dcoupler 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dcoupler/__init__.py ADDED
@@ -0,0 +1,52 @@
1
"""Top-level dCoupler package.

Re-exports the public classes of ``dcoupler.core`` and
``dcoupler.optimization``, plus every name listed in the ``__all__`` of the
``observers``, ``diagnostics``, ``utils``, ``losses``, ``components`` and
``wrappers`` subpackages, and defines ``__version__``.
"""
from dcoupler.core import (
    CouplingGraph,
    DifferentiableComponent,
    FluxDirection,
    FluxSpec,
    GradientMethod,
    ParameterSpec,
    FluxConnection,
    SpatialRemapper,
    TemporalOrchestrator,
    ConservationChecker,
    BMIMixin,
)
from dcoupler.optimization import (
    ParameterManager,
    MultiObservationLoss,
    LossTerm,
    Trainer,
    TrainingResult,
)
from dcoupler import observers as _observers
from dcoupler import diagnostics as _diagnostics
from dcoupler import utils as _utils
from dcoupler.observers import *  # noqa: F401,F403
from dcoupler.diagnostics import *  # noqa: F401,F403
# ``__all__`` below includes ``_utils.__all__``, so those names must actually
# be bound in this namespace; without this import ``from dcoupler import *``
# raises AttributeError for every utils export.
from dcoupler.utils import *  # noqa: F401,F403
from dcoupler import losses as _losses
from dcoupler import components as _components
from dcoupler import wrappers as _wrappers
from dcoupler.losses import *  # noqa: F401,F403
from dcoupler.components import *  # noqa: F401,F403
from dcoupler.wrappers import *  # noqa: F401,F403

__version__ = "0.2.0"

__all__ = [
    "CouplingGraph",
    "DifferentiableComponent",
    "FluxDirection",
    "FluxSpec",
    "GradientMethod",
    "ParameterSpec",
    "FluxConnection",
    "SpatialRemapper",
    "TemporalOrchestrator",
    "ConservationChecker",
    "BMIMixin",
    "ParameterManager",
    "MultiObservationLoss",
    "LossTerm",
    "Trainer",
    "TrainingResult",
] + _losses.__all__ + _components.__all__ + _wrappers.__all__ + _observers.__all__ + _diagnostics.__all__ + _utils.__all__
@@ -0,0 +1,13 @@
1
"""Optional model components for dCoupler.

Each component is imported on a best-effort basis: if its module cannot be
imported, the public name is bound to ``None`` and left out of ``__all__``.
An ImportError (typically a missing optional dependency) is handled
silently; any other exception is reported with a ``RuntimeWarning`` so that
genuine bugs in a component module are not hidden.
"""

import warnings as _warnings

__all__ = []

try:
    from .fuse import FUSEComponent
except ImportError:
    # Component unavailable, usually because an optional dependency is missing.
    FUSEComponent = None  # type: ignore
except Exception as _exc:  # noqa: BLE001 - keep the package importable
    # Not an ImportError: most likely a bug in the module itself. Surface it,
    # but keep the best-effort behaviour of binding the name to None.
    _warnings.warn(
        f"dcoupler.components: FUSEComponent is unavailable ({_exc!r})",
        RuntimeWarning,
        stacklevel=2,
    )
    FUSEComponent = None  # type: ignore
else:
    __all__.append("FUSEComponent")

try:
    from .routing import MuskingumCungeRouting
except ImportError:
    # Component unavailable, usually because an optional dependency is missing.
    MuskingumCungeRouting = None  # type: ignore
except Exception as _exc:  # noqa: BLE001 - keep the package importable
    _warnings.warn(
        f"dcoupler.components: MuskingumCungeRouting is unavailable ({_exc!r})",
        RuntimeWarning,
        stacklevel=2,
    )
    MuskingumCungeRouting = None  # type: ignore
else:
    __all__.append("MuskingumCungeRouting")
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List, Optional
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from dcoupler.core.component import (
8
+ DifferentiableComponent,
9
+ FluxDirection,
10
+ FluxSpec,
11
+ GradientMethod,
12
+ ParameterSpec,
13
+ )
14
+
15
+ try:
16
+ import cfuse
17
+ import cfuse_core
18
+ from cfuse.torch import DifferentiableFUSEBatch
19
+ except ImportError:
20
+ cfuse = None
21
+ cfuse_core = None
22
+ DifferentiableFUSEBatch = None
23
+
24
+
25
class FUSEComponent(DifferentiableComponent, nn.Module):
    """FUSE component backed by ``cfuse``'s ``DifferentiableFUSEBatch``.

    Each FUSE parameter is stored as an unconstrained raw tensor and mapped
    into its ``cfuse.PARAM_BOUNDS`` interval with a sigmoid. Optimisers can
    therefore update the raw values freely while the physical values stay
    within bounds.

    NOTE(review): the ``parameters`` property (presumably required by the
    ``DifferentiableComponent`` interface — confirm) shadows
    ``nn.Module.parameters()``. Module helpers that call
    ``self.parameters()``, such as ``zero_grad()`` and ``requires_grad_()``,
    will fail on this class. Use ``get_torch_parameters()`` to reach the
    trainable tensors.
    """

    def __init__(
        self,
        name: str,
        fuse_config,
        n_hrus: int,
        dt: float = 86400.0,
        spatial_params: bool = True,
        n_states: Optional[int] = None,
    ) -> None:
        """Create the component and initialise its raw parameters.

        Args:
            name: Component name returned by ``name``.
            fuse_config: FUSE model configuration. If it has ``to_dict()``,
                the resulting dict is what gets passed to cfuse; otherwise the
                object itself is passed through. Must not be None.
            n_hrus: Number of HRUs simulated together in one batch.
            dt: Timestep in seconds (default 86400, i.e. one day).
            spatial_params: If True, every parameter has one value per HRU;
                otherwise each parameter is one scalar shared by all HRUs.
            n_states: Number of model state variables. When omitted, it is
                inferred with ``cfuse_core.get_num_active_states``.

        Raises:
            RuntimeError: If ``cfuse``/``DifferentiableFUSEBatch`` could not be
                imported, or ``cfuse_core`` is needed but unavailable.
            ValueError: If ``fuse_config`` is None.
        """
        super().__init__()
        if DifferentiableFUSEBatch is None or cfuse is None:
            raise RuntimeError("cfuse and DifferentiableFUSEBatch are required for FUSEComponent")
        if fuse_config is None:
            raise ValueError("fuse_config is required")

        self._name = name
        self.fuse_config = fuse_config
        self.config_dict = fuse_config.to_dict() if hasattr(fuse_config, "to_dict") else fuse_config
        self.n_hrus = n_hrus
        self.dt_seconds = float(dt)
        # DifferentiableFUSEBatch is called with the step length in days.
        self.dt_days = self.dt_seconds / 86400.0
        self.spatial_params = spatial_params

        if n_states is None:
            if cfuse_core is None:
                raise RuntimeError("cfuse_core required to infer state size")
            self.n_states = int(cfuse_core.get_num_active_states(self.config_dict))
        else:
            self.n_states = int(n_states)

        self.param_names = list(cfuse.PARAM_NAMES)
        self.n_params = len(self.param_names)
        lowers = torch.tensor([cfuse.PARAM_BOUNDS[n][0] for n in self.param_names], dtype=torch.float32)
        uppers = torch.tensor([cfuse.PARAM_BOUNDS[n][1] for n in self.param_names], dtype=torch.float32)
        # Buffers move with the module on .to(device) but are not trainable.
        self.register_buffer("param_lower", lowers)
        self.register_buffer("param_upper", uppers)

        # Raw (pre-sigmoid) values. A raw value of 0 maps to the midpoint of
        # the bounds; spatial parameters get small per-HRU random noise.
        self._raw_params = nn.ParameterList()
        for _ in range(self.n_params):
            if self.spatial_params:
                init = torch.randn(self.n_hrus) * 0.2
            else:
                init = torch.zeros(())
            self._raw_params.append(nn.Parameter(init))

    @property
    def name(self) -> str:
        """Component name given at construction."""
        return self._name

    @property
    def input_fluxes(self) -> List[FluxSpec]:
        """Single required input: ``forcing`` with dims (time, hru, var)."""
        return [
            FluxSpec(
                name="forcing",
                units="mixed",
                direction=FluxDirection.INPUT,
                spatial_type="hru",
                temporal_resolution=self.dt_seconds,
                dims=("time", "hru", "var"),
                optional=False,
            )
        ]

    @property
    def output_fluxes(self) -> List[FluxSpec]:
        """Single output: ``runoff`` in mm/day with dims (time, hru)."""
        return [
            FluxSpec(
                name="runoff",
                units="mm/day",
                direction=FluxDirection.OUTPUT,
                spatial_type="hru",
                temporal_resolution=self.dt_seconds,
                dims=("time", "hru"),
                conserved_quantity="water_mass",
            )
        ]

    @property
    def parameters(self) -> List[ParameterSpec]:
        """One ``ParameterSpec`` per cfuse parameter, with its physical bounds."""
        specs: List[ParameterSpec] = []
        for i, name in enumerate(self.param_names):
            specs.append(
                ParameterSpec(
                    name=name,
                    lower_bound=float(self.param_lower[i].item()),
                    upper_bound=float(self.param_upper[i].item()),
                    spatial=self.spatial_params,
                    n_spatial=self.n_hrus if self.spatial_params else None,
                    log_transform=False,
                )
            )
        return specs

    @property
    def gradient_method(self) -> GradientMethod:
        """Gradients are provided via Enzyme."""
        return GradientMethod.ENZYME

    @property
    def state_size(self) -> int:
        """Number of state variables per HRU."""
        return self.n_states

    @property
    def requires_batch(self) -> bool:
        """Always True: only ``run`` is supported, ``step`` raises."""
        return True

    def get_initial_state(self) -> torch.Tensor:
        """Return a default ``(n_hrus, n_states)`` state tensor.

        The first three state columns, when present, are set to 50, 20 and
        200; all other states start at zero.
        NOTE(review): these look like default storage values (presumably in
        mm) — confirm against cfuse's state ordering.
        """
        state = torch.zeros(self.n_hrus, self.n_states)
        for idx, value in enumerate((50.0, 20.0, 200.0)[: self.n_states]):
            state[:, idx] = value
        return state

    def _raw_param_matrix(self) -> torch.Tensor:
        """Stack raw parameters: (n_hrus, n_params) if spatial, else (n_params,)."""
        stack_dim = 1 if self.spatial_params else 0
        return torch.stack(list(self._raw_params), dim=stack_dim)

    def get_physical_parameters(self) -> Dict[str, torch.Tensor]:
        """Map each parameter name to its bounded physical value(s).

        Values have shape ``(n_hrus,)`` when ``spatial_params`` is True and
        are 0-d tensors otherwise. They remain attached to the autograd graph.
        """
        phys = self._physical_param_tensor()
        # The last axis indexes parameters in both layouts.
        return dict(zip(self.param_names, phys.unbind(dim=-1)))

    def _physical_param_tensor(self) -> torch.Tensor:
        """Sigmoid-map raw parameters into ``[param_lower, param_upper]``.

        The bounds have shape ``(n_params,)``. They broadcast over the leading
        HRU axis in the spatial layout, so one expression serves both layouts.
        """
        raw = self._raw_param_matrix()
        return self.param_lower + (self.param_upper - self.param_lower) * torch.sigmoid(raw)

    def step(
        self,
        inputs: Dict[str, torch.Tensor],
        state: torch.Tensor,
        dt: float,
    ):
        """Not supported; FUSE is only run in batch mode.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("FUSEComponent requires batch execution")

    def run(
        self,
        inputs: Dict[str, torch.Tensor],
        state: Optional[torch.Tensor],
        dt: float,
        n_timesteps: int,
    ):
        """Run FUSE over the whole ``forcing`` series in a single batch call.

        Args:
            inputs: Must contain ``"forcing"``.
            state: Initial state. If None, ``get_initial_state()`` is used,
                moved to ``forcing``'s device.
            dt: Not used; the step length is fixed at construction
                (``self.dt_days``).
            n_timesteps: Not used by this implementation.

        Returns:
            ``({"runoff": runoff}, state)``. The returned state is the initial
            state that was used, not an advanced one, because
            ``DifferentiableFUSEBatch.apply`` only returns runoff here.

        Raises:
            ValueError: If ``"forcing"`` is missing from ``inputs``.
        """
        forcing = inputs.get("forcing")
        if forcing is None:
            raise ValueError("Missing required input 'forcing'")
        phys_params = self._physical_param_tensor()
        if state is None:
            state = self.get_initial_state().to(forcing.device)
        runoff = DifferentiableFUSEBatch.apply(
            phys_params,
            state,
            forcing,
            self.config_dict,
            self.dt_days,
        )
        return {"runoff": runoff}, state

    def get_torch_parameters(self) -> List[torch.nn.Parameter]:
        """Return the trainable raw (pre-sigmoid) parameter tensors."""
        return list(self._raw_params)
@@ -0,0 +1,219 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List, Optional
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import xarray as xr
9
+
10
+ import droute as dmc
11
+
12
+ from dcoupler.core.component import (
13
+ DifferentiableComponent,
14
+ FluxDirection,
15
+ FluxSpec,
16
+ GradientMethod,
17
+ ParameterSpec,
18
+ )
19
+ from dcoupler.wrappers.enzyme import DifferentiableRouting
20
+
21
+
22
class MuskingumCungeRouting(DifferentiableComponent, nn.Module):
    """Muskingum-Cunge routing component with Enzyme AD.

    The river network is read from a topology NetCDF file (variables
    ``segId``, ``downSegId``, ``length``, ``slope``, optional ``mann_n``, and
    ``hruToSegId``) and routed with ``droute.MuskingumCungeRouter``. The
    trainable parameter is per-reach Manning's n, stored in log space.

    NOTE(review): the ``parameters`` property (presumably required by the
    ``DifferentiableComponent`` interface — confirm) shadows
    ``nn.Module.parameters()``. ``zero_grad()``/``requires_grad_()`` will
    therefore fail on this class; use ``get_torch_parameters()`` instead.
    """

    def __init__(
        self,
        name: str,
        topology_file: str,
        hru_areas: np.ndarray,
        dt: float = 86400.0,
        outlet_reach_id: Optional[int] = None,
    ) -> None:
        """Load the network, create the router and build the HRU->reach map.

        Args:
            name: Component name returned by ``name``.
            topology_file: Path to the topology NetCDF file.
            hru_areas: Area of each HRU, in the order of ``hruToSegId``.
                NOTE(review): the mapping divides by 1000 and 86400, which
                converts mm/day over m² into m³/s, so areas are presumably
                in m² — confirm.
            dt: Routing timestep in seconds.
            outlet_reach_id: Reach whose discharge is returned. Defaults to
                the last reach in the network's topological order.

        Raises:
            ValueError: If ``hru_areas`` and ``hruToSegId`` differ in length.
        """
        super().__init__()
        self._name = name
        self.dt_seconds = float(dt)

        self.network = self._load_network(topology_file)
        self.n_reaches = self.network.num_reaches()
        self.n_hrus = len(hru_areas)

        topo_order = list(self.network.topological_order())
        self.reach_ids = topo_order
        self.id_to_idx = {rid: i for i, rid in enumerate(topo_order)}

        if outlet_reach_id is None:
            outlet_reach_id = int(topo_order[-1])
        self.outlet_reach_id = int(outlet_reach_id)

        config = dmc.RouterConfig()
        config.dt = self.dt_seconds
        config.num_substeps = 4
        # Presumably disabled because gradients come from the Enzyme wrapper
        # (DifferentiableRouting) rather than the router itself — confirm.
        config.enable_gradients = False

        self.router = dmc.MuskingumCungeRouter(self.network, config)

        # Start near n = 0.035 with small per-reach noise. Storing log(n)
        # keeps n positive under unconstrained optimisation.
        initial_log_n = torch.full((self.n_reaches,), np.log(0.035))
        initial_log_n = initial_log_n + torch.randn(self.n_reaches) * 0.1
        self.log_manning_n = nn.Parameter(initial_log_n)

        self.register_buffer(
            "mapping_matrix",
            self._build_mapping_matrix(topology_file, hru_areas),
        )

    @property
    def name(self) -> str:
        """Component name given at construction."""
        return self._name

    @property
    def input_fluxes(self) -> List[FluxSpec]:
        """Single required input: ``lateral_inflow`` in m3/s, dims (time, reach)."""
        return [
            FluxSpec(
                name="lateral_inflow",
                units="m3/s",
                direction=FluxDirection.INPUT,
                spatial_type="reach",
                temporal_resolution=self.dt_seconds,
                dims=("time", "reach"),
                optional=False,
            )
        ]

    @property
    def output_fluxes(self) -> List[FluxSpec]:
        """Single output: outlet ``discharge`` in m3/s, dims (time,)."""
        return [
            FluxSpec(
                name="discharge",
                units="m3/s",
                direction=FluxDirection.OUTPUT,
                spatial_type="point",
                temporal_resolution=self.dt_seconds,
                dims=("time",),
                conserved_quantity="water_mass",
            )
        ]

    @property
    def parameters(self) -> List[ParameterSpec]:
        """Per-reach Manning's n, log-transformed, bounded to [1e-4, 1.0]."""
        return [
            ParameterSpec(
                name="manning_n",
                lower_bound=1e-4,
                upper_bound=1.0,
                spatial=True,
                n_spatial=self.n_reaches,
                log_transform=True,
            )
        ]

    @property
    def gradient_method(self) -> GradientMethod:
        """Gradients are provided via Enzyme."""
        return GradientMethod.ENZYME

    @property
    def state_size(self) -> int:
        """This component carries no explicit state tensor."""
        return 0

    @property
    def requires_batch(self) -> bool:
        """Always True: only ``run`` is supported, ``step`` raises."""
        return True

    def get_initial_state(self) -> torch.Tensor:
        """Return an empty tensor (the component is stateless here)."""
        return torch.empty(0)

    def get_physical_parameters(self) -> Dict[str, torch.Tensor]:
        """Return ``{"manning_n": exp(log_manning_n)}``."""
        return {"manning_n": torch.exp(self.log_manning_n)}

    def step(self, inputs: Dict[str, torch.Tensor], state: torch.Tensor, dt: float):
        """Not supported; routing is only run in batch mode.

        Raises:
            RuntimeError: Always.
        """
        raise RuntimeError("MuskingumCungeRouting requires batch execution")

    def run(
        self,
        inputs: Dict[str, torch.Tensor],
        state: torch.Tensor,
        dt: float,
        n_timesteps: int,
    ):
        """Route ``lateral_inflow`` through the network in one batch call.

        ``dt`` and ``n_timesteps`` are not used; the router's timestep is
        fixed at construction. ``state`` is returned unchanged.

        Returns:
            ``({"discharge": discharge}, state)`` where ``discharge`` is the
            output of ``DifferentiableRouting.apply`` for the outlet reach.

        Raises:
            ValueError: If ``"lateral_inflow"`` is missing from ``inputs``.
        """
        lateral = inputs.get("lateral_inflow")
        if lateral is None:
            raise ValueError("Missing required input 'lateral_inflow'")
        manning_n = torch.exp(self.log_manning_n)
        discharge = DifferentiableRouting.apply(
            lateral,
            manning_n,
            self.router,
            self.network,
            self.outlet_reach_id,
            self.dt_seconds,
        )
        return {"discharge": discharge}, state

    def get_torch_parameters(self) -> List[torch.nn.Parameter]:
        """Return the trainable ``log_manning_n`` parameter."""
        return [self.log_manning_n]

    def _load_network(self, topology_file: str) -> dmc.Network:
        """Build a ``droute`` Network from the segment table in ``topology_file``.

        Each segment becomes one reach plus one junction sharing the segment
        id. The junction sits at the reach's upstream end and receives every
        segment whose ``downSegId`` equals that id. A segment whose
        ``downSegId`` is not in the file gets ``downstream_junction_id = -1``.
        Slopes are floored at 1e-4, a missing ``mann_n`` defaults to 0.035,
        and hydraulic-geometry coefficients are fixed defaults.
        """
        # Read everything up front so the file handle is released even if
        # network construction below fails.
        with xr.open_dataset(topology_file) as ds:
            seg_ids = ds["segId"].values.astype(int)
            down_seg_ids = ds["downSegId"].values.astype(int)
            lengths = ds["length"].values.astype(float)
            slopes = ds["slope"].values.astype(float)
            mann_n = ds["mann_n"].values if "mann_n" in ds else np.full(len(seg_ids), 0.035)

        network = dmc.Network()
        seg_id_set = set(seg_ids)

        upstream_map = {int(sid): [] for sid in seg_ids}
        for i, down_id in enumerate(down_seg_ids):
            if int(down_id) in seg_id_set:
                upstream_map[int(down_id)].append(int(seg_ids[i]))

        for i, sid in enumerate(seg_ids):
            reach = dmc.Reach()
            reach.id = int(sid)
            reach.length = float(lengths[i])
            reach.slope = max(float(slopes[i]), 0.0001)
            reach.manning_n = float(mann_n[i])
            reach.geometry.width_coef = 7.2
            reach.geometry.width_exp = 0.5
            reach.geometry.depth_coef = 0.27
            reach.geometry.depth_exp = 0.3
            reach.upstream_junction_id = int(sid)
            down_id = int(down_seg_ids[i])
            reach.downstream_junction_id = down_id if down_id in seg_id_set else -1
            network.add_reach(reach)

        for sid in seg_ids:
            junc = dmc.Junction()
            junc.id = int(sid)
            junc.upstream_reach_ids = upstream_map[int(sid)]
            junc.downstream_reach_ids = [int(sid)]
            network.add_junction(junc)

        network.build_topology()
        return network

    def _build_mapping_matrix(self, topology_file: str, hru_areas: np.ndarray) -> torch.Tensor:
        """Build the sparse ``(n_reaches, n_hrus)`` HRU -> reach mapping matrix.

        Entry ``[r, h]`` is ``hru_areas[h] / 1000 / 86400`` when HRU ``h``
        drains (via ``hruToSegId``) to the reach at position ``r`` of the
        network's topological order. HRUs draining to a segment that is not
        in the network, and zero-valued entries, are omitted. The entries are
        collected directly, so no dense R x H matrix is allocated.

        Raises:
            ValueError: If ``len(hru_areas)`` differs from the number of HRUs
                in ``hruToSegId``.
        """
        with xr.open_dataset(topology_file) as ds:
            hru_to_seg = ds["hruToSegId"].values.astype(int)

        n_hrus = len(hru_to_seg)
        if len(hru_areas) != n_hrus:
            raise ValueError(
                f"hru_areas has {len(hru_areas)} entries but the topology file "
                f"lists {n_hrus} HRUs in hruToSegId"
            )

        topo_order = list(self.network.topological_order())
        id_to_idx = {rid: i for i, rid in enumerate(topo_order)}
        n_reaches = len(topo_order)

        rows: List[int] = []
        cols: List[int] = []
        vals: List[float] = []
        for h_idx, seg_id in enumerate(hru_to_seg):
            r_idx = id_to_idx.get(seg_id)
            if r_idx is None:
                continue
            # mm/day over the HRU area -> m3/s (mm->m, day->s).
            conversion = hru_areas[h_idx] / 1000.0 / 86400.0
            if conversion == 0:
                continue
            rows.append(r_idx)
            cols.append(h_idx)
            vals.append(float(conversion))

        indices = torch.tensor([rows, cols], dtype=torch.long)
        values = torch.tensor(vals, dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, values, (n_reaches, n_hrus)).coalesce()
@@ -0,0 +1,26 @@
1
+ from .component import (
2
+ DifferentiableComponent,
3
+ FluxDirection,
4
+ FluxSpec,
5
+ GradientMethod,
6
+ ParameterSpec,
7
+ )
8
+ from .connection import FluxConnection, SpatialRemapper
9
+ from .temporal import TemporalOrchestrator
10
+ from .conservation import ConservationChecker
11
+ from .graph import CouplingGraph
12
+ from .bmi import BMIMixin
13
+
14
+ __all__ = [
15
+ "DifferentiableComponent",
16
+ "FluxDirection",
17
+ "FluxSpec",
18
+ "GradientMethod",
19
+ "ParameterSpec",
20
+ "FluxConnection",
21
+ "SpatialRemapper",
22
+ "TemporalOrchestrator",
23
+ "ConservationChecker",
24
+ "CouplingGraph",
25
+ "BMIMixin",
26
+ ]
dcoupler/core/bmi.py ADDED
@@ -0,0 +1,76 @@
1
+ """BMI-aligned lifecycle mixin for dCoupler components.
2
+
3
+ The Basic Model Interface (BMI) is a standardized interface for model
4
+ coupling. This mixin maps BMI lifecycle methods to the DifferentiableComponent
5
+ protocol so that existing components can be used in BMI-based workflows
6
+ without modifying their core implementation.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import abc
12
+ from typing import Any, Dict, List
13
+
14
+
15
class BMIMixin(abc.ABC):
    """Abstract mixin giving a dCoupler component a BMI-style lifecycle.

    Subclasses implement every abstract ``bmi_*`` hook. Only
    ``bmi_update_batch`` ships with a default, built on ``bmi_update``.
    A subclass keeps taking part in a dCoupler CouplingGraph.

    Intended correspondence with the DifferentiableComponent protocol:

        bmi_initialize(config)    -> initialize(config) + get_initial_state()
        bmi_update(inputs, dt)    -> step(inputs, state, dt)
        bmi_update_batch(..., n)  -> run(inputs, state, dt, n)
        bmi_get_state()           -> return internal state tensor
        bmi_set_state(state)      -> set internal state tensor
        bmi_get_value(name)       -> index into last output dict
        bmi_finalize()            -> cleanup hook
    """

    @abc.abstractmethod
    def bmi_initialize(self, config: dict) -> None:
        """Set the model up once from ``config`` (parse it, allocate state, load data)."""

    @abc.abstractmethod
    def bmi_update(self, inputs: Dict[str, Any], dt: float) -> Dict[str, Any]:
        """Move the model forward by one step of length ``dt``; return the outputs."""

    def bmi_update_batch(
        self, inputs: Dict[str, Any], dt: float, n_timesteps: int
    ) -> Dict[str, Any]:
        """Call ``bmi_update`` ``n_timesteps`` times and collect the results.

        The same ``inputs`` mapping is passed to every step. The returned
        dict maps each output name to a list of its values, one per step in
        which the name was produced, in step order. Zero steps yield ``{}``.
        """
        history: Dict[str, list] = {}
        for _step in range(n_timesteps):
            step_outputs = self.bmi_update(inputs, dt)
            for var_name, var_value in step_outputs.items():
                if var_name not in history:
                    history[var_name] = []
                history[var_name].append(var_value)
        return history

    @abc.abstractmethod
    def bmi_finalize(self) -> None:
        """Free any resources held by the model."""

    @abc.abstractmethod
    def bmi_get_state(self) -> Any:
        """Return the model's current internal state."""

    @abc.abstractmethod
    def bmi_set_state(self, state: Any) -> None:
        """Replace the model's internal state with ``state``."""

    @abc.abstractmethod
    def bmi_get_value(self, name: str) -> Any:
        """Return the output or state variable called ``name``."""

    @abc.abstractmethod
    def bmi_set_value(self, name: str, value: Any) -> None:
        """Assign ``value`` to the input variable called ``name``."""

    @abc.abstractmethod
    def bmi_get_output_var_names(self) -> List[str]:
        """Return the names of all output variables."""

    @abc.abstractmethod
    def bmi_get_input_var_names(self) -> List[str]:
        """Return the names of all input variables."""