policyengine 3.0.0__py3-none-any.whl → 3.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- policyengine/__pycache__/__init__.cpython-313.pyc +0 -0
- policyengine/core/__init__.py +22 -0
- policyengine/core/dataset.py +260 -0
- policyengine/core/dataset_version.py +16 -0
- policyengine/core/dynamic.py +43 -0
- policyengine/core/output.py +26 -0
- policyengine/{models → core}/parameter.py +4 -2
- policyengine/{models → core}/parameter_value.py +1 -1
- policyengine/core/policy.py +43 -0
- policyengine/{models → core}/simulation.py +10 -14
- policyengine/core/tax_benefit_model.py +11 -0
- policyengine/core/tax_benefit_model_version.py +34 -0
- policyengine/core/variable.py +15 -0
- policyengine/outputs/__init__.py +21 -0
- policyengine/outputs/aggregate.py +124 -0
- policyengine/outputs/change_aggregate.py +184 -0
- policyengine/outputs/decile_impact.py +140 -0
- policyengine/tax_benefit_models/uk/__init__.py +26 -0
- policyengine/tax_benefit_models/uk/analysis.py +97 -0
- policyengine/tax_benefit_models/uk/datasets.py +176 -0
- policyengine/tax_benefit_models/uk/model.py +268 -0
- policyengine/tax_benefit_models/uk/outputs.py +108 -0
- policyengine/tax_benefit_models/uk.py +33 -0
- policyengine/tax_benefit_models/us/__init__.py +36 -0
- policyengine/tax_benefit_models/us/analysis.py +99 -0
- policyengine/tax_benefit_models/us/datasets.py +307 -0
- policyengine/tax_benefit_models/us/model.py +447 -0
- policyengine/tax_benefit_models/us/outputs.py +108 -0
- policyengine/tax_benefit_models/us.py +32 -0
- policyengine/utils/__init__.py +3 -0
- policyengine/utils/dates.py +40 -0
- policyengine/utils/parametric_reforms.py +39 -0
- policyengine/utils/plotting.py +179 -0
- {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/METADATA +185 -20
- policyengine-3.1.1.dist-info/RECORD +39 -0
- policyengine/database/__init__.py +0 -56
- policyengine/database/aggregate.py +0 -33
- policyengine/database/baseline_parameter_value_table.py +0 -66
- policyengine/database/baseline_variable_table.py +0 -40
- policyengine/database/database.py +0 -251
- policyengine/database/dataset_table.py +0 -41
- policyengine/database/dynamic_table.py +0 -34
- policyengine/database/link.py +0 -82
- policyengine/database/model_table.py +0 -27
- policyengine/database/model_version_table.py +0 -28
- policyengine/database/parameter_table.py +0 -31
- policyengine/database/parameter_value_table.py +0 -62
- policyengine/database/policy_table.py +0 -34
- policyengine/database/report_element_table.py +0 -48
- policyengine/database/report_table.py +0 -24
- policyengine/database/simulation_table.py +0 -50
- policyengine/database/user_table.py +0 -28
- policyengine/database/versioned_dataset_table.py +0 -28
- policyengine/models/__init__.py +0 -30
- policyengine/models/aggregate.py +0 -92
- policyengine/models/baseline_parameter_value.py +0 -14
- policyengine/models/baseline_variable.py +0 -12
- policyengine/models/dataset.py +0 -18
- policyengine/models/dynamic.py +0 -15
- policyengine/models/model.py +0 -124
- policyengine/models/model_version.py +0 -14
- policyengine/models/policy.py +0 -17
- policyengine/models/policyengine_uk.py +0 -114
- policyengine/models/policyengine_us.py +0 -115
- policyengine/models/report.py +0 -10
- policyengine/models/report_element.py +0 -36
- policyengine/models/user.py +0 -14
- policyengine/models/versioned_dataset.py +0 -12
- policyengine/utils/charts.py +0 -286
- policyengine/utils/compress.py +0 -20
- policyengine/utils/datasets.py +0 -71
- policyengine-3.0.0.dist-info/RECORD +0 -47
- policyengine-3.0.0.dist-info/entry_points.txt +0 -2
- {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/WHEEL +0 -0
- {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/licenses/LICENSE +0 -0
- {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/top_level.txt +0 -0
|
@@ -1,251 +0,0 @@
|
|
|
1
|
-
from typing import Any
|
|
2
|
-
|
|
3
|
-
from sqlmodel import Session, SQLModel
|
|
4
|
-
|
|
5
|
-
from .aggregate import aggregate_table_link
|
|
6
|
-
from .baseline_parameter_value_table import baseline_parameter_value_table_link
|
|
7
|
-
from .baseline_variable_table import baseline_variable_table_link
|
|
8
|
-
from .dataset_table import dataset_table_link
|
|
9
|
-
from .dynamic_table import dynamic_table_link
|
|
10
|
-
from .link import TableLink
|
|
11
|
-
|
|
12
|
-
# Import all table links
|
|
13
|
-
from .model_table import model_table_link
|
|
14
|
-
from .model_version_table import model_version_table_link
|
|
15
|
-
from .parameter_table import parameter_table_link
|
|
16
|
-
from .parameter_value_table import parameter_value_table_link
|
|
17
|
-
from .policy_table import policy_table_link
|
|
18
|
-
from .report_element_table import report_element_table_link
|
|
19
|
-
from .report_table import report_table_link
|
|
20
|
-
from .simulation_table import simulation_table_link
|
|
21
|
-
from .user_table import user_table_link
|
|
22
|
-
from .versioned_dataset_table import versioned_dataset_table_link
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class Database:
    """Persistence facade that stores policyengine domain models in SQL tables.

    Each domain model class is paired with a SQLModel table through a
    ``TableLink``; ``get`` and ``set`` look that link up by model class and
    delegate to it. All queries run on ``self.session``.
    """

    url: str
    # Registered links; populated per instance in ``__init__``.
    _model_table_links: list[TableLink]

    def __init__(self, url: str):
        """Create an engine for ``url``, open a session and register all links.

        Args:
            url: SQLAlchemy database URL handed to ``create_engine``.
        """
        self.url = url
        self.engine = self._create_engine()
        self.session = Session(self.engine)
        # Must be per-instance: a class-level ``= []`` default is one list
        # shared by every Database, so each new instance would append all
        # links again and lookups would scan an ever-growing list.
        self._model_table_links = []

        for link in [
            model_table_link,
            model_version_table_link,
            dataset_table_link,
            versioned_dataset_table_link,
            policy_table_link,
            dynamic_table_link,
            parameter_table_link,
            parameter_value_table_link,
            baseline_parameter_value_table_link,
            baseline_variable_table_link,
            simulation_table_link,
            aggregate_table_link,
            user_table_link,
            report_table_link,
            report_element_table_link,
        ]:
            self.register_table(link)

    def _create_engine(self):
        """Return a SQLAlchemy engine for ``self.url`` with SQL echo disabled."""
        from sqlmodel import create_engine

        return create_engine(self.url, echo=False)

    def create_tables(self):
        """Create all database tables."""
        SQLModel.metadata.create_all(self.engine)

    def drop_tables(self):
        """Drop all database tables."""
        SQLModel.metadata.drop_all(self.engine)

    def reset(self):
        """Drop and recreate all tables."""
        self.drop_tables()
        self.create_tables()

    def __enter__(self):
        """Open a fresh session for the ``with`` block and return it."""
        # Close the session being replaced so its connection is released now
        # instead of whenever the dropped object is garbage-collected.
        self.session.close()
        self.session = Session(self.engine)
        return self.session

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit on success, roll back on error, then close the session.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        if exc_type:
            self.session.rollback()
        else:
            self.session.commit()
        self.session.close()

    def register_table(self, link: TableLink):
        """Register ``link`` and create any tables that do not exist yet."""
        self._model_table_links.append(link)
        # ``table_cls.metadata`` is normally the shared ``SQLModel.metadata``,
        # so this creates every missing known table, not only this one.
        link.table_cls.metadata.create_all(self.engine)

    def _find_link(self, model_cls: type) -> "TableLink | None":
        """Return the link registered for exactly ``model_cls``, or None."""
        return next(
            (
                link
                for link in self._model_table_links
                if link.model_cls is model_cls
            ),
            None,
        )

    def get(self, model_cls: type, **kwargs):
        """Load one ``model_cls`` instance whose row matches the filters ``kwargs``.

        Returns:
            The reconstructed model, or None when no link is registered for
            ``model_cls`` or no row matches.
        """
        table_link = self._find_link(model_cls)
        if table_link is None:
            return None
        return table_link.get(self, **kwargs)

    def set(self, object: Any, commit: bool = True):
        """Insert or update ``object`` through the link registered for its type.

        Objects whose exact type has no registered link are silently ignored.

        Args:
            object: Domain model instance to persist.
            commit: Commit the session after writing when True.
        """
        table_link = self._find_link(type(object))
        if table_link is not None:
            table_link.set(self, object, commit=commit)

    def register_model_version(self, model_version):
        """Register a model version with its model and seed objects.

        Inserts the model and model-version rows if they are missing. It then
        replaces every parameter, baseline parameter value and baseline
        variable stored for the version's model (across all of its versions)
        with the objects returned by ``model.create_seed_objects``.

        Everything is committed as one transaction. On any exception the
        session is rolled back and the exception is re-raised.
        """
        from uuid import uuid4

        from policyengine.utils.compress import compress_data

        from .baseline_parameter_value_table import BaselineParameterValueTable
        from .baseline_variable_table import BaselineVariableTable
        from .model_table import ModelTable
        from .model_version_table import ModelVersionTable
        from .parameter_table import ParameterTable
        from .parameter_value_table import transform_value_to_table

        session = self.session
        model = model_version.model
        try:
            existing_model = (
                session.query(ModelTable)
                .filter(ModelTable.id == model.id)
                .first()
            )
            if existing_model is None:
                session.add(
                    ModelTable(
                        id=model.id,
                        name=model.name,
                        description=model.description,
                        simulation_function=compress_data(
                            model.simulation_function
                        ),
                    )
                )
                session.flush()

            existing_version = (
                session.query(ModelVersionTable)
                .filter(ModelVersionTable.id == model_version.id)
                .first()
            )
            if existing_version is None:
                session.add(
                    ModelVersionTable(
                        id=model_version.id,
                        model_id=model.id,
                        version=model_version.version,
                        description=model_version.description,
                        created_at=model_version.created_at,
                    )
                )
                session.flush()

            seed_objects = model.create_seed_objects(model_version)

            # Remove ALL existing seed data for this model, not just this
            # version. Order matters because of foreign-key constraints:
            # baseline values reference parameters, so they go first.
            # The bulk deletes execute immediately inside the open
            # transaction. They are deliberately not committed on their own:
            # a failure below must not leave the model without seed data.
            for table in (
                BaselineParameterValueTable,
                BaselineVariableTable,
                ParameterTable,
            ):
                session.query(table).filter(table.model_id == model.id).delete()

            for parameter in seed_objects.parameters:
                session.add(
                    ParameterTable(
                        id=parameter.id,
                        model_id=parameter.model.id,
                        description=parameter.description,
                        data_type=(
                            parameter.data_type.__name__
                            if parameter.data_type
                            else None
                        ),
                    )
                )
            # Parameters must exist before the baseline values referencing them.
            session.flush()

            for baseline_param_value in seed_objects.baseline_parameter_values:
                session.add(
                    BaselineParameterValueTable(
                        id=str(uuid4()),
                        parameter_id=baseline_param_value.parameter.id,
                        model_id=baseline_param_value.parameter.model.id,
                        model_version_id=baseline_param_value.model_version.id,
                        # Non-finite floats become JSON-safe sentinel strings.
                        value=transform_value_to_table(baseline_param_value),
                        start_date=baseline_param_value.start_date,
                        end_date=baseline_param_value.end_date,
                    )
                )

            for baseline_variable in seed_objects.baseline_variables:
                session.add(
                    BaselineVariableTable(
                        id=baseline_variable.id,
                        model_id=baseline_variable.model_version.model.id,
                        model_version_id=baseline_variable.model_version.id,
                        entity=baseline_variable.entity,
                        label=baseline_variable.label,
                        description=baseline_variable.description,
                        data_type=(
                            compress_data(baseline_variable.data_type)
                            if baseline_variable.data_type
                            else None
                        ),
                    )
                )

            session.commit()
        except Exception:
            session.rollback()
            raise
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
from uuid import uuid4
|
|
2
|
-
|
|
3
|
-
from sqlmodel import Field, SQLModel
|
|
4
|
-
|
|
5
|
-
from policyengine.models import Dataset
|
|
6
|
-
from policyengine.utils.compress import compress_data, decompress_data
|
|
7
|
-
|
|
8
|
-
from .link import TableLink
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class DatasetTable(SQLModel, table=True):
    """Row in the ``datasets`` table describing one stored dataset."""

    __tablename__ = "datasets"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    name: str = Field(nullable=False)
    description: str | None = Field(default=None)
    version: str | None = Field(default=None)
    # Parent versioned dataset; set to NULL if that row is deleted.
    versioned_dataset_id: str | None = Field(
        default=None,
        foreign_key="versioned_datasets.id",
        ondelete="SET NULL",
    )
    year: int | None = Field(default=None)
    # Compressed dataset payload.
    data: bytes | None = Field(default=None)
    # Owning model; set to NULL if that row is deleted.
    model_id: str | None = Field(
        default=None,
        foreign_key="models.id",
        ondelete="SET NULL",
    )
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
dataset_table_link = TableLink(
    model_cls=Dataset,
    table_cls=DatasetTable,
    # Domain -> row: flatten related objects to ids and compress the payload.
    model_to_table_custom_transforms={
        "versioned_dataset_id": lambda ds: (
            ds.versioned_dataset.id if ds.versioned_dataset else None
        ),
        "model_id": lambda ds: ds.model.id if ds.model else None,
        "data": lambda ds: compress_data(ds.data) if ds.data else None,
    },
    # Row -> domain: decompress the stored payload when present.
    table_to_model_custom_transforms={
        "data": lambda blob: decompress_data(blob) if blob else None,
    },
)
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from uuid import uuid4
|
|
3
|
-
|
|
4
|
-
from sqlmodel import Field, SQLModel
|
|
5
|
-
|
|
6
|
-
from policyengine.models import Dynamic
|
|
7
|
-
from policyengine.utils.compress import compress_data, decompress_data
|
|
8
|
-
|
|
9
|
-
from .link import TableLink
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class DynamicTable(SQLModel, table=True):
    """Row in the ``dynamics`` table: a named, optionally serialized modifier."""

    __tablename__ = "dynamics"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    name: str = Field(nullable=False)
    description: str | None = Field(default=None)
    # Compressed simulation modifier, if any.
    simulation_modifier: bytes | None = Field(default=None)
    # Both timestamps default to naive local ``datetime.now()`` at creation.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
dynamic_table_link = TableLink(
    model_cls=Dynamic,
    table_cls=DynamicTable,
    # Domain -> row: compress the modifier when one is set.
    model_to_table_custom_transforms={
        "simulation_modifier": lambda dyn: (
            None
            if not dyn.simulation_modifier
            else compress_data(dyn.simulation_modifier)
        ),
    },
    # Row -> domain: decompress the stored bytes when present.
    table_to_model_custom_transforms={
        "simulation_modifier": lambda blob: (
            decompress_data(blob) if blob else None
        ),
    },
)
|
policyengine/database/link.py
DELETED
|
@@ -1,82 +0,0 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
2
|
-
from typing import TYPE_CHECKING
|
|
3
|
-
|
|
4
|
-
from pydantic import BaseModel
|
|
5
|
-
from sqlmodel import SQLModel, select
|
|
6
|
-
|
|
7
|
-
if TYPE_CHECKING:
|
|
8
|
-
from .database import Database
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _field_names(cls_: type) -> set[str]:
    """Return the field names declared on pydantic/SQLModel class ``cls_``.

    Uses ``model_fields``, which includes fields inherited from base
    classes; the class's own ``__annotations__`` does not. Falls back to
    ``__annotations__`` when ``model_fields`` is unavailable.
    """
    fields = getattr(cls_, "model_fields", None)
    if isinstance(fields, dict):
        return set(fields)
    return set(getattr(cls_, "__annotations__", {}))


class TableLink(BaseModel):
    """Bidirectional mapping between a domain model class and a SQLModel table.

    Fields:
        model_cls: Domain class returned by ``get`` and accepted by ``set``.
        table_cls: SQLModel table class the rows live in.
        model_to_table_custom_transforms: Maps a column name to a callable.
            The callable receives the whole domain object and returns the
            column value.
        table_to_model_custom_transforms: Maps a field name to a callable.
            The callable receives the stored column value (not the row) and
            returns the domain field value.
        primary_key: Column name, or a tuple of names for a composite key.
            ``set`` uses it to find an existing row.
    """

    model_cls: type[BaseModel]
    table_cls: type[SQLModel]
    model_to_table_custom_transforms: dict[str, Callable] | None = None
    table_to_model_custom_transforms: dict[str, Callable] | None = None
    # A single column name, or a tuple of names for a composite key.
    primary_key: str | tuple[str, ...] = "id"

    def get(self, database: "Database", **kwargs):
        """Load the first row matching the column filters ``kwargs``.

        Returns:
            A ``model_cls`` instance built from the row, or None if no row
            matches. Only values whose names are fields of ``model_cls`` are
            passed to its constructor.
        """
        statement = select(self.table_cls).filter_by(**kwargs)
        row = database.session.exec(statement).first()
        if row is None:
            return None

        data = row.model_dump()
        for field, transform in (self.table_to_model_custom_transforms or {}).items():
            # Table-to-model transforms are called with the column value.
            data[field] = transform(getattr(row, field))

        allowed = _field_names(self.model_cls)
        return self.model_cls(**{k: v for k, v in data.items() if k in allowed})

    def set(self, database: "Database", obj: BaseModel, commit: bool = True):
        """Insert ``obj`` as a row, or update the row sharing its primary key.

        Args:
            database: Provides the session to write through.
            obj: Domain object to persist.
            commit: Commit the session afterwards when True.
        """
        data = obj.model_dump()
        for field, transform in (self.model_to_table_custom_transforms or {}).items():
            # Model-to-table transforms are called with the whole object.
            data[field] = transform(obj)

        allowed = _field_names(self.table_cls)
        row_data = {k: v for k, v in data.items() if k in allowed}
        table_obj = self.table_cls(**row_data)

        # Treat a single key name as a one-element composite key.
        key_columns = (
            self.primary_key
            if isinstance(self.primary_key, tuple)
            else (self.primary_key,)
        )
        query = select(self.table_cls)
        for key in key_columns:
            query = query.where(
                getattr(self.table_cls, key) == getattr(table_obj, key)
            )

        existing = database.session.exec(query).first()
        if existing is not None:
            for key, value in row_data.items():
                setattr(existing, key, value)
            database.session.add(existing)
        else:
            database.session.add(table_obj)

        if commit:
            database.session.commit()
|
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
from sqlmodel import Field, SQLModel
|
|
2
|
-
|
|
3
|
-
from policyengine.models import Model
|
|
4
|
-
from policyengine.utils.compress import compress_data, decompress_data
|
|
5
|
-
|
|
6
|
-
from .link import TableLink
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class ModelTable(SQLModel, table=True, extend_existing=True):
    """Row in the ``models`` table: one registered model.

    NOTE(review): ``extend_existing`` is passed here as a class keyword,
    whereas ParameterTable uses ``__table_args__``; confirm that SQLModel
    actually forwards the keyword to the Table.
    """

    __tablename__ = "models"

    # Caller-supplied identifier; there is no generated default.
    id: str = Field(primary_key=True)
    name: str = Field(nullable=False)
    description: str | None = Field(default=None)
    # Compressed, serialized simulation function (required).
    simulation_function: bytes
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
model_table_link = TableLink(
    model_cls=Model,
    table_cls=ModelTable,
    # The simulation function is always stored compressed.
    model_to_table_custom_transforms={
        "simulation_function": lambda mdl: compress_data(mdl.simulation_function),
    },
    table_to_model_custom_transforms={
        "simulation_function": lambda blob: decompress_data(blob),
    },
)
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from uuid import uuid4
|
|
3
|
-
|
|
4
|
-
from sqlmodel import Field, SQLModel
|
|
5
|
-
|
|
6
|
-
from policyengine.models import ModelVersion
|
|
7
|
-
|
|
8
|
-
from .link import TableLink
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ModelVersionTable(SQLModel, table=True):
    """Row in the ``model_versions`` table: one version of a registered model."""

    __tablename__ = "model_versions"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Declared ON DELETE CASCADE against the parent model row.
    model_id: str = Field(foreign_key="models.id", ondelete="CASCADE")
    version: str = Field(nullable=False)
    description: str | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.now)
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
model_version_table_link = TableLink(
    model_cls=ModelVersion,
    table_cls=ModelVersionTable,
    # The row keeps only the parent model's id, read off the domain object.
    model_to_table_custom_transforms={
        "model_id": lambda version_obj: version_obj.model.id,
    },
    # No column needs converting when loading.
    table_to_model_custom_transforms={},
)
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
from sqlmodel import Field, SQLModel
|
|
2
|
-
|
|
3
|
-
from policyengine.models import Parameter
|
|
4
|
-
|
|
5
|
-
from .link import TableLink
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class ParameterTable(SQLModel, table=True):
    """Row in the ``parameters`` table, keyed by (parameter name, model id)."""

    __tablename__ = "parameters"
    __table_args__ = ({"extend_existing": True},)

    # Parameter name; unique only within a model, hence the composite key.
    id: str = Field(primary_key=True)
    # Second half of the composite primary key; references ``models.id``.
    model_id: str = Field(primary_key=True, foreign_key="models.id")
    description: str | None = Field(default=None)
    # ``__name__`` of the parameter's data type, or None. The explicit
    # default matches the Optional annotation; without it the field was
    # required for validation even though the column is nullable.
    data_type: str | None = Field(default=None, nullable=True)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def _data_type_from_table(stored):
    """Map a stored data-type name back to the builtin type it names.

    ``TableLink.get`` calls table-to-model transforms with the column value,
    i.e. the type's ``__name__`` string. An object exposing ``.data_type``
    (a whole row) is also accepted.

    Returns:
        The builtin type, or None when nothing is stored.

    Raises:
        ValueError: If the name does not denote a builtin type.
    """
    import builtins

    name = getattr(stored, "data_type", stored)
    if not name:
        return None
    # Resolve against builtins only instead of ``eval`` on database contents.
    resolved = getattr(builtins, name, None)
    if not isinstance(resolved, type):
        raise ValueError(f"Unknown parameter data type: {name!r}")
    return resolved


parameter_table_link = TableLink(
    model_cls=Parameter,
    table_cls=ParameterTable,
    primary_key=("id", "model_id"),  # Composite primary key
    model_to_table_custom_transforms=dict(
        data_type=lambda p: p.data_type.__name__ if p.data_type else None,
        model_id=lambda p: p.model.id if p.model else None,
    ),
    # The previous lambda read ``t.data_type`` from the value it was given,
    # which is already the string, so every load raised AttributeError.
    table_to_model_custom_transforms=dict(
        data_type=_data_type_from_table,
    ),
)
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from typing import Any
|
|
3
|
-
from uuid import uuid4
|
|
4
|
-
|
|
5
|
-
from sqlmodel import JSON, Column, Field, SQLModel
|
|
6
|
-
|
|
7
|
-
from policyengine.models import ParameterValue
|
|
8
|
-
|
|
9
|
-
from .link import TableLink
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class ParameterValueTable(SQLModel, table=True):
    """Row in the ``parameter_values`` table: one dated value of a parameter."""

    __tablename__ = "parameter_values"
    __table_args__ = ({"extend_existing": True},)

    # A random UUID string unless the caller supplies an id.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # (parameter_id, model_id) mirror the parameters table's composite key.
    # NOTE(review): no database-level foreign key is declared for this pair.
    parameter_id: str = Field(nullable=False)
    model_id: str = Field(nullable=False)
    # Stored in a JSON column so values of any JSON-compatible type fit.
    value: Any | None = Field(default=None, sa_column=Column(JSON))
    start_date: datetime = Field(nullable=False)
    end_date: datetime | None = Field(default=None)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def transform_value_to_table(pv):
    """Return ``pv.value`` in a form the JSON ``value`` column can hold.

    Non-finite floats are replaced with the sentinel strings "Infinity",
    "-Infinity" and "NaN", because JSON cannot represent them. Every other
    value is passed through untouched.
    """
    import math

    raw = pv.value
    if not isinstance(raw, float) or math.isfinite(raw):
        return raw
    if math.isnan(raw):
        return "NaN"
    return "-Infinity" if raw < 0 else "Infinity"
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def transform_value_from_table(table_row):
    """Convert a stored parameter value back to its Python form.

    Reverses ``transform_value_to_table``: the sentinel strings "Infinity",
    "-Infinity" and "NaN" become the matching floats, and every other value
    is returned unchanged.

    Args:
        table_row: The stored value itself, or an object exposing it as
            ``.value`` (e.g. a ``ParameterValueTable`` row). ``TableLink.get``
            passes the value itself to table-to-model transforms.

    Returns:
        The decoded value.
    """
    # Reading ``table_row.value`` unconditionally raised AttributeError for
    # every value TableLink.get handed over, so accept both forms.
    value = getattr(table_row, "value", table_row)
    # Only strings can be sentinels; this also avoids elementwise ``==`` on
    # array-like values.
    if isinstance(value, str):
        if value == "Infinity":
            return float("inf")
        if value == "-Infinity":
            return float("-inf")
        if value == "NaN":
            return float("nan")
    return value
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
parameter_value_table_link = TableLink(
    model_cls=ParameterValue,
    table_cls=ParameterValueTable,
    model_to_table_custom_transforms={
        # Both halves of the parameter's composite key come from pv.parameter.
        "parameter_id": lambda pv: pv.parameter.id,
        "model_id": lambda pv: pv.parameter.model.id,
        # Encode non-finite floats as JSON-safe sentinel strings.
        "value": transform_value_to_table,
    },
    table_to_model_custom_transforms={
        # Decode the sentinel strings back to floats.
        "value": transform_value_from_table,
    },
)
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
from datetime import datetime
|
|
2
|
-
from uuid import uuid4
|
|
3
|
-
|
|
4
|
-
from sqlmodel import Field, SQLModel
|
|
5
|
-
|
|
6
|
-
from policyengine.models import Policy
|
|
7
|
-
from policyengine.utils.compress import compress_data, decompress_data
|
|
8
|
-
|
|
9
|
-
from .link import TableLink
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class PolicyTable(SQLModel, table=True):
    """Row in the ``policies`` table: a named, optionally serialized modifier."""

    __tablename__ = "policies"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    name: str = Field(nullable=False)
    description: str | None = Field(default=None)
    # Compressed simulation modifier, if any.
    simulation_modifier: bytes | None = Field(default=None)
    # Both timestamps default to naive local ``datetime.now()`` at creation.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
policy_table_link = TableLink(
    model_cls=Policy,
    table_cls=PolicyTable,
    # Domain -> row: compress the modifier when one is set.
    model_to_table_custom_transforms={
        "simulation_modifier": lambda pol: (
            None
            if not pol.simulation_modifier
            else compress_data(pol.simulation_modifier)
        ),
    },
    # Row -> domain: decompress the stored bytes when present.
    table_to_model_custom_transforms={
        "simulation_modifier": lambda blob: (
            decompress_data(blob) if blob else None
        ),
    },
)
|
|
@@ -1,48 +0,0 @@
|
|
|
1
|
-
import uuid
|
|
2
|
-
from datetime import datetime
|
|
3
|
-
|
|
4
|
-
from sqlmodel import Field, SQLModel
|
|
5
|
-
|
|
6
|
-
from policyengine.models.report_element import ReportElement
|
|
7
|
-
|
|
8
|
-
from .link import TableLink
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ReportElementTable(SQLModel, table=True, extend_existing=True):
    """Row in the ``report_elements`` table: one chart or markdown block.

    NOTE(review): ``datetime.utcnow`` is deprecated since Python 3.12, and
    other tables here use local ``datetime.now``; consider unifying.
    """

    __tablename__ = "report_elements"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(
        primary_key=True,
        default_factory=lambda: str(uuid.uuid4()),
    )
    label: str = Field(nullable=False)
    # Element kind: "chart" or "markdown".
    type: str = Field(nullable=False)

    # --- Data source: name of the table read from, e.g. "aggregates".
    data_table: str | None = Field(default=None)

    # --- Chart configuration.
    # One of "bar", "line", "scatter", "area", "pie".
    chart_type: str | None = Field(default=None)
    x_axis_variable: str | None = Field(default=None)
    y_axis_variable: str | None = Field(default=None)
    group_by: str | None = Field(default=None)
    color_by: str | None = Field(default=None)
    size_by: str | None = Field(default=None)

    # --- Markdown-only content.
    markdown_content: str | None = Field(default=None)

    # --- Metadata.
    report_id: str | None = Field(default=None, foreign_key="reports.id")
    user_id: str | None = Field(default=None, foreign_key="users.id")
    position: int | None = Field(default=None)
    visible: bool | None = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
# No custom transforms: ReportElement fields map onto columns by name.
report_element_table_link = TableLink(
    table_cls=ReportElementTable,
    model_cls=ReportElement,
)
|
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
import uuid
|
|
2
|
-
from datetime import datetime
|
|
3
|
-
|
|
4
|
-
from sqlmodel import Field, SQLModel
|
|
5
|
-
|
|
6
|
-
from policyengine.models.report import Report
|
|
7
|
-
|
|
8
|
-
from .link import TableLink
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class ReportTable(SQLModel, table=True, extend_existing=True):
    """Row in the ``reports`` table: a labelled report."""

    __tablename__ = "reports"

    # A random UUID string unless the caller supplies an id.
    id: str = Field(
        primary_key=True,
        default_factory=lambda: str(uuid.uuid4()),
    )
    label: str = Field(nullable=False)
    # Naive UTC timestamp taken when the row object is created.
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
# No custom transforms: Report fields map onto ReportTable columns by name.
report_table_link = TableLink(
    table_cls=ReportTable,
    model_cls=Report,
)
|