policyengine 3.0.0__py3-none-any.whl → 3.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. policyengine/__pycache__/__init__.cpython-313.pyc +0 -0
  2. policyengine/core/__init__.py +22 -0
  3. policyengine/core/dataset.py +260 -0
  4. policyengine/core/dataset_version.py +16 -0
  5. policyengine/core/dynamic.py +43 -0
  6. policyengine/core/output.py +26 -0
  7. policyengine/{models → core}/parameter.py +4 -2
  8. policyengine/{models → core}/parameter_value.py +1 -1
  9. policyengine/core/policy.py +43 -0
  10. policyengine/{models → core}/simulation.py +10 -14
  11. policyengine/core/tax_benefit_model.py +11 -0
  12. policyengine/core/tax_benefit_model_version.py +34 -0
  13. policyengine/core/variable.py +15 -0
  14. policyengine/outputs/__init__.py +21 -0
  15. policyengine/outputs/aggregate.py +124 -0
  16. policyengine/outputs/change_aggregate.py +184 -0
  17. policyengine/outputs/decile_impact.py +140 -0
  18. policyengine/tax_benefit_models/uk/__init__.py +26 -0
  19. policyengine/tax_benefit_models/uk/analysis.py +97 -0
  20. policyengine/tax_benefit_models/uk/datasets.py +176 -0
  21. policyengine/tax_benefit_models/uk/model.py +268 -0
  22. policyengine/tax_benefit_models/uk/outputs.py +108 -0
  23. policyengine/tax_benefit_models/uk.py +33 -0
  24. policyengine/tax_benefit_models/us/__init__.py +36 -0
  25. policyengine/tax_benefit_models/us/analysis.py +99 -0
  26. policyengine/tax_benefit_models/us/datasets.py +307 -0
  27. policyengine/tax_benefit_models/us/model.py +447 -0
  28. policyengine/tax_benefit_models/us/outputs.py +108 -0
  29. policyengine/tax_benefit_models/us.py +32 -0
  30. policyengine/utils/__init__.py +3 -0
  31. policyengine/utils/dates.py +40 -0
  32. policyengine/utils/parametric_reforms.py +39 -0
  33. policyengine/utils/plotting.py +179 -0
  34. {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/METADATA +185 -20
  35. policyengine-3.1.1.dist-info/RECORD +39 -0
  36. policyengine/database/__init__.py +0 -56
  37. policyengine/database/aggregate.py +0 -33
  38. policyengine/database/baseline_parameter_value_table.py +0 -66
  39. policyengine/database/baseline_variable_table.py +0 -40
  40. policyengine/database/database.py +0 -251
  41. policyengine/database/dataset_table.py +0 -41
  42. policyengine/database/dynamic_table.py +0 -34
  43. policyengine/database/link.py +0 -82
  44. policyengine/database/model_table.py +0 -27
  45. policyengine/database/model_version_table.py +0 -28
  46. policyengine/database/parameter_table.py +0 -31
  47. policyengine/database/parameter_value_table.py +0 -62
  48. policyengine/database/policy_table.py +0 -34
  49. policyengine/database/report_element_table.py +0 -48
  50. policyengine/database/report_table.py +0 -24
  51. policyengine/database/simulation_table.py +0 -50
  52. policyengine/database/user_table.py +0 -28
  53. policyengine/database/versioned_dataset_table.py +0 -28
  54. policyengine/models/__init__.py +0 -30
  55. policyengine/models/aggregate.py +0 -92
  56. policyengine/models/baseline_parameter_value.py +0 -14
  57. policyengine/models/baseline_variable.py +0 -12
  58. policyengine/models/dataset.py +0 -18
  59. policyengine/models/dynamic.py +0 -15
  60. policyengine/models/model.py +0 -124
  61. policyengine/models/model_version.py +0 -14
  62. policyengine/models/policy.py +0 -17
  63. policyengine/models/policyengine_uk.py +0 -114
  64. policyengine/models/policyengine_us.py +0 -115
  65. policyengine/models/report.py +0 -10
  66. policyengine/models/report_element.py +0 -36
  67. policyengine/models/user.py +0 -14
  68. policyengine/models/versioned_dataset.py +0 -12
  69. policyengine/utils/charts.py +0 -286
  70. policyengine/utils/compress.py +0 -20
  71. policyengine/utils/datasets.py +0 -71
  72. policyengine-3.0.0.dist-info/RECORD +0 -47
  73. policyengine-3.0.0.dist-info/entry_points.txt +0 -2
  74. {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/WHEEL +0 -0
  75. {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/licenses/LICENSE +0 -0
  76. {policyengine-3.0.0.dist-info → policyengine-3.1.1.dist-info}/top_level.txt +0 -0
@@ -1,50 +0,0 @@
1
- from datetime import datetime
2
- from uuid import uuid4
3
-
4
- from sqlmodel import Field, SQLModel
5
-
6
- from policyengine.models import Simulation
7
- from policyengine.utils.compress import compress_data, decompress_data
8
-
9
- from .link import TableLink
10
-
11
-
def _generate_simulation_id() -> str:
    """Return a new random UUID4 as a string (default primary key)."""
    return str(uuid4())


class SimulationTable(SQLModel, table=True):
    """Row definition for the ``simulations`` table.

    The policy, dynamic and model-version references are optional and are
    declared with ``ondelete="SET NULL"``; the dataset and model references
    are required and declared with ``ondelete="CASCADE"``.
    """

    __tablename__ = "simulations"

    # Identity and bookkeeping; both timestamps default to ``datetime.now``
    # when the row object is constructed.
    id: str = Field(primary_key=True, default_factory=_generate_simulation_id)
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    # Optional references.
    policy_id: str | None = Field(
        foreign_key="policies.id", ondelete="SET NULL", default=None
    )
    dynamic_id: str | None = Field(
        foreign_key="dynamics.id", ondelete="SET NULL", default=None
    )

    # Required references.
    dataset_id: str = Field(ondelete="CASCADE", foreign_key="datasets.id")
    model_id: str = Field(ondelete="CASCADE", foreign_key="models.id")

    model_version_id: str | None = Field(
        foreign_key="model_versions.id", ondelete="SET NULL", default=None
    )

    # Stored as bytes; the accompanying table link compresses the
    # simulation result before writing it here.
    result: bytes | None = Field(default=None)
32
-
33
-
def _optional_id(related):
    """Return ``related.id``, or ``None`` when ``related`` is falsy."""
    return related.id if related else None


# Maps ``Simulation`` model objects to ``SimulationTable`` rows and back.
# Nested objects are flattened to their ids on the way in; ``result`` is
# compressed on the way in and decompressed on the way out.
simulation_table_link = TableLink(
    model_cls=Simulation,
    table_cls=SimulationTable,
    model_to_table_custom_transforms={
        "policy_id": lambda sim: _optional_id(sim.policy),
        "dynamic_id": lambda sim: _optional_id(sim.dynamic),
        # Dataset and model are dereferenced unconditionally.
        "dataset_id": lambda sim: sim.dataset.id,
        "model_id": lambda sim: sim.model.id,
        "model_version_id": lambda sim: _optional_id(sim.model_version),
        "result": lambda sim: (
            compress_data(sim.result) if sim.result else None
        ),
    },
    table_to_model_custom_transforms={
        "result": lambda blob: decompress_data(blob) if blob else None,
    },
)
@@ -1,28 +0,0 @@
1
- import uuid
2
- from datetime import datetime
3
-
4
- from sqlmodel import Field, SQLModel
5
-
6
- from policyengine.models.user import User
7
-
8
- from .link import TableLink
9
-
10
-
def _generate_user_id() -> str:
    """Return a new random UUID4 as a string (default primary key)."""
    return str(uuid.uuid4())


class UserTable(SQLModel, table=True, extend_existing=True):
    """Row definition for the ``users`` table."""

    __tablename__ = "users"

    id: str = Field(default_factory=_generate_user_id, primary_key=True)

    # The only mandatory, unique attribute of a user.
    username: str = Field(unique=True, nullable=False)

    # Optional profile details.
    first_name: str | None = Field(default=None)
    last_name: str | None = Field(default=None)
    email: str | None = Field(default=None)

    # NOTE(review): these defaults use ``datetime.utcnow`` whereas
    # ``SimulationTable`` uses ``datetime.now`` -- confirm which is intended.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
23
-
24
-
# Maps ``User`` model objects to ``UserTable`` rows; no custom transforms
# are supplied.
user_table_link = TableLink(model_cls=User, table_cls=UserTable)
@@ -1,28 +0,0 @@
1
- from uuid import uuid4
2
-
3
- from sqlmodel import Field, SQLModel
4
-
5
- from policyengine.models import VersionedDataset
6
-
7
- from .link import TableLink
8
-
9
-
def _generate_versioned_dataset_id() -> str:
    """Return a new random UUID4 as a string (default primary key)."""
    return str(uuid4())


class VersionedDatasetTable(SQLModel, table=True):
    """Row definition for the ``versioned_datasets`` table."""

    __tablename__ = "versioned_datasets"

    id: str = Field(
        primary_key=True, default_factory=_generate_versioned_dataset_id
    )

    # Both descriptive columns are declared non-nullable.
    name: str = Field(nullable=False)
    description: str = Field(nullable=False)

    # Optional reference to the owning model, declared with
    # ``ondelete="SET NULL"``.
    model_id: str | None = Field(
        foreign_key="models.id", ondelete="SET NULL", default=None
    )
19
-
20
-
def _versioned_dataset_model_id(versioned_dataset):
    """Return the id of the dataset's model, or ``None`` if it has none."""
    model = versioned_dataset.model
    return model.id if model else None


# Maps ``VersionedDataset`` model objects to ``VersionedDatasetTable`` rows;
# only ``model_id`` needs a custom transform (object -> id).
versioned_dataset_table_link = TableLink(
    model_cls=VersionedDataset,
    table_cls=VersionedDatasetTable,
    model_to_table_custom_transforms={
        "model_id": _versioned_dataset_model_id,
    },
    table_to_model_custom_transforms=dict(),
)
@@ -1,30 +0,0 @@
1
- from .aggregate import Aggregate as Aggregate
2
- from .aggregate import AggregateType as AggregateType
3
- from .baseline_parameter_value import (
4
- BaselineParameterValue as BaselineParameterValue,
5
- )
6
- from .baseline_variable import BaselineVariable as BaselineVariable
7
- from .dataset import Dataset as Dataset
8
- from .dynamic import Dynamic as Dynamic
9
- from .model import Model as Model
10
- from .model_version import ModelVersion as ModelVersion
11
- from .parameter import Parameter as Parameter
12
- from .parameter_value import ParameterValue as ParameterValue
13
- from .policy import Policy as Policy
14
- from .policyengine_uk import (
15
- policyengine_uk_latest_version as policyengine_uk_latest_version,
16
- )
17
- from .policyengine_uk import (
18
- policyengine_uk_model as policyengine_uk_model,
19
- )
20
- from .policyengine_us import (
21
- policyengine_us_latest_version as policyengine_us_latest_version,
22
- )
23
- from .policyengine_us import (
24
- policyengine_us_model as policyengine_us_model,
25
- )
26
- from .report import Report as Report
27
- from .report_element import ReportElement as ReportElement
28
- from .simulation import Simulation as Simulation
29
- from .user import User as User
30
- from .versioned_dataset import VersionedDataset as VersionedDataset
@@ -1,92 +0,0 @@
1
- from enum import Enum
2
- from typing import TYPE_CHECKING, Literal
3
-
4
- import pandas as pd
5
- from microdf import MicroDataFrame
6
- from pydantic import BaseModel
7
-
8
- if TYPE_CHECKING:
9
- from policyengine.models import Simulation
10
-
11
-
class AggregateType(str, Enum):
    """Aggregation functions selectable on an ``Aggregate``.

    Members are also ``str`` instances, so each compares equal to its
    lowercase string value.
    """

    SUM = "sum"      # total of the variable
    MEAN = "mean"    # average of the variable
    COUNT = "count"  # number of rows where the variable is positive
16
-
17
-
class Aggregate(BaseModel):
    """One aggregate statistic of a single variable in a simulation's results.

    ``value`` is ``None`` until ``Aggregate.run`` has been called on the
    instance. An optional filter variable restricts the rows that are
    aggregated: by equality (``filter_variable_value``), upper bound
    (``filter_variable_leq``) and/or lower bound (``filter_variable_geq``).
    """

    simulation: "Simulation"
    entity: str
    variable_name: str
    year: int | None = None
    filter_variable_name: str | None = None
    filter_variable_value: str | None = None
    filter_variable_leq: float | None = None
    filter_variable_geq: float | None = None
    aggregate_function: Literal[
        AggregateType.SUM, AggregateType.MEAN, AggregateType.COUNT
    ]

    value: float | None = None

    @staticmethod
    def run(aggregates: list["Aggregate"]) -> list["Aggregate"]:
        """Compute ``value`` for every aggregate and return them in order.

        All aggregates are assumed to belong to the same simulation: only the
        result tables (and dataset year) of ``aggregates[0].simulation`` are
        used. The aggregates are updated in place (``value``, and ``year``
        when it was ``None``) and the same objects are returned.

        Args:
            aggregates: Aggregates to evaluate. An empty list yields ``[]``.

        Returns:
            The input aggregates with ``value`` populated.

        Raises:
            ValueError: If an aggregate names an entity, variable or filter
                variable that is absent from the simulation results.
        """
        # Nothing to compute; avoids an IndexError on ``aggregates[0]``.
        if not aggregates:
            return []

        simulation = aggregates[0].simulation
        raw_tables = simulation.result

        # Build local frames instead of overwriting ``simulation.result``,
        # so computing aggregates does not alter the stored results.
        tables = {}
        for table_name in raw_tables:
            frame = pd.DataFrame(raw_tables[table_name])
            weight_col = f"{table_name}_weight"
            # Tables that carry a ``<entity>_weight`` column are wrapped in
            # a MicroDataFrame using that column as weights.
            if weight_col in frame.columns:
                frame = MicroDataFrame(frame, weights=weight_col)
            tables[table_name] = frame

        results = []
        for agg in aggregates:
            if agg.entity not in tables:
                raise ValueError(
                    f"Entity {agg.entity} not found in simulation results"
                )
            df = tables[agg.entity]

            if agg.variable_name not in df.columns:
                raise ValueError(
                    f"Variable {agg.variable_name} not found in entity {agg.entity}"
                )

            # Default the year to that of the (shared) simulation dataset.
            if agg.year is None:
                agg.year = simulation.dataset.year

            if agg.filter_variable_name is not None:
                if agg.filter_variable_name not in df.columns:
                    raise ValueError(
                        f"Filter variable {agg.filter_variable_name} not found in entity {agg.entity}"
                    )
                # The three filter conditions are cumulative.
                if agg.filter_variable_value is not None:
                    df = df[
                        df[agg.filter_variable_name]
                        == agg.filter_variable_value
                    ]
                if agg.filter_variable_leq is not None:
                    df = df[
                        df[agg.filter_variable_name] <= agg.filter_variable_leq
                    ]
                if agg.filter_variable_geq is not None:
                    df = df[
                        df[agg.filter_variable_name] >= agg.filter_variable_geq
                    ]

            column = df[agg.variable_name]
            if agg.aggregate_function == AggregateType.SUM:
                agg.value = float(column.sum())
            elif agg.aggregate_function == AggregateType.MEAN:
                agg.value = float(column.mean())
            elif agg.aggregate_function == AggregateType.COUNT:
                # COUNT sums the mask of strictly positive values.
                agg.value = float((column > 0).sum())

            results.append(agg)

        return results
@@ -1,14 +0,0 @@
1
- from datetime import datetime
2
-
3
- from pydantic import BaseModel
4
-
5
- from .model_version import ModelVersion
6
- from .parameter import Parameter
7
-
8
-
class BaselineParameterValue(BaseModel):
    """Value of a parameter under a given model version over a date range."""

    # What the value belongs to.
    parameter: Parameter
    model_version: ModelVersion

    # The value itself; may be a scalar, a list, or absent.
    value: float | int | str | bool | list | None = None

    # Date range of the value; ``end_date`` may be left unset.
    start_date: datetime
    end_date: datetime | None = None
@@ -1,12 +0,0 @@
1
- from pydantic import BaseModel
2
-
3
- from .model_version import ModelVersion
4
-
5
-
class BaselineVariable(BaseModel):
    """Description of a variable as defined by a given model version."""

    # Required identification.
    id: str
    model_version: ModelVersion
    entity: str

    # Optional descriptive metadata.
    label: str | None = None
    description: str | None = None
    data_type: type | None = None
@@ -1,18 +0,0 @@
1
- from typing import Any
2
- from uuid import uuid4
3
-
4
- from pydantic import BaseModel, Field
5
-
6
- from .model import Model
7
- from .versioned_dataset import VersionedDataset
8
-
9
-
def _generate_dataset_id() -> str:
    """Return a new random UUID4 as a string (default dataset id)."""
    return str(uuid4())


class Dataset(BaseModel):
    """A named dataset; everything except ``name`` is optional."""

    id: str = Field(default_factory=_generate_dataset_id)
    name: str
    description: str | None = None

    # Versioning information.
    version: str | None = None
    versioned_dataset: VersionedDataset | None = None

    # Year the data refers to and the payload itself (untyped here).
    year: int | None = None
    data: Any | None = None

    # Model the dataset is associated with, if any.
    model: Model | None = None
@@ -1,15 +0,0 @@
1
- from collections.abc import Callable
2
- from datetime import datetime
3
- from uuid import uuid4
4
-
5
- from pydantic import BaseModel, Field
6
-
7
-
def _generate_dynamic_id() -> str:
    """Return a new random UUID4 as a string (default dynamic id)."""
    return str(uuid4())


class Dynamic(BaseModel):
    """A named set of parameter values and/or a simulation modifier."""

    id: str = Field(default_factory=_generate_dynamic_id)
    name: str
    description: str | None = None

    # NOTE(review): declared as ``list[str]``, but the UK/US runner functions
    # read ``.parameter.id``, ``.value``, ``.start_date`` and ``.end_date``
    # from each element -- confirm the intended element type.
    parameter_values: list[str] = []

    # Optional callable; the runner functions call it with the simulation.
    simulation_modifier: Callable | None = None

    # Both timestamps default to ``datetime.now`` at construction time.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
@@ -1,124 +0,0 @@
1
- from collections.abc import Callable
2
- from datetime import datetime
3
- from typing import TYPE_CHECKING
4
-
5
- from pydantic import BaseModel
6
-
7
- if TYPE_CHECKING:
8
- from .baseline_parameter_value import BaselineParameterValue
9
- from .baseline_variable import BaselineVariable
10
- from .parameter import Parameter
11
-
12
-
class Model(BaseModel):
    """A tax-benefit model: an id, display metadata and a run function."""

    id: str
    name: str
    description: str | None = None
    simulation_function: Callable

    def create_seed_objects(self, model_version):
        """Build seed records from the country package's tax-benefit system.

        Walks every parameter and variable of the system selected by
        ``self.id`` and produces the matching ``Parameter``,
        ``BaselineParameterValue`` and ``BaselineVariable`` objects.

        Args:
            model_version: Model version attached to every baseline value
                and baseline variable that is created.

        Returns:
            A ``SeedObjects`` holding the three lists.

        Raises:
            ValueError: If ``self.id`` is neither ``"policyengine_uk"`` nor
                ``"policyengine_us"``, or if a parameter instant cannot be
                parsed by ``safe_parse_instant_str``.
        """
        from policyengine_core.parameters import Parameter as CoreParameter

        from .baseline_parameter_value import BaselineParameterValue
        from .baseline_variable import BaselineVariable
        from .parameter import Parameter

        # Only the country package matching this model is imported.
        if self.id == "policyengine_uk":
            from policyengine_uk.tax_benefit_system import system
        elif self.id == "policyengine_us":
            from policyengine_us.system import system
        else:
            raise ValueError("Unsupported model.")

        parameters = []
        baseline_parameter_values = []
        seen_parameter_ids = set()

        for parameter in system.parameters.get_descendants():
            # Each parameter name is processed once, even if the tree
            # yields it more than once.
            if parameter.name in seen_parameter_ids:
                continue
            seen_parameter_ids.add(parameter.name)

            param = Parameter(
                id=parameter.name,
                description=parameter.description,
                data_type=None,
                model=self,
            )
            parameters.append(param)

            # Only core ``Parameter`` descendants are given values here;
            # other descendants get a ``Parameter`` record only.
            if not isinstance(parameter, CoreParameter):
                continue

            values = parameter.values_list[::-1]
            # Guard: a parameter without values has no data type and no
            # baseline values (previously an IndexError on ``values[-1]``).
            if not values:
                continue
            param.data_type = type(values[-1].value)

            # Parse every instant once; each value ends where the next one
            # in the reversed list starts, and the last one is open-ended.
            start_dates = [
                safe_parse_instant_str(value_at_instant.instant_str)
                for value_at_instant in values
            ]
            end_dates = start_dates[1:] + [None]
            for value_at_instant, start_date, end_date in zip(
                values, start_dates, end_dates
            ):
                baseline_parameter_values.append(
                    BaselineParameterValue(
                        parameter=param,
                        model_version=model_version,
                        value=value_at_instant.value,
                        start_date=start_date,
                        end_date=end_date,
                    )
                )

        baseline_variables = [
            BaselineVariable(
                id=variable.name,
                model_version=model_version,
                entity=variable.entity.key,
                label=variable.label,
                description=variable.documentation,
                data_type=variable.value_type,
            )
            for variable in system.variables.values()
        ]

        return SeedObjects(
            parameters=parameters,
            baseline_parameter_values=baseline_parameter_values,
            baseline_variables=baseline_variables,
        )
89
-
90
-
def safe_parse_instant_str(instant_str: str) -> datetime:
    """Parse a ``YYYY-MM-DD`` instant string into a ``datetime``.

    Two irregular inputs are tolerated:

    * ``"0000-01-01"`` is mapped to ``datetime(1, 1, 1)``, because year 0
      cannot be represented by ``datetime``.
    * A day past the end of its month (e.g. ``"2021-06-31"``) is clamped to
      the last valid day of that month, and a warning is logged.

    Args:
        instant_str: Date string in ``YYYY-MM-DD`` form.

    Returns:
        The parsed (possibly clamped) date at midnight.

    Raises:
        ValueError: If the string cannot be interpreted as a date. The
            message names the offending input and the original parsing
            error is chained as the cause.
    """
    import calendar
    import logging

    if instant_str == "0000-01-01":
        return datetime(1, 1, 1)

    try:
        return datetime.strptime(instant_str, "%Y-%m-%d")
    except ValueError as err:
        clamped = None
        try:
            # Unpacking fails unless there are exactly three numeric parts;
            # ``monthrange`` fails on an invalid month and ``datetime`` on
            # an out-of-range year. All of these raise ValueError.
            year, month, day = (int(part) for part in instant_str.split("-"))
            last_day = calendar.monthrange(year, month)[1]
            if day > last_day:
                clamped = datetime(year, month, last_day)
        except ValueError:
            clamped = None

        if clamped is None:
            # Always report the input itself rather than whichever
            # low-level conversion happened to fail.
            raise ValueError(f"Cannot parse date {instant_str!r}") from err

        logging.getLogger(__name__).warning(
            "Invalid date %s, using %s",
            instant_str,
            clamped.strftime("%Y-%m-%d"),
        )
        return clamped
119
-
120
-
class SeedObjects(BaseModel):
    """Container for the three lists built by ``Model.create_seed_objects``."""

    # One entry per distinct parameter name.
    parameters: list["Parameter"]
    # One entry per value of each core parameter.
    baseline_parameter_values: list["BaselineParameterValue"]
    # One entry per variable of the tax-benefit system.
    baseline_variables: list["BaselineVariable"]
@@ -1,14 +0,0 @@
1
- from datetime import datetime
2
- from uuid import uuid4
3
-
4
- from pydantic import BaseModel, Field
5
-
6
- from .model import Model
7
-
8
-
def _generate_model_version_id() -> str:
    """Return a new random UUID4 as a string (default version id)."""
    return str(uuid4())


class ModelVersion(BaseModel):
    """A specific version of a ``Model``."""

    id: str = Field(default_factory=_generate_model_version_id)

    # The model and its version string are both required.
    model: Model
    version: str

    description: str | None = None
    # Defaults to ``datetime.now`` at construction time.
    created_at: datetime = Field(default_factory=datetime.now)
@@ -1,17 +0,0 @@
1
- from collections.abc import Callable
2
- from datetime import datetime
3
- from uuid import uuid4
4
-
5
- from pydantic import BaseModel, Field
6
-
7
- from .parameter_value import ParameterValue
8
-
9
-
def _generate_policy_id() -> str:
    """Return a new random UUID4 as a string (default policy id)."""
    return str(uuid4())


class Policy(BaseModel):
    """A named set of parameter values and/or a simulation modifier."""

    id: str = Field(default_factory=_generate_policy_id)
    name: str
    description: str | None = None

    # Parameter overrides that make up the policy; empty by default.
    parameter_values: list[ParameterValue] = []

    # Optional callable; the runner functions call it with the simulation.
    simulation_modifier: Callable | None = None

    # Both timestamps default to ``datetime.now`` at construction time.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
@@ -1,114 +0,0 @@
1
- import importlib.metadata
2
-
3
- import pandas as pd
4
-
5
- from ..models import Dataset, Dynamic, Model, ModelVersion, Policy
6
-
7
-
def _apply_parameter_values(sim, parameter_values) -> None:
    """Write each parameter value into the simulation's parameter tree.

    For every element, the parameter named by ``parameter.id`` is looked up
    and updated with the element's value, from ``start_date`` until
    ``end_date`` (open-ended when ``end_date`` is falsy).
    """
    for parameter_value in parameter_values:
        end_date = parameter_value.end_date
        sim.tax_benefit_system.parameters.get_child(
            parameter_value.parameter.id
        ).update(
            value=parameter_value.value,
            start=parameter_value.start_date.strftime("%Y-%m-%d"),
            stop=end_date.strftime("%Y-%m-%d") if end_date else None,
        )


def run_policyengine_uk(
    dataset: "Dataset",
    policy: "Policy | None" = None,
    dynamic: "Dynamic | None" = None,
) -> dict[str, "pd.DataFrame"]:
    """Run a PolicyEngine UK microsimulation and collect yearly variables.

    Args:
        dataset: Supplies ``data`` (with ``person``, ``benunit`` and
            ``household`` tables) and ``year``.
        policy: Optional policy whose parameter values and simulation
            modifier are applied to the simulation.
        dynamic: Optional dynamic, applied in the same way.

    Returns:
        A dict with one DataFrame per entity (``person``, ``benunit``,
        ``household``). Each frame holds one column per variable of that
        entity whose definition period is ``"year"`` and which is not in
        the temporary skip list.
    """
    data: dict[str, pd.DataFrame] = dataset.data

    from policyengine_uk import Microsimulation
    from policyengine_uk.data import UKSingleYearDataset

    pe_input_data = UKSingleYearDataset(
        person=data["person"],
        benunit=data["benunit"],
        household=data["household"],
        fiscal_year=dataset.year,
    )

    sim = Microsimulation(dataset=pe_input_data)
    sim.default_calculation_period = dataset.year

    # Order matters and matches the original: policy parameters, dynamic
    # parameters, then the dynamic modifier, then the policy modifier.
    if policy is not None:
        _apply_parameter_values(sim, policy.parameter_values)
    if dynamic is not None:
        _apply_parameter_values(sim, dynamic.parameter_values)
    if dynamic is not None and dynamic.simulation_modifier is not None:
        dynamic.simulation_modifier(sim)
    if policy is not None and policy.simulation_modifier is not None:
        policy.simulation_modifier(sim)

    # Skip reforms for now

    # TEMPORARY: we need to fix policyengine-uk to make these only take a
    # long time with non-default parameters set to true.
    skipped_variables = frozenset(
        {
            "is_uc_entitled_baseline",
            "income_elasticity_lsr",
            "child_benefit_opts_out",
            "housing_benefit_baseline_entitlement",
            "baseline_ctc_entitlement",
            "pre_budget_change_household_tax",
            "pre_budget_change_household_net_income",
            "is_on_cliff",
            "marginal_tax_rate_on_capital_gains",
            "relative_capital_gains_mtr_change",
            "pre_budget_change_ons_equivalised_income_decile",
            "substitution_elasticity",
            "marginal_tax_rate",
            "cliff_evaluated",
            "cliff_gap",
            "substitution_elasticity_lsr",
            "relative_wage_change",
            "relative_income_change",
            "pre_budget_change_household_benefits",
        }
    )

    output_data = {}
    # Entities are processed one after another so variables are calculated
    # in the same order as before.
    for entity in ["person", "benunit", "household"]:
        frame = pd.DataFrame()
        for variable in sim.tax_benefit_system.variables.values():
            if variable.name in skipped_variables:
                continue
            if variable.definition_period != "year":
                continue
            if variable.entity.key != entity:
                continue
            frame[variable.name] = sim.calculate(variable.name)
        output_data[entity] = frame

    return output_data
100
-
101
-
# Model record for PolicyEngine UK, wired to its run function.
policyengine_uk_model = Model(
    simulation_function=run_policyengine_uk,
    id="policyengine_uk",
    name="PolicyEngine UK",
    description="PolicyEngine's open-source tax-benefit microsimulation model.",
)

# Version string of the installed ``policyengine_uk`` distribution.
_installed_policyengine_uk_version = importlib.metadata.version(
    "policyengine_uk"
)

policyengine_uk_latest_version = ModelVersion(
    version=_installed_policyengine_uk_version,
    model=policyengine_uk_model,
)
@@ -1,115 +0,0 @@
1
- import importlib.metadata
2
-
3
- import pandas as pd
4
-
5
- from ..models import Dataset, Dynamic, Model, ModelVersion, Policy
6
-
7
-
def _apply_parameter_values(sim, parameter_values) -> None:
    """Write each parameter value into the simulation's parameter tree.

    For every element, the parameter named by ``parameter.id`` is looked up
    and updated with the element's value, from ``start_date`` until
    ``end_date`` (open-ended when ``end_date`` is falsy).
    """
    for parameter_value in parameter_values:
        end_date = parameter_value.end_date
        sim.tax_benefit_system.parameters.get_child(
            parameter_value.parameter.id
        ).update(
            parameter_value.value,
            start=parameter_value.start_date.strftime("%Y-%m-%d"),
            stop=end_date.strftime("%Y-%m-%d") if end_date else None,
        )


def run_policyengine_us(
    dataset: "Dataset",
    policy: "Policy | None" = None,
    dynamic: "Dynamic | None" = None,
) -> dict[str, "pd.DataFrame"]:
    """Run a PolicyEngine US microsimulation and collect yearly variables.

    The entity tables in ``dataset.data`` are flattened into one
    person-level frame whose columns are named ``<column>__<year>``; that
    frame is the simulation input.

    Args:
        dataset: Supplies ``data`` (a dict of entity tables including
            ``person``) and ``year``. Every non-person table ``t`` must
            have a ``t_id`` column, and ``person`` a matching
            ``person_t_id`` column.
        policy: Optional policy whose parameter values and simulation
            modifier are applied to the simulation.
        dynamic: Optional dynamic, applied in the same way.

    Returns:
        A dict with one DataFrame per entity. Each frame holds one column
        per variable of that entity whose definition period is ``"year"``
        and whose holder already knows the dataset year.
    """
    data: dict[str, pd.DataFrame] = dataset.data
    year = dataset.year

    # Collect all columns first and build the frame once; growing a
    # DataFrame one column at a time fragments it.
    person_columns = {}
    for table_name, table in data.items():
        if table_name == "person":
            projected = table
        else:
            # Broadcast group-level rows to persons via the person's
            # foreign key into this table.
            foreign_key = data["person"][f"person_{table_name}_id"]
            primary_key = data[table_name][f"{table_name}_id"]
            projected = table.set_index(primary_key).loc[foreign_key]

        for col in projected.columns:
            person_columns[f"{col}__{year}"] = projected[col].values

    person_df = pd.DataFrame(person_columns)

    from policyengine_us import Microsimulation

    sim = Microsimulation(dataset=person_df)
    sim.default_calculation_period = year

    # Order matters and matches the original: policy parameters, dynamic
    # parameters, then the dynamic modifier, then the policy modifier.
    if policy is not None:
        _apply_parameter_values(sim, policy.parameter_values)
    if dynamic is not None:
        _apply_parameter_values(sim, dynamic.parameter_values)
    if dynamic is not None and dynamic.simulation_modifier is not None:
        dynamic.simulation_modifier(sim)
    if policy is not None and policy.simulation_modifier is not None:
        policy.simulation_modifier(sim)

    # Skip reforms for now

    # Calculated up front; presumably so that these variables (and what
    # they depend on) have known periods for the selection below -- confirm.
    variable_whitelist = [
        "household_net_income",
    ]
    for variable_name in variable_whitelist:
        sim.calculate(variable_name)

    year_str = str(year)
    output_data = {}
    for entity in [
        "person",
        "marital_unit",
        "family",
        "tax_unit",
        "spm_unit",
        "household",
    ]:
        frame = pd.DataFrame()
        for variable in sim.tax_benefit_system.variables.values():
            # Cheap attribute checks first, so the holder lookup happens
            # only for candidates of this entity.
            if variable.entity.key != entity:
                continue
            if variable.definition_period != "year":
                continue
            known_periods = sim.get_holder(variable.name).get_known_periods()
            if not any(str(period) == year_str for period in known_periods):
                continue
            frame[variable.name] = sim.calculate(variable.name)
        output_data[entity] = frame

    return output_data
100
-
101
-
# Model record for PolicyEngine US, wired to its run function.
policyengine_us_model = Model(
    simulation_function=run_policyengine_us,
    id="policyengine_us",
    name="PolicyEngine US",
    description="PolicyEngine's open-source tax-benefit microsimulation model.",
)

# Version string of the installed ``policyengine_us`` distribution.
_installed_policyengine_us_version = importlib.metadata.version(
    "policyengine_us"
)

policyengine_us_latest_version = ModelVersion(
    version=_installed_policyengine_us_version,
    model=policyengine_us_model,
)