terrax 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- terrax/__init__.py +17 -0
- terrax/atmosphere/__init__.py +13 -0
- terrax/atmosphere/diagnostics.py +656 -0
- terrax/atmosphere/diagnostics_test.py +247 -0
- terrax/atmosphere/equations.py +489 -0
- terrax/atmosphere/equations_test.py +121 -0
- terrax/atmosphere/fixers.py +248 -0
- terrax/atmosphere/fixers_test.py +118 -0
- terrax/atmosphere/idealized_states.py +252 -0
- terrax/atmosphere/idealized_states_test.py +154 -0
- terrax/atmosphere/interpolators.py +400 -0
- terrax/atmosphere/interpolators_test.py +398 -0
- terrax/atmosphere/observation_operators.py +96 -0
- terrax/atmosphere/observation_operators_test.py +101 -0
- terrax/atmosphere/parameterizations.py +94 -0
- terrax/atmosphere/state_conversion.py +208 -0
- terrax/atmosphere/transforms.py +265 -0
- terrax/atmosphere/transforms_test.py +179 -0
- terrax/auxiliary/__init__.py +13 -0
- terrax/auxiliary/models.py +400 -0
- terrax/auxiliary/models_test.py +384 -0
- terrax/core/__init__.py +13 -0
- terrax/core/api.py +1089 -0
- terrax/core/api_test.py +735 -0
- terrax/core/array_transforms.py +78 -0
- terrax/core/boundaries.py +148 -0
- terrax/core/boundaries_test.py +114 -0
- terrax/core/checkpointing.py +106 -0
- terrax/core/checkpointing_test.py +88 -0
- terrax/core/coordinates.py +1875 -0
- terrax/core/coordinates_test.py +607 -0
- terrax/core/data_specs.py +514 -0
- terrax/core/data_specs_test.py +423 -0
- terrax/core/diagnostics.py +581 -0
- terrax/core/diagnostics_test.py +734 -0
- terrax/core/dynamic_io.py +169 -0
- terrax/core/equations.py +112 -0
- terrax/core/feature_transforms.py +218 -0
- terrax/core/feature_transforms_test.py +244 -0
- terrax/core/fiddle_tags.py +21 -0
- terrax/core/field_utils.py +400 -0
- terrax/core/field_utils_test.py +637 -0
- terrax/core/interpolators.py +181 -0
- terrax/core/interpolators_test.py +63 -0
- terrax/core/learned_transforms.py +503 -0
- terrax/core/learned_transforms_test.py +436 -0
- terrax/core/module_utils.py +685 -0
- terrax/core/module_utils_test.py +569 -0
- terrax/core/normalizations.py +233 -0
- terrax/core/normalizations_test.py +346 -0
- terrax/core/observation_operators.py +227 -0
- terrax/core/observation_operators_test.py +359 -0
- terrax/core/orographies.py +166 -0
- terrax/core/parallelism.py +577 -0
- terrax/core/parallelism_test.py +323 -0
- terrax/core/pytree_utils.py +254 -0
- terrax/core/pytree_utils_test.py +311 -0
- terrax/core/random_processes.py +1151 -0
- terrax/core/random_processes_test.py +1068 -0
- terrax/core/scales.py +227 -0
- terrax/core/scan_utils.py +284 -0
- terrax/core/scan_utils_test.py +529 -0
- terrax/core/spatial_filters.py +132 -0
- terrax/core/spatial_filters_test.py +76 -0
- terrax/core/spherical_harmonics.py +549 -0
- terrax/core/spherical_harmonics_test.py +444 -0
- terrax/core/standard_layers.py +673 -0
- terrax/core/standard_layers_test.py +569 -0
- terrax/core/step_filters.py +70 -0
- terrax/core/step_filters_test.py +41 -0
- terrax/core/time_integrators.py +147 -0
- terrax/core/towers.py +389 -0
- terrax/core/towers_test.py +521 -0
- terrax/core/transformer_layers.py +1414 -0
- terrax/core/transformer_layers_test.py +291 -0
- terrax/core/transforms.py +2536 -0
- terrax/core/transforms_test.py +1908 -0
- terrax/core/typing.py +207 -0
- terrax/core/units.py +246 -0
- terrax/core/xarray_utils.py +369 -0
- terrax/core/xarray_utils_test.py +352 -0
- terrax/couplers/__init__.py +13 -0
- terrax/couplers/generic.py +132 -0
- terrax/couplers/generic_test.py +83 -0
- terrax/inference/__init__.py +13 -0
- terrax/inference/dynamic_inputs.py +377 -0
- terrax/inference/dynamic_inputs_test.py +260 -0
- terrax/inference/runner.py +732 -0
- terrax/inference/runner_test.py +393 -0
- terrax/inference/streaming.py +73 -0
- terrax/inference/streaming_test.py +68 -0
- terrax/inference/timing.py +107 -0
- terrax/inference/timing_test.py +53 -0
- terrax/init_test.py +27 -0
- terrax/metrics/__init__.py +24 -0
- terrax/metrics/aggregation.py +368 -0
- terrax/metrics/aggregation_test.py +94 -0
- terrax/metrics/base.py +454 -0
- terrax/metrics/binning.py +90 -0
- terrax/metrics/deterministic_losses.py +51 -0
- terrax/metrics/deterministic_metrics.py +234 -0
- terrax/metrics/deterministic_test.py +283 -0
- terrax/metrics/evaluators.py +446 -0
- terrax/metrics/evaluators_test.py +501 -0
- terrax/metrics/probabilistic_losses.py +53 -0
- terrax/metrics/probabilistic_metrics.py +353 -0
- terrax/metrics/probabilistic_test.py +185 -0
- terrax/metrics/scaling.py +554 -0
- terrax/metrics/scaling_test.py +193 -0
- terrax/metrics/weighting.py +205 -0
- terrax/metrics/weighting_test.py +128 -0
- terrax/toy_model_examples/__init__.py +13 -0
- terrax/toy_model_examples/data_model.py +156 -0
- terrax/toy_model_examples/data_model_test.py +98 -0
- terrax/toy_model_examples/lorenz96.py +296 -0
- terrax/toy_model_examples/lorenz96_test.py +152 -0
- terrax/training/__init__.py +13 -0
- terrax/training/checkpointing.py +189 -0
- terrax/training/data_loading.py +1233 -0
- terrax/training/data_loading_test.py +692 -0
- terrax/training/model_calibrators.py +247 -0
- terrax/training/model_calibrators_test.py +247 -0
- terrax/training/train_utils.py +299 -0
- terrax/training/trainer.py +2135 -0
- terrax/training/trainer_test.py +580 -0
- terrax/xreader/__init__.py +26 -0
- terrax/xreader/iterators.py +422 -0
- terrax/xreader/iterators_test.py +268 -0
- terrax/xreader/stencils.py +269 -0
- terrax/xreader/stencils_test.py +349 -0
- terrax-0.0.5.dist-info/METADATA +53 -0
- terrax-0.0.5.dist-info/RECORD +135 -0
- terrax-0.0.5.dist-info/WHEEL +5 -0
- terrax-0.0.5.dist-info/licenses/LICENSE +202 -0
- terrax-0.0.5.dist-info/top_level.txt +1 -0
terrax/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Terrax: Library for AI-first Earth System Modeling."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Package version string exposed as `terrax.__version__`.
__version__ = "0.0.5" # keep in sync with pyproject.toml
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
@@ -0,0 +1,656 @@
|
|
|
1
|
+
# Copyright 2024 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module-based API for calculating diagnostics of NeuralGCM models."""
|
|
16
|
+
|
|
17
|
+
from typing import Literal, Protocol
|
|
18
|
+
|
|
19
|
+
import coordax as cx
|
|
20
|
+
from flax import nnx
|
|
21
|
+
import jax.numpy as jnp
|
|
22
|
+
from terrax.core import coordinates
|
|
23
|
+
from terrax.core import observation_operators
|
|
24
|
+
from terrax.core import orographies
|
|
25
|
+
from terrax.core import spherical_harmonics
|
|
26
|
+
from terrax.core import transforms
|
|
27
|
+
from terrax.core import typing
|
|
28
|
+
from terrax.core import units
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class EnergyBalanceModule(Protocol):
  """Structural interface for modules that correct tendencies for energy.

  An implementation is any callable object that accepts a column energy
  `imbalance` plus the current `tendencies` (and arbitrary extra arguments)
  and hands back a dictionary of adjusted tendencies.
  """

  def __call__(
      self, imbalance: cx.Field, tendencies: dict[str, cx.Field], *args, **kwargs
  ) -> dict[str, cx.Field]:
    """Returns `tendencies` adjusted according to `imbalance`."""
    ...
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@nnx.dataclass
class ExtractPrecipitationPlusEvaporation(nnx.Module):
  """Diagnoses precipitation plus evaporation rate from physics tendencies.

  The computation of P + E is based on the integration of non-dynamical moisture
  tendency over the vertical column. We define precipitation and evaporation
  rates as the rate of change of non-atmospheric moisture, i.e. resulting in
  positive values for precipitation and negative values for evaporation. This
  is in line with how these quantities are often defined in datasets like ERA5
  or IMERG. This is also in line with the convention of having "downward"
  fluxes as positive and "upward" fluxes as negative.

  Attributes:
    ylm_map: Mapping used to convert fields to their nodal representation.
    levels: Vertical levels over which the column integral is computed.
    sim_units: Object defining nondimensionalization and physical constants.
    moisture_species: Names of tendencies whose sum is vertically integrated.
      Names absent from the tendencies are ignored, but at least one of them
      must be present.
    prognostics_arg_key: Keyword name (str) or positional index into `*args`
      (int) under which prognostics are passed to `__call__`.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  moisture_species: tuple[str, ...] = (
      'specific_humidity',
      'specific_cloud_ice_water_content',
      'specific_cloud_liquid_water_content',
  )
  prognostics_arg_key: str | int = 'prognostics'

  def _compute_p_plus_e_rate(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> cx.Field:
    """Returns -(1/g) * pressure-integral of the summed moisture tendencies.

    Args:
      tendencies: Tendencies; entries named in `moisture_species` are used.
      prognostics: Model state; must contain `log_surface_pressure`.

    Returns:
      The precipitation plus evaporation rate field.

    Raises:
      ValueError: If none of `moisture_species` is present in `tendencies`.
    """
    to_nodal = self.ylm_map.to_nodal
    p_surface = cx.cmap(jnp.exp)(to_nodal(prognostics['log_surface_pressure']))
    scale = 1 / self.sim_units.gravity_acceleration
    moisture_tendencies_nodal = [
        to_nodal(v) for k, v in tendencies.items() if k in self.moisture_species
    ]
    # Without this check `sum([])` would yield the int 0 and fail the isinstance
    # assertion below with an uninformative AssertionError.
    if not moisture_tendencies_nodal:
      raise ValueError(
          f'None of {self.moisture_species=} found in tendencies with keys'
          f' {list(tendencies.keys())}.'
      )
    moisture_tendencies_sum = sum(moisture_tendencies_nodal)
    assert isinstance(moisture_tendencies_sum, cx.Field)
    # Moisture removed from the atmosphere (negative tendency) corresponds to
    # positive precipitation, hence the leading minus sign.
    p_plus_e = -scale * self.levels.integrate_over_pressure(
        moisture_tendencies_sum,
        p_surface,
        self.sim_units,
    )
    return p_plus_e

  def __call__(self, inputs, *args, **kwargs) -> dict[str, cx.Field]:
    """Returns {'precipitation_plus_evaporation_rate': P + E}.

    Args:
      inputs: Tendencies dictionary.
      *args: Positional arguments; holds prognostics if `prognostics_arg_key`
        is an int.
      **kwargs: Keyword arguments; holds prognostics if `prognostics_arg_key`
        is a str.

    Raises:
      ValueError: If prognostics cannot be found under `prognostics_arg_key`,
        or if no moisture tendencies are present.
    """
    tendencies = inputs
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if prognostics is None:
      raise ValueError(
          f'Prognostics not found under {self.prognostics_arg_key=}.'
      )
    p_plus_e_rate = self._compute_p_plus_e_rate(tendencies, prognostics)
    return {'precipitation_plus_evaporation_rate': p_plus_e_rate}
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Allowed output scalings for precipitation/evaporation fields, as applied by
# `ExtractPrecipitationAndEvaporation._apply_scaling`:
#   'rate'       -> divided by water density,
#   'cumulative' -> multiplied by dt and divided by water density,
#   'mass_rate'  -> left unscaled.
PrecipitationScales = Literal['rate', 'cumulative', 'mass_rate']
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@nnx.dataclass
class ExtractPrecipitationAndEvaporation(nnx.Module):
  """Extracts balanced precipitation and evaporation values.

  This module can be attached in diagnostics that have access to both
  parameterization tendencies and model state to infer balanced precipitation
  and evaporation. We use `observation_operator` to predict one of the
  two (either `precipitation` or `evaporation`) and infer the other from the
  precipitation_plus_evaporation calculation. The mode is defined by the
  provided operator, query and inference variable indicating which variable
  will be computed from the balance equations.

  Attributes:
    observation_operator: Observation operator used to predict one of the two
      variables from the balance equations.
    operator_query: Query used for the observation operator. Must contain
      exactly one of `precipitation_key` / `evaporation_key`; that one is
      observed and the other is diagnosed.
    extract_p_plus_e: Module that extracts precipitation plus evaporation from
      tendencies and prognostics.
    prognostics_arg_key: Key or index of the prognostics argument in the call
      signature.
    precipitation_scaling: Scaling strategy for the precipitation field. Must be
      one of `rate`, `mass_rate` or `cumulative`. If using `cumulative` scaling,
      `dt` must be set.
    evaporation_scaling: Scaling strategy for the evaporation field. Must be one
      of `rate`, `mass_rate` or `cumulative`. If using `cumulative` scaling,
      `dt` must be set.
    dt: Timestep by which a field is scaled (only used when
      `precipitation_scaling` or `evaporation_scaling` is set to `cumulative`).
    sim_units: Object defining nondimensionalization and physical constants.
    precipitation_key: Key under which the precipitation field is stored in the
      output. Must differ from `evaporation_key`.
    evaporation_key: Key under which the evaporation field is stored in the
      output.
  """

  observation_operator: typing.ObservationOperator = nnx.data()
  operator_query: dict[str, cx.Coordinate]
  extract_p_plus_e: ExtractPrecipitationPlusEvaporation
  prognostics_arg_key: str | int = 'prognostics'
  precipitation_scaling: PrecipitationScales = 'rate'
  evaporation_scaling: PrecipitationScales = 'rate'
  dt: float | None = None
  precipitation_key: str = 'precipitation'
  evaporation_key: str = 'evaporation'
  sim_units: units.SimUnits = nnx.static(kw_only=True)

  def __post_init__(self):
    """Determines which key is observed and which is diagnosed.

    Raises:
      ValueError: If the two output keys coincide, or if `operator_query` does
        not contain exactly one of them.
    """
    # With identical keys the set below has one element and the diagnosed key
    # would be undefined (previously surfaced as a cryptic unpacking error).
    if self.precipitation_key == self.evaporation_key:
      raise ValueError(
          f'{self.precipitation_key=} and {self.evaporation_key=} must differ.'
      )
    valid_keys = {self.precipitation_key, self.evaporation_key}
    query_keys = set(self.operator_query.keys())
    observed_keys = valid_keys.intersection(query_keys)
    if len(observed_keys) != 1:
      raise ValueError(
          f'{self.operator_query=} should contain exactly one of {valid_keys=}.'
      )
    [self.observe_key] = observed_keys
    [self.diagnosed_key] = valid_keys.difference(query_keys)

  def _extract_prognostics(self, *args, **kwargs):
    """Returns the prognostics dict found under `prognostics_arg_key`.

    Raises:
      ValueError: If the located value is not a dict.
    """
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )
    return prognostics

  def _apply_scaling(self, precipitation_and_evaporation):
    """Rescales entries of the dict in place and returns the same dict.

    Raises:
      ValueError: If `cumulative` scaling is requested without `dt`, or if a
        scaling option is not recognized.
    """
    water_density = self.sim_units.water_density
    for key, scaling in zip(
        [self.precipitation_key, self.evaporation_key],
        [self.precipitation_scaling, self.evaporation_scaling],
    ):
      if scaling == 'cumulative':
        if self.dt is None:
          raise ValueError(
              'dt must be provided when using cumulative precipitation scaling.'
          )
        precipitation_and_evaporation[key] *= self.dt / water_density
      elif scaling == 'rate':
        precipitation_and_evaporation[key] *= 1 / water_density
      elif scaling == 'mass_rate':
        continue
      else:
        raise ValueError(
            f'{scaling=} should be one of rate, mass_rate or cumulative.'
        )
    return precipitation_and_evaporation

  def __call__(self, result, *args, **kwargs):
    """Returns scaled precipitation and evaporation fields.

    The observed quantity comes from `observation_operator`; the diagnosed one
    is `P + E - observed`.
    """
    tendencies = result
    [p_plus_e] = self.extract_p_plus_e(tendencies, *args, **kwargs).values()
    prognostics = self._extract_prognostics(*args, **kwargs)
    observation = self.observation_operator.observe(
        prognostics, query=self.operator_query
    )
    observation = observation[self.observe_key]
    precipitation_and_evaporation = {
        self.diagnosed_key: p_plus_e - observation,
        self.observe_key: observation,
    }
    return self._apply_scaling(precipitation_and_evaporation)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
@nnx.dataclass
class ExtractPrecipitationAndEvaporationWithConstraints(
    ExtractPrecipitationAndEvaporation
):
  """Extracts balanced precipitation and evaporation with sign constraints.

  Works like `ExtractPrecipitationAndEvaporation`: `observation_operator`
  predicts one of the two quantities and the other follows from the
  precipitation plus evaporation balance. On top of that, the observed value is
  clipped so that precipitation stays non-negative and evaporation stays
  non-positive:

    * observing precipitation: P = max(P_obs, max(P + E, 0)); E = (P + E) - P
    * observing evaporation:   E = min(E_obs, min(P + E, 0)); P = (P + E) - E
  """

  def __call__(self, result, *args, **kwargs):
    (p_plus_e,) = self.extract_p_plus_e(result, *args, **kwargs).values()
    prognostics = self._extract_prognostics(*args, **kwargs)
    observed = self.observation_operator.observe(
        prognostics, query=self.operator_query
    )[self.observe_key]

    # Pick the clipping direction that keeps both outputs sign-consistent.
    if self.observe_key == self.precipitation_key:
      clip = jnp.maximum
    elif self.observe_key == self.evaporation_key:
      clip = jnp.minimum
    else:
      raise ValueError(
          f'{self.observe_key=} should be either {self.precipitation_key=} or'
          f' {self.evaporation_key=}.'
      )

    constrained = cx.cmap(lambda x, a, b: clip(x, clip(a, b)))(
        observed, p_plus_e, 0
    )
    return self._apply_scaling({
        self.observe_key: constrained,
        self.diagnosed_key: p_plus_e - constrained,
    })
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@nnx.dataclass
class ExtractColumnDryAirMass(nnx.Module):
  """Diagnoses the vertically integrated mass of dry air.

  The result is (1 / g) times the pressure integral of one minus the summed
  `moisture_species` prognostics, with surface pressure recovered from the
  `log_surface_pressure` prognostic.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  moisture_species: tuple[str, ...] = (
      'specific_humidity',
      'specific_cloud_ice_water_content',
      'specific_cloud_liquid_water_content',
  )

  def __call__(
      self, prognostics: dict[str, cx.Field], *args, **kwargs
  ) -> dict[str, cx.Field]:
    """Returns a dict with the single entry `column_dry_air_mass`.

    Raises:
      KeyError: If any of `moisture_species` is absent from `prognostics`.
    """
    del args, kwargs  # Unused.
    nodal = self.ylm_map.to_nodal
    surface_pressure = cx.cmap(jnp.exp)(
        nodal(prognostics['log_surface_pressure'])
    )
    gravity = self.sim_units.gravity_acceleration

    absent = [name for name in self.moisture_species if name not in prognostics]
    if absent:
      raise KeyError(f'Moisture species {absent} not found in prognostics.')

    moisture_fields = [nodal(prognostics[name]) for name in self.moisture_species]
    # An empty species list means the whole column is treated as dry air.
    total_moisture = sum(moisture_fields) if moisture_fields else 0.0
    assert isinstance(total_moisture, (cx.Field, float, int))

    dry_integral = self.levels.integrate_over_pressure(
        1.0 - total_moisture, surface_pressure, self.sim_units
    )
    return {'column_dry_air_mass': (1 / gravity) * dry_integral}
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
@nnx.dataclass
class ExtractEnergyResiduals(nnx.Module):
  """Computes column energy imbalance based on moist enthalpy formulation.

  This module calculates the imbalance between surface and TOA fluxes (RT - FS)
  and the column energy tendency due to parameterizations dE/dt|_NN based on
    E = phi_s*p_s/g + p_s/g * integral(Cp*T + Lv*q - Lf*qi + k)dsigma
  (See Durran's book section 8.6.4 where it is shown that this is equivalent to
    E = p_s/g * integral(Cv*T + Lv*q + phi - Lf*qi + k)dsigma
  albeit it does not have moisture species there).
  The tendency dE/dt|_NN is computed as:
    dE/dt|_NN = p_s/g * [
      (phi_s + integral(Cp*T + Lv*q - Lf*qi + k)dsigma) * d(log p_s)/dt|_NN +
      integral(Cp*dT/dt|_NN + Lv*dq/dt|_NN - Lf*dqi/dt|_NN + dk/dt|_NN)dsigma
    ]
  The module returns the imbalance: (RT - FS) - dE/dt|_NN
  where RT and FS are TOA and surface fluxes obtained from
  observation_operator.
  If use_evaporation_for_latent_heat is True, FS uses latent heat flux
  derived from mean_evaporation_rate (by multiplying by Lv which is inaccurate
  in ice covered regions), otherwise it uses surface_latent_heat_flux
  from net_energy_terms.
  If use_liquid_ice_moist_static_energy is True, qi is included in the
  column energy integral and snowfall must also be predicted to close the
  budget; otherwise it is excluded.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  model_orography: orographies.ModalOrography
  observation_operator: typing.ObservationOperator = nnx.data()
  energy_fluxes_query: dict[str, cx.Coordinate]
  prognostics_arg_key: str | int = 'prognostics'
  use_evaporation_for_latent_heat: bool = False
  use_liquid_ice_moist_static_energy: bool = False

  def __post_init__(self):
    """Sets flux key lists and validates `energy_fluxes_query`.

    Raises:
      ValueError: If a flux required by the selected options is missing from
        `energy_fluxes_query`.
    """
    self.rt_keys = ['top_net_thermal_radiation', 'top_net_solar_radiation']
    self.fs_keys = [
        'surface_sensible_heat_flux',
        'surface_net_solar_radiation',
        'surface_net_thermal_radiation',
    ]
    if self.use_evaporation_for_latent_heat:
      required_keys = self.rt_keys + self.fs_keys + ['mean_evaporation_rate']
    else:
      required_keys = self.rt_keys + self.fs_keys + ['surface_latent_heat_flux']

    if self.use_liquid_ice_moist_static_energy:
      required_keys.append('snowfall')
    missing_keys = [
        k for k in required_keys if k not in self.energy_fluxes_query
    ]
    if missing_keys:
      raise ValueError(
          f'Missing energy terms in energy_fluxes_query: {missing_keys}'
      )

  def _compute_ke_and_tendency(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> tuple[cx.Field, cx.Field]:
    """Computes nodal kinetic energy k = (u^2 + v^2) / 2 and dk/dt."""
    velocity_from_div_curl = transforms.VelocityFromDivCurl(self.ylm_map)
    winds = velocity_from_div_curl({
        'vorticity': prognostics['vorticity'],
        'divergence': prognostics['divergence'],
    })
    u_nodal = winds['u_component_of_wind']
    v_nodal = winds['v_component_of_wind']
    k_nodal = 0.5 * (u_nodal**2 + v_nodal**2)
    wind_tends = velocity_from_div_curl({
        'vorticity': tendencies['vorticity'],
        'divergence': tendencies['divergence'],
    })
    du_dt_nodal = wind_tends['u_component_of_wind']
    dv_dt_nodal = wind_tends['v_component_of_wind']
    # Chain rule: dk/dt = u du/dt + v dv/dt.
    dk_dt_nodal = u_nodal * du_dt_nodal + v_nodal * dv_dt_nodal
    return k_nodal, dk_dt_nodal

  def _compute_vertically_integrated_tendency(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> cx.Field:
    """Computes column energy tendency due to parameterization (dE/dt|_NN)."""
    to_nodal = self.ylm_map.to_nodal
    p_surface_field = cx.cmap(jnp.exp)(
        to_nodal(prognostics['log_surface_pressure'])
    )
    cp = self.sim_units.Cp
    lv = self.sim_units.Lv
    g = self.sim_units.gravity_acceleration
    lf = self.sim_units.Lf

    t_nodal_field = to_nodal(prognostics['temperature'])
    q_nodal_field = to_nodal(prognostics['specific_humidity'])

    if self.use_liquid_ice_moist_static_energy:
      qi_nodal_field = to_nodal(prognostics['specific_cloud_ice_water_content'])
      dqi_dt_nodal_field = to_nodal(
          tendencies['specific_cloud_ice_water_content']
      )
    else:
      qi_nodal_field = 0.0
      dqi_dt_nodal_field = 0.0

    k_nodal_field, dk_dt_nodal_field = self._compute_ke_and_tendency(
        tendencies, prognostics
    )
    temp_tend_nodal_field = to_nodal(tendencies['temperature'])
    hum_tend_nodal_field = to_nodal(tendencies['specific_humidity'])

    phi_s = self.model_orography.nodal_orography * g
    # A missing surface pressure tendency is treated as zero.
    log_sp_tend = tendencies.get('log_surface_pressure')
    if log_sp_tend is not None:
      log_sp_tend_nodal_field = to_nodal(log_sp_tend)
    else:
      log_sp_tend_nodal_field = p_surface_field * 0

    integrand1 = (
        cp * t_nodal_field
        + lv * q_nodal_field
        - lf * qi_nodal_field
        + k_nodal_field
    )
    i1 = self.levels.integrate_over_pressure(
        integrand1, p_surface_field, self.sim_units
    )
    integrand2 = (
        cp * temp_tend_nodal_field
        + lv * hum_tend_nodal_field
        - lf * dqi_dt_nodal_field
        + dk_dt_nodal_field
    )
    i2 = self.levels.integrate_over_pressure(
        integrand2, p_surface_field, self.sim_units
    )

    # d(p_s X)/dt = p_s X d(log p_s)/dt + p_s dX/dt, applied to both terms of E.
    energy_tendency_data = (1 / g) * (
        (p_surface_field * phi_s + i1) * log_sp_tend_nodal_field + i2
    )

    return energy_tendency_data

  def _extract_prognostics(self, *args, **kwargs) -> dict[str, cx.Field]:
    """Returns the prognostics dict found under `prognostics_arg_key`.

    Raises:
      ValueError: If the located value is not a dict.
    """
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )
    return prognostics

  def __call__(
      self,
      inputs: dict[str, cx.Field],
      *args,
      **kwargs,
  ) -> dict[str, cx.Field]:
    """Computes the column energy imbalance (RT - FS) - dE/dt|_NN.

    Args:
      inputs: Parameterization tendencies.
      *args: Positional arguments; holds prognostics if `prognostics_arg_key`
        is an int.
      **kwargs: Keyword arguments; holds prognostics if `prognostics_arg_key`
        is a str.

    Returns:
      A dict with the single entry `imbalance`.

    Raises:
      ValueError: If prognostics is not a dictionary.
    """
    tendencies = inputs
    prognostics = self._extract_prognostics(*args, **kwargs)

    e_tendency_nn = self._compute_vertically_integrated_tendency(
        tendencies, prognostics
    )

    net_energy_terms = self.observation_operator.observe(
        prognostics, query=self.energy_fluxes_query
    )

    # Assuming observation_operator returns RT and FS fluxes in J/m^2
    # accumulated over an hour time and need to be converted to W/m^2.
    # RT is TOA flux into atm, and FS is surface flux from atm, so that the
    # column energy budget reads dE/dt = RT - FS.
    sec_in_hour_inv = 1 / self.sim_units.nondimensionalize(
        3600 * typing.units.seconds
    )
    rt = sum(net_energy_terms[k] for k in self.rt_keys) * sec_in_hour_inv
    fs = sum(net_energy_terms[k] for k in self.fs_keys) * sec_in_hour_inv

    if self.use_evaporation_for_latent_heat:
      # mean_evaporation_rate is mass rate per second, in SI: (kg/m^2/s).
      # Multiplying by Lv gives nondim equivalent of W/m^2.
      fs += net_energy_terms['mean_evaporation_rate'] * self.sim_units.Lv
    else:
      fs += net_energy_terms['surface_latent_heat_flux'] * sec_in_hour_inv
    if self.use_liquid_ice_moist_static_energy:
      # snowfall is in m of water equivalent (accumulated), so we multiply by
      # density and Lf to get J/m^2 and then divide by time to get W/m^2.
      fs += (
          net_energy_terms['snowfall']
          * self.sim_units.water_density
          * self.sim_units.Lf
          * sec_in_hour_inv
      )

    # Energy imbalance: difference between required tendency (rt - fs) and
    # tendency from NN (e_tendency_nn).
    imbalance = (rt - fs) - e_tendency_nn
    return {'imbalance': imbalance}
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
@nnx.dataclass
class ExtractColumnEnergyBudget(nnx.Module):
  """Computes column energy and, optionally, a flux-based column budget.

  Column energy is computed as

    E_col = phi_s*p_s/g + p_s/g * integral(Cp*T + Lv*q - Lf*qi + k)dsigma

  where phi_s = orography * g is the surface geopotential, p_s is the surface
  pressure, T is temperature, q is specific humidity, qi is specific cloud ice
  water content and k = (u^2 + v^2) / 2 is the kinetic energy per unit mass.

  When `observation_operator` is set, the net TOA energy (RT) and net surface
  energy (FS) are obtained from it via `energy_fluxes_query` and the column
  budget E_col + RT - FS is returned alongside E_col. Both outputs are
  per-column fields; this module does not integrate them horizontally (a
  global budget would be E = horizontal_integral(E_col + RT - FS)).

  If `use_evaporation_for_latent_heat` is True, the latent heat part of FS is
  derived from `mean_evaporation_rate` multiplied by Lv and `dt` (multiplying
  by Lv is inaccurate over ice-covered regions); otherwise
  `surface_latent_heat_flux` is used directly.

  If `use_liquid_ice_moist_static_energy` is True, qi is included in the
  column energy integral and snowfall is added to FS so that the budget can
  close; otherwise qi is excluded and snowfall is not required.
  """

  # Used for `to_nodal` transforms and for computing winds from div/curl.
  ylm_map: spherical_harmonics.FixedYlmMapping
  # Vertical coordinate providing `integrate_over_pressure`.
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  # Source of Cp, Lv, Lf, gravity_acceleration and water_density.
  sim_units: units.SimUnits
  # Provides `nodal_orography` used for the surface geopotential term.
  model_orography: orographies.ModalOrography
  # Optional source of TOA/surface energy terms for the budget.
  observation_operator: typing.ObservationOperator | None = nnx.data(
      default=None
  )
  # Query passed to `observation_operator`; must contain every flux key used.
  energy_fluxes_query: dict[str, cx.Coordinate] | None = None
  # Multiplies mean_evaporation_rate * Lv; required when
  # use_evaporation_for_latent_heat is True and fluxes are computed.
  dt: float | None = None
  use_evaporation_for_latent_heat: bool = False
  use_liquid_ice_moist_static_energy: bool = False
  # Where `__call__` finds the prognostics: None -> `inputs`,
  # int -> `args[key]`, str -> `kwargs[key]`.
  prognostics_arg_key: str | int | None = None

  def __post_init__(self):
    """Sets RT/FS key lists and checks `energy_fluxes_query` covers them.

    Raises:
      ValueError: if `energy_fluxes_query` is provided but lacks any of the
        energy terms required by the current configuration.
    """
    if self.energy_fluxes_query is None:
      return
    self.rt_keys = ['top_net_thermal_radiation', 'top_net_solar_radiation']
    self.fs_keys = [
        'surface_sensible_heat_flux',
        'surface_net_solar_radiation',
        'surface_net_thermal_radiation',
    ]
    latent_heat_key = (
        'mean_evaporation_rate'
        if self.use_evaporation_for_latent_heat
        else 'surface_latent_heat_flux'
    )
    # A fresh list, so appending below never mutates rt_keys / fs_keys.
    required_keys = self.rt_keys + self.fs_keys + [latent_heat_key]
    if self.use_liquid_ice_moist_static_energy:
      required_keys.append('snowfall')
    missing_keys = [
        k for k in required_keys if k not in self.energy_fluxes_query
    ]
    if missing_keys:
      raise ValueError(
          f'Missing energy terms in energy_fluxes_query: {missing_keys}'
      )

  def _compute_column_energy(
      self, prognostics: dict[str, cx.Field]
  ) -> cx.Field:
    """Returns E_col (see class docstring) evaluated on the nodal grid."""
    to_nodal = self.ylm_map.to_nodal
    # The prognostic is log(p_s); exponentiate after moving to nodal space.
    p_surface_field = cx.cmap(jnp.exp)(
        to_nodal(prognostics['log_surface_pressure'])
    )
    cp = self.sim_units.Cp
    lv = self.sim_units.Lv
    g = self.sim_units.gravity_acceleration
    lf = self.sim_units.Lf

    t_nodal_field = to_nodal(prognostics['temperature'])
    q_nodal_field = to_nodal(prognostics['specific_humidity'])

    if self.use_liquid_ice_moist_static_energy:
      qi_nodal_field = to_nodal(prognostics['specific_cloud_ice_water_content'])
    else:
      # Drops the -Lf*qi term from the integrand.
      qi_nodal_field = 0.0

    velocity_from_div_curl = transforms.VelocityFromDivCurl(self.ylm_map)
    winds = velocity_from_div_curl({
        'vorticity': prognostics['vorticity'],
        'divergence': prognostics['divergence'],
    })
    u_nodal = winds['u_component_of_wind']
    v_nodal = winds['v_component_of_wind']
    k_nodal_field = 0.5 * (u_nodal**2 + v_nodal**2)

    # Surface geopotential.
    phi_s = self.model_orography.nodal_orography * g

    integrand = (
        cp * t_nodal_field
        + lv * q_nodal_field
        - lf * qi_nodal_field
        + k_nodal_field
    )
    # NOTE(review): assumed to equal p_s * integral(integrand) dsigma, matching
    # the class docstring formula before the 1/g factor -- confirm against
    # `integrate_over_pressure`.
    vert_integrated_energy = self.levels.integrate_over_pressure(
        integrand, p_surface_field, self.sim_units
    )
    column_energy = (1 / g) * (p_surface_field * phi_s + vert_integrated_energy)
    return column_energy

  def _select_prognostics(
      self, inputs: dict[str, cx.Field], args, kwargs
  ) -> dict[str, cx.Field]:
    """Returns the prognostics dict selected by `prognostics_arg_key`.

    Raises:
      ValueError: if the selected object is not a dict.
    """
    if self.prognostics_arg_key is None:
      prognostics = inputs
    elif isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs[self.prognostics_arg_key]

    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )
    return prognostics

  def _compute_net_fluxes(
      self, prognostics: dict[str, cx.Field]
  ) -> tuple[cx.Field, cx.Field]:
    """Returns (RT, FS) built from `observation_operator` energy terms.

    Raises:
      ValueError: if `energy_fluxes_query` is None, or if the latent heat is
        derived from evaporation while `dt` is None.
    """
    if self.energy_fluxes_query is None:
      raise ValueError(
          'energy_fluxes_query must be provided if observation_operator is'
          ' provided.'
      )
    if self.use_evaporation_for_latent_heat and self.dt is None:
      # Without this check the failure would be an opaque TypeError from
      # multiplying a field by None.
      raise ValueError(
          'dt must be provided when use_evaporation_for_latent_heat is True.'
      )
    net_energy_terms = self.observation_operator.observe(
        prognostics, query=self.energy_fluxes_query
    )

    # Assuming observation_operator returns RT and FS fluxes in J/m^2
    # accumulated over an hour time.
    rt = sum(net_energy_terms[k] for k in self.rt_keys)
    fs = sum(net_energy_terms[k] for k in self.fs_keys)

    if self.use_evaporation_for_latent_heat:
      # mean_evaporation_rate is a rate (kg/m^2/s), so we multiply by Lv and
      # dt to get J/m^2. NOTE(review): presumably dt should match the
      # accumulation period of the other terms -- confirm.
      fs += (
          net_energy_terms['mean_evaporation_rate']
          * self.sim_units.Lv
          * self.dt
      )
    else:
      fs += net_energy_terms['surface_latent_heat_flux']

    if self.use_liquid_ice_moist_static_energy:
      # snowfall is in m of water equivalent (accumulated), so we multiply by
      # density and Lf to get J/m^2.
      fs += (
          net_energy_terms['snowfall']
          * self.sim_units.water_density
          * self.sim_units.Lf
      )
    return rt, fs

  def __call__(
      self,
      inputs: dict[str, cx.Field],
      *args,
      **kwargs,
  ) -> dict[str, cx.Field]:
    """Computes total energy budget.

    Args:
      inputs: Prognostics dict, used when `prognostics_arg_key` is None.
      *args: Extra positional args; `args[prognostics_arg_key]` holds the
        prognostics when `prognostics_arg_key` is an int.
      **kwargs: Extra keyword args; `kwargs[prognostics_arg_key]` holds the
        prognostics when `prognostics_arg_key` is a str.

    Returns:
      Dict with 'column_energy' and, when `observation_operator` is set,
      'column_energy_budget' (= column_energy + RT - FS).

    Raises:
      ValueError: if the selected prognostics are not a dict, if
        `observation_operator` is set without `energy_fluxes_query`, or if
        `use_evaporation_for_latent_heat` is True while `dt` is None.
    """
    prognostics = self._select_prognostics(inputs, args, kwargs)
    column_energy = self._compute_column_energy(prognostics)
    results = {'column_energy': column_energy}

    if self.observation_operator is not None:
      rt, fs = self._compute_net_fluxes(prognostics)
      results['column_energy_budget'] = column_energy + rt - fs
    return results
|