anemoi-datasets 0.5.18__py3-none-any.whl → 0.5.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/compare-lam.py +401 -0
  3. anemoi/datasets/commands/grib-index.py +114 -0
  4. anemoi/datasets/create/__init__.py +8 -1
  5. anemoi/datasets/create/config.py +5 -0
  6. anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py +3 -1
  7. anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py +3 -1
  8. anemoi/datasets/create/filters/transform.py +1 -3
  9. anemoi/datasets/create/filters/wz_to_w.py +3 -2
  10. anemoi/datasets/create/input/action.py +4 -1
  11. anemoi/datasets/create/input/function.py +5 -5
  12. anemoi/datasets/create/input/result.py +1 -1
  13. anemoi/datasets/create/sources/accumulations.py +13 -0
  14. anemoi/datasets/create/sources/accumulations2.py +652 -0
  15. anemoi/datasets/create/sources/anemoi_dataset.py +73 -0
  16. anemoi/datasets/create/sources/grib.py +7 -0
  17. anemoi/datasets/create/sources/grib_index.py +614 -0
  18. anemoi/datasets/create/sources/legacy.py +7 -2
  19. anemoi/datasets/create/sources/xarray_support/__init__.py +1 -1
  20. anemoi/datasets/create/sources/xarray_support/fieldlist.py +2 -2
  21. anemoi/datasets/create/sources/xarray_support/flavour.py +6 -0
  22. anemoi/datasets/create/sources/xarray_support/grid.py +15 -4
  23. anemoi/datasets/data/__init__.py +16 -0
  24. anemoi/datasets/data/complement.py +4 -1
  25. anemoi/datasets/data/dataset.py +14 -0
  26. anemoi/datasets/data/interpolate.py +76 -0
  27. anemoi/datasets/data/masked.py +77 -0
  28. anemoi/datasets/data/misc.py +159 -0
  29. anemoi/datasets/grids.py +8 -2
  30. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/METADATA +10 -4
  31. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/RECORD +35 -30
  32. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/WHEEL +1 -1
  33. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/entry_points.txt +0 -0
  34. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/licenses/LICENSE +0 -0
  35. {anemoi_datasets-0.5.18.dist-info → anemoi_datasets-0.5.20.dist-info}/top_level.txt +0 -0
anemoi/datasets/_version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.5.18'
- __version_tuple__ = version_tuple = (0, 5, 18)
+ __version__ = version = '0.5.20'
+ __version_tuple__ = version_tuple = (0, 5, 20)
anemoi/datasets/commands/compare-lam.py
@@ -0,0 +1,401 @@
+ # (C) Copyright 2024 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+ import logging
+ import math
+ import os
+
+ from anemoi.datasets import open_dataset
+
+ from . import Command
+
+ RADIUS_EARTH_KM = 6371.0  # Earth's radius in kilometers
+
+ LOG = logging.getLogger(__name__)
+
+
+ class HTML_Writer:
+     def __init__(self):
+         self.html_content = """
+ <!DOCTYPE html>
+ <html>
+ <head>
+ <style>
+ table {
+ border-collapse: collapse;
+ width: 100%;
+ }
+ th, td {
+ border: 1px solid black;
+ padding: 8px;
+ text-align: center;
+ }
+ th {
+ background-color: #e2e2e2;
+ }
+ </style>
+ </head>
+ <body>
+ <table>
+ <thead>
+ <tr>
+ <th>Variable</th>
+ <th>Global Mean</th>
+ <th>LAM Mean</th>
+ <th>Mean Diff (%)</th>
+ <th>Global Std</th>
+ <th>LAM Std</th>
+ <th>Std Diff (%)</th>
+ <th>Global Max</th>
+ <th>LAM Max</th>
+ <th>Max Diff (%)</th>
+ <th>Global Min</th>
+ <th>LAM Min</th>
+ <th>Min Diff (%)</th>
+ </tr>
+ </thead>
+ <tbody>
+ """
+
+     def update_table(
+         self,
+         v1,
+         global_mean,
+         lam_mean,
+         mean_diff,
+         global_std,
+         lam_std,
+         std_diff,
+         global_max,
+         lam_max,
+         max_diff,
+         global_min,
+         lam_min,
+         min_diff,
+     ):
+
+         # Determine inline style for HTML
+         mean_bg_color = "background-color: #d4edda;" if abs(mean_diff) < 20 else "background-color: #f8d7da;"
+         std_bg_color = "background-color: #d4edda;" if abs(std_diff) < 20 else "background-color: #f8d7da;"
+         max_bg_color = "background-color: #d4edda;" if abs(max_diff) < 20 else "background-color: #f8d7da;"
+         min_bg_color = "background-color: #d4edda;" if abs(min_diff) < 20 else "background-color: #f8d7da;"
+
+         # Add a row to the HTML table with inline styles
+         self.html_content += f"""
+ <tr>
+ <td style="background-color: #f2f2f2;">{v1}</td>
+ <td>{global_mean}</td>
+ <td>{lam_mean}</td>
+ <td style="{mean_bg_color}">{mean_diff}%</td>
+ <td>{global_std}</td>
+ <td>{lam_std}</td>
+ <td style="{std_bg_color}">{std_diff}%</td>
+ <td>{global_max}</td>
+ <td>{lam_max}</td>
+ <td style="{max_bg_color}">{max_diff}%</td>
+ <td>{global_min}</td>
+ <td>{lam_min}</td>
+ <td style="{min_bg_color}">{min_diff}%</td>
+ </tr>
+ """
+
+     def save_table(self, save_path="stats_table.html"):
+         # Close the HTML tags
+         self.html_content += """
+ </tbody>
+ </table>
+ </body>
+ </html>
+ """
+
+         # Save the HTML content to a file
+         with open(save_path, "w") as f:
+             f.write(self.html_content)
+
+         LOG.info(f"\nHTML table saved to: {save_path}")
+
+
+ def plot_coordinates_on_map(lats, lons):
+     import cartopy.crs as ccrs
+     import cartopy.feature as cfeature
+     import matplotlib.pyplot as plt
+
+     """
+     Plots the given latitude and longitude coordinates on a map using Cartopy and Matplotlib.
+
+     Parameters:
+     - lats: List of latitudes
+     - lons: List of longitudes
+     """
+
+     if len(lats) != len(lons):
+         raise ValueError("The length of latitude and longitude lists must be the same.")
+
+     # Create a figure and axis using the PlateCarree projection
+
+     # Define source (PlateCarree) and target (LambertConformal) projections
+     target_proj = ccrs.LambertConformal(central_latitude=0, central_longitude=10, standard_parallels=[63.3, 63.3])
+
+     # Create a figure and axis
+     fig, ax = plt.subplots(figsize=(14, 12), subplot_kw={"projection": target_proj})
+
+     # Set the extent of the map based on the transformed coordinates
+     margin = 10
+     ax.set_extent(
+         [min(lons) - margin, max(lons) + margin, min(lats) - margin, max(lats) + margin], crs=ccrs.PlateCarree()
+     )
+     # ax.set_extent([-25, 45, 30, 75], crs=ccrs.PlateCarree())
+
+     # Add map features
+     ax.add_feature(cfeature.LAND)
+     ax.add_feature(cfeature.OCEAN)
+     ax.add_feature(cfeature.COASTLINE.with_scale("50m"), zorder=1, alpha=0.8)
+     ax.add_feature(cfeature.BORDERS.with_scale("50m"), linestyle=":", zorder=1)
+
+     # Plot transformed coordinates
+     ax.scatter(lons, lats, color="blue", s=1, edgecolor="b", transform=ccrs.PlateCarree(), alpha=0.3)
+     ax.set_title("Latitude and Longitude")
+     ax.title.set_size(20)
+
+     # Show the plot
+     return fig
+
+
+ def haversine(lat1, lon1, lat2, lon2):
+     """Calculate the great-circle distance between two points on the Earth's surface using the Haversine formula."""
+     dlat = math.radians(lat2 - lat1)
+     dlon = math.radians(lon2 - lon1)
+     a = math.sin(dlat / 2) ** 2 + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2
+     c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
+     return RADIUS_EARTH_KM * c
+
+
+ def rectangle_area_km2(lat1, lon1, lat2, lon2):
+     """Calculate the area of a rectangle given the coordinates of the top left and bottom right corners in
+     latitude and longitude.
+
+     Parameters:
+     lat1, lon1 - Latitude and longitude of the top-left corner.
+     lat2, lon2 - Latitude and longitude of the bottom-right corner.
+
+     Returns:
+     Area in square kilometers (km^2).
+     """
+     # Calculate the height (difference in latitude)
+     height_km = haversine(lat1, lon1, lat2, lon1)
+
+     # Calculate the width (difference in longitude)
+     width_km = haversine(lat1, lon1, lat1, lon2)
+
+     # Area of the rectangle
+     area_km2 = height_km * width_km
+     return area_km2
+
+
+ def check_order(vars_1, vars_2):
+     for v1, v2 in zip(vars_1, vars_2):
+         if v1 != v2 and ((v1 in vars_2) and (v2 in vars_1)):
+             return False
+
+     return True
+
+
+ def compute_wighted_diff(s1, s2, round_ndigits):
+     return round((s2 - s1) * 100 / s2, ndigits=round_ndigits)
+
+
+ class CompareLAM(Command):
+     """Compare statistic of two datasets. \
+     This command compares the statistics of each variable in two datasets ONLY in the overlapping area between the two. \
+     """
+
+     def add_arguments(self, command_parser):
+         command_parser.add_argument("dataset1", help="Path of the global dataset or the largest dataset.")
+         command_parser.add_argument("dataset2", help="Path of the LAM dataset or the smallest dataset.")
+         command_parser.add_argument(
+             "-D", "--number-of-dates", type=int, default=10, help="Number of datapoints (in time) to compare over."
+         )
+         command_parser.add_argument("-O", "--outpath", default="./", help="Path to output folder.")
+         command_parser.add_argument("-R", "--number-of-digits", type=int, default=4, help="Number of digits to keep.")
+         command_parser.add_argument(
+             "--selected-vars",
+             nargs="+",
+             default=["10u", "10v", "2d", "2t"],
+             help="List of selected variables to use in the script.",
+         )
+         command_parser.add_argument(
+             "--save-plots", action="store_true", help="Toggle to save a picture of the data grid."
+         )
+
+     def run(self, args):
+         import matplotlib.pyplot as plt
+         import numpy as np
+         from prettytable import PrettyTable
+         from termcolor import colored  # For coloring text in the terminal
+
+         # Unpack args
+         date_idx = args.number_of_dates
+         round_ndigits = args.number_of_digits
+         selected_vars = args.selected_vars
+         global_name = args.dataset1
+         lam_name = args.dataset2
+         date_idx = 10  # "all" or specific index to stop at
+         name = f"{global_name}-{lam_name}_{date_idx}"
+         save_path = os.path.join(args.outpath, f"comparison_table_{name}.html")
+
+         # Open LAM dataset
+         lam_dataset = open_dataset(lam_name, select=selected_vars)
+         lam_vars = list(lam_dataset.variables)
+         lam_num_grid_points = lam_dataset[0, 0].shape[1]
+         lam_area = rectangle_area_km2(
+             max(lam_dataset.latitudes),
+             max(lam_dataset.longitudes),
+             min(lam_dataset.latitudes),
+             min(lam_dataset.longitudes),
+         )
+         l_coords = (
+             max(lam_dataset.latitudes),
+             min(lam_dataset.longitudes),
+             min(lam_dataset.latitudes),
+             max(lam_dataset.longitudes),
+         )
+
+         if args.save_plots:
+             _ = plot_coordinates_on_map(lam_dataset.latitudes, lam_dataset.longitudes)
+             plt.savefig(os.path.join(args.outpath, "lam_dataset.png"))
+
+         LOG.info(f"Dataset {lam_name}, has {lam_num_grid_points} grid points. \n")
+         LOG.info("LAM (north, west, south, east): ", l_coords)
+         LOG.info(f"Point every: {math.sqrt(lam_area / lam_num_grid_points)} km")
+
+         # Open global dataset and cut it
+         lam_start = lam_dataset.dates[0]
+         lam_end = lam_dataset.dates[-1]
+         global_dataset = open_dataset(global_name, start=lam_start, end=lam_end, area=l_coords, select=selected_vars)
+         global_vars = list(global_dataset.variables)
+         global_num_grid_points = global_dataset[0, 0].shape[1]
+         global_area = rectangle_area_km2(
+             max(global_dataset.latitudes),
+             max(global_dataset.longitudes),
+             min(global_dataset.latitudes),
+             min(global_dataset.longitudes),
+         )
+         g_coords = (
+             max(global_dataset.latitudes),
+             min(global_dataset.longitudes),
+             min(global_dataset.latitudes),
+             max(global_dataset.longitudes),
+         )
+
+         if args.save_plots:
+             _ = plot_coordinates_on_map(global_dataset.latitudes, global_dataset.longitudes)
+             plt.savefig(os.path.join(args.outpath, "global_dataset.png"))
+
+         LOG.info(f"Dataset {global_name}, has {global_num_grid_points} grid points. \n")
+         LOG.info("Global-lam cut (north, west, south, east): ", g_coords)
+         LOG.info(f"Point every: {math.sqrt(global_area / global_num_grid_points)} km")
+
+         # Check variable ordering
+         same_order = check_order(global_vars, lam_vars)
+         LOG.info(f"Lam dataset has the same order of variables as the global dataset: {same_order}")
+
+         LOG.info("\nComparing statistics..")
+         table = PrettyTable()
+         table.field_names = [
+             "Variable",
+             "Global Mean",
+             "LAM Mean",
+             "Mean Diff (%)",
+             "Global Std",
+             "LAM Std",
+             "Std Diff (%)",
+             "Global Max",
+             "LAM Max",
+             "Max Diff (%)",
+             "Global Min",
+             "LAM Min",
+             "Min Diff (%)",
+         ]
+
+         # Create a styled HTML table
+         html_writer = HTML_Writer()
+
+         for v1, v2 in zip(global_vars, lam_vars):
+             assert v1 == v2
+             idx = global_vars.index(v1)
+
+             if date_idx == "all":
+                 lam_mean = lam_dataset.statistics["mean"][idx]
+                 lam_std = lam_dataset.statistics["stdev"][idx]
+                 lam_max = lam_dataset.statistics["max"][idx]
+                 lam_min = lam_dataset.statistics["min"][idx]
+                 global_mean = global_dataset.statistics["mean"][idx]
+                 global_std = global_dataset.statistics["stdev"][idx]
+                 global_max = global_dataset.statistics["max"][idx]
+                 global_min = global_dataset.statistics["min"][idx]
+
+             else:
+                 lam_mean = np.nanmean(lam_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 lam_std = np.nanstd(lam_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 lam_max = np.nanmax(lam_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 lam_min = np.nanmin(lam_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 global_mean = np.nanmean(global_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 global_std = np.nanstd(global_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 global_max = np.nanmax(global_dataset[:date_idx], axis=(0, 3))[idx][0]
+                 global_min = np.nanmin(global_dataset[:date_idx], axis=(0, 3))[idx][0]
+
+             mean_diff = compute_wighted_diff(lam_mean, global_mean, round_ndigits)
+             std_diff = compute_wighted_diff(lam_std, global_std, round_ndigits)
+             max_diff = compute_wighted_diff(lam_max, global_max, round_ndigits)
+             min_diff = compute_wighted_diff(lam_min, global_min, round_ndigits)
+
+             mean_color = "red" if abs(mean_diff) >= 20 else "green"
+             std_color = "red" if abs(std_diff) >= 20 else "green"
+             max_color = "red" if abs(max_diff) >= 20 else "green"
+             min_color = "red" if abs(min_diff) >= 20 else "green"
+
+             table.add_row(
+                 [
+                     v1,
+                     round(global_mean, ndigits=round_ndigits),
+                     round(lam_mean, ndigits=round_ndigits),
+                     colored(f"{mean_diff}%", mean_color),
+                     round(global_std, ndigits=round_ndigits),
+                     round(lam_std, ndigits=round_ndigits),
+                     colored(f"{std_diff}%", std_color),
+                     round(global_max, ndigits=round_ndigits),
+                     round(lam_max, ndigits=round_ndigits),
+                     colored(f"{max_diff}%", max_color),
+                     round(global_min, ndigits=round_ndigits),
+                     round(lam_min, ndigits=round_ndigits),
+                     colored(f"{min_diff}%", min_color),
+                 ]
+             )
+
+             html_writer.update_table(
+                 v1,
+                 global_mean,
+                 lam_mean,
+                 mean_diff,
+                 global_std,
+                 lam_std,
+                 std_diff,
+                 global_max,
+                 lam_max,
+                 max_diff,
+                 global_min,
+                 lam_min,
+                 min_diff,
+             )
+
+         html_writer.save_table(save_path)
+         print(table)
+
+
+ command = CompareLAM
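For orientation, the sketch below mirrors what the new compare-lam command does internally: derive the LAM bounding box, cut the global dataset to that area and period with open_dataset, and compare per-variable statistics over the first few dates (only the mean is shown here). The dataset paths, variable list and number of dates are made-up placeholders, not values from this release.

    # Rough sketch of the comparison performed by `compare-lam` (paths and values are hypothetical).
    import numpy as np

    from anemoi.datasets import open_dataset

    selected_vars = ["10u", "10v", "2d", "2t"]

    lam_ds = open_dataset("lam.zarr", select=selected_vars)
    # Bounding box as (north, west, south, east), as built by the command
    area = (max(lam_ds.latitudes), min(lam_ds.longitudes), min(lam_ds.latitudes), max(lam_ds.longitudes))

    # Cut the global dataset to the LAM period and area
    global_ds = open_dataset("global.zarr", start=lam_ds.dates[0], end=lam_ds.dates[-1], area=area, select=selected_vars)

    # Per-variable mean over the first 10 dates, plus the percentage difference shown in the tables
    for idx, name in enumerate(global_ds.variables):
        g_mean = np.nanmean(global_ds[:10], axis=(0, 3))[idx][0]
        l_mean = np.nanmean(lam_ds[:10], axis=(0, 3))[idx][0]
        print(name, g_mean, l_mean, round((g_mean - l_mean) * 100 / g_mean, ndigits=4))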
anemoi/datasets/commands/grib-index.py
@@ -0,0 +1,114 @@
+ # (C) Copyright 2024 Anemoi contributors.
+ #
+ # This software is licensed under the terms of the Apache Licence Version 2.0
+ # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+ #
+ # In applying this licence, ECMWF does not waive the privileges and immunities
+ # granted to it by virtue of its status as an intergovernmental organisation
+ # nor does it submit to any jurisdiction.
+
+ import fnmatch
+ import os
+ from typing import Any
+
+ import tqdm
+
+ from . import Command
+
+
+ class GribIndexCmd(Command):
+     internal = True
+     timestamp = True
+
+     def add_arguments(self, command_parser: Any) -> None:
+         """Add arguments to the command parser.
+
+         Parameters
+         ----------
+         command_parser : Any
+             The command parser to which arguments are added.
+         """
+
+         from anemoi.datasets.create.sources.grib_index import KEYS
+
+         command_parser.add_argument(
+             "--index",
+             help="Path to the index file to create or update",
+             required=True,
+         )
+
+         command_parser.add_argument(
+             "--overwrite",
+             action="store_true",
+             help="Overwrite the index file if it exists (default is to update)",
+         )
+
+         command_parser.add_argument(
+             "--match",
+             help="Give a glob pattern to match files (default: *.grib)",
+             default="*.grib",
+         )
+
+         command_parser.add_argument(
+             "--keys",
+             help="GRIB keys to add to the index, separated by commas. If the list starts with a +, the keys are added to default list.",
+             default=",".join(KEYS),
+         )
+
+         command_parser.add_argument(
+             "--flavour",
+             help="GRIB flavour file (yaml or json)",
+         )
+
+         command_parser.add_argument("paths", nargs="+", help="Paths to scan")
+
+     def run(self, args: Any) -> None:
+         """Execute the scan command.
+
+         Parameters
+         ----------
+         args : Any
+             The arguments passed to the command.
+         """
+
+         def match(path: str) -> bool:
+             """Check if a path matches the given pattern.
+
+             Parameters
+             ----------
+             path : str
+                 The path to check.
+
+             Returns
+             -------
+             bool
+                 True if the path matches, False otherwise.
+             """
+             return fnmatch.fnmatch(path, args.match)
+
+         from anemoi.datasets.create.sources.grib_index import GribIndex
+
+         index = GribIndex(
+             args.index,
+             keys=args.keys,
+             update=True,
+             overwrite=args.overwrite,
+             flavour=args.flavour,
+         )
+
+         paths = []
+         for path in args.paths:
+             if os.path.isfile(path):
+                 paths.append(path)
+             else:
+                 for root, _, files in os.walk(path):
+                     for file in files:
+                         full = os.path.join(root, file)
+                         paths.append(full)
+
+         for path in tqdm.tqdm(paths, leave=False):
+             if match(path):
+                 index.add_grib_file(path)
+
+
+ command = GribIndexCmd
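As a rough sketch, assuming a hypothetical index path and GRIB directory, the following reproduces what the command's run() does for each matching file: build a GribIndex in update mode and feed it every *.grib file found under the given paths.

    # Sketch of the indexing loop performed by `grib-index` (index path and directory are hypothetical).
    import fnmatch
    import os

    from anemoi.datasets.create.sources.grib_index import KEYS, GribIndex

    index = GribIndex("my-index", keys=",".join(KEYS), update=True, overwrite=False, flavour=None)

    for root, _, files in os.walk("/data/grib"):
        for name in files:
            path = os.path.join(root, name)
            if fnmatch.fnmatch(path, "*.grib"):
                index.add_grib_file(path)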
anemoi/datasets/create/__init__.py
@@ -294,7 +294,14 @@ class Dataset:
          import zarr

          z = zarr.open(self.path, mode="r")
-         return loader_config(z.attrs.get("_create_yaml_config"))
+         config = loader_config(z.attrs.get("_create_yaml_config"))
+
+         if "env" in config:
+             for k, v in config["env"].items():
+                 LOG.info(f"Setting env variable {k}={v}")
+                 os.environ[k] = str(v)
+
+         return config


  class WritableDataset(Dataset):
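This hunk, together with the matching addition to loader_config in create/config.py just below, lets a recipe carry an optional env mapping that is exported to the process environment when the configuration is loaded. A minimal sketch of the effect, with a made-up variable name:

    # Sketch only: how an "env" section in a recipe is applied (the variable name is hypothetical).
    import os

    recipe = {
        "env": {"MY_TOOL_CACHE": "/tmp/cache"},
        # ... dates, input, output, etc. ...
    }

    if "env" in recipe:
        for k, v in recipe["env"].items():
            os.environ[k] = str(v)  # available to sources and filters during creation

In the create/__init__.py hunk the same loop runs when the stored _create_yaml_config is read back from the Zarr, so the variables are also set in later creation steps.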
anemoi/datasets/create/config.py
@@ -420,6 +420,11 @@ def loader_config(config: dict, is_test: bool = False) -> LoadersConfig:
          print(b)
          raise ValueError("Serialisation failed")

+     if "env" in copy:
+         for k, v in copy["env"].items():
+             LOG.info(f"Setting env variable {k}={v}")
+             os.environ[k] = str(v)
+
      return copy

anemoi/datasets/create/filters/pressure_level_relative_humidity_to_specific_humidity.py
@@ -71,7 +71,9 @@ def execute(context: Any, input: ekd.FieldList, t: str, rh: str, q: str = "q") -

          t_pl = values[t].to_numpy(flatten=True)
          rh_pl = values[rh].to_numpy(flatten=True)
-         pressure = keys[4][1] * 100  # TODO: REMOVE HARDCODED INDICES
+         pressure = next(
+             float(v) * 100 for k, v in keys if k in ["level", "levelist"]
+         )  # Looks first for "level" then "levelist" value
          # print(f"Handling fields for pressure level {pressure}...")

          # actual conversion from rh --> q_v
anemoi/datasets/create/filters/pressure_level_specific_humidity_to_relative_humidity.py
@@ -72,7 +72,9 @@ def execute(context: Any, input: ekd.FieldList, t: str, q: str, rh: str = "r") -

          t_pl = values[t].to_numpy(flatten=True)
          q_pl = values[q].to_numpy(flatten=True)
-         pressure = keys[4][1] * 100  # TODO: REMOVE HARDCODED INDICES
+         pressure = next(
+             float(v) * 100 for k, v in keys if k in ["level", "levelist"]
+         )  # Looks first for "level" then "levelist" value
          # print(f"Handling fields for pressure level {pressure}...")

          # actual conversion from rh --> q_v
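This change, the identical one in the previous hunk, and the one in wz_to_w.py further down all replace a hardcoded tuple index with a lookup of the level or levelist metadata key. A small worked example, with a made-up keys value:

    # `keys` here is a made-up example of (key, value) metadata pairs for a field on the 850 hPa level.
    keys = [("param", "q"), ("levtype", "pl"), ("levelist", 850)]

    pressure = next(float(v) * 100 for k, v in keys if k in ["level", "levelist"])
    assert pressure == 85000.0  # hPa converted to Pa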
anemoi/datasets/create/filters/transform.py
@@ -33,15 +33,13 @@ class TransformFilter(Filter):
          from anemoi.transform.filters import create_filter

          self.name = name
-         self.transform_filter = create_filter(self, config)
+         self.transform_filter = create_filter(context, config)

      def execute(self, input: ekd.FieldList) -> ekd.FieldList:
          """Execute the transformation filter.

          Parameters
          ----------
-         context : Any
-             The context in which the execution occurs.
          input : ekd.FieldList
              The input data to be transformed.

anemoi/datasets/create/filters/wz_to_w.py
@@ -66,8 +66,9 @@ def execute(context: Any, input: ekd.FieldList, wz: str, t: str, w: str = "w") -

          wz_pl = values[wz].to_numpy(flatten=True)
          t_pl = values[t].to_numpy(flatten=True)
-         pressure = keys[4][1] * 100  # TODO: REMOVE HARDCODED INDICES
-
+         pressure = next(
+             float(v) * 100 for k, v in keys if k in ["level", "levelist"]
+         )  # Looks first for "level" then "levelist" value
          w_pl = wz_to_w(wz_pl, t_pl, pressure)
          result.append(new_field_from_numpy(values[wz], w_pl, param=w))

anemoi/datasets/create/input/action.py
@@ -17,6 +17,7 @@ from earthkit.data.core.order import build_remapping

  from ...dates.groups import GroupOfDates
  from .context import Context
+ from .template import substitute

  LOG = logging.getLogger(__name__)

@@ -238,17 +239,19 @@ def action_factory(config: Dict[str, Any], context: ActionContext, action_path:

      cls = {
          "data_sources": DataSourcesAction,
+         "data-sources": DataSourcesAction,
          "concat": ConcatAction,
          "join": JoinAction,
          "pipe": PipeAction,
          "function": FunctionAction,
          "repeated_dates": RepeatedDatesAction,
+         "repeated-dates": RepeatedDatesAction,
      }.get(key)

      if cls is None:
          from ..sources import create_source

-         source = create_source(None, config)
+         source = create_source(None, substitute(context, config))
          return FunctionAction(context, action_path + [key], key, source)

      return cls(context, action_path + [key], *args, **kwargs)
anemoi/datasets/create/input/function.py
@@ -20,7 +20,6 @@ from .misc import _tidy
  from .misc import assert_fieldlist
  from .result import Result
  from .template import notify_result
- from .template import resolve
  from .template import substitute
  from .trace import trace
  from .trace import trace_datasource
@@ -79,6 +78,9 @@ class FunctionContext:
          """Returns whether partial results are acceptable."""
          return self.owner.group_of_dates.partial_ok

+     def get_result(self, *args, **kwargs) -> Any:
+         return self.owner.context.get_result(*args, **kwargs)
+


  class FunctionAction(Action):
@@ -203,14 +205,12 @@ class FunctionResult(Result):
      @trace_datasource
      def datasource(self) -> FieldList:
          """Returns the datasource for the function result."""
-         args, kwargs = resolve(self.context, (self.args, self.kwargs))
+         # args, kwargs = resolve(self.context, (self.args, self.kwargs))
          self.action.source.context = FunctionContext(self)

          return _tidy(
              self.action.source.execute(
-                 self.group_of_dates,  # Will provide a list of datetime objects
-                 *args,
-                 **kwargs,
+                 list(self.group_of_dates),  # Will provide a list of datetime objects
              )
          )

anemoi/datasets/create/input/result.py
@@ -215,7 +215,7 @@ def _fields_metatata(variables: Tuple[str, ...], cube: Any) -> Dict[str, Any]:
          result[k] = dict(mars=v) if v else {}
          result[k].update(other[k])
          result[k].update(KNOWN.get(k, {}))
-         assert result[k], k
+         # assert result[k], k

      assert i + 1 == len(variables), (i + 1, len(variables))
      return result