terrax 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. terrax/__init__.py +17 -0
  2. terrax/atmosphere/__init__.py +13 -0
  3. terrax/atmosphere/diagnostics.py +656 -0
  4. terrax/atmosphere/diagnostics_test.py +247 -0
  5. terrax/atmosphere/equations.py +489 -0
  6. terrax/atmosphere/equations_test.py +121 -0
  7. terrax/atmosphere/fixers.py +248 -0
  8. terrax/atmosphere/fixers_test.py +118 -0
  9. terrax/atmosphere/idealized_states.py +252 -0
  10. terrax/atmosphere/idealized_states_test.py +154 -0
  11. terrax/atmosphere/interpolators.py +400 -0
  12. terrax/atmosphere/interpolators_test.py +398 -0
  13. terrax/atmosphere/observation_operators.py +96 -0
  14. terrax/atmosphere/observation_operators_test.py +101 -0
  15. terrax/atmosphere/parameterizations.py +94 -0
  16. terrax/atmosphere/state_conversion.py +208 -0
  17. terrax/atmosphere/transforms.py +265 -0
  18. terrax/atmosphere/transforms_test.py +179 -0
  19. terrax/auxiliary/__init__.py +13 -0
  20. terrax/auxiliary/models.py +400 -0
  21. terrax/auxiliary/models_test.py +384 -0
  22. terrax/core/__init__.py +13 -0
  23. terrax/core/api.py +1089 -0
  24. terrax/core/api_test.py +735 -0
  25. terrax/core/array_transforms.py +78 -0
  26. terrax/core/boundaries.py +148 -0
  27. terrax/core/boundaries_test.py +114 -0
  28. terrax/core/checkpointing.py +106 -0
  29. terrax/core/checkpointing_test.py +88 -0
  30. terrax/core/coordinates.py +1875 -0
  31. terrax/core/coordinates_test.py +607 -0
  32. terrax/core/data_specs.py +514 -0
  33. terrax/core/data_specs_test.py +423 -0
  34. terrax/core/diagnostics.py +581 -0
  35. terrax/core/diagnostics_test.py +734 -0
  36. terrax/core/dynamic_io.py +169 -0
  37. terrax/core/equations.py +112 -0
  38. terrax/core/feature_transforms.py +218 -0
  39. terrax/core/feature_transforms_test.py +244 -0
  40. terrax/core/fiddle_tags.py +21 -0
  41. terrax/core/field_utils.py +400 -0
  42. terrax/core/field_utils_test.py +637 -0
  43. terrax/core/interpolators.py +181 -0
  44. terrax/core/interpolators_test.py +63 -0
  45. terrax/core/learned_transforms.py +503 -0
  46. terrax/core/learned_transforms_test.py +436 -0
  47. terrax/core/module_utils.py +685 -0
  48. terrax/core/module_utils_test.py +569 -0
  49. terrax/core/normalizations.py +233 -0
  50. terrax/core/normalizations_test.py +346 -0
  51. terrax/core/observation_operators.py +227 -0
  52. terrax/core/observation_operators_test.py +359 -0
  53. terrax/core/orographies.py +166 -0
  54. terrax/core/parallelism.py +577 -0
  55. terrax/core/parallelism_test.py +323 -0
  56. terrax/core/pytree_utils.py +254 -0
  57. terrax/core/pytree_utils_test.py +311 -0
  58. terrax/core/random_processes.py +1151 -0
  59. terrax/core/random_processes_test.py +1068 -0
  60. terrax/core/scales.py +227 -0
  61. terrax/core/scan_utils.py +284 -0
  62. terrax/core/scan_utils_test.py +529 -0
  63. terrax/core/spatial_filters.py +132 -0
  64. terrax/core/spatial_filters_test.py +76 -0
  65. terrax/core/spherical_harmonics.py +549 -0
  66. terrax/core/spherical_harmonics_test.py +444 -0
  67. terrax/core/standard_layers.py +673 -0
  68. terrax/core/standard_layers_test.py +569 -0
  69. terrax/core/step_filters.py +70 -0
  70. terrax/core/step_filters_test.py +41 -0
  71. terrax/core/time_integrators.py +147 -0
  72. terrax/core/towers.py +389 -0
  73. terrax/core/towers_test.py +521 -0
  74. terrax/core/transformer_layers.py +1414 -0
  75. terrax/core/transformer_layers_test.py +291 -0
  76. terrax/core/transforms.py +2536 -0
  77. terrax/core/transforms_test.py +1908 -0
  78. terrax/core/typing.py +207 -0
  79. terrax/core/units.py +246 -0
  80. terrax/core/xarray_utils.py +369 -0
  81. terrax/core/xarray_utils_test.py +352 -0
  82. terrax/couplers/__init__.py +13 -0
  83. terrax/couplers/generic.py +132 -0
  84. terrax/couplers/generic_test.py +83 -0
  85. terrax/inference/__init__.py +13 -0
  86. terrax/inference/dynamic_inputs.py +377 -0
  87. terrax/inference/dynamic_inputs_test.py +260 -0
  88. terrax/inference/runner.py +732 -0
  89. terrax/inference/runner_test.py +393 -0
  90. terrax/inference/streaming.py +73 -0
  91. terrax/inference/streaming_test.py +68 -0
  92. terrax/inference/timing.py +107 -0
  93. terrax/inference/timing_test.py +53 -0
  94. terrax/init_test.py +27 -0
  95. terrax/metrics/__init__.py +24 -0
  96. terrax/metrics/aggregation.py +368 -0
  97. terrax/metrics/aggregation_test.py +94 -0
  98. terrax/metrics/base.py +454 -0
  99. terrax/metrics/binning.py +90 -0
  100. terrax/metrics/deterministic_losses.py +51 -0
  101. terrax/metrics/deterministic_metrics.py +234 -0
  102. terrax/metrics/deterministic_test.py +283 -0
  103. terrax/metrics/evaluators.py +446 -0
  104. terrax/metrics/evaluators_test.py +501 -0
  105. terrax/metrics/probabilistic_losses.py +53 -0
  106. terrax/metrics/probabilistic_metrics.py +353 -0
  107. terrax/metrics/probabilistic_test.py +185 -0
  108. terrax/metrics/scaling.py +554 -0
  109. terrax/metrics/scaling_test.py +193 -0
  110. terrax/metrics/weighting.py +205 -0
  111. terrax/metrics/weighting_test.py +128 -0
  112. terrax/toy_model_examples/__init__.py +13 -0
  113. terrax/toy_model_examples/data_model.py +156 -0
  114. terrax/toy_model_examples/data_model_test.py +98 -0
  115. terrax/toy_model_examples/lorenz96.py +296 -0
  116. terrax/toy_model_examples/lorenz96_test.py +152 -0
  117. terrax/training/__init__.py +13 -0
  118. terrax/training/checkpointing.py +189 -0
  119. terrax/training/data_loading.py +1233 -0
  120. terrax/training/data_loading_test.py +692 -0
  121. terrax/training/model_calibrators.py +247 -0
  122. terrax/training/model_calibrators_test.py +247 -0
  123. terrax/training/train_utils.py +299 -0
  124. terrax/training/trainer.py +2135 -0
  125. terrax/training/trainer_test.py +580 -0
  126. terrax/xreader/__init__.py +26 -0
  127. terrax/xreader/iterators.py +422 -0
  128. terrax/xreader/iterators_test.py +268 -0
  129. terrax/xreader/stencils.py +269 -0
  130. terrax/xreader/stencils_test.py +349 -0
  131. terrax-0.0.5.dist-info/METADATA +53 -0
  132. terrax-0.0.5.dist-info/RECORD +135 -0
  133. terrax-0.0.5.dist-info/WHEEL +5 -0
  134. terrax-0.0.5.dist-info/licenses/LICENSE +202 -0
  135. terrax-0.0.5.dist-info/top_level.txt +1 -0
terrax/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Terrax: Library for AI-first Earth System Modeling."""
15
+
16
+
17
# Package version string; keep in sync with the version in pyproject.toml.
__version__ = "0.0.5"
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,656 @@
1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module-based API for calculating diagnostics of NeuralGCM models."""
16
+
17
+ from typing import Literal, Protocol
18
+
19
+ import coordax as cx
20
+ from flax import nnx
21
+ import jax.numpy as jnp
22
+ from terrax.core import coordinates
23
+ from terrax.core import observation_operators
24
+ from terrax.core import orographies
25
+ from terrax.core import spherical_harmonics
26
+ from terrax.core import transforms
27
+ from terrax.core import typing
28
+ from terrax.core import units
29
+
30
+
31
class EnergyBalanceModule(Protocol):
  """Structural interface for modules that correct tendencies for energy.

  Any callable object whose `__call__` accepts an energy `imbalance` field and a
  dictionary of `tendencies` (plus arbitrary extra arguments) and returns a new
  dictionary of tendencies satisfies this protocol.
  """

  def __call__(
      self,
      imbalance: cx.Field,
      tendencies: dict[str, cx.Field],
      *args,
      **kwargs,
  ) -> dict[str, cx.Field]:
    """Returns `tendencies` adjusted according to the energy `imbalance`."""
    ...
42
+
43
+
44
@nnx.dataclass
class ExtractPrecipitationPlusEvaporation(nnx.Module):
  """Diagnoses precipitation plus evaporation rate from physics tendencies.

  The computation of P + E is based on the integration of non-dynamical moisture
  tendency over the vertical column. We define precipitation and evaporation
  rates as the rate of change of non-atmospheric moisture, i.e. resulting in
  positive values for precipitation and negative values for evaporation. This
  is in line with how these quantities are often defined in datasets like ERA5
  or IMERG. This is also in line with the convention of having "downward"
  fluxes as positive and "upward" fluxes as negative.

  Attributes:
    ylm_map: Mapping whose `to_nodal` converts inputs to nodal values.
    levels: Vertical levels providing `integrate_over_pressure`.
    sim_units: Object defining nondimensionalization and physical constants.
    moisture_species: Names of tendencies summed into the column moisture
      tendency. Species absent from the tendencies are ignored.
    prognostics_arg_key: Index into `*args` (int) or key into `**kwargs` (str)
      used to locate the prognostics passed to `__call__`.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  moisture_species: tuple[str, ...] = (
      'specific_humidity',
      'specific_cloud_ice_water_content',
      'specific_cloud_liquid_water_content',
  )
  prognostics_arg_key: str | int = 'prognostics'

  def _compute_p_plus_e_rate(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> cx.Field:
    """Returns -(1/g) times the pressure integral of the moisture tendencies.

    Args:
      tendencies: Tendencies keyed by variable name.
      prognostics: Prognostic fields; must contain `log_surface_pressure`.

    Raises:
      ValueError: If none of `moisture_species` is present in `tendencies`.
    """
    to_nodal = self.ylm_map.to_nodal
    p_surface = cx.cmap(jnp.exp)(to_nodal(prognostics['log_surface_pressure']))
    scale = 1 / self.sim_units.gravity_acceleration
    moisture_tendencies_nodal = [
        to_nodal(v) for k, v in tendencies.items() if k in self.moisture_species
    ]
    if not moisture_tendencies_nodal:
      # Without this check `sum([])` returns the int 0 and the isinstance
      # assertion below fails with an uninformative AssertionError.
      raise ValueError(
          f'None of {self.moisture_species=} found in tendencies with keys'
          f' {list(tendencies)}.'
      )
    moisture_tendencies_sum = sum(moisture_tendencies_nodal)
    assert isinstance(moisture_tendencies_sum, cx.Field)  # Type narrowing.
    # Moisture leaving the atmosphere (negative tendency) yields positive P + E.
    return -scale * self.levels.integrate_over_pressure(
        moisture_tendencies_sum,
        p_surface,
        self.sim_units,
    )

  def __call__(self, inputs, *args, **kwargs) -> dict[str, cx.Field]:
    """Computes the precipitation plus evaporation rate.

    Args:
      inputs: Tendencies keyed by variable name.
      *args: Extra positional arguments; holds the prognostics when
        `prognostics_arg_key` is an int.
      **kwargs: Extra keyword arguments; holds the prognostics when
        `prognostics_arg_key` is a str.

    Returns:
      `{'precipitation_plus_evaporation_rate': field}`.

    Raises:
      ValueError: If no prognostics are found under `prognostics_arg_key`.
    """
    tendencies = inputs
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if prognostics is None:
      raise ValueError(
          f'Prognostics not found under {self.prognostics_arg_key=}.'
      )
    p_plus_e_rate = self._compute_p_plus_e_rate(tendencies, prognostics)
    return {'precipitation_plus_evaporation_rate': p_plus_e_rate}
95
+
96
+
97
# Scaling strategies accepted by `ExtractPrecipitationAndEvaporation`:
#   'mass_rate': field is returned unscaled (mass flux rate).
#   'rate': field is divided by water density.
#   'cumulative': field is multiplied by dt and divided by water density.
PrecipitationScales = Literal['rate', 'cumulative', 'mass_rate']
98
+
99
+
100
@nnx.dataclass
class ExtractPrecipitationAndEvaporation(nnx.Module):
  """Extracts balanced precipitation and evaporation values.

  This module can be attached in diagnostics that have access to both
  parameterization tendencies and model state to infer balanced precipitation
  and evaporation. We use `observation_operator` to predict one of the
  two (either `precipitation` or `evaporation`) and infer the other from the
  precipitation_plus_evaporation calculation. The mode is defined by the
  provided operator, query and inference variable indicating which variable
  will be computed from the balance equations.

  Attributes:
    observation_operator: Observation operator used to predict one of the two
      variables from the balance equations.
    operator_query: Query used for the observation operator. Must contain
      exactly one of `precipitation_key` or `evaporation_key`.
    extract_p_plus_e: Module that extracts precipitation plus evaporation from
      tendencies and prognostics.
    prognostics_arg_key: Key or index of the prognostics argument in the call
      signature.
    precipitation_scaling: Scaling strategy for the precipitation field. Must be
      one of `rate`, `mass_rate` or `cumulative`. If using `cumulative` scaling,
      `dt` must be set.
    evaporation_scaling: Scaling strategy for the evaporation field. Must be one
      of `rate`, `mass_rate` or `cumulative`. If using `cumulative` scaling,
      `dt` must be set.
    dt: Timestep by which a field is scaled (only used when
      `precipitation_scaling` or `evaporation_scaling` is `cumulative`).
    sim_units: Object defining nondimensionalization and physical constants.
    precipitation_key: Key under which the precipitation field is stored in the
      output.
    evaporation_key: Key under which the evaporation field is stored in the
      output. Must differ from `precipitation_key`.
  """

  observation_operator: typing.ObservationOperator = nnx.data()
  operator_query: dict[str, cx.Coordinate]
  extract_p_plus_e: ExtractPrecipitationPlusEvaporation
  prognostics_arg_key: str | int = 'prognostics'
  precipitation_scaling: PrecipitationScales = 'rate'
  evaporation_scaling: PrecipitationScales = 'rate'
  dt: float | None = None
  precipitation_key: str = 'precipitation'
  evaporation_key: str = 'evaporation'
  sim_units: units.SimUnits = nnx.static(kw_only=True)

  def __post_init__(self):
    """Determines which variable is observed and which one is diagnosed.

    Raises:
      ValueError: If the output keys coincide, or if `operator_query` does not
        contain exactly one of them.
    """
    if self.precipitation_key == self.evaporation_key:
      # With equal keys the diagnosed key below would be undefined.
      raise ValueError(
          f'{self.precipitation_key=} and {self.evaporation_key=} must differ.'
      )
    valid_keys = {self.precipitation_key, self.evaporation_key}
    query_keys = set(self.operator_query.keys())
    if len(query_keys.intersection(valid_keys)) != 1:
      raise ValueError(
          f'{self.operator_query=} should contain exactly one of {valid_keys=}.'
      )
    [self.observe_key] = valid_keys.intersection(query_keys)
    [self.diagnosed_key] = valid_keys.difference(query_keys)

  def _extract_prognostics(self, *args, **kwargs):
    """Returns the prognostics located via `prognostics_arg_key`.

    Raises:
      ValueError: If the located value is not a dictionary.
    """
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )
    return prognostics

  def _apply_scaling(self, precipitation_and_evaporation):
    """Scales both entries in place according to their scaling strategy.

    Args:
      precipitation_and_evaporation: Dict holding both output keys; mutated.

    Returns:
      The same (mutated) dictionary.

    Raises:
      ValueError: On an unknown scaling, or `cumulative` scaling without `dt`.
    """
    water_density = self.sim_units.water_density
    for key, scaling in zip(
        [self.precipitation_key, self.evaporation_key],
        [self.precipitation_scaling, self.evaporation_scaling],
    ):
      if scaling == 'cumulative':
        if self.dt is None:
          raise ValueError(
              'dt must be provided when using cumulative precipitation scaling.'
          )
        precipitation_and_evaporation[key] *= self.dt / water_density
      elif scaling == 'rate':
        precipitation_and_evaporation[key] *= 1 / water_density
      elif scaling == 'mass_rate':
        continue
      else:
        raise ValueError(
            f'{scaling=} should be one of rate, mass_rate or cumulative.'
        )
    return precipitation_and_evaporation

  def __call__(self, result, *args, **kwargs):
    """Returns scaled precipitation and evaporation fields.

    Args:
      result: Parameterization tendencies keyed by variable name.
      *args: Extra positional arguments forwarded to `extract_p_plus_e`.
      **kwargs: Extra keyword arguments forwarded to `extract_p_plus_e`.

    Returns:
      Dict with `precipitation_key` and `evaporation_key` entries, where the
      diagnosed entry equals `P + E` minus the observed entry before scaling.
    """
    tendencies = result
    [p_plus_e] = self.extract_p_plus_e(tendencies, *args, **kwargs).values()
    prognostics = self._extract_prognostics(*args, **kwargs)
    observation = self.observation_operator.observe(
        prognostics, query=self.operator_query
    )
    observation = observation[self.observe_key]
    precipitation_and_evaporation = {
        self.diagnosed_key: p_plus_e - observation,
        self.observe_key: observation,
    }
    return self._apply_scaling(precipitation_and_evaporation)
202
+
203
+
204
@nnx.dataclass
class ExtractPrecipitationAndEvaporationWithConstraints(
    ExtractPrecipitationAndEvaporation
):
  """Sign-constrained variant of `ExtractPrecipitationAndEvaporation`.

  One of precipitation or evaporation is predicted by `observation_operator`
  and the other is inferred from the precipitation-plus-evaporation balance.
  Before the balance is applied, the observed value is clipped so that the
  final precipitation is non-negative and the final evaporation is
  non-positive:

  * observing precipitation: P = max(P_obs, P + E, 0), E = (P + E) - P.
  * observing evaporation: E = min(E_obs, P + E, 0), P = (P + E) - E.
  """

  def __call__(self, result, *args, **kwargs):
    (p_plus_e,) = self.extract_p_plus_e(result, *args, **kwargs).values()
    state = self._extract_prognostics(*args, **kwargs)
    observed = self.observation_operator.observe(
        state, query=self.operator_query
    )[self.observe_key]

    # Pick the clipping direction for the observed quantity.
    if self.observe_key == self.precipitation_key:
      combine = jnp.maximum
    elif self.observe_key == self.evaporation_key:
      combine = jnp.minimum
    else:
      raise ValueError(
          f'{self.observe_key=} should be either {self.precipitation_key=} or'
          f' {self.evaporation_key=}.'
      )

    clipped = cx.cmap(lambda x, a, b: combine(x, combine(a, b)))(
        observed, p_plus_e, 0
    )
    return self._apply_scaling({
        self.observe_key: clipped,
        self.diagnosed_key: p_plus_e - clipped,
    })
250
+
251
+
252
@nnx.dataclass
class ExtractColumnDryAirMass(nnx.Module):
  """Computes the column-integrated dry air mass from prognostics.

  The dry mass is (1/g) times the pressure integral of (1 - q_total), where
  q_total is the sum of all `moisture_species` converted to nodal values.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  moisture_species: tuple[str, ...] = (
      'specific_humidity',
      'specific_cloud_ice_water_content',
      'specific_cloud_liquid_water_content',
  )

  def __call__(
      self, prognostics: dict[str, cx.Field], *args, **kwargs
  ) -> dict[str, cx.Field]:
    """Returns `{'column_dry_air_mass': field}`.

    Raises:
      KeyError: If any of `moisture_species` is missing from `prognostics`.
    """
    del args, kwargs  # Unused.
    nodal = self.ylm_map.to_nodal
    surface_pressure = cx.cmap(jnp.exp)(
        nodal(prognostics['log_surface_pressure'])
    )
    gravity = self.sim_units.gravity_acceleration

    absent = [name for name in self.moisture_species if name not in prognostics]
    if absent:
      raise KeyError(f'Moisture species {absent} not found in prognostics.')

    if self.moisture_species:
      total_moisture = sum(
          nodal(prognostics[name]) for name in self.moisture_species
      )
    else:
      total_moisture = 0.0
    assert isinstance(total_moisture, (cx.Field, float, int))

    dry_fraction_integral = self.levels.integrate_over_pressure(
        1.0 - total_moisture, surface_pressure, self.sim_units
    )
    return {'column_dry_air_mass': (1 / gravity) * dry_fraction_integral}
292
+
293
+
294
@nnx.dataclass
class ExtractEnergyResiduals(nnx.Module):
  """Computes column energy imbalance based on moist enthalpy formulation.

  This module calculates the imbalance between surface and TOA fluxes (RT - FS)
  and the column energy tendency due to parameterizations dE/dt|_NN based on
    E = phi_s*p_s/g + p_s/g * integral(Cp*T + Lv*q - Lf*qi + k)dsigma
  (See Durran's book section 8.6.4 where it is shown that this is equivalent to
    E = p_s/g * integral(Cv*T + Lv*q + phi - Lf*qi + k)dsigma
  albeit it does not have moisture species there.)
  The tendency dE/dt|_NN is computed as:
    dE/dt|_NN = p_s/g * [
      (phi_s + integral(Cp*T + Lv*q - Lf*qi + k)dsigma) * d(log p_s)/dt|_NN +
      integral(Cp*dT/dt|_NN + Lv*dq/dt|_NN - Lf*dqi/dt|_NN + dk/dt|_NN)dsigma
    ]
  The module returns the imbalance: (RT - FS) - dE/dt|_NN
  where RT and FS are TOA and surface fluxes obtained from
  observation_operator.
  If use_evaporation_for_latent_heat is True, FS uses latent heat flux
  derived from mean_evaporation_rate (by multiplying by Lv which is inaccurate
  in ice covered regions), otherwise it uses surface_latent_heat_flux
  from net_energy_terms.
  If use_liquid_ice_moist_static_energy is True, qi is included in the
  column energy integral and we need to predict also snowfall to close budget,
  otherwise it is excluded.

  Attributes:
    ylm_map: Mapping whose `to_nodal` converts inputs to nodal values.
    levels: Vertical levels providing `integrate_over_pressure`.
    sim_units: Object defining nondimensionalization and physical constants.
    model_orography: Orography providing `nodal_orography`.
    observation_operator: Operator returning the flux terms in
      `energy_fluxes_query`.
    energy_fluxes_query: Query for the flux terms; must contain all terms
      required by the selected options (validated in `__post_init__`).
    prognostics_arg_key: Index into `*args` (int) or key into `**kwargs` (str)
      used to locate the prognostics passed to `__call__`.
    use_evaporation_for_latent_heat: Whether to derive the latent heat flux
      from `mean_evaporation_rate`.
    use_liquid_ice_moist_static_energy: Whether to include cloud ice and
      snowfall in the energy budget.
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  model_orography: orographies.ModalOrography
  observation_operator: typing.ObservationOperator = nnx.data()
  energy_fluxes_query: dict[str, cx.Coordinate]
  prognostics_arg_key: str | int = 'prognostics'
  use_evaporation_for_latent_heat: bool = False
  use_liquid_ice_moist_static_energy: bool = False

  def __post_init__(self):
    """Sets flux key groups and validates `energy_fluxes_query`.

    Raises:
      ValueError: If a required flux term is missing from the query.
    """
    self.rt_keys = ['top_net_thermal_radiation', 'top_net_solar_radiation']
    self.fs_keys = [
        'surface_sensible_heat_flux',
        'surface_net_solar_radiation',
        'surface_net_thermal_radiation',
    ]
    if self.use_evaporation_for_latent_heat:
      required_keys = self.rt_keys + self.fs_keys + ['mean_evaporation_rate']
    else:
      required_keys = self.rt_keys + self.fs_keys + ['surface_latent_heat_flux']

    if self.use_liquid_ice_moist_static_energy:
      required_keys.append('snowfall')
    missing_keys = [
        k for k in required_keys if k not in self.energy_fluxes_query
    ]
    if missing_keys:
      raise ValueError(
          f'Missing energy terms in energy_fluxes_query: {missing_keys}'
      )

  def _compute_ke_and_tendency(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> tuple[cx.Field, cx.Field]:
    """Computes nodal kinetic energy k and its tendency dk/dt = u*du + v*dv."""
    velocity_from_div_curl = transforms.VelocityFromDivCurl(self.ylm_map)
    winds = velocity_from_div_curl({
        'vorticity': prognostics['vorticity'],
        'divergence': prognostics['divergence'],
    })
    u_nodal = winds['u_component_of_wind']
    v_nodal = winds['v_component_of_wind']
    k_nodal = 0.5 * (u_nodal**2 + v_nodal**2)
    wind_tends = velocity_from_div_curl({
        'vorticity': tendencies['vorticity'],
        'divergence': tendencies['divergence'],
    })
    du_dt_nodal = wind_tends['u_component_of_wind']
    dv_dt_nodal = wind_tends['v_component_of_wind']
    dk_dt_nodal = u_nodal * du_dt_nodal + v_nodal * dv_dt_nodal
    return k_nodal, dk_dt_nodal

  def _compute_vertically_integrated_tendency(
      self,
      tendencies: dict[str, cx.Field],
      prognostics: dict[str, cx.Field],
  ) -> cx.Field:
    """Computes column energy tendency dE/dt|_NN due to parameterization."""
    to_nodal = self.ylm_map.to_nodal
    p_surface_field = cx.cmap(jnp.exp)(
        to_nodal(prognostics['log_surface_pressure'])
    )
    cp = self.sim_units.Cp
    lv = self.sim_units.Lv
    g = self.sim_units.gravity_acceleration
    lf = self.sim_units.Lf

    t_nodal_field = to_nodal(prognostics['temperature'])
    q_nodal_field = to_nodal(prognostics['specific_humidity'])

    if self.use_liquid_ice_moist_static_energy:
      qi_nodal_field = to_nodal(prognostics['specific_cloud_ice_water_content'])
      dqi_dt_nodal_field = to_nodal(
          tendencies['specific_cloud_ice_water_content']
      )
    else:
      qi_nodal_field = 0.0
      dqi_dt_nodal_field = 0.0

    k_nodal_field, dk_dt_nodal_field = self._compute_ke_and_tendency(
        tendencies, prognostics
    )
    temp_tend_nodal_field = to_nodal(tendencies['temperature'])
    hum_tend_nodal_field = to_nodal(tendencies['specific_humidity'])

    phi_s = self.model_orography.nodal_orography * g
    log_sp_tend = tendencies.get('log_surface_pressure')
    if log_sp_tend is not None:
      log_sp_tend_nodal_field = to_nodal(log_sp_tend)
    else:
      # No surface pressure tendency: use a zero field of matching shape.
      log_sp_tend_nodal_field = p_surface_field * 0

    integrand1 = (
        cp * t_nodal_field
        + lv * q_nodal_field
        - lf * qi_nodal_field
        + k_nodal_field
    )
    i1 = self.levels.integrate_over_pressure(
        integrand1, p_surface_field, self.sim_units
    )
    integrand2 = (
        cp * temp_tend_nodal_field
        + lv * hum_tend_nodal_field
        - lf * dqi_dt_nodal_field
        + dk_dt_nodal_field
    )
    i2 = self.levels.integrate_over_pressure(
        integrand2, p_surface_field, self.sim_units
    )

    # Product rule on (p_s*phi_s + integral over pressure) w.r.t. time, using
    # d(p_s)/dt = p_s * d(log p_s)/dt.
    energy_tendency_data = (1 / g) * (
        (p_surface_field * phi_s + i1) * log_sp_tend_nodal_field + i2
    )

    return energy_tendency_data

  def __call__(
      self,
      inputs: dict[str, cx.Field],
      *args,
      **kwargs,
  ) -> dict[str, cx.Field]:
    """Returns the column energy imbalance `(RT - FS) - dE/dt|_NN`.

    Args:
      inputs: Parameterization tendencies keyed by variable name.
      *args: Extra positional arguments; holds the prognostics when
        `prognostics_arg_key` is an int.
      **kwargs: Extra keyword arguments; holds the prognostics when
        `prognostics_arg_key` is a str.

    Returns:
      `{'imbalance': field}`.

    Raises:
      ValueError: If the located prognostics are not a dictionary.
    """
    tendencies = inputs
    if isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs.get(self.prognostics_arg_key)
    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )

    e_tendency_nn = self._compute_vertically_integrated_tendency(
        tendencies, prognostics
    )

    net_energy_terms = self.observation_operator.observe(
        prognostics, query=self.energy_fluxes_query
    )

    # NOTE: assumes observation_operator returns RT and FS fluxes in J/m^2
    # accumulated over one hour; dividing by one hour converts them to W/m^2.
    # RT is the TOA flux into the atmosphere and FS the surface flux out of it,
    # so the flux balance reads dE/dt = RT - FS.
    sec_in_hour_inv = 1 / self.sim_units.nondimensionalize(
        3600 * typing.units.seconds
    )
    rt = sum(net_energy_terms[k] for k in self.rt_keys) * sec_in_hour_inv
    fs = sum(net_energy_terms[k] for k in self.fs_keys) * sec_in_hour_inv

    if self.use_evaporation_for_latent_heat:
      # mean_evaporation_rate is a mass rate (kg/m^2/s in SI); multiplying by
      # Lv gives the nondimensional equivalent of W/m^2.
      fs += net_energy_terms['mean_evaporation_rate'] * self.sim_units.Lv
    else:
      fs += net_energy_terms['surface_latent_heat_flux'] * sec_in_hour_inv
    if self.use_liquid_ice_moist_static_energy:
      # snowfall is in m of water equivalent (accumulated), so we multiply by
      # density and Lf to get J/m^2 and then divide by time to get W/m^2.
      fs += (
          net_energy_terms['snowfall']
          * self.sim_units.water_density
          * self.sim_units.Lf
          * sec_in_hour_inv
      )

    # Energy imbalance: difference between the flux-implied tendency (rt - fs)
    # and the tendency produced by the parameterization (e_tendency_nn).
    imbalance = (rt - fs) - e_tendency_nn
    return {'imbalance': imbalance}
496
+
497
+
498
@nnx.dataclass
class ExtractColumnEnergyBudget(nnx.Module):
  """Computes column energy and, optionally, the flux-based column budget.

  This module calculates column energy based on
    E_col = phi_s*p_s/g + p_s/g * integral(Cp*T + Lv*q - Lf*qi + k)dsigma
  and, when `observation_operator` is provided, adds TOA and surface fluxes
  (RT - FS) obtained from it to compute budget = E_col + RT - FS, representing
  the column energy based on fluxes.
  Both outputs are per-column (nodal) fields. A global energy budget is obtained
  by integrating `column_energy_budget` horizontally, which this module does
  not do.
  If use_evaporation_for_latent_heat is True, FS uses latent heat flux
  derived from mean_evaporation_rate (by multiplying by Lv which is inaccurate
  in ice covered regions) and `dt`, otherwise it uses surface_latent_heat_flux
  from net_energy_terms.
  If use_liquid_ice_moist_static_energy is True, qi is included in the
  column energy integral and we need to predict also snowfall to close budget,
  otherwise it is excluded.

  Attributes:
    ylm_map: Mapping whose `to_nodal` converts inputs to nodal values.
    levels: Vertical levels providing `integrate_over_pressure`.
    sim_units: Object defining nondimensionalization and physical constants.
    model_orography: Orography providing `nodal_orography`.
    observation_operator: Optional operator returning the flux terms. When
      None, only `column_energy` is returned.
    energy_fluxes_query: Query for the flux terms; required when
      `observation_operator` is set.
    dt: Time span multiplying `mean_evaporation_rate`; required when
      `use_evaporation_for_latent_heat` is True and fluxes are observed.
    use_evaporation_for_latent_heat: Whether to derive the latent heat flux
      from `mean_evaporation_rate`.
    use_liquid_ice_moist_static_energy: Whether to include cloud ice and
      snowfall in the energy budget.
    prognostics_arg_key: If None, `inputs` are the prognostics; otherwise an
      index into `*args` (int) or key into `**kwargs` (str).
  """

  ylm_map: spherical_harmonics.FixedYlmMapping
  levels: coordinates.SigmaLevels | coordinates.HybridLevels
  sim_units: units.SimUnits
  model_orography: orographies.ModalOrography
  observation_operator: typing.ObservationOperator | None = nnx.data(
      default=None
  )
  energy_fluxes_query: dict[str, cx.Coordinate] | None = None
  dt: float | None = None
  use_evaporation_for_latent_heat: bool = False
  use_liquid_ice_moist_static_energy: bool = False
  prognostics_arg_key: str | int | None = None

  def __post_init__(self):
    """Sets flux key groups and validates `energy_fluxes_query` if given.

    Raises:
      ValueError: If a required flux term is missing from the query.
    """
    if self.energy_fluxes_query is not None:
      self.rt_keys = ['top_net_thermal_radiation', 'top_net_solar_radiation']
      self.fs_keys = [
          'surface_sensible_heat_flux',
          'surface_net_solar_radiation',
          'surface_net_thermal_radiation',
      ]
      if self.use_evaporation_for_latent_heat:
        required_keys = self.rt_keys + self.fs_keys + ['mean_evaporation_rate']
      else:
        required_keys = (
            self.rt_keys + self.fs_keys + ['surface_latent_heat_flux']
        )

      if self.use_liquid_ice_moist_static_energy:
        required_keys.append('snowfall')
      missing_keys = [
          k for k in required_keys if k not in self.energy_fluxes_query
      ]
      if missing_keys:
        raise ValueError(
            f'Missing energy terms in energy_fluxes_query: {missing_keys}'
        )

  def _compute_column_energy(
      self, prognostics: dict[str, cx.Field]
  ) -> cx.Field:
    """Returns E_col = (p_s*phi_s + pressure integral of the integrand) / g."""
    to_nodal = self.ylm_map.to_nodal
    p_surface_field = cx.cmap(jnp.exp)(
        to_nodal(prognostics['log_surface_pressure'])
    )
    cp = self.sim_units.Cp
    lv = self.sim_units.Lv
    g = self.sim_units.gravity_acceleration
    lf = self.sim_units.Lf

    t_nodal_field = to_nodal(prognostics['temperature'])
    q_nodal_field = to_nodal(prognostics['specific_humidity'])

    if self.use_liquid_ice_moist_static_energy:
      qi_nodal_field = to_nodal(prognostics['specific_cloud_ice_water_content'])
    else:
      qi_nodal_field = 0.0

    velocity_from_div_curl = transforms.VelocityFromDivCurl(self.ylm_map)
    winds = velocity_from_div_curl({
        'vorticity': prognostics['vorticity'],
        'divergence': prognostics['divergence'],
    })
    u_nodal = winds['u_component_of_wind']
    v_nodal = winds['v_component_of_wind']
    k_nodal_field = 0.5 * (u_nodal**2 + v_nodal**2)

    phi_s = self.model_orography.nodal_orography * g

    integrand = (
        cp * t_nodal_field
        + lv * q_nodal_field
        - lf * qi_nodal_field
        + k_nodal_field
    )
    vert_integrated_energy = self.levels.integrate_over_pressure(
        integrand, p_surface_field, self.sim_units
    )
    column_energy = (1 / g) * (p_surface_field * phi_s + vert_integrated_energy)
    return column_energy

  def __call__(
      self,
      inputs: dict[str, cx.Field],
      *args,
      **kwargs,
  ) -> dict[str, cx.Field]:
    """Computes column energy and, if fluxes are observed, the column budget.

    Args:
      inputs: Prognostics when `prognostics_arg_key` is None; otherwise unused.
      *args: Extra positional arguments; holds the prognostics when
        `prognostics_arg_key` is an int.
      **kwargs: Extra keyword arguments; holds the prognostics when
        `prognostics_arg_key` is a str.

    Returns:
      Dict with `column_energy` and, when `observation_operator` is set,
      `column_energy_budget`.

    Raises:
      ValueError: If prognostics are not a dictionary, if
        `energy_fluxes_query` is missing while `observation_operator` is set,
        or if `dt` is missing while `use_evaporation_for_latent_heat` is True.
    """
    if self.prognostics_arg_key is None:
      prognostics = inputs
    elif isinstance(self.prognostics_arg_key, int):
      prognostics = args[self.prognostics_arg_key]
    else:
      prognostics = kwargs[self.prognostics_arg_key]

    if not isinstance(prognostics, dict):
      raise ValueError(
          f'Prognostics must be a dictionary, got {type(prognostics)=} instead.'
      )

    column_energy = self._compute_column_energy(prognostics)
    results = {'column_energy': column_energy}

    if self.observation_operator is not None:
      if self.energy_fluxes_query is None:
        raise ValueError(
            'energy_fluxes_query must be provided if observation_operator is'
            ' provided.'
        )
      net_energy_terms = self.observation_operator.observe(
          prognostics, query=self.energy_fluxes_query
      )

      # NOTE: assumes observation_operator returns RT and FS fluxes in J/m^2
      # accumulated over one hour.
      rt = sum(net_energy_terms[k] for k in self.rt_keys)
      fs = sum(net_energy_terms[k] for k in self.fs_keys)

      if self.use_evaporation_for_latent_heat:
        if self.dt is None:
          # Without this check the multiplication below fails with an opaque
          # TypeError on `None`.
          raise ValueError(
              'dt must be provided when use_evaporation_for_latent_heat is'
              ' True.'
          )
        # mean_evaporation_rate is a rate (kg/m^2/s), so we multiply by Lv and
        # dt to get J/m^2.
        fs += (
            net_energy_terms['mean_evaporation_rate']
            * self.sim_units.Lv
            * self.dt
        )
      else:
        fs += net_energy_terms['surface_latent_heat_flux']
      if self.use_liquid_ice_moist_static_energy:
        # snowfall is in m of water equivalent (accumulated), so we multiply by
        # density and Lf to get J/m^2.
        fs += (
            net_energy_terms['snowfall']
            * self.sim_units.water_density
            * self.sim_units.Lf
        )
      column_budget = column_energy + rt - fs
      results['column_energy_budget'] = column_budget
    return results