dcoupler 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dcoupler/__init__.py +52 -0
- dcoupler/components/__init__.py +13 -0
- dcoupler/components/fuse.py +193 -0
- dcoupler/components/routing.py +219 -0
- dcoupler/core/__init__.py +26 -0
- dcoupler/core/bmi.py +76 -0
- dcoupler/core/component.py +158 -0
- dcoupler/core/connection.py +118 -0
- dcoupler/core/conservation.py +55 -0
- dcoupler/core/graph.py +394 -0
- dcoupler/core/temporal.py +31 -0
- dcoupler/diagnostics/__init__.py +3 -0
- dcoupler/diagnostics/gradients.py +125 -0
- dcoupler/losses/__init__.py +29 -0
- dcoupler/losses/hydrological.py +159 -0
- dcoupler/observers/__init__.py +4 -0
- dcoupler/observers/base.py +29 -0
- dcoupler/observers/streamflow.py +34 -0
- dcoupler/optimization/__init__.py +11 -0
- dcoupler/optimization/multi_observation.py +102 -0
- dcoupler/optimization/parameters.py +139 -0
- dcoupler/optimization/trainer.py +196 -0
- dcoupler/py.typed +0 -0
- dcoupler/utils/__init__.py +3 -0
- dcoupler/utils/temporal.py +45 -0
- dcoupler/wrappers/__init__.py +17 -0
- dcoupler/wrappers/enzyme.py +188 -0
- dcoupler/wrappers/jax.py +420 -0
- dcoupler/wrappers/process.py +248 -0
- dcoupler-0.2.0.dist-info/METADATA +101 -0
- dcoupler-0.2.0.dist-info/RECORD +34 -0
- dcoupler-0.2.0.dist-info/WHEEL +5 -0
- dcoupler-0.2.0.dist-info/licenses/LICENSE +191 -0
- dcoupler-0.2.0.dist-info/top_level.txt +1 -0
dcoupler/__init__.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Top-level ``dcoupler`` package.

Re-exports the core coupling API, the optimization helpers, and the public
names (``__all__``) of the ``losses``, ``components``, ``wrappers``,
``observers``, ``diagnostics`` and ``utils`` subpackages.
"""
from dcoupler.core import (
    CouplingGraph,
    DifferentiableComponent,
    FluxDirection,
    FluxSpec,
    GradientMethod,
    ParameterSpec,
    FluxConnection,
    SpatialRemapper,
    TemporalOrchestrator,
    ConservationChecker,
    BMIMixin,
)
from dcoupler.optimization import (
    ParameterManager,
    MultiObservationLoss,
    LossTerm,
    Trainer,
    TrainingResult,
)
from dcoupler import observers as _observers
from dcoupler import diagnostics as _diagnostics
from dcoupler import utils as _utils
from dcoupler.observers import *  # noqa: F401,F403
from dcoupler.diagnostics import *  # noqa: F401,F403
# utils' public names are listed in __all__ below, so they must also be bound
# here; otherwise ``from dcoupler import *`` fails with AttributeError.
from dcoupler.utils import *  # noqa: F401,F403
from dcoupler import losses as _losses
from dcoupler import components as _components
from dcoupler import wrappers as _wrappers
from dcoupler.losses import *  # noqa: F401,F403
from dcoupler.components import *  # noqa: F401,F403
from dcoupler.wrappers import *  # noqa: F401,F403

__version__ = "0.2.0"

# Explicit core/optimization names first, then every subpackage's own __all__.
__all__ = (
    [
        "CouplingGraph",
        "DifferentiableComponent",
        "FluxDirection",
        "FluxSpec",
        "GradientMethod",
        "ParameterSpec",
        "FluxConnection",
        "SpatialRemapper",
        "TemporalOrchestrator",
        "ConservationChecker",
        "BMIMixin",
        "ParameterManager",
        "MultiObservationLoss",
        "LossTerm",
        "Trainer",
        "TrainingResult",
    ]
    + _losses.__all__
    + _components.__all__
    + _wrappers.__all__
    + _observers.__all__
    + _diagnostics.__all__
    + _utils.__all__
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Components with optional native backends. Each one is exported only when its
# module imports cleanly; otherwise the name is bound to None so that
# ``from dcoupler.components import FUSEComponent`` still works.
# NOTE(review): ``except Exception`` is deliberately broad, which also hides
# genuine errors raised while importing these modules.
__all__ = []

try:
    from .fuse import FUSEComponent
except Exception:
    FUSEComponent = None  # type: ignore
else:
    __all__.append("FUSEComponent")

try:
    from .routing import MuskingumCungeRouting
except Exception:
    MuskingumCungeRouting = None  # type: ignore
else:
    __all__.append("MuskingumCungeRouting")
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Optional
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
from dcoupler.core.component import (
|
|
8
|
+
DifferentiableComponent,
|
|
9
|
+
FluxDirection,
|
|
10
|
+
FluxSpec,
|
|
11
|
+
GradientMethod,
|
|
12
|
+
ParameterSpec,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import cfuse
|
|
17
|
+
import cfuse_core
|
|
18
|
+
from cfuse.torch import DifferentiableFUSEBatch
|
|
19
|
+
except ImportError:
|
|
20
|
+
cfuse = None
|
|
21
|
+
cfuse_core = None
|
|
22
|
+
DifferentiableFUSEBatch = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FUSEComponent(DifferentiableComponent, nn.Module):
    """FUSE component wrapper using ``DifferentiableFUSEBatch``.

    Each FUSE parameter is stored as an unconstrained raw ``nn.Parameter`` and
    mapped into the ``[lower, upper]`` range from ``cfuse.PARAM_BOUNDS`` with a
    scaled sigmoid. Only batch execution through :meth:`run` is supported;
    :meth:`step` always raises.

    NOTE(review): the ``parameters`` property (presumably part of the
    ``DifferentiableComponent`` protocol) shadows ``nn.Module.parameters()``.
    Generic torch helpers that call ``self.parameters()`` (e.g.
    ``nn.Module.zero_grad``) will fail on this class. Use
    :meth:`get_torch_parameters` to obtain the trainable tensors.
    """

    def __init__(
        self,
        name: str,
        fuse_config,
        n_hrus: int,
        dt: float = 86400.0,
        spatial_params: bool = True,
        n_states: Optional[int] = None,
    ) -> None:
        """Create the component and its trainable raw parameters.

        Args:
            name: Name reported by :attr:`name`.
            fuse_config: FUSE model configuration. If it has ``to_dict()``,
                the resulting dict is what gets passed to cfuse; otherwise the
                object is passed as-is.
            n_hrus: Number of HRUs simulated together.
            dt: Timestep in seconds; converted to days for cfuse.
            spatial_params: If true, every parameter has one value per HRU;
                otherwise one scalar is shared by all HRUs.
            n_states: Number of state variables. When omitted it is inferred
                with ``cfuse_core.get_num_active_states``.

        Raises:
            RuntimeError: cfuse / DifferentiableFUSEBatch are not importable,
                or ``n_states`` must be inferred and cfuse_core is missing.
            ValueError: ``fuse_config`` is None.
        """
        super().__init__()
        if DifferentiableFUSEBatch is None or cfuse is None:
            raise RuntimeError("cfuse and DifferentiableFUSEBatch are required for FUSEComponent")
        if fuse_config is None:
            raise ValueError("fuse_config is required")

        self._name = name
        self.fuse_config = fuse_config
        self.config_dict = fuse_config.to_dict() if hasattr(fuse_config, "to_dict") else fuse_config
        self.n_hrus = n_hrus
        self.dt_seconds = float(dt)
        self.dt_days = self.dt_seconds / 86400.0
        self.spatial_params = spatial_params

        if n_states is None:
            if cfuse_core is None:
                raise RuntimeError("cfuse_core required to infer state size")
            self.n_states = int(cfuse_core.get_num_active_states(self.config_dict))
        else:
            self.n_states = int(n_states)

        # Parameter names and bounds come from cfuse. The bounds are registered
        # as buffers so they move with the module (e.g. on .to(device)).
        self.param_names = list(cfuse.PARAM_NAMES)
        self.n_params = len(self.param_names)
        lowers = torch.tensor([cfuse.PARAM_BOUNDS[n][0] for n in self.param_names], dtype=torch.float32)
        uppers = torch.tensor([cfuse.PARAM_BOUNDS[n][1] for n in self.param_names], dtype=torch.float32)
        self.register_buffer("param_lower", lowers)
        self.register_buffer("param_upper", uppers)

        # One raw tensor per FUSE parameter. raw == 0 maps to the midpoint of
        # the bounds (sigmoid(0) == 0.5). Spatial parameters get a small random
        # per-HRU jitter around it; shared parameters start exactly there.
        self._raw_params = nn.ParameterList()
        for _ in range(self.n_params):
            if self.spatial_params:
                init = torch.randn(self.n_hrus) * 0.2
            else:
                init = torch.zeros(())
            self._raw_params.append(nn.Parameter(init))

    @property
    def name(self) -> str:
        """Component name given at construction."""
        return self._name

    @property
    def input_fluxes(self) -> List[FluxSpec]:
        """Single required input: the ``forcing`` tensor over (time, hru, var)."""
        return [
            FluxSpec(
                name="forcing",
                units="mixed",
                direction=FluxDirection.INPUT,
                spatial_type="hru",
                temporal_resolution=self.dt_seconds,
                dims=("time", "hru", "var"),
                optional=False,
            )
        ]

    @property
    def output_fluxes(self) -> List[FluxSpec]:
        """Single output: per-HRU ``runoff`` in mm/day over (time, hru)."""
        return [
            FluxSpec(
                name="runoff",
                units="mm/day",
                direction=FluxDirection.OUTPUT,
                spatial_type="hru",
                temporal_resolution=self.dt_seconds,
                dims=("time", "hru"),
                conserved_quantity="water_mass",
            )
        ]

    @property
    def parameters(self) -> List[ParameterSpec]:
        """Bounds metadata for every FUSE parameter, in ``param_names`` order."""
        return [
            ParameterSpec(
                name=name,
                lower_bound=float(self.param_lower[i].item()),
                upper_bound=float(self.param_upper[i].item()),
                spatial=self.spatial_params,
                n_spatial=self.n_hrus if self.spatial_params else None,
                log_transform=False,
            )
            for i, name in enumerate(self.param_names)
        ]

    @property
    def gradient_method(self) -> GradientMethod:
        """Gradients are provided through Enzyme (via cfuse)."""
        return GradientMethod.ENZYME

    @property
    def state_size(self) -> int:
        """Number of state variables per HRU."""
        return self.n_states

    @property
    def requires_batch(self) -> bool:
        """Always True: the component only supports :meth:`run`."""
        return True

    def get_initial_state(self) -> torch.Tensor:
        """Return a ``(n_hrus, n_states)`` float32 CPU state tensor.

        The first three state columns (when present) get hard-coded starting
        values 50, 20 and 200; any remaining columns start at zero.
        """
        state = torch.zeros(self.n_hrus, self.n_states)
        if self.n_states > 0:
            state[:, 0] = 50.0
        if self.n_states > 1:
            state[:, 1] = 20.0
        if self.n_states > 2:
            state[:, 2] = 200.0
        return state

    def _raw_param_matrix(self) -> torch.Tensor:
        """Stack raw parameters: ``(n_hrus, n_params)`` if spatial else ``(n_params,)``."""
        if self.spatial_params:
            return torch.stack(list(self._raw_params), dim=1)
        return torch.stack(list(self._raw_params), dim=0)

    def _physical_param_tensor(self) -> torch.Tensor:
        """Map raw parameters into their bounds via ``lower + (upper - lower) * sigmoid(raw)``.

        Returns a tensor with the same shape as :meth:`_raw_param_matrix`. The
        ``(n_params,)`` bound buffers broadcast across the HRU axis in the
        spatial case, so no branching is needed.
        """
        raw = self._raw_param_matrix()
        return self.param_lower + (self.param_upper - self.param_lower) * torch.sigmoid(raw)

    def get_physical_parameters(self) -> Dict[str, torch.Tensor]:
        """Return bounded parameter values keyed by FUSE parameter name.

        Each value has shape ``(n_hrus,)`` when ``spatial_params`` is true,
        otherwise it is a 0-d tensor.
        """
        phys = self._physical_param_tensor()
        if self.spatial_params:
            return {name: phys[:, i] for i, name in enumerate(self.param_names)}
        return {name: phys[i] for i, name in enumerate(self.param_names)}

    def step(
        self,
        inputs: Dict[str, torch.Tensor],
        state: torch.Tensor,
        dt: float,
    ):
        """Not supported; FUSE is only executed as a batch through :meth:`run`."""
        raise RuntimeError("FUSEComponent requires batch execution")

    def run(
        self,
        inputs: Dict[str, torch.Tensor],
        state: Optional[torch.Tensor],
        dt: float,
        n_timesteps: int,
    ):
        """Run FUSE over the whole forcing window in one batched call.

        Args:
            inputs: Must contain ``"forcing"`` (declared dims
                ``("time", "hru", "var")``).
            state: Initial state, or None to use :meth:`get_initial_state`
                moved to the forcing tensor's device.
            dt: Ignored; the timestep fixed at construction
                (``self.dt_days``) is passed to cfuse.
            n_timesteps: Ignored by this implementation.

        Returns:
            ``({"runoff": runoff}, state)``.

        Raises:
            ValueError: ``"forcing"`` is missing from ``inputs``.
        """
        forcing = inputs.get("forcing")
        if forcing is None:
            raise ValueError("Missing required input 'forcing'")
        phys_params = self._physical_param_tensor()
        if state is None:
            state = self.get_initial_state().to(forcing.device)
        runoff = DifferentiableFUSEBatch.apply(
            phys_params,
            state,
            forcing,
            self.config_dict,
            self.dt_days,
        )
        # NOTE(review): the returned state is the *initial* state, not the
        # end-of-window state, so chained windows restart from the same state.
        # Confirm whether DifferentiableFUSEBatch can expose final states.
        return {"runoff": runoff}, state

    def get_torch_parameters(self) -> List[torch.nn.Parameter]:
        """Trainable raw (unbounded) parameter tensors, in ``param_names`` order."""
        return list(self._raw_params)
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import xarray as xr
|
|
9
|
+
|
|
10
|
+
import droute as dmc
|
|
11
|
+
|
|
12
|
+
from dcoupler.core.component import (
|
|
13
|
+
DifferentiableComponent,
|
|
14
|
+
FluxDirection,
|
|
15
|
+
FluxSpec,
|
|
16
|
+
GradientMethod,
|
|
17
|
+
ParameterSpec,
|
|
18
|
+
)
|
|
19
|
+
from dcoupler.wrappers.enzyme import DifferentiableRouting
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MuskingumCungeRouting(DifferentiableComponent, nn.Module):
    """Muskingum-Cunge routing component with Enzyme AD.

    The river network is read from a dataset opened with
    ``xarray.open_dataset``. It uses the variables ``segId``, ``downSegId``,
    ``length``, ``slope``, an optional ``mann_n`` and ``hruToSegId``, and is
    handed to ``droute``. Manning's n is the only trainable parameter and is
    stored per reach in log space. A sparse ``(n_reaches, n_hrus)``
    HRU-to-reach buffer is exposed as ``mapping_matrix``; ``run`` itself
    expects lateral inflow already in reach space.

    NOTE(review): the ``parameters`` property (presumably part of the
    ``DifferentiableComponent`` protocol) shadows ``nn.Module.parameters()``.
    Generic torch helpers that call ``self.parameters()`` (e.g.
    ``nn.Module.zero_grad``) will fail on this class. Use
    :meth:`get_torch_parameters` instead.
    """

    def __init__(
        self,
        name: str,
        topology_file: str,
        hru_areas: np.ndarray,
        dt: float = 86400.0,
        outlet_reach_id: Optional[int] = None,
    ) -> None:
        """Load the network, build the router and initialise Manning's n.

        Args:
            name: Name reported by :attr:`name`.
            topology_file: Path of the topology dataset (see class docstring).
            hru_areas: One area per HRU, in ``hruToSegId`` order. Presumably
                m^2, given the mm/day -> m^3/s conversion in the mapping
                matrix; confirm.
            dt: Routing timestep in seconds.
            outlet_reach_id: Reach whose discharge is returned. Defaults to the
                last reach in topological order.

        Raises:
            ValueError: ``hru_areas`` length differs from ``hruToSegId``.
        """
        super().__init__()
        self._name = name
        self.dt_seconds = float(dt)

        self.network = self._load_network(topology_file)
        self.n_reaches = self.network.num_reaches()
        self.n_hrus = len(hru_areas)

        topo_order = list(self.network.topological_order())
        self.reach_ids = topo_order
        self.id_to_idx = {rid: i for i, rid in enumerate(topo_order)}

        if outlet_reach_id is None:
            outlet_reach_id = int(topo_order[-1])
        self.outlet_reach_id = int(outlet_reach_id)

        config = dmc.RouterConfig()
        config.dt = self.dt_seconds
        config.num_substeps = 4
        config.enable_gradients = False

        self.router = dmc.MuskingumCungeRouter(self.network, config)

        # log(n) starts around log(0.035) with small random per-reach jitter.
        initial_log_n = torch.full((self.n_reaches,), np.log(0.035))
        initial_log_n = initial_log_n + torch.randn(self.n_reaches) * 0.1
        self.log_manning_n = nn.Parameter(initial_log_n)

        self.register_buffer(
            "mapping_matrix",
            self._build_mapping_matrix(topology_file, hru_areas),
        )

    @property
    def name(self) -> str:
        """Component name given at construction."""
        return self._name

    @property
    def input_fluxes(self) -> List[FluxSpec]:
        """Single required input: ``lateral_inflow`` in m3/s over (time, reach)."""
        return [
            FluxSpec(
                name="lateral_inflow",
                units="m3/s",
                direction=FluxDirection.INPUT,
                spatial_type="reach",
                temporal_resolution=self.dt_seconds,
                dims=("time", "reach"),
                optional=False,
            )
        ]

    @property
    def output_fluxes(self) -> List[FluxSpec]:
        """Single output: outlet ``discharge`` in m3/s over time."""
        return [
            FluxSpec(
                name="discharge",
                units="m3/s",
                direction=FluxDirection.OUTPUT,
                spatial_type="point",
                temporal_resolution=self.dt_seconds,
                dims=("time",),
                conserved_quantity="water_mass",
            )
        ]

    @property
    def parameters(self) -> List[ParameterSpec]:
        """Metadata for the per-reach, log-transformed Manning's n."""
        return [
            ParameterSpec(
                name="manning_n",
                lower_bound=1e-4,
                upper_bound=1.0,
                spatial=True,
                n_spatial=self.n_reaches,
                log_transform=True,
            )
        ]

    @property
    def gradient_method(self) -> GradientMethod:
        """Gradients come from Enzyme via ``DifferentiableRouting``."""
        return GradientMethod.ENZYME

    @property
    def state_size(self) -> int:
        """The component carries no explicit state between calls."""
        return 0

    @property
    def requires_batch(self) -> bool:
        """Always True: the component only supports :meth:`run`."""
        return True

    def get_initial_state(self) -> torch.Tensor:
        """Return an empty tensor (no state)."""
        return torch.empty(0)

    def get_physical_parameters(self) -> Dict[str, torch.Tensor]:
        """Return ``{"manning_n": exp(log_manning_n)}`` with shape ``(n_reaches,)``."""
        return {"manning_n": torch.exp(self.log_manning_n)}

    def step(self, inputs: Dict[str, torch.Tensor], state: torch.Tensor, dt: float):
        """Not supported; routing is only executed as a batch through :meth:`run`."""
        raise RuntimeError("MuskingumCungeRouting requires batch execution")

    def run(
        self,
        inputs: Dict[str, torch.Tensor],
        state: torch.Tensor,
        dt: float,
        n_timesteps: int,
    ):
        """Route lateral inflow through the network and return outlet discharge.

        Args:
            inputs: Must contain ``"lateral_inflow"``.
            state: Returned unchanged (the component is stateless).
            dt: Ignored; the construction-time ``self.dt_seconds`` is used.
            n_timesteps: Ignored by this implementation.

        Returns:
            ``({"discharge": discharge}, state)``.

        Raises:
            ValueError: ``"lateral_inflow"`` is missing from ``inputs``.
        """
        lateral = inputs.get("lateral_inflow")
        if lateral is None:
            raise ValueError("Missing required input 'lateral_inflow'")
        manning_n = torch.exp(self.log_manning_n)
        discharge = DifferentiableRouting.apply(
            lateral,
            manning_n,
            self.router,
            self.network,
            self.outlet_reach_id,
            self.dt_seconds,
        )
        return {"discharge": discharge}, state

    def get_torch_parameters(self) -> List[torch.nn.Parameter]:
        """Trainable tensors: only ``log_manning_n``."""
        return [self.log_manning_n]

    def _load_network(self, topology_file: str) -> dmc.Network:
        """Build a ``droute`` network with one reach and one junction per segment.

        Reach ``sid`` starts at junction ``sid`` and ends at junction
        ``downSegId``, or -1 when the downstream id is not in the file. Slopes
        are floored at 1e-4. Missing ``mann_n`` defaults to 0.035. Every reach
        gets the same fixed hydraulic-geometry coefficients.
        """
        # Materialise everything as numpy arrays inside the context manager so
        # the file is closed even if a read fails.
        with xr.open_dataset(topology_file) as ds:
            seg_ids = ds["segId"].values.astype(int)
            down_seg_ids = ds["downSegId"].values.astype(int)
            lengths = ds["length"].values.astype(float)
            slopes = ds["slope"].values.astype(float)
            mann_n = ds["mann_n"].values if "mann_n" in ds else np.full(len(seg_ids), 0.035)

        network = dmc.Network()
        seg_id_set = {int(sid) for sid in seg_ids}

        # Invert the downstream links: segment id -> ids of segments draining into it.
        upstream_map: Dict[int, List[int]] = {int(sid): [] for sid in seg_ids}
        for sid, down_id in zip(seg_ids, down_seg_ids):
            if int(down_id) in seg_id_set:
                upstream_map[int(down_id)].append(int(sid))

        for i, sid in enumerate(seg_ids):
            reach = dmc.Reach()
            reach.id = int(sid)
            reach.length = float(lengths[i])
            reach.slope = max(float(slopes[i]), 0.0001)
            reach.manning_n = float(mann_n[i])
            reach.geometry.width_coef = 7.2
            reach.geometry.width_exp = 0.5
            reach.geometry.depth_coef = 0.27
            reach.geometry.depth_exp = 0.3
            reach.upstream_junction_id = int(sid)
            down_id = int(down_seg_ids[i])
            reach.downstream_junction_id = down_id if down_id in seg_id_set else -1
            network.add_reach(reach)

        for sid in seg_ids:
            junc = dmc.Junction()
            junc.id = int(sid)
            junc.upstream_reach_ids = upstream_map[int(sid)]
            junc.downstream_reach_ids = [int(sid)]
            network.add_junction(junc)

        network.build_topology()
        return network

    def _build_mapping_matrix(self, topology_file: str, hru_areas: np.ndarray) -> torch.Tensor:
        """Build the sparse ``(n_reaches, n_hrus)`` HRU-to-reach conversion matrix.

        Entry ``[r, h]`` is ``hru_areas[h] / 1000 / 86400`` when HRU ``h``
        drains to reach ``r`` (per ``hruToSegId``), which converts mm/day of
        runoff into m^3/s if areas are in m^2. HRUs draining to an unknown
        segment, and zero-valued entries, are omitted. Rows follow the
        topological reach order in ``self.id_to_idx``.

        Raises:
            ValueError: ``hru_areas`` length differs from ``hruToSegId``.
        """
        with xr.open_dataset(topology_file) as ds:
            hru_to_seg = ds["hruToSegId"].values.astype(int)

        n_hrus = len(hru_to_seg)
        if len(hru_areas) != n_hrus:
            raise ValueError(
                f"hru_areas has {len(hru_areas)} entries but hruToSegId has {n_hrus}"
            )
        n_reaches = len(self.reach_ids)

        # Collect COO triplets directly instead of allocating a dense matrix.
        rows: List[int] = []
        cols: List[int] = []
        vals: List[float] = []
        for h_idx, seg_id in enumerate(hru_to_seg):
            r_idx = self.id_to_idx.get(int(seg_id))
            if r_idx is None:
                continue
            rows.append(r_idx)
            cols.append(h_idx)
            vals.append(float(hru_areas[h_idx]) / 1000.0 / 86400.0)

        indices = torch.tensor([rows, cols], dtype=torch.long).reshape(2, -1)
        values = torch.tensor(vals, dtype=torch.float32)
        keep = values != 0  # drop zero entries, matching a dense nonzero() build
        sparse = torch.sparse_coo_tensor(indices[:, keep], values[keep], (n_reaches, n_hrus))
        return sparse.coalesce()
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from .component import (
|
|
2
|
+
DifferentiableComponent,
|
|
3
|
+
FluxDirection,
|
|
4
|
+
FluxSpec,
|
|
5
|
+
GradientMethod,
|
|
6
|
+
ParameterSpec,
|
|
7
|
+
)
|
|
8
|
+
from .connection import FluxConnection, SpatialRemapper
|
|
9
|
+
from .temporal import TemporalOrchestrator
|
|
10
|
+
from .conservation import ConservationChecker
|
|
11
|
+
from .graph import CouplingGraph
|
|
12
|
+
from .bmi import BMIMixin
|
|
13
|
+
|
|
14
|
+
# Public API of dcoupler.core, grouped by the submodule that defines each name.
__all__ = [
    # .component
    "DifferentiableComponent", "FluxDirection", "FluxSpec", "GradientMethod", "ParameterSpec",
    # .connection
    "FluxConnection", "SpatialRemapper",
    # .temporal, .conservation, .graph, .bmi
    "TemporalOrchestrator", "ConservationChecker", "CouplingGraph", "BMIMixin",
]
|
dcoupler/core/bmi.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""BMI-aligned lifecycle mixin for dCoupler components.
|
|
2
|
+
|
|
3
|
+
The Basic Model Interface (BMI) is a standardized interface for model
|
|
4
|
+
coupling. This mixin maps BMI lifecycle methods to the DifferentiableComponent
|
|
5
|
+
protocol so that existing components can be used in BMI-based workflows
|
|
6
|
+
without modifying their core implementation.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import abc
|
|
12
|
+
from typing import Any, Dict, List
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BMIMixin(abc.ABC):
    """Abstract BMI-style lifecycle for dCoupler components.

    Concrete subclasses implement the abstract hooks below so that a component
    can be driven through a Basic Model Interface-like workflow while still
    taking part in a dCoupler ``CouplingGraph``.

    Intended correspondence with ``DifferentiableComponent``:

        bmi_initialize(config)      -> initialize(config) + get_initial_state()
        bmi_update(inputs, dt)      -> step(inputs, state, dt)
        bmi_update_batch(..., n)    -> run(inputs, state, dt, n)
        bmi_get_state()             -> current internal state tensor
        bmi_set_state(state)        -> replace internal state tensor
        bmi_get_value(name)         -> look up a name in the latest outputs
        bmi_set_value(name, value)  -> set a named input
        bmi_get_*_var_names()       -> list input / output variable names
        bmi_finalize()              -> cleanup hook

    Only :meth:`bmi_update_batch` has a default implementation.
    """

    @abc.abstractmethod
    def bmi_initialize(self, config: dict) -> None:
        """Perform one-time setup from ``config`` (parse, allocate state, load data)."""

    @abc.abstractmethod
    def bmi_update(self, inputs: Dict[str, Any], dt: float) -> Dict[str, Any]:
        """Advance a single timestep of length ``dt`` and return the outputs dict."""

    def bmi_update_batch(
        self, inputs: Dict[str, Any], dt: float, n_timesteps: int
    ) -> Dict[str, Any]:
        """Advance ``n_timesteps`` steps by calling :meth:`bmi_update` repeatedly.

        The same ``inputs`` mapping is passed to every step. The result maps
        each output name to the list of values it took, in step order; a name
        appears as soon as any step produces it. ``n_timesteps <= 0`` yields
        an empty dict.
        """
        history: Dict[str, List[Any]] = {}
        for _step in range(n_timesteps):
            step_outputs = self.bmi_update(inputs, dt)
            for var_name, value in step_outputs.items():
                if var_name not in history:
                    history[var_name] = []
                history[var_name].append(value)
        return history

    @abc.abstractmethod
    def bmi_finalize(self) -> None:
        """Release any resources held by the component."""

    @abc.abstractmethod
    def bmi_get_state(self) -> Any:
        """Return the component's current internal state."""

    @abc.abstractmethod
    def bmi_set_state(self, state: Any) -> None:
        """Replace the component's internal state with ``state``."""

    @abc.abstractmethod
    def bmi_get_value(self, name: str) -> Any:
        """Return the output or state variable called ``name``."""

    @abc.abstractmethod
    def bmi_set_value(self, name: str, value: Any) -> None:
        """Assign ``value`` to the input variable called ``name``."""

    @abc.abstractmethod
    def bmi_get_output_var_names(self) -> List[str]:
        """Return the names of the variables this component outputs."""

    @abc.abstractmethod
    def bmi_get_input_var_names(self) -> List[str]:
        """Return the names of the variables this component accepts as input."""
|