causalrl 0.99.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- causalrl/__init__.py +274 -0
- causalrl/_backend/__init__.py +15 -0
- causalrl/agents/__init__.py +1 -0
- causalrl/agents/bandits.py +61 -0
- causalrl/agents/base.py +22 -0
- causalrl/agents/baselines.py +59 -0
- causalrl/agents/counterfactual.py +64 -0
- causalrl/agents/deep_deconfounded.py +80 -0
- causalrl/agents/dovi.py +139 -0
- causalrl/agents/offline_online.py +63 -0
- causalrl/agents/primitives.py +25 -0
- causalrl/agents/scbandit.py +100 -0
- causalrl/curriculum.py +117 -0
- causalrl/data/__init__.py +1 -0
- causalrl/data/dataset.py +90 -0
- causalrl/discovery.py +709 -0
- causalrl/envs/__init__.py +1 -0
- causalrl/envs/base.py +49 -0
- causalrl/envs/suite/__init__.py +1 -0
- causalrl/envs/suite/counterfactual_bandit.py +93 -0
- causalrl/envs/suite/curriculum.py +24 -0
- causalrl/envs/suite/discovery.py +59 -0
- causalrl/envs/suite/dtr.py +96 -0
- causalrl/envs/suite/games.py +38 -0
- causalrl/envs/suite/gridworld.py +91 -0
- causalrl/envs/suite/imitation.py +82 -0
- causalrl/envs/suite/mabuc.py +90 -0
- causalrl/envs/suite/scbandit.py +170 -0
- causalrl/envs/suite/seq_dtr.py +103 -0
- causalrl/envs/suite/seq_mabuc.py +96 -0
- causalrl/envs/suite/shaping.py +34 -0
- causalrl/envs/suite/transport.py +58 -0
- causalrl/eval/__init__.py +1 -0
- causalrl/eval/benchmark.py +168 -0
- causalrl/eval/harness.py +29 -0
- causalrl/eval/metrics.py +8 -0
- causalrl/eval/ope.py +20 -0
- causalrl/exceptions.py +25 -0
- causalrl/experimental/__init__.py +1 -0
- causalrl/experimental/ope.py +19 -0
- causalrl/games.py +331 -0
- causalrl/identification/__init__.py +1 -0
- causalrl/identification/_separation.py +50 -0
- causalrl/identification/bounds.py +114 -0
- causalrl/identification/counterfactual.py +131 -0
- causalrl/identification/criteria.py +33 -0
- causalrl/identification/id_algorithm.py +816 -0
- causalrl/identification/intervention_sets.py +152 -0
- causalrl/identification/transport.py +204 -0
- causalrl/imitation.py +124 -0
- causalrl/py.typed +0 -0
- causalrl/scm/__init__.py +1 -0
- causalrl/scm/graph.py +169 -0
- causalrl/scm/mechanisms.py +57 -0
- causalrl/scm/scm.py +131 -0
- causalrl/shaping.py +116 -0
- causalrl-0.99.0.dist-info/METADATA +392 -0
- causalrl-0.99.0.dist-info/RECORD +60 -0
- causalrl-0.99.0.dist-info/WHEEL +4 -0
- causalrl-0.99.0.dist-info/licenses/LICENSE +21 -0
causalrl/__init__.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
# pyright: reportUnsupportedDunderAll=false
|
|
2
|
+
"""causalrl: causal intervention-selection and causal-RL research tools.
|
|
3
|
+
|
|
4
|
+
The stable public API is loaded lazily so graph algorithms and tabular components can be used
|
|
5
|
+
without installing the optional PyTorch-backed SCM and neural functionality.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from importlib import import_module as _import_module
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _pkg_version
from typing import cast as _cast
|
|
11
|
+
|
|
12
|
+
# Resolve the installed distribution's version. When the package is imported from a source
# checkout without installed metadata (e.g. via PYTHONPATH), importlib.metadata raises
# PackageNotFoundError; fall back to a PEP 440 local-version placeholder instead of making
# `import causalrl` itself fail.
try:
    __version__ = _pkg_version("causalrl")
except _PackageNotFoundError:
    __version__ = "0+unknown"
|
|
13
|
+
|
|
14
|
+
_EXPORTS: dict[str, tuple[str, str]] = {
    # Lazy-export registry: public name -> (defining module, attribute name).
    # Every public object is re-exported under its own name, so the table is written as
    # (export name, defining module) pairs and expanded below; insertion order follows
    # the pair order.
    _export: (_module, _export)
    for _export, _module in (
        ("DOVI", "causalrl.agents.dovi"),
        ("UCDTR", "causalrl.agents.offline_online"),
        ("Agent", "causalrl.agents.base"),
        ("BruteForceInterventionTS", "causalrl.agents.scbandit"),
        ("BehavioralCloning", "causalrl.imitation"),
        ("BenchmarkEstimate", "causalrl.eval.benchmark"),
        ("CPDAG", "causalrl.discovery"),
        ("PAG", "causalrl.discovery"),
        ("CausalEnv", "causalrl.envs.base"),
        ("CausalGame", "causalrl.games"),
        ("CausalGraph", "causalrl.scm.graph"),
        ("CausalGraphError", "causalrl.exceptions"),
        ("CausalImitator", "causalrl.imitation"),
        ("CausalRLError", "causalrl.exceptions"),
        ("CausalThompsonSampling", "causalrl.agents.bandits"),
        ("ConfoundedGridworld", "causalrl.envs.suite.gridworld"),
        ("ConfoundedMDP", "causalrl.envs.base"),
        ("ConfoundedTrajectoryDataset", "causalrl.data.dataset"),
        ("CounterfactualBanditEnv", "causalrl.envs.suite.counterfactual_bandit"),
        ("CounterfactualOptimalPolicy", "causalrl.agents.counterfactual"),
        ("DTREnv", "causalrl.envs.suite.dtr"),
        ("DeepDeconfoundedQ", "causalrl.agents.deep_deconfounded"),
        ("Domain", "causalrl.identification.id_algorithm"),
        ("Estimand", "causalrl.identification.id_algorithm"),
        ("FixedSetThompsonSampling", "causalrl.agents.scbandit"),
        ("FunctionalMechanism", "causalrl.scm.mechanisms"),
        ("LinearGaussianMechanism", "causalrl.scm.mechanisms"),
        ("MABUCEnv", "causalrl.envs.suite.mabuc"),
        ("Mechanism", "causalrl.scm.mechanisms"),
        ("NaiveOffline", "causalrl.agents.baselines"),
        ("NaivePOMISThompsonSampling", "causalrl.agents.scbandit"),
        ("NaiveThompsonSampling", "causalrl.agents.bandits"),
        ("NeuralMechanism", "causalrl.scm.mechanisms"),
        ("NotIdentifiableError", "causalrl.exceptions"),
        ("OnlineOnlyUCB", "causalrl.agents.baselines"),
        ("POMISThompsonSampling", "causalrl.agents.scbandit"),
        ("PrerequisiteLearner", "causalrl.curriculum"),
        ("RealizabilityError", "causalrl.exceptions"),
        ("UnverifiedAssumptionError", "causalrl.exceptions"),
        ("SelectionDiagram", "causalrl.identification.transport"),
        ("SequentialDTREnv", "causalrl.envs.suite.seq_dtr"),
        ("SequentialMABUCEnv", "causalrl.envs.suite.seq_mabuc"),
        ("StructuralCausalBanditEnv", "causalrl.envs.suite.scbandit"),
        ("StructuralCausalModel", "causalrl.scm.scm"),
        ("TabularMDP", "causalrl.shaping"),
        ("Transition", "causalrl.data.dataset"),
        ("TransportFormula", "causalrl.identification.transport"),
        ("apply_potential_shaping", "causalrl.shaping"),
        ("backdoor_adjustment_set", "causalrl.identification.criteria"),
        ("best_response", "causalrl.games"),
        ("causal_curriculum", "causalrl.curriculum"),
        ("causal_potential", "causalrl.shaping"),
        ("causal_q_bounds", "causalrl.identification.bounds"),
        ("ipw_sensitivity_bounds", "causalrl.identification.bounds"),
        ("manski_bounds", "causalrl.identification.bounds"),
        ("conditional_mutual_information", "causalrl.discovery"),
        ("counterfactual_expectation", "causalrl.identification.counterfactual"),
        ("cumulative_regret", "causalrl.eval.metrics"),
        ("curriculum_q_learning", "causalrl.curriculum"),
        ("discover", "causalrl.discovery"),
        ("discover_interventional", "causalrl.discovery"),
        ("discover_latent", "causalrl.discovery"),
        ("effect_of_treatment_on_treated", "causalrl.identification.counterfactual"),
        ("estimate_effect", "causalrl.identification.id_algorithm"),
        ("estimate_effect_with_experiments", "causalrl.identification.id_algorithm"),
        ("estimate_transport_general", "causalrl.identification.id_algorithm"),
        ("estimate_transported_effect", "causalrl.identification.id_algorithm"),
        ("finite_horizon_regret", "causalrl.eval.metrics"),
        ("generate_logs", "causalrl.data.dataset"),
        ("identify_effect", "causalrl.identification.id_algorithm"),
        ("identify_effect_with_experiments", "causalrl.identification.id_algorithm"),
        ("identify_transport", "causalrl.identification.id_algorithm"),
        ("identify_transport_general", "causalrl.identification.id_algorithm"),
        ("imitation_backdoor_set", "causalrl.imitation"),
        ("ipw_value", "causalrl.eval.ope"),
        ("is_backdoor_admissible", "causalrl.identification.transport"),
        ("is_gid_identifiable", "causalrl.identification.id_algorithm"),
        ("is_identifiable", "causalrl.identification.criteria"),
        ("is_identifiable_effect", "causalrl.identification.id_algorithm"),
        ("is_imitable", "causalrl.imitation"),
        ("is_nash_equilibrium", "causalrl.games"),
        ("is_transportable", "causalrl.identification.transport"),
        ("is_transportable_effect", "causalrl.identification.id_algorithm"),
        ("is_transportable_general", "causalrl.identification.id_algorithm"),
        ("is_valid_curriculum", "causalrl.curriculum"),
        ("minimal_intervention_sets", "causalrl.identification.intervention_sets"),
        ("mixed_nash_equilibria", "causalrl.games"),
        ("pomis", "causalrl.identification.intervention_sets"),
        ("pure_nash_equilibria", "causalrl.games"),
        ("q_learning", "causalrl.shaping"),
        ("report_to_dict", "causalrl.eval.benchmark"),
        ("requires_experiment", "causalrl.identification.intervention_sets"),
        ("run_episodes", "causalrl.eval.harness"),
        ("run_confounded_chain_benchmark", "causalrl.eval.benchmark"),
        ("run_frontdoor_benchmark", "causalrl.eval.benchmark"),
        ("transport_estimand", "causalrl.identification.transport"),
        ("transport_formula", "causalrl.identification.transport"),
        ("transported_effect", "causalrl.identification.transport"),
        ("value_iteration", "causalrl.shaping"),
    )
}
|
|
151
|
+
|
|
152
|
+
# Public API surface: every lazy export in _EXPORTS plus __version__. Kept in this explicit
# order for `from causalrl import *` and for tooling that reads __all__ statically.
__all__ = [
    "CPDAG", "DOVI", "PAG", "UCDTR",
    "Agent", "BehavioralCloning", "BenchmarkEstimate", "BruteForceInterventionTS",
    "CausalEnv", "CausalGame", "CausalGraph", "CausalGraphError",
    "CausalImitator", "CausalRLError", "CausalThompsonSampling",
    "ConfoundedGridworld", "ConfoundedMDP", "ConfoundedTrajectoryDataset",
    "CounterfactualBanditEnv", "CounterfactualOptimalPolicy",
    "DTREnv", "DeepDeconfoundedQ", "Domain", "Estimand",
    "FixedSetThompsonSampling", "FunctionalMechanism", "LinearGaussianMechanism",
    "MABUCEnv", "Mechanism", "NaiveOffline", "NaivePOMISThompsonSampling",
    "NaiveThompsonSampling", "NeuralMechanism", "NotIdentifiableError",
    "OnlineOnlyUCB", "POMISThompsonSampling", "PrerequisiteLearner", "RealizabilityError",
    "SelectionDiagram", "SequentialDTREnv", "SequentialMABUCEnv",
    "StructuralCausalBanditEnv", "StructuralCausalModel",
    "TabularMDP", "Transition", "TransportFormula", "UnverifiedAssumptionError",
    "__version__",
    "apply_potential_shaping", "backdoor_adjustment_set", "best_response",
    "causal_curriculum", "causal_potential", "causal_q_bounds",
    "conditional_mutual_information", "counterfactual_expectation",
    "cumulative_regret", "curriculum_q_learning",
    "discover", "discover_interventional", "discover_latent",
    "effect_of_treatment_on_treated",
    "estimate_effect", "estimate_effect_with_experiments",
    "estimate_transport_general", "estimate_transported_effect",
    "finite_horizon_regret", "generate_logs",
    "identify_effect", "identify_effect_with_experiments",
    "identify_transport", "identify_transport_general",
    "imitation_backdoor_set", "ipw_sensitivity_bounds", "ipw_value",
    "is_backdoor_admissible", "is_gid_identifiable", "is_identifiable",
    "is_identifiable_effect", "is_imitable", "is_nash_equilibrium",
    "is_transportable", "is_transportable_effect", "is_transportable_general",
    "is_valid_curriculum", "manski_bounds", "minimal_intervention_sets",
    "mixed_nash_equilibria", "pomis", "pure_nash_equilibria", "q_learning",
    "report_to_dict", "requires_experiment",
    "run_confounded_chain_benchmark", "run_episodes", "run_frontdoor_benchmark",
    "transport_estimand", "transport_formula", "transported_effect",
    "value_iteration",
]
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def __getattr__(name: str) -> object:
|
|
257
|
+
target = _EXPORTS.get(name)
|
|
258
|
+
if target is None:
|
|
259
|
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
260
|
+
module_name, attribute = target
|
|
261
|
+
try:
|
|
262
|
+
value = _cast(object, getattr(_import_module(module_name), attribute))
|
|
263
|
+
except ModuleNotFoundError as exc:
|
|
264
|
+
if exc.name == "torch":
|
|
265
|
+
raise ImportError(
|
|
266
|
+
f"{name} requires PyTorch support; install the 'causalrl[torch]' extra"
|
|
267
|
+
) from exc
|
|
268
|
+
raise
|
|
269
|
+
globals()[name] = value
|
|
270
|
+
return value
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def __dir__() -> list[str]:
    """Support ``dir(causalrl)`` including lazy exports that have not been loaded yet.

    Returns the sorted, de-duplicated union of names already bound in this module's
    namespace and every name listed in ``__all__``.
    """
    visible = set(globals())
    visible.update(__all__)
    return sorted(visible)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Thin numerics seam. PyTorch today; an alternate backend can re-implement this module."""
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
# torch.Generator/Tensor are public API but not re-exported in torch's type stubs.
|
|
6
|
+
Tensor = torch.Tensor
|
|
7
|
+
Generator = torch.Generator # type: ignore[reportPrivateImportUsage]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def default_generator(seed: int | None = None) -> Generator:
|
|
11
|
+
"""Return a torch.Generator, optionally seeded, for reproducible sampling."""
|
|
12
|
+
gen = Generator()
|
|
13
|
+
if seed is not None:
|
|
14
|
+
gen.manual_seed(seed)
|
|
15
|
+
return gen
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Causal RL agents."""
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from causalrl.agents.base import Agent
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class NaiveThompsonSampling(Agent):
    """Context-free Beta-Bernoulli Thompson sampling with a single posterior per arm.

    The observation (including any ``intuition`` entry) is never read, so all rounds for an
    arm are pooled; the agent therefore cannot tell apart arms whose interventional means
    are equal, which is the point of this baseline.

    Draws come from a private ``numpy.random.Generator`` built from ``seed``: agents with
    different seeds produce independent action sequences, and a fixed seed reproduces the
    same sequence regardless of the global RNG state.
    """

    def __init__(self, n_arms: int, seed: int | None = None) -> None:
        self.n_arms = n_arms
        # Beta(1, 1) (uniform) prior for every arm.
        self._alpha = np.ones(n_arms)
        self._beta = np.ones(n_arms)
        self._rng = np.random.default_rng(seed)

    def act(self, observation: dict[str, Any]) -> int:
        """Draw one success probability per arm from its posterior and return the argmax arm."""
        draws = self._rng.beta(self._alpha, self._beta)
        return int(np.argmax(np.asarray(draws, dtype=np.float64)))

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """Treat ``reward > 0`` as a success for ``action`` and anything else as a failure."""
        counts = self._alpha if reward > 0 else self._beta
        counts[action] += 1.0
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class CausalThompsonSampling(Agent):
    """Context-aware Thompson sampling with an independent Beta posterior per (intuition, arm).

    The observation's ``intuition`` value (an observed proxy for the confounder) selects
    which row of posteriors is sampled and updated, so the agent can learn a different best
    arm for each intuition value instead of pooling them.

    Draws come from a private ``numpy.random.Generator`` built from ``seed`` (same
    reproducibility contract as :class:`NaiveThompsonSampling`).
    """

    def __init__(self, n_arms: int, n_contexts: int, seed: int | None = None) -> None:
        self.n_arms = n_arms
        self.n_contexts = n_contexts
        # Beta(1, 1) prior for every (context, arm) cell.
        grid = (n_contexts, n_arms)
        self._alpha = np.ones(grid)
        self._beta = np.ones(grid)
        self._rng = np.random.default_rng(seed)

    def act(self, observation: dict[str, Any]) -> int:
        """Sample each arm's posterior for the observed intuition and return the argmax arm."""
        row = int(observation["intuition"])
        draws = self._rng.beta(self._alpha[row], self._beta[row])
        return int(np.argmax(np.asarray(draws, dtype=np.float64)))

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """Count ``reward > 0`` as a success for (intuition, action); anything else as a failure."""
        row = int(observation["intuition"])
        counts = self._alpha if reward > 0 else self._beta
        counts[row, action] += 1.0
|
causalrl/agents/base.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class Agent(ABC):
    """Abstract interface shared by the causalrl agents.

    An agent picks an integer action for an observation dict (:meth:`act`) and then learns
    from the scalar reward that action produced (:meth:`update`). Subclasses must implement
    both; :meth:`observe_transition` is an optional hook with a no-op default.
    """

    @abstractmethod
    def act(self, observation: dict[str, Any]) -> int:
        """Return the action to take given ``observation``."""

    @abstractmethod
    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """Learn from ``reward`` received after taking ``action`` under ``observation``."""

    def observe_transition(  # noqa: B027
        self, state: int, action: int, next_state: int, done: bool
    ) -> None:
        """Optional model-learning hook receiving one ``(s, a, s', done)`` transition.

        Does nothing by default. Model-based agents (e.g. DOVI's value iteration) override
        it to build an empirical transition model; reward-only agents can ignore it.
        """
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from causalrl.agents.base import Agent
|
|
7
|
+
from causalrl.data.dataset import ConfoundedTrajectoryDataset
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OnlineOnlyUCB(Agent):
    """UCB1 run independently in each state, ignoring any offline data.

    The 'learn from scratch online' baseline. Every state is treated as its own multi-armed
    bandit. In the current state the agent first plays an action it has never tried there
    (picked uniformly at random by its seeded numpy Generator). Once all actions have been
    tried, it plays

        argmax_a  mean[s, a] + sqrt(2 * ln(n_s) / n[s, a])

    where ``n[s, a]`` is the number of updates for the pair and ``n_s = sum_a n[s, a]`` is the
    number of updates recorded in state ``s``. Using the per-state count (rather than a global
    step counter) is what makes this UCB1 *per state*: visits to other states do not inflate
    this state's exploration bonus.
    """

    def __init__(self, n_states: int, n_actions: int, seed: int | None = None) -> None:
        self.n_states = n_states
        self.n_actions = n_actions
        self._counts = np.zeros((n_states, n_actions))
        self._sums = np.zeros((n_states, n_actions))
        # Global step counter (1 + number of updates). Retained for introspection/backward
        # compatibility; the exploration bonus uses the per-state count instead.
        self._t = 1
        self._rng = np.random.default_rng(seed)

    def ingest_offline(self, dataset: ConfoundedTrajectoryDataset) -> None:
        """No-op: this baseline ignores offline data by design."""

    def act(self, observation: dict[str, Any]) -> int:
        """Return an untried action in ``observation["state"]`` if any, else the UCB1 argmax."""
        s = int(observation["state"])
        counts = self._counts[s]
        untried = [a for a in range(self.n_actions) if counts[a] == 0]
        if untried:
            return int(self._rng.choice(untried))
        means = self._sums[s] / counts
        # Plays of *this* state's bandit; >= n_actions >= 1 here, so the log is well defined.
        n_state = float(counts.sum())
        bonus = np.sqrt(2.0 * math.log(n_state) / counts)
        return int(np.argmax(np.asarray(means + bonus, dtype=np.float64)))

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """Record ``reward`` for (state, action) and advance the global step counter."""
        s = int(observation["state"])
        self._counts[s, action] += 1.0
        self._sums[s, action] += reward
        self._t += 1
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class NaiveOffline(Agent):
    """Greedy policy on the logged mean reward E[R | s, a], treating the logs as unconfounded.

    The cautionary baseline: the estimates are provably biased under confounding, and the
    policy is frozen after :meth:`ingest_offline` (no online learning).
    """

    def __init__(self, n_states: int, n_actions: int) -> None:
        self.n_states = n_states
        self.n_actions = n_actions
        self._mean = np.zeros((n_states, n_actions))

    def ingest_offline(self, dataset: ConfoundedTrajectoryDataset) -> None:
        """Fill the (state, action) table with ``dataset.mean_reward(s, a)``."""
        # np.ndindex walks the grid in row-major order: all actions of state 0, then state 1, ...
        for s, a in np.ndindex(self.n_states, self.n_actions):
            self._mean[s, a] = dataset.mean_reward(s, a)

    def act(self, observation: dict[str, Any]) -> int:
        """Return the action with the highest logged mean reward in ``observation["state"]``."""
        row = self._mean[int(observation["state"])]
        return int(np.argmax(np.asarray(row, dtype=np.float64)))

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """Naive-offline does not learn online (fixed policy from the logs)."""
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Counterfactual decision-making agents (Layer 3).
|
|
2
|
+
|
|
3
|
+
A model-based policy that decides by querying the counterfactual reward ``E[Y_{do(a)} | intent]``
|
|
4
|
+
on a known SCM — the Regret Decision Criterion of Bareinboim, Forney & Pearl (NeurIPS 2015).
|
|
5
|
+
Conditioning the choice on the agent's own intent (a proxy for the unobserved confounder) recovers
|
|
6
|
+
the per-intent optimum that the best fixed intervention cannot see.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from collections.abc import Sequence
|
|
12
|
+
from typing import TYPE_CHECKING, Any
|
|
13
|
+
|
|
14
|
+
from causalrl.agents.base import Agent
|
|
15
|
+
from causalrl.identification.counterfactual import regret_decision_table
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from causalrl.scm.scm import StructuralCausalModel
|
|
19
|
+
|
|
20
|
+
__all__ = ["CounterfactualOptimalPolicy"]


class CounterfactualOptimalPolicy(Agent):
    """Plays ``argmax_a E[Y_{do(action_node=a)} | intent]`` on a known SCM.

    The Layer-3 oracle. At construction it computes the Regret Decision Criterion table once
    with ``regret_decision_table`` (``n`` and ``seed`` are forwarded to it, presumably as the
    sample budget and RNG seed), keeps it as ``decision_table`` for inspection, and
    precomputes the best arm for every intent. :meth:`act` then looks up the observed intent
    (read from ``observation[intent_key]``); an intent that is not in the table raises
    ``KeyError``. :meth:`update` is a no-op because the model is known.
    """

    def __init__(
        self,
        scm: StructuralCausalModel,
        *,
        outcome: str,
        action_node: str,
        intent_node: str,
        arms: Sequence[int],
        intents: Sequence[int],
        intent_key: str = "intuition",
        n: int = 20_000,
        seed: int | None = None,
    ) -> None:
        self._intent_key = intent_key
        table = regret_decision_table(
            scm,
            outcome=outcome,
            action_node=action_node,
            intent_node=intent_node,
            arms=arms,
            intents=intents,
            n=n,
            seed=seed,
        )
        self.decision_table = table
        # For each intent keep the arm with the largest value; max() keeps the first
        # maximal arm in row order on ties.
        self._best_arm: dict[int, int] = {}
        for intent, arm_values in table.items():
            self._best_arm[intent] = max(arm_values, key=arm_values.__getitem__)

    def act(self, observation: dict[str, Any]) -> int:
        """Return the precomputed best arm for the observed intent."""
        intent = int(observation[self._intent_key])
        return self._best_arm[intent]

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """No-op: the SCM is known, so there is nothing to learn online."""
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
from causalrl.agents.base import Agent
|
|
7
|
+
from causalrl.agents.primitives import bounds_table
|
|
8
|
+
from causalrl.data.dataset import ConfoundedTrajectoryDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _QNet(torch.nn.Module):
    """Small MLP Q-network: one-hot state vector in, one Q-value per action out.

    Layout: ``Linear(n_states, 64) -> ReLU -> Linear(64, n_actions)`` held in ``self.net``.
    """

    def __init__(self, n_states: int, n_actions: int) -> None:
        super().__init__()
        width = 64
        # Layers are created in the same order as they run so seeded initialisation is stable.
        layers = [
            torch.nn.Linear(n_states, width),  # type: ignore[reportPrivateImportUsage]
            torch.nn.ReLU(),  # type: ignore[reportPrivateImportUsage]
            torch.nn.Linear(width, n_actions),  # type: ignore[reportPrivateImportUsage]
        ]
        self.net = torch.nn.Sequential(*layers)  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q_values = self.net(x)  # type: ignore[reportUnknownMemberType]
        return q_values  # type: ignore[no-any-return]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DeepDeconfoundedQ(Agent):
    """DQN-style agent whose regression targets are clamped into the Manski causal bounds.

    Function approximation can extrapolate Q outside the causally-valid envelope; clamping
    each (state, action) target into ``[lower, upper]`` keeps the learned values consistent
    with what the confounded offline data can support. The bounds are loaded from
    ``bounds_table(dataset)`` by :meth:`ingest_offline`; until then every pair defaults to
    ``[0, 1]``, i.e. rewards are assumed to lie in the unit interval.

    :meth:`update` regresses ``Q(s, a)`` onto the clamped *immediate* reward with one squared-
    error gradient step; there is no next-state bootstrap term (sufficient for the toy demo).
    A full offline-RL backbone (d3rlpy) is the designated reuse path at real scale.

    Network initialisation is seeded from ``seed`` (``0`` when ``seed is None``) inside a forked
    torch RNG scope, so constructing an agent does not reseed the caller's global torch CPU
    generator. Exploration is epsilon-greedy using a per-instance numpy Generator.

    Args:
        n_states: number of discrete states (one-hot input width).
        n_actions: number of discrete actions (Q-network output width).
        seed: seed for network initialisation and the exploration RNG.
        epsilon: probability of taking a uniformly random action in :meth:`act`.
        lr: Adam learning rate.
    """

    def __init__(
        self,
        n_states: int,
        n_actions: int,
        seed: int | None = None,
        *,
        epsilon: float = 0.2,
        lr: float = 1e-2,
    ) -> None:
        self.n_states = n_states
        self.n_actions = n_actions
        # Seed only the weight initialisation, inside a forked CPU RNG scope that is restored on
        # exit. Seeding the CPU default generator here yields the same weights that a global
        # torch.manual_seed(seed) followed by construction would.
        with torch.random.fork_rng(devices=[]):  # type: ignore[reportPrivateImportUsage]
            torch.default_generator.manual_seed(0 if seed is None else seed)  # type: ignore[reportPrivateImportUsage]
            self._q = _QNet(n_states, n_actions)
        self._rng = np.random.default_rng(seed)
        self._opt = torch.optim.Adam(self._q.parameters(), lr=lr)  # type: ignore[reportPrivateImportUsage]
        self._eps = epsilon
        self._lower = np.zeros((n_states, n_actions))
        self._upper = np.ones((n_states, n_actions))

    def ingest_offline(self, dataset: ConfoundedTrajectoryDataset) -> None:
        """Load per-(state, action) ``(lower, upper)`` bounds from ``bounds_table(dataset)``."""
        for (s, a), (lo, hi) in bounds_table(dataset).items():
            self._lower[s, a] = lo
            self._upper[s, a] = hi

    def bound(self, state: int, action: int) -> tuple[float, float]:
        """Return the current ``(lower, upper)`` bound for ``(state, action)``."""
        return float(self._lower[state, action]), float(self._upper[state, action])

    def clamp_target(self, state: int, action: int, target: float) -> float:
        """Clip ``target`` into the ``(state, action)`` bound."""
        lo, hi = self.bound(state, action)
        return float(min(max(target, lo), hi))

    def _onehot(self, state: int) -> torch.Tensor:
        x = torch.zeros(self.n_states)  # type: ignore[reportPrivateImportUsage]
        x[state] = 1.0
        return x

    def act(self, observation: dict[str, Any]) -> int:
        """Epsilon-greedy action for ``observation["state"]``."""
        s = int(observation["state"])
        if self._rng.random() < self._eps:
            return int(self._rng.integers(0, self.n_actions))
        with torch.no_grad():  # type: ignore[reportPrivateImportUsage]
            q = self._q(self._onehot(s))
        return int(torch.argmax(q).item())  # type: ignore[reportPrivateImportUsage]

    def update(self, observation: dict[str, Any], action: int, reward: float) -> None:
        """One Adam step on ``(Q(s, a) - clamp(reward))**2``; no next-state bootstrap."""
        s = int(observation["state"])
        target = self.clamp_target(s, action, float(reward))
        pred = self._q(self._onehot(s))[action]
        loss = (pred - torch.tensor(target)) ** 2  # type: ignore[reportPrivateImportUsage]
        self._opt.zero_grad()
        loss.backward()
        self._opt.step()  # type: ignore[reportUnknownMemberType]
|