dsgrid-toolkit 0.3.3__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- build_backend.py +93 -0
- dsgrid/__init__.py +22 -0
- dsgrid/api/__init__.py +0 -0
- dsgrid/api/api_manager.py +179 -0
- dsgrid/api/app.py +419 -0
- dsgrid/api/models.py +60 -0
- dsgrid/api/response_models.py +116 -0
- dsgrid/apps/__init__.py +0 -0
- dsgrid/apps/project_viewer/app.py +216 -0
- dsgrid/apps/registration_gui.py +444 -0
- dsgrid/chronify.py +32 -0
- dsgrid/cli/__init__.py +0 -0
- dsgrid/cli/common.py +120 -0
- dsgrid/cli/config.py +176 -0
- dsgrid/cli/download.py +13 -0
- dsgrid/cli/dsgrid.py +157 -0
- dsgrid/cli/dsgrid_admin.py +92 -0
- dsgrid/cli/install_notebooks.py +62 -0
- dsgrid/cli/query.py +729 -0
- dsgrid/cli/registry.py +1862 -0
- dsgrid/cloud/__init__.py +0 -0
- dsgrid/cloud/cloud_storage_interface.py +140 -0
- dsgrid/cloud/factory.py +31 -0
- dsgrid/cloud/fake_storage_interface.py +37 -0
- dsgrid/cloud/s3_storage_interface.py +156 -0
- dsgrid/common.py +36 -0
- dsgrid/config/__init__.py +0 -0
- dsgrid/config/annual_time_dimension_config.py +194 -0
- dsgrid/config/common.py +142 -0
- dsgrid/config/config_base.py +148 -0
- dsgrid/config/dataset_config.py +907 -0
- dsgrid/config/dataset_schema_handler_factory.py +46 -0
- dsgrid/config/date_time_dimension_config.py +136 -0
- dsgrid/config/dimension_config.py +54 -0
- dsgrid/config/dimension_config_factory.py +65 -0
- dsgrid/config/dimension_mapping_base.py +350 -0
- dsgrid/config/dimension_mappings_config.py +48 -0
- dsgrid/config/dimensions.py +1025 -0
- dsgrid/config/dimensions_config.py +71 -0
- dsgrid/config/file_schema.py +190 -0
- dsgrid/config/index_time_dimension_config.py +80 -0
- dsgrid/config/input_dataset_requirements.py +31 -0
- dsgrid/config/mapping_tables.py +209 -0
- dsgrid/config/noop_time_dimension_config.py +42 -0
- dsgrid/config/project_config.py +1462 -0
- dsgrid/config/registration_models.py +188 -0
- dsgrid/config/representative_period_time_dimension_config.py +194 -0
- dsgrid/config/simple_models.py +49 -0
- dsgrid/config/supplemental_dimension.py +29 -0
- dsgrid/config/time_dimension_base_config.py +192 -0
- dsgrid/data_models.py +155 -0
- dsgrid/dataset/__init__.py +0 -0
- dsgrid/dataset/dataset.py +123 -0
- dsgrid/dataset/dataset_expression_handler.py +86 -0
- dsgrid/dataset/dataset_mapping_manager.py +121 -0
- dsgrid/dataset/dataset_schema_handler_base.py +945 -0
- dsgrid/dataset/dataset_schema_handler_one_table.py +209 -0
- dsgrid/dataset/dataset_schema_handler_two_table.py +322 -0
- dsgrid/dataset/growth_rates.py +162 -0
- dsgrid/dataset/models.py +51 -0
- dsgrid/dataset/table_format_handler_base.py +257 -0
- dsgrid/dataset/table_format_handler_factory.py +17 -0
- dsgrid/dataset/unpivoted_table.py +121 -0
- dsgrid/dimension/__init__.py +0 -0
- dsgrid/dimension/base_models.py +230 -0
- dsgrid/dimension/dimension_filters.py +308 -0
- dsgrid/dimension/standard.py +252 -0
- dsgrid/dimension/time.py +352 -0
- dsgrid/dimension/time_utils.py +103 -0
- dsgrid/dsgrid_rc.py +88 -0
- dsgrid/exceptions.py +105 -0
- dsgrid/filesystem/__init__.py +0 -0
- dsgrid/filesystem/cloud_filesystem.py +32 -0
- dsgrid/filesystem/factory.py +32 -0
- dsgrid/filesystem/filesystem_interface.py +136 -0
- dsgrid/filesystem/local_filesystem.py +74 -0
- dsgrid/filesystem/s3_filesystem.py +118 -0
- dsgrid/loggers.py +132 -0
- dsgrid/minimal_patterns.cp313-win_amd64.pyd +0 -0
- dsgrid/notebooks/connect_to_dsgrid_registry.ipynb +949 -0
- dsgrid/notebooks/registration.ipynb +48 -0
- dsgrid/notebooks/start_notebook.sh +11 -0
- dsgrid/project.py +451 -0
- dsgrid/query/__init__.py +0 -0
- dsgrid/query/dataset_mapping_plan.py +142 -0
- dsgrid/query/derived_dataset.py +388 -0
- dsgrid/query/models.py +728 -0
- dsgrid/query/query_context.py +287 -0
- dsgrid/query/query_submitter.py +994 -0
- dsgrid/query/report_factory.py +19 -0
- dsgrid/query/report_peak_load.py +70 -0
- dsgrid/query/reports_base.py +20 -0
- dsgrid/registry/__init__.py +0 -0
- dsgrid/registry/bulk_register.py +165 -0
- dsgrid/registry/common.py +287 -0
- dsgrid/registry/config_update_checker_base.py +63 -0
- dsgrid/registry/data_store_factory.py +34 -0
- dsgrid/registry/data_store_interface.py +74 -0
- dsgrid/registry/dataset_config_generator.py +158 -0
- dsgrid/registry/dataset_registry_manager.py +950 -0
- dsgrid/registry/dataset_update_checker.py +16 -0
- dsgrid/registry/dimension_mapping_registry_manager.py +575 -0
- dsgrid/registry/dimension_mapping_update_checker.py +16 -0
- dsgrid/registry/dimension_registry_manager.py +413 -0
- dsgrid/registry/dimension_update_checker.py +16 -0
- dsgrid/registry/duckdb_data_store.py +207 -0
- dsgrid/registry/filesystem_data_store.py +150 -0
- dsgrid/registry/filter_registry_manager.py +123 -0
- dsgrid/registry/project_config_generator.py +57 -0
- dsgrid/registry/project_registry_manager.py +1623 -0
- dsgrid/registry/project_update_checker.py +48 -0
- dsgrid/registry/registration_context.py +223 -0
- dsgrid/registry/registry_auto_updater.py +316 -0
- dsgrid/registry/registry_database.py +667 -0
- dsgrid/registry/registry_interface.py +446 -0
- dsgrid/registry/registry_manager.py +558 -0
- dsgrid/registry/registry_manager_base.py +367 -0
- dsgrid/registry/versioning.py +92 -0
- dsgrid/rust_ext/__init__.py +14 -0
- dsgrid/rust_ext/find_minimal_patterns.py +129 -0
- dsgrid/spark/__init__.py +0 -0
- dsgrid/spark/functions.py +589 -0
- dsgrid/spark/types.py +110 -0
- dsgrid/tests/__init__.py +0 -0
- dsgrid/tests/common.py +140 -0
- dsgrid/tests/make_us_data_registry.py +265 -0
- dsgrid/tests/register_derived_datasets.py +103 -0
- dsgrid/tests/utils.py +25 -0
- dsgrid/time/__init__.py +0 -0
- dsgrid/time/time_conversions.py +80 -0
- dsgrid/time/types.py +67 -0
- dsgrid/units/__init__.py +0 -0
- dsgrid/units/constants.py +113 -0
- dsgrid/units/convert.py +71 -0
- dsgrid/units/energy.py +145 -0
- dsgrid/units/power.py +87 -0
- dsgrid/utils/__init__.py +0 -0
- dsgrid/utils/dataset.py +830 -0
- dsgrid/utils/files.py +179 -0
- dsgrid/utils/filters.py +125 -0
- dsgrid/utils/id_remappings.py +100 -0
- dsgrid/utils/py_expression_eval/LICENSE +19 -0
- dsgrid/utils/py_expression_eval/README.md +8 -0
- dsgrid/utils/py_expression_eval/__init__.py +847 -0
- dsgrid/utils/py_expression_eval/tests.py +283 -0
- dsgrid/utils/run_command.py +70 -0
- dsgrid/utils/scratch_dir_context.py +65 -0
- dsgrid/utils/spark.py +918 -0
- dsgrid/utils/spark_partition.py +98 -0
- dsgrid/utils/timing.py +239 -0
- dsgrid/utils/utilities.py +221 -0
- dsgrid/utils/versioning.py +36 -0
- dsgrid_toolkit-0.3.3.dist-info/METADATA +193 -0
- dsgrid_toolkit-0.3.3.dist-info/RECORD +157 -0
- dsgrid_toolkit-0.3.3.dist-info/WHEEL +4 -0
- dsgrid_toolkit-0.3.3.dist-info/entry_points.txt +4 -0
- dsgrid_toolkit-0.3.3.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,667 @@
|
|
|
1
|
+
import getpass
|
|
2
|
+
import logging
|
|
3
|
+
import sqlite3
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from typing import Any, Generator
|
|
7
|
+
|
|
8
|
+
from chronify.utils.path_utils import check_overwrite
|
|
9
|
+
from sqlalchemy import (
|
|
10
|
+
Column,
|
|
11
|
+
Connection,
|
|
12
|
+
Engine,
|
|
13
|
+
Integer,
|
|
14
|
+
ForeignKey,
|
|
15
|
+
JSON,
|
|
16
|
+
MetaData,
|
|
17
|
+
String,
|
|
18
|
+
Table,
|
|
19
|
+
UniqueConstraint,
|
|
20
|
+
create_engine,
|
|
21
|
+
delete,
|
|
22
|
+
insert,
|
|
23
|
+
select,
|
|
24
|
+
update,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
from dsgrid.exceptions import (
|
|
28
|
+
DSGValueNotRegistered,
|
|
29
|
+
DSGInvalidOperation,
|
|
30
|
+
DSGValueNotStored,
|
|
31
|
+
DSGDuplicateValueRegistered,
|
|
32
|
+
)
|
|
33
|
+
from dsgrid.registry.common import (
|
|
34
|
+
DataStoreType,
|
|
35
|
+
DatabaseConnection,
|
|
36
|
+
RegistrationModel,
|
|
37
|
+
RegistryTables,
|
|
38
|
+
RegistryType,
|
|
39
|
+
MODEL_TYPE_TO_ID_FIELD_MAPPING,
|
|
40
|
+
)
|
|
41
|
+
from dsgrid.registry.data_store_interface import DataStoreInterface
|
|
42
|
+
from dsgrid.registry.data_store_factory import make_data_store
|
|
43
|
+
from dsgrid.utils.files import dump_data
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RegistryDatabase:
|
|
50
|
+
"""Database containing a dsgrid registry"""
|
|
51
|
+
|
|
52
|
+
def __init__(self, engine: Engine, data_store: DataStoreInterface | None = None) -> None:
    """Store the engine and optional data store and start with empty table metadata.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine for the registry database.
    data_store : DataStoreInterface | None
        Data store for the registry. When None, it must be assigned through the
        ``data_store`` property before that property is read.
    """
    self._engine = engine
    self._data_store = data_store
    # Table definitions are added later by create_tables() or reflection.
    self._metadata = MetaData()
|
|
57
|
+
|
|
58
|
+
def dispose(self) -> None:
    """Dispose of the database engine, releasing the connections it holds."""
    engine = self._engine
    engine.dispose()
|
|
61
|
+
|
|
62
|
+
@classmethod
def create(
    cls,
    conn: DatabaseConnection,
    data_path: Path,
    data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
    overwrite: bool = False,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Create a new registry database together with a new data directory.

    Parameters
    ----------
    conn : DatabaseConnection
        Supplies the database filename and URL.
    data_path : Path
        Directory to create for registry data.
    data_store_type : DataStoreType
        Kind of data store to initialize in ``data_path``.
    overwrite : bool
        Forwarded to ``check_overwrite`` for both the database file and
        ``data_path``. NOTE(review): presumably existing paths are removed when
        True and rejected when False; confirm against chronify.
    **connect_kwargs : Any
        Forwarded to ``sqlalchemy.create_engine``.

    Returns
    -------
    RegistryDatabase
        The initialized database.
    """
    # Check the database file first, then the data directory.
    check_overwrite(conn.get_filename(), overwrite=overwrite)
    check_overwrite(data_path, overwrite=overwrite)
    data_path.mkdir()
    store = make_data_store(data_path, data_store_type, initialize=True)
    engine = create_engine(conn.url, **connect_kwargs)
    database = cls(engine, store)
    database.initialize_db(data_path, data_store_type)
    return database
|
|
80
|
+
|
|
81
|
+
@classmethod
def create_with_existing_data(
    cls,
    conn: DatabaseConnection,
    data_path: Path,
    data_store_type: DataStoreType = DataStoreType.FILESYSTEM,
    overwrite: bool = False,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Create a new registry database on top of existing registry data.

    Only the database file is passed to ``check_overwrite``; ``data_path`` is
    opened as an existing data store and is not re-initialized.

    Parameters
    ----------
    conn : DatabaseConnection
        Supplies the database filename and URL.
    data_path : Path
        Directory that already contains registry data.
    data_store_type : DataStoreType
        Kind of data store located at ``data_path``.
    overwrite : bool
        Forwarded to ``check_overwrite`` for the database file.
    **connect_kwargs : Any
        Forwarded to ``sqlalchemy.create_engine``.

    Returns
    -------
    RegistryDatabase
        The initialized database, an instance of ``cls``.
    """
    filename = conn.get_filename()
    check_overwrite(filename, overwrite=overwrite)
    store = make_data_store(data_path, data_store_type, initialize=False)
    # Instantiate cls rather than RegistryDatabase so that subclasses get their own type.
    db = cls(create_engine(conn.url, **connect_kwargs), store)
    db.initialize_db(data_path, data_store_type)
    return db
|
|
97
|
+
|
|
98
|
+
@classmethod
def connect(
    cls,
    conn: DatabaseConnection,
    **connect_kwargs: Any,
) -> "RegistryDatabase":
    """Load an existing registry database.

    The table schemas are reflected from the database. The data path and data
    store type are then read from the key-value table and used to open the
    data store.

    Parameters
    ----------
    conn : DatabaseConnection
        Supplies the database filename and URL.
    **connect_kwargs : Any
        Forwarded to ``sqlalchemy.create_engine``.

    Returns
    -------
    RegistryDatabase
        The connected database, an instance of ``cls``.
    """
    # This tests the connection.
    conn.get_filename()
    engine = create_engine(conn.url, **connect_kwargs)
    # Instantiate cls rather than RegistryDatabase so that subclasses get their own type.
    db = cls(engine)
    db.update_sqlalchemy_metadata()
    base_path = db.get_data_path()
    data_store_type = db.get_data_store_type()
    db.data_store = make_data_store(base_path, data_store_type, initialize=False)
    return db
|
|
114
|
+
|
|
115
|
+
def update_sqlalchemy_metadata(self) -> None:
    """Refresh the cached table schemas by reflecting them from the engine.

    Call this after tables are added to the database outside of this class.
    Views are reflected as well.
    """
    metadata = self._metadata
    metadata.reflect(self._engine, views=True)
|
|
120
|
+
|
|
121
|
+
@property
def engine(self) -> Engine:
    """Return the sqlalchemy engine that backs this database."""
    sqlalchemy_engine = self._engine
    return sqlalchemy_engine
|
|
125
|
+
|
|
126
|
+
@property
def data_store(self) -> DataStoreInterface:
    """Return the data store.

    Raises
    ------
    DSGInvalidOperation
        Raised if no data store has been assigned yet.
    """
    store = self._data_store
    if store is not None:
        return store
    msg = "Data store is not initialized. Use create() or connect() to initialize."
    raise DSGInvalidOperation(msg)

@data_store.setter
def data_store(self, store: DataStoreInterface) -> None:
    """Assign the data store."""
    self._data_store = store
|
|
138
|
+
|
|
139
|
+
def initialize_db(self, data_path: Path, data_store_type: DataStoreType) -> None:
    """Create the registry tables and record who created the registry, when, and where.

    The same four settings are written to the key-value table and to
    ``registry.json5`` inside ``data_path``.

    Parameters
    ----------
    data_path : Path
        Directory that holds the registry data.
    data_store_type : DataStoreType
        Kind of data store used by the registry.
    """
    create_tables(self._engine, self._metadata)
    record = {
        "created_by": getpass.getuser(),
        "created_on": str(datetime.now()),
        "data_path": str(data_path),
        "data_store_type": data_store_type.value,
    }
    with self._engine.begin() as conn:
        kv_table = self.get_table(RegistryTables.KEY_VALUE)
        # One row per setting, inserted in the dict's order.
        for key, value in record.items():
            conn.execute(insert(kv_table).values(key=key, value=value))

    dump_data(record, data_path / "registry.json5")
|
|
161
|
+
|
|
162
|
+
def get_table(self, name: RegistryTables) -> Table:
    """Return the sqlalchemy Table object for ``name``.

    Raises
    ------
    DSGValueNotStored
        Raised if the table is not present in the cached metadata.
    """
    if self.has_table(name):
        return Table(name.value, self._metadata)
    msg = f"{name=}"
    raise DSGValueNotStored(msg)
|
|
168
|
+
|
|
169
|
+
def has_table(self, name: RegistryTables) -> bool:
    """Report whether a table called ``name`` is present in the cached metadata."""
    known_tables = self._metadata.tables
    return name in known_tables
|
|
172
|
+
|
|
173
|
+
def try_get_table(self, name: RegistryTables) -> Table | None:
    """Return the sqlalchemy Table object for ``name``, or None if it is not stored."""
    return Table(name.value, self._metadata) if self.has_table(name) else None
|
|
178
|
+
|
|
179
|
+
def list_tables(self) -> list[str]:
    """Return the tables in the cached metadata as RegistryTables members."""
    tables = []
    for table_name in self._metadata.tables:
        tables.append(RegistryTables(table_name))
    return tables
|
|
182
|
+
|
|
183
|
+
@classmethod
def copy(
    cls,
    src_conn: DatabaseConnection,
    dst_conn: DatabaseConnection,
    dst_data_path: Path,
) -> "RegistryDatabase":
    """Copy the contents of a source database to a destination and return the destination.

    Currently, only supports SQLite backends. Any existing destination database
    file is deleted and a new database is created before the source is backed
    up into it. The destination keeps its own ``data_path`` setting.

    Parameters
    ----------
    src_conn : DatabaseConnection
        Connection information for the database to copy from.
    dst_conn : DatabaseConnection
        Connection information for the database to create.
    dst_data_path : Path
        Data directory for the destination registry.

    Returns
    -------
    RegistryDatabase
        The destination database. ``dispose()`` has already been called on it.

    Raises
    ------
    NotImplementedError
        Raised if either URL does not contain ``sqlite:///``.
    """
    sqlite_base = "sqlite:///"
    if sqlite_base not in src_conn.url or sqlite_base not in dst_conn.url:
        # If/when we need postgres, we can use postgres tools or copy the tables through
        # sqlalchemy.
        msg = f"Both src and destination databases must be sqlite. {src_conn=} {dst_conn=}"
        raise NotImplementedError(msg)

    cls.delete(dst_conn)
    dst = cls.create(dst_conn, dst_data_path)
    # sqlite3's connection context manager only commits or rolls back; it does not
    # close the connection. Close the source handle explicitly so that the file is
    # released.
    src = sqlite3.connect(src_conn.url.replace(sqlite_base, ""))
    try:
        with dst.engine.begin() as dst_conn_:
            # The backup below will overwrite the data_path value.
            table = dst.get_table(RegistryTables.KEY_VALUE)
            stmt = select(table.c.value).where(table.c.key == "data_path")
            row = dst_conn_.execute(stmt).fetchone()
            assert row is not None
            orig_data_path = row.value
            assert dst_conn_._dbapi_connection is not None
            assert isinstance(
                dst_conn_._dbapi_connection.driver_connection, sqlite3.Connection
            )
            src.backup(dst_conn_._dbapi_connection.driver_connection)
            # Restore the destination's own data_path after the backup replaced it.
            stmt = update(table).where(table.c.key == "data_path").values(value=orig_data_path)
            dst_conn_.execute(stmt)
    finally:
        src.close()

    logger.info("Copied database %s to %s", src_conn.url, dst_conn.url)
    dst.dispose()
    return dst
|
|
221
|
+
|
|
222
|
+
@staticmethod
|
|
223
|
+
def delete(conn: DatabaseConnection) -> None:
|
|
224
|
+
"""Delete the dsgrid database."""
|
|
225
|
+
filename = conn.get_filename()
|
|
226
|
+
if filename.exists():
|
|
227
|
+
filename.unlink()
|
|
228
|
+
|
|
229
|
+
@staticmethod
|
|
230
|
+
def has_database(conn: DatabaseConnection) -> bool:
|
|
231
|
+
"""Return True if the database exists."""
|
|
232
|
+
filename = conn.get_filename()
|
|
233
|
+
return filename.exists()
|
|
234
|
+
|
|
235
|
+
def _get_model_id(self, model_type: RegistryType, model: dict[str, Any]) -> str:
|
|
236
|
+
return model[self._get_model_id_field(model_type)]
|
|
237
|
+
|
|
238
|
+
@staticmethod
def _get_model_id_field(model_type: RegistryType) -> str:
    """Return the name of the ID field used by models of ``model_type``."""
    id_field = MODEL_TYPE_TO_ID_FIELD_MAPPING[model_type]
    return id_field
|
|
241
|
+
|
|
242
|
+
def insert_model(
    self,
    conn: Connection,
    model_type: RegistryType,
    model: dict[str, Any],
    registration: RegistrationModel,
) -> dict[str, Any]:
    """Add a model to the database and make it the current version.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    model_type : RegistryType
        Type of the model being registered.
    model : dict[str, Any]
        Model contents. Must include ``version`` and the ID field for ``model_type``.
    registration : RegistrationModel
        Registration entry whose ``id`` is stored with the model.

    Returns
    -------
    dict[str, Any]
        The stored model with its ``id`` field set to the database-assigned value.

    Raises
    ------
    DSGDuplicateValueRegistered
        Raised if a model with the same type and ID is already stored.
    """
    table = self.get_table(RegistryTables.MODELS)
    model_id = self._get_model_id(model_type, model)
    if self.has(conn, model_type, model_id):
        msg = f"{model_type=} {model_id}"
        raise DSGDuplicateValueRegistered(msg)

    row_values = {
        "registration_id": registration.id,
        "model_type": model_type.value,
        "model_id": model_id,
        "version": model["version"],
        "model": model,
    }
    result = conn.execute(insert(table).values(**row_values))
    db_id = result.lastrowid
    stored_model = self._add_database_id(conn, table, db_id)
    self._set_current_version(conn, model_type, model_id, db_id)
    logger.debug("Inserted model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
    return stored_model
|
|
270
|
+
|
|
271
|
+
def update_model(
    self,
    conn: Connection,
    model_type: RegistryType,
    model: dict[str, Any],
    registration: RegistrationModel,
) -> dict[str, Any]:
    """Insert a new row for the model and point its current-version entry at that row.

    Unlike insert_model, this does not check whether the model ID is already
    stored, and it updates the existing current-version entry instead of
    inserting one.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    model_type : RegistryType
        Type of the model being updated.
    model : dict[str, Any]
        Model contents. Must include ``version`` and the ID field for ``model_type``.
    registration : RegistrationModel
        Registration entry whose ``id`` is stored with the model.

    Returns
    -------
    dict[str, Any]
        The stored model with its ``id`` field set to the database-assigned value.
    """
    table = self.get_table(RegistryTables.MODELS)
    model_id = self._get_model_id(model_type, model)
    res = conn.execute(
        insert(table).values(
            registration_id=registration.id,
            model_type=model_type.value,
            model_id=model_id,
            version=model["version"],
            model=model,
        )
    )
    db_id = res.lastrowid
    new_model = self._add_database_id(conn, table, db_id)
    self._update_current_version(conn, model_type, model_id, db_id)
    # A successful update is routine; log at debug level like insert_model does.
    logger.debug("Updated model_type=%s model_id=%s id=%s", model_type, model_id, db_id)
    return new_model
|
|
295
|
+
|
|
296
|
+
def _add_database_id(self, conn: Connection, table: Table, db_id: int) -> dict[str, Any]:
    """Write the newly-generated ID into the model's JSON blob and return the blob.

    The row's ``model`` column is read, ``id`` is set to ``db_id``, and the column
    is written back.
    """
    select_stmt = select(table.c.model).where(table.c.id == db_id)
    row = conn.execute(select_stmt).fetchone()
    assert row
    model = row.model
    model["id"] = db_id
    update_stmt = update(table).where(table.c.id == db_id).values(model=model)
    conn.execute(update_stmt)
    return model
|
|
305
|
+
|
|
306
|
+
def _set_current_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
) -> None:
    """Insert a current-version entry that points the model at row ``db_id``.

    The entry is stamped with the current local time.
    """
    table = self.get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = insert(table).values(
        model_type=model_type.value,
        model_id=model_id,
        current_id=db_id,
        update_timestamp=str(datetime.now()),
    )
    conn.execute(stmt)
    logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)
|
|
319
|
+
|
|
320
|
+
def _update_current_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, db_id: int
) -> None:
    """Point the existing current-version entry for the model at row ``db_id``.

    The entry's timestamp is replaced with the current local time.
    """
    table = self.get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = update(table).where(table.c.model_type == model_type)
    stmt = stmt.where(table.c.model_id == model_id)
    stmt = stmt.values(current_id=db_id, update_timestamp=str(datetime.now()))
    conn.execute(stmt)
    logger.debug("Set the current version of %s %s to %s", model_type, model_id, db_id)
|
|
332
|
+
|
|
333
|
+
def get_containing_models_by_db_id(
    self,
    conn: Connection,
    db_id: int,
    parent_model_type: RegistryType | None = None,
) -> list[tuple[RegistryType, dict[str, Any]]]:
    """Return the current-version models that contain the model with database ID ``db_id``.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    db_id : int
        Database ID of the child model.
    parent_model_type : RegistryType | None
        If set, only return parents of this type.

    Returns
    -------
    list[tuple[RegistryType, dict[str, Any]]]
        Pairs of parent model type and parent model.
    """
    contains = self.get_table(RegistryTables.CONTAINS)
    models = self.get_table(RegistryTables.MODELS)
    current = self.get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = select(models.c.model_type, models.c.model)
    stmt = stmt.join(models, contains.c.parent_id == models.c.id)
    # Keep only parents that are the current version of their model.
    stmt = stmt.join(current, models.c.id == current.c.current_id)
    stmt = stmt.where(contains.c.child_id == db_id)
    if parent_model_type is not None:
        stmt = stmt.where(models.c.model_type == parent_model_type)
    rows = conn.execute(stmt).fetchall()
    return [(RegistryType(row.model_type), row.model) for row in rows]
|
|
351
|
+
|
|
352
|
+
def get_containing_models(
    self,
    conn: Connection,
    child_model_type: RegistryType,
    model_id: str,
    version: str,
    parent_model_type: RegistryType | None = None,
):
    """Return the models that contain the given version of a child model.

    The child's database ID is looked up from its type, ID, and version, and the
    result of get_containing_models_by_db_id for that ID is returned.
    """
    child_db_id = self._get_db_id(conn, child_model_type, model_id, version)
    parents = self.get_containing_models_by_db_id(
        conn, child_db_id, parent_model_type=parent_model_type
    )
    return parents
|
|
364
|
+
|
|
365
|
+
def _get_db_id(
    self, conn: Connection, model_type: RegistryType, model_id: str, version: str
) -> int:
    """Return the database ID of the row matching the model type, ID, and version.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no matching row is stored.
    """
    table = self.get_table(RegistryTables.MODELS)
    stmt = (
        select(table.c.id)
        .where(table.c.model_type == model_type)
        .where(table.c.model_id == model_id)
        .where(table.c.version == version)
    )
    row = conn.execute(stmt).fetchone()
    # Raise explicitly: an assert would be stripped under -O and the lookup would
    # then fail with an AttributeError on None.
    if row is None:
        msg = f"{model_type=} {model_id=} {version=}"
        raise DSGValueNotRegistered(msg)
    return row.id
|
|
378
|
+
|
|
379
|
+
def delete_models(self, conn: Connection, model_type: RegistryType, model_id: str) -> None:
    """Delete all documents of model_type with the model_id.

    Matching rows are removed from the models table and then from the
    current-versions table.
    """
    type_value = model_type.value
    for table_type in (RegistryTables.MODELS, RegistryTables.CURRENT_VERSIONS):
        table = self._get_table(table_type)
        stmt = delete(table).where(table.c.model_type == type_value)
        stmt = stmt.where(table.c.model_id == model_id)
        conn.execute(stmt)

    logger.info("Deleted all documents with model_id=%s", model_id)
|
|
392
|
+
|
|
393
|
+
def get_data_path(self) -> Path:
    """Return the ``data_path`` value recorded in the key-value table.

    Raises
    ------
    Exception
        Raised if the key-value table has no ``data_path`` entry.
    """
    table = self._get_table(RegistryTables.KEY_VALUE)
    stmt = select(table.c.value).where(table.c.key == "data_path")
    with self._engine.connect() as conn:
        row = conn.execute(stmt).fetchone()
        if row is None:
            msg = "Bug: received no result in query for data_path"
            raise Exception(msg)
        return Path(row.value)
|
|
402
|
+
|
|
403
|
+
def get_data_store_type(self) -> DataStoreType:
    """Return the data store type recorded in the key-value table.

    Falls back to ``DataStoreType.FILESYSTEM`` when the registry has no
    ``data_store_type`` entry.
    """
    table = self._get_table(RegistryTables.KEY_VALUE)
    stmt = select(table.c.value).where(table.c.key == "data_store_type")
    with self._engine.connect() as conn:
        row = conn.execute(stmt).fetchone()
        if row is not None:
            return DataStoreType(row.value)
        # Allow legacy registries to keep working.
        return DataStoreType.FILESYSTEM
|
|
414
|
+
|
|
415
|
+
def _get_table(self, table_type: RegistryTables) -> Table:
    """Return the Table for ``table_type``.

    Unlike get_table, this does not call has_table first.
    """
    table_name = table_type.value
    return Table(table_name, self._metadata)
|
|
417
|
+
|
|
418
|
+
def get_latest(self, conn: Connection, model_type: RegistryType, model_id: str):
    """Return the stored model of the current version for the given type and ID."""
    column = "model"
    return self._get_latest_column(conn, model_type, model_id, column)
|
|
420
|
+
|
|
421
|
+
def get_latest_version(self, conn: Connection, model_type: RegistryType, model_id: str) -> str:
    """Return the version string of the current version for the given type and ID."""
    column = "version"
    return self._get_latest_column(conn, model_type, model_id, column)
|
|
423
|
+
|
|
424
|
+
def _get_latest_column(
    self, conn: Connection, model_type: RegistryType, model_id: str, column: str
) -> Any:
    """Return one column of the models-table row that is the model's current version.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    model_type : RegistryType
        Type of the model.
    model_id : str
        ID of the model.
    column : str
        Name of the models-table column to return.

    Returns
    -------
    Any
        Value of ``column`` in the current-version row.

    Raises
    ------
    DSGValueNotRegistered
        Raised if the model has no current-version entry.
    Exception
        Raised if more than one row is marked as current.
    """
    table1 = self._get_table(RegistryTables.MODELS)
    table2 = self._get_table(RegistryTables.CURRENT_VERSIONS)
    stmt = (
        select(
            table1.c[column],
        )
        .join_from(table1, table2, table1.c.id == table2.c.current_id)
        .where(table2.c.model_type == model_type)
        .where(table2.c.model_id == model_id)
    )
    rows = conn.execute(stmt).fetchall()
    if not rows:
        msg = f"{model_type=} {model_id=} is not registered"
        raise DSGValueNotRegistered(msg)
    if len(rows) != 1:
        # f-string so that the offending rows appear in the message.
        msg = f"Bug: received more than one model set to latest: {rows}"
        raise Exception(msg)
    return getattr(rows[0], column)
|
|
445
|
+
|
|
446
|
+
def list_model_ids(self, conn: Connection, model_type: RegistryType) -> list[str]:
    """Return the distinct model IDs stored for ``model_type``."""
    models = self.get_table(RegistryTables.MODELS)
    stmt = select(models.c.model_id).where(models.c.model_type == model_type).distinct()
    rows = conn.execute(stmt).fetchall()
    return [row.model_id for row in rows]
|
|
450
|
+
|
|
451
|
+
def iter_models(
    self, conn: Connection, model_type: RegistryType, all_versions: bool = False
) -> Generator[dict[str, Any], None, None]:
    """Yield the stored models of ``model_type``.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    model_type : RegistryType
        Type of models to yield.
    all_versions : bool
        If True, yield every stored version. Otherwise, yield only the rows
        referenced by the current-versions table.
    """
    models = self.get_table(RegistryTables.MODELS)
    if not all_versions:
        current = self._get_table(RegistryTables.CURRENT_VERSIONS)
        stmt = select(
            models.c.model,
        ).join_from(models, current, models.c.id == current.c.current_id)
        stmt = stmt.where(current.c.model_type == model_type)
    else:
        stmt = select(models.c.model).where(models.c.model_type == model_type)
    yield from (row.model for row in conn.execute(stmt).fetchall())
|
|
468
|
+
|
|
469
|
+
def _get_by_version(
    self, conn: Connection, model_type: RegistryType, model_id: str, version: str
):
    """Return the stored model matching the model type, ID, and version.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no matching row is stored.
    Exception
        Raised if more than one row matches.
    """
    table = self.get_table(RegistryTables.MODELS)
    stmt = select(table.c.model).where(table.c.model_type == model_type)
    stmt = stmt.where(table.c.model_id == model_id)
    # The version is compared in its string form.
    stmt = stmt.where(table.c.version == str(version))
    rows = conn.execute(stmt).fetchall()
    if not rows:
        msg = f"{model_type=} {model_id}"
        raise DSGValueNotRegistered(msg)
    if len(rows) > 1:
        msg = f"Bug: found more than one row. {model_type=} {model_id=} {version=}"
        raise Exception(msg)
    only_row = rows[0]
    return only_row.model
|
|
487
|
+
|
|
488
|
+
def insert_registration(
    self,
    conn: Connection,
    registration: RegistrationModel,
) -> RegistrationModel:
    """Insert a registration entry into the database.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    registration : RegistrationModel
        Entry to store. Its timestamp is stored in string form.

    Returns
    -------
    RegistrationModel
        Built from the input's JSON dump with ``id`` set to the database-assigned value.
    """
    table = self.get_table(RegistryTables.REGISTRATIONS)
    stmt = insert(table).values(
        timestamp=str(registration.timestamp),
        submitter=registration.submitter,
        update_type=registration.update_type,
        log_message=registration.log_message,
    )
    result = conn.execute(stmt)
    fields = registration.model_dump(mode="json")
    fields["id"] = result.lastrowid
    return RegistrationModel(**fields)
|
|
512
|
+
|
|
513
|
+
def get_registration(self, conn: Connection, db_id: int) -> RegistrationModel:
    """Return the registration information for the model row with database ID ``db_id``.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no matching row is stored.
    Exception
        Raised if more than one row matches.
    """
    models = self.get_table(RegistryTables.MODELS)
    registrations = self.get_table(RegistryTables.REGISTRATIONS)
    columns = (
        registrations.c.id,
        registrations.c.timestamp,
        registrations.c.submitter,
        registrations.c.update_type,
        registrations.c.log_message,
    )
    stmt = select(*columns).join_from(
        models, registrations, models.c.registration_id == registrations.c.id
    )
    stmt = stmt.where(models.c.id == db_id)
    rows = conn.execute(stmt).fetchall()
    if not rows:
        msg = f"{db_id=}"
        raise DSGValueNotRegistered(msg)
    if len(rows) > 1:
        msg = f"Bug: found more than one row matching {db_id=}"
        raise Exception(msg)
    found = rows[0]
    return RegistrationModel(
        id=found.id,
        timestamp=found.timestamp,
        submitter=found.submitter,
        update_type=found.update_type,
        log_message=found.log_message,
    )
|
|
543
|
+
|
|
544
|
+
def get_initial_registration(
    self, conn: Connection, model_type: RegistryType, model_id: str
) -> RegistrationModel:
    """Return the registration information of the model's first stored row.

    "First" means the matching models-table row with the lowest database ID.

    Raises
    ------
    DSGValueNotRegistered
        Raised if no matching row is stored.
    """
    models = self.get_table(RegistryTables.MODELS)
    registrations = self.get_table(RegistryTables.REGISTRATIONS)
    columns = (
        registrations.c.id,
        registrations.c.timestamp,
        registrations.c.submitter,
        registrations.c.update_type,
        registrations.c.log_message,
    )
    stmt = select(*columns).join_from(
        models, registrations, models.c.registration_id == registrations.c.id
    )
    stmt = stmt.where(models.c.model_type == model_type.value)
    stmt = stmt.where(models.c.model_id == model_id)
    stmt = stmt.order_by(models.c.id).limit(1)
    row = conn.execute(stmt).fetchone()
    if not row:
        msg = f"{model_type=} {model_id=}"
        raise DSGValueNotRegistered(msg)
    return RegistrationModel(
        id=row.id,
        timestamp=row.timestamp,
        submitter=row.submitter,
        update_type=row.update_type,
        log_message=row.log_message,
    )
|
|
576
|
+
|
|
577
|
+
def has(
    self,
    conn: Connection,
    model_type: RegistryType,
    model_id: str,
    version: str | None = None,
) -> bool:
    """Return True if the database has a document matching the inputs.

    Parameters
    ----------
    conn : Connection
        Open database connection.
    model_type : RegistryType
        Type of the model.
    model_id : str
        ID of the model.
    version : str | None
        If set, the document must also have this version.
    """
    models = self.get_table(RegistryTables.MODELS)
    stmt = select(models.c.id).where(models.c.model_type == model_type)
    stmt = stmt.where(models.c.model_id == model_id)
    if version is not None:
        stmt = stmt.where(models.c.version == version)
    # One row is enough to answer the question.
    first_match = conn.execute(stmt.limit(1)).fetchone()
    return bool(first_match)
|
|
596
|
+
|
|
597
|
+
def insert_contains_edge(
    self,
    conn: Connection,
    parent_id: int,
    child_id: int,
) -> None:
    """Record in the contains table that row ``parent_id`` contains row ``child_id``."""
    contains = self.get_table(RegistryTables.CONTAINS)
    stmt = insert(contains).values(parent_id=parent_id, child_id=child_id)
    conn.execute(stmt)
|
|
605
|
+
|
|
606
|
+
def replace_model(self, conn: Connection, model: dict[str, Any]):
    """Overwrite the stored model of the row whose database ID is ``model["id"]``."""
    models = self.get_table(RegistryTables.MODELS)
    stmt = update(models).where(models.c.id == model["id"]).values(model=model)
    conn.execute(stmt)
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def create_tables(engine: Engine, metadata: MetaData) -> None:
    """Create the registry tables in the database.

    Five tables are defined on ``metadata`` (key-value settings, registrations,
    models, current versions, and contains) and then created through ``engine``.

    Parameters
    ----------
    engine : Engine
        Engine for the database in which to create the tables.
    metadata : MetaData
        Metadata object that receives the table definitions.
    """
    # Note to devs: Please update dev/registry_database.md if you change the schema.
    Table(
        RegistryTables.KEY_VALUE.value,
        metadata,
        Column("key", String(), unique=True, primary_key=True),
        Column("value", String(), nullable=False),
    )
    reg_table = Table(
        RegistryTables.REGISTRATIONS.value,
        metadata,
        Column("id", Integer, unique=True, primary_key=True),
        Column("timestamp", String(), nullable=False),
        Column("submitter", String(), nullable=False),
        Column("update_type", String(), nullable=False),
        Column("log_message", String(), nullable=False),
    )
    models_table = Table(
        RegistryTables.MODELS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("registration_id", Integer(), ForeignKey(reg_table.c.id), nullable=False),
        # project, dataset, dimension, dimension_mapping
        Column("model_type", String(), nullable=False),
        # project_id, dataset_id, dimension_id, mapping_id
        Column("model_id", String(), nullable=False),
        Column("version", String(), nullable=False),
        # This is project_config, dataset_config, etc.
        Column("model", JSON(), nullable=False),
        UniqueConstraint("model_type", "model_id", "version"),
    )
    Table(
        RegistryTables.CURRENT_VERSIONS.value,
        metadata,
        Column("id", Integer(), unique=True, primary_key=True),
        Column("model_type", String(), nullable=False),
        Column("model_id", String(), nullable=False),
        Column("current_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("update_timestamp", String(), nullable=False),
        UniqueConstraint("model_type", "model_id"),
    )
    # This table manages associations between
    #   projects and datasets,
    #   projects and dimensions,
    #   datasets and dimensions,
    #   dimension mappings and dimensions,
    #   and possibly derived datasets and datasets.
    Table(
        # Use the plain string value, like every other table here and every lookup
        # by name, rather than the enum member itself.
        RegistryTables.CONTAINS.value,
        metadata,
        Column("id", Integer, primary_key=True, unique=True),
        Column("parent_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
        Column("child_id", Integer(), ForeignKey(models_table.c.id), nullable=False),
    )
    metadata.create_all(engine)
|