anemoi-datasets 0.5.0__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/inspect.py +1 -1
  3. anemoi/datasets/commands/publish.py +30 -0
  4. anemoi/datasets/create/__init__.py +42 -3
  5. anemoi/datasets/create/check.py +6 -0
  6. anemoi/datasets/create/functions/filters/rename.py +2 -3
  7. anemoi/datasets/create/functions/sources/__init__.py +7 -1
  8. anemoi/datasets/create/functions/sources/accumulations.py +2 -0
  9. anemoi/datasets/create/functions/sources/grib.py +1 -1
  10. anemoi/datasets/create/functions/sources/xarray/__init__.py +7 -2
  11. anemoi/datasets/create/functions/sources/xarray/coordinates.py +12 -1
  12. anemoi/datasets/create/functions/sources/xarray/field.py +13 -4
  13. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
  14. anemoi/datasets/create/functions/sources/xarray/flavour.py +130 -13
  15. anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
  16. anemoi/datasets/create/functions/sources/xarray/metadata.py +3 -11
  17. anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
  18. anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
  19. anemoi/datasets/create/input/__init__.py +69 -0
  20. anemoi/datasets/create/input/action.py +123 -0
  21. anemoi/datasets/create/input/concat.py +92 -0
  22. anemoi/datasets/create/input/context.py +59 -0
  23. anemoi/datasets/create/input/data_sources.py +71 -0
  24. anemoi/datasets/create/input/empty.py +42 -0
  25. anemoi/datasets/create/input/filter.py +76 -0
  26. anemoi/datasets/create/input/function.py +122 -0
  27. anemoi/datasets/create/input/join.py +57 -0
  28. anemoi/datasets/create/input/misc.py +85 -0
  29. anemoi/datasets/create/input/pipe.py +33 -0
  30. anemoi/datasets/create/input/repeated_dates.py +217 -0
  31. anemoi/datasets/create/input/result.py +413 -0
  32. anemoi/datasets/create/input/step.py +99 -0
  33. anemoi/datasets/create/{template.py → input/template.py} +0 -42
  34. anemoi/datasets/create/statistics/__init__.py +1 -1
  35. anemoi/datasets/create/zarr.py +4 -2
  36. anemoi/datasets/dates/__init__.py +1 -0
  37. anemoi/datasets/dates/groups.py +12 -4
  38. anemoi/datasets/fields.py +66 -0
  39. anemoi/datasets/utils/fields.py +47 -0
  40. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/METADATA +1 -1
  41. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/RECORD +46 -30
  42. anemoi/datasets/create/input.py +0 -1087
  43. anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
  44. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/LICENSE +0 -0
  45. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/WHEEL +0 -0
  46. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/entry_points.txt +0 -0
  47. {anemoi_datasets-0.5.0.dist-info → anemoi_datasets-0.5.6.dist-info}/top_level.txt +0 -0
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.5.0'
16
- __version_tuple__ = version_tuple = (0, 5, 0)
15
+ __version__ = version = '0.5.6'
16
+ __version_tuple__ = version_tuple = (0, 5, 6)
@@ -311,7 +311,7 @@ class Version:
311
311
  print(f"🕰️ Dataset initialized {when(start)}.")
312
312
  if built and latest:
313
313
  speed = (latest - start) / built
314
- eta = datetime.datetime.utcnow() + speed * (total - built)
314
+ eta = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + speed * (total - built)
315
315
  print(f"🏁 ETA {when(eta)}.")
316
316
  else:
317
317
  if latest:
@@ -0,0 +1,30 @@
1
+ import logging
2
+
3
+ from . import Command
4
+
5
+ LOG = logging.getLogger(__name__)
6
+
7
+
8
+ class Publish(Command):
9
+ """Publish a dataset."""
10
+
11
+ # This is a command that is used to publish a dataset.
12
+ # it is a class, inheriting from Command.
13
+
14
+ internal = True
15
+ timestamp = True
16
+
17
+ def add_arguments(self, parser):
18
+ parser.add_argument("path", help="Path of the dataset to publish.")
19
+
20
+ def run(self, args):
21
+ try:
22
+ from anemoi.registry import publish_dataset
23
+ except ImportError:
24
+ LOG.error("anemoi-registry is not installed. Please install it to use this command.")
25
+ return
26
+
27
+ publish_dataset(args.path)
28
+
29
+
30
+ command = Publish
@@ -14,6 +14,7 @@ import os
14
14
  import time
15
15
  import uuid
16
16
  import warnings
17
+ from copy import deepcopy
17
18
  from functools import cached_property
18
19
 
19
20
  import numpy as np
@@ -24,9 +25,11 @@ from anemoi.utils.dates import frequency_to_string
24
25
  from anemoi.utils.dates import frequency_to_timedelta
25
26
  from anemoi.utils.humanize import compress_dates
26
27
  from anemoi.utils.humanize import seconds_to_human
28
+ from earthkit.data.core.order import build_remapping
27
29
 
28
30
  from anemoi.datasets import MissingDateError
29
31
  from anemoi.datasets import open_dataset
32
+ from anemoi.datasets.create.input.trace import enable_trace
30
33
  from anemoi.datasets.create.persistent import build_storage
31
34
  from anemoi.datasets.data.misc import as_first_date
32
35
  from anemoi.datasets.data.misc import as_last_date
@@ -308,7 +311,6 @@ class HasElementForDataMixin:
308
311
 
309
312
 
310
313
  def build_input_(main_config, output_config):
311
- from earthkit.data.core.order import build_remapping
312
314
 
313
315
  builder = build_input(
314
316
  main_config.input,
@@ -323,6 +325,43 @@ def build_input_(main_config, output_config):
323
325
  return builder
324
326
 
325
327
 
328
+ def tidy_recipe(config: object):
329
+ """Remove potentially private information in the config"""
330
+ config = deepcopy(config)
331
+ if isinstance(config, (tuple, list)):
332
+ return [tidy_recipe(_) for _ in config]
333
+ if isinstance(config, (dict, DotDict)):
334
+ for k, v in config.items():
335
+ if k.startswith("_"):
336
+ config[k] = "*** REMOVED FOR SECURITY ***"
337
+ else:
338
+ config[k] = tidy_recipe(v)
339
+ if isinstance(config, str):
340
+ if config.startswith("_"):
341
+ return "*** REMOVED FOR SECURITY ***"
342
+ if config.startswith("s3://"):
343
+ return "*** REMOVED FOR SECURITY ***"
344
+ if config.startswith("gs://"):
345
+ return "*** REMOVED FOR SECURITY ***"
346
+ if config.startswith("http"):
347
+ return "*** REMOVED FOR SECURITY ***"
348
+ if config.startswith("ftp"):
349
+ return "*** REMOVED FOR SECURITY ***"
350
+ if config.startswith("file"):
351
+ return "*** REMOVED FOR SECURITY ***"
352
+ if config.startswith("ssh"):
353
+ return "*** REMOVED FOR SECURITY ***"
354
+ if config.startswith("scp"):
355
+ return "*** REMOVED FOR SECURITY ***"
356
+ if config.startswith("rsync"):
357
+ return "*** REMOVED FOR SECURITY ***"
358
+ if config.startswith("/"):
359
+ return "*** REMOVED FOR SECURITY ***"
360
+ if "@" in config:
361
+ return "*** REMOVED FOR SECURITY ***"
362
+ return config
363
+
364
+
326
365
  class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
327
366
  dataset_class = NewDataset
328
367
  def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs): # fmt: skip
@@ -409,6 +448,7 @@ class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
409
448
  metadata.update(self.main_config.get("add_metadata", {}))
410
449
 
411
450
  metadata["_create_yaml_config"] = self.main_config.get_serialisable_dict()
451
+ metadata["recipe"] = tidy_recipe(self.main_config.get_serialisable_dict())
412
452
 
413
453
  metadata["description"] = self.main_config.description
414
454
  metadata["licence"] = self.main_config["licence"]
@@ -524,7 +564,7 @@ class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi
524
564
  # assert isinstance(group[0], datetime.datetime), type(group[0])
525
565
  LOG.debug(f"Building data for group {igroup}/{self.n_groups}")
526
566
 
527
- result = self.input.select(dates=group)
567
+ result = self.input.select(group_of_dates=group)
528
568
  assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group)
529
569
 
530
570
  # There are several groups.
@@ -992,7 +1032,6 @@ def chain(tasks):
992
1032
 
993
1033
  def creator_factory(name, trace=None, **kwargs):
994
1034
  if trace:
995
- from anemoi.datasets.create.trace import enable_trace
996
1035
 
997
1036
  enable_trace(trace)
998
1037
 
@@ -140,9 +140,15 @@ class StatisticsValueError(ValueError):
140
140
 
141
141
  def check_data_values(arr, *, name: str, log=[], allow_nans=False):
142
142
 
143
+ shape = arr.shape
144
+
143
145
  if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans:
144
146
  arr = arr[~np.isnan(arr)]
145
147
 
148
+ if arr.size == 0:
149
+ warnings.warn(f"Empty array for {name} ({shape})")
150
+ return
151
+
146
152
  assert arr.size > 0, (name, *log)
147
153
 
148
154
  min, max = arr.min(), arr.max()
@@ -32,7 +32,7 @@ class RenamedFieldMapping:
32
32
 
33
33
  value = self.field.metadata(key, **kwargs)
34
34
  if key == self.what:
35
- return self.renaming.get(value, value)
35
+ return self.renaming.get(self.what, {}).get(value, value)
36
36
 
37
37
  return value
38
38
 
@@ -68,8 +68,7 @@ class RenamedFieldFormat:
68
68
 
69
69
 
70
70
  def execute(context, input, what="param", **kwargs):
71
- # print('🍍🍍🍍🍍🍍🍍🍍🍍🍍🍍🍍🍍🍍 ==========', kwargs)
72
- if what in kwargs:
71
+ if what in kwargs and isinstance(kwargs[what], str):
73
72
  return FieldArray([RenamedFieldFormat(fs, kwargs[what]) for fs in input])
74
73
 
75
74
  return FieldArray([RenamedFieldMapping(fs, what, kwargs) for fs in input])
@@ -16,6 +16,10 @@ LOG = logging.getLogger(__name__)
16
16
 
17
17
 
18
18
  def _expand(paths):
19
+
20
+ if not isinstance(paths, list):
21
+ paths = [paths]
22
+
19
23
  for path in paths:
20
24
  if path.startswith("file://"):
21
25
  path = path[7:]
@@ -40,8 +44,10 @@ def iterate_patterns(path, dates, **kwargs):
40
44
  given_paths = path if isinstance(path, list) else [path]
41
45
 
42
46
  dates = [d.isoformat() for d in dates]
47
+ if len(dates) > 0:
48
+ kwargs["date"] = dates
43
49
 
44
50
  for path in given_paths:
45
- paths = Pattern(path, ignore_missing_keys=True).substitute(date=dates, **kwargs)
51
+ paths = Pattern(path, ignore_missing_keys=True).substitute(**kwargs)
46
52
  for path in _expand(paths):
47
53
  yield path, dates
@@ -375,6 +375,8 @@ def accumulations(context, dates, **request):
375
375
  ("od", "elda"): dict(base_times=(6, 18)),
376
376
  ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)),
377
377
  ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)),
378
+ ("rr", "oper"): dict(data_accumulation_period=3, base_times=(0, 3, 6, 9, 12, 15, 18, 21)),
379
+ ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)),
378
380
  }
379
381
 
380
382
  kwargs = KWARGS.get((class_, stream), {})
@@ -135,7 +135,7 @@ def execute(context, dates, path, latitudes=None, longitudes=None, *args, **kwar
135
135
  s = s.sel(valid_datetime=dates, **kwargs)
136
136
  ds = ds + s
137
137
 
138
- if kwargs:
138
+ if kwargs and not context.partial_ok:
139
139
  check(ds, given_paths, valid_datetime=dates, **kwargs)
140
140
 
141
141
  if geography is not None:
@@ -12,6 +12,7 @@ import logging
12
12
  from earthkit.data.core.fieldlist import MultiFieldList
13
13
 
14
14
  from anemoi.datasets.data.stores import name_to_zarr_store
15
+ from anemoi.datasets.utils.fields import NewMetadataField as NewMetadataField
15
16
 
16
17
  from .. import iterate_patterns
17
18
  from .fieldlist import XarrayFieldList
@@ -41,7 +42,7 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
41
42
  We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`.
42
43
  """
43
44
 
44
- context.trace(emoji, dataset, options)
45
+ context.trace(emoji, dataset, options, kwargs)
45
46
 
46
47
  if isinstance(dataset, str) and ".zarr" in dataset:
47
48
  data = xr.open_zarr(name_to_zarr_store(dataset), **options)
@@ -49,7 +50,11 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
49
50
  data = xr.open_dataset(dataset, **options)
50
51
 
51
52
  fs = XarrayFieldList.from_xarray(data, flavour)
52
- result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
53
+
54
+ if len(dates) == 0:
55
+ return fs.sel(**kwargs)
56
+ else:
57
+ result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
53
58
 
54
59
  if len(result) == 0:
55
60
  LOG.warning(f"No data found for {dataset} and dates {dates} and {kwargs}")
@@ -56,6 +56,8 @@ class Coordinate:
56
56
  is_step = False
57
57
  is_date = False
58
58
  is_member = False
59
+ is_x = False
60
+ is_y = False
59
61
 
60
62
  def __init__(self, variable):
61
63
  self.variable = variable
@@ -66,10 +68,11 @@ class Coordinate:
66
68
  return 1 if self.scalar else len(self.variable)
67
69
 
68
70
  def __repr__(self):
69
- return "%s[name=%s,values=%s]" % (
71
+ return "%s[name=%s,values=%s,shape=%s]" % (
70
72
  self.__class__.__name__,
71
73
  self.variable.name,
72
74
  self.variable.values if self.scalar else len(self),
75
+ self.variable.shape,
73
76
  )
74
77
 
75
78
  def reduced(self, i):
@@ -225,11 +228,13 @@ class LatitudeCoordinate(Coordinate):
225
228
 
226
229
  class XCoordinate(Coordinate):
227
230
  is_grid = True
231
+ is_x = True
228
232
  mars_names = ("x",)
229
233
 
230
234
 
231
235
  class YCoordinate(Coordinate):
232
236
  is_grid = True
237
+ is_y = True
233
238
  mars_names = ("y",)
234
239
 
235
240
 
@@ -239,3 +244,9 @@ class ScalarCoordinate(Coordinate):
239
244
  @property
240
245
  def mars_names(self):
241
246
  return (self.variable.name,)
247
+
248
+
249
+ class UnsupportedCoordinate(Coordinate):
250
+ @property
251
+ def mars_names(self):
252
+ return (self.variable.name,)
@@ -72,13 +72,18 @@ class XArrayField(Field):
72
72
  def shape(self):
73
73
  return self._shape
74
74
 
75
- def to_numpy(self, flatten=False, dtype=None):
76
- values = self.selection.values
75
+ def to_numpy(self, flatten=False, dtype=None, index=None):
76
+ if index is not None:
77
+ values = self.selection[index]
78
+ else:
79
+ values = self.selection
77
80
 
78
81
  assert dtype is None
82
+
79
83
  if flatten:
80
- return values.flatten()
81
- return values.reshape(self.shape)
84
+ return values.values.flatten()
85
+
86
+ return values # .reshape(self.shape)
82
87
 
83
88
  def _make_metadata(self):
84
89
  return XArrayMetadata(self)
@@ -113,3 +118,7 @@ class XArrayField(Field):
113
118
 
114
119
  def __repr__(self):
115
120
  return repr(self._metadata)
121
+
122
+ def _values(self):
123
+ # we don't use .values as this will download the data
124
+ return self.selection
@@ -70,10 +70,10 @@ class XarrayFieldList(FieldList):
70
70
  skip.update(attr_val.split(" "))
71
71
 
72
72
  for name in ds.data_vars:
73
- v = ds[name]
74
- _skip_attr(v, "coordinates")
75
- _skip_attr(v, "bounds")
76
- _skip_attr(v, "grid_mapping")
73
+ variable = ds[name]
74
+ _skip_attr(variable, "coordinates")
75
+ _skip_attr(variable, "bounds")
76
+ _skip_attr(variable, "grid_mapping")
77
77
 
78
78
  # Select only geographical variables
79
79
  for name in ds.data_vars:
@@ -81,14 +81,14 @@ class XarrayFieldList(FieldList):
81
81
  if name in skip:
82
82
  continue
83
83
 
84
- v = ds[name]
84
+ variable = ds[name]
85
85
  coordinates = []
86
86
 
87
- for coord in v.coords:
87
+ for coord in variable.coords:
88
88
 
89
89
  c = guess.guess(ds[coord], coord)
90
90
  assert c, f"Could not guess coordinate for {coord}"
91
- if coord not in v.dims:
91
+ if coord not in variable.dims:
92
92
  c.is_dim = False
93
93
  coordinates.append(c)
94
94
 
@@ -98,17 +98,17 @@ class XarrayFieldList(FieldList):
98
98
  if grid_coords < 2:
99
99
  continue
100
100
 
101
- variables.append(
102
- Variable(
103
- ds=ds,
104
- var=v,
105
- coordinates=coordinates,
106
- grid=guess.grid(coordinates),
107
- time=Time.from_coordinates(coordinates),
108
- metadata={},
109
- )
101
+ v = Variable(
102
+ ds=ds,
103
+ variable=variable,
104
+ coordinates=coordinates,
105
+ grid=guess.grid(coordinates, variable),
106
+ time=Time.from_coordinates(coordinates),
107
+ metadata={},
110
108
  )
111
109
 
110
+ variables.append(v)
111
+
112
112
  return cls(ds, variables)
113
113
 
114
114
  def sel(self, **kwargs):
@@ -8,6 +8,8 @@
8
8
  #
9
9
 
10
10
 
11
+ import logging
12
+
11
13
  from .coordinates import DateCoordinate
12
14
  from .coordinates import EnsembleCoordinate
13
15
  from .coordinates import LatitudeCoordinate
@@ -16,10 +18,16 @@ from .coordinates import LongitudeCoordinate
16
18
  from .coordinates import ScalarCoordinate
17
19
  from .coordinates import StepCoordinate
18
20
  from .coordinates import TimeCoordinate
21
+ from .coordinates import UnsupportedCoordinate
19
22
  from .coordinates import XCoordinate
20
23
  from .coordinates import YCoordinate
24
+ from .coordinates import is_scalar
21
25
  from .grid import MeshedGrid
26
+ from .grid import MeshProjectionGrid
22
27
  from .grid import UnstructuredGrid
28
+ from .grid import UnstructuredProjectionGrid
29
+
30
+ LOG = logging.getLogger(__name__)
23
31
 
24
32
 
25
33
  class CoordinateGuesser:
@@ -150,36 +158,145 @@ class CoordinateGuesser:
150
158
  if c.shape in ((1,), tuple()):
151
159
  return ScalarCoordinate(c)
152
160
 
153
- raise NotImplementedError(
161
+ LOG.warning(
154
162
  f"Coordinate {coord} not supported\n{axis=}, {name=},"
155
163
  f" {long_name=}, {standard_name=}, units\n\n{c}\n\n{type(c.values)} {c.shape}"
156
164
  )
157
165
 
158
- def grid(self, coordinates):
166
+ return UnsupportedCoordinate(c)
167
+
168
+ def grid(self, coordinates, variable):
159
169
  lat = [c for c in coordinates if c.is_lat]
160
170
  lon = [c for c in coordinates if c.is_lon]
161
171
 
162
- if len(lat) != 1:
163
- raise NotImplementedError(f"Expected 1 latitude coordinate, got {len(lat)}")
172
+ if len(lat) == 1 and len(lon) == 1:
173
+ return self._lat_lon_provided(lat, lon, variable)
174
+
175
+ x = [c for c in coordinates if c.is_x]
176
+ y = [c for c in coordinates if c.is_y]
164
177
 
165
- if len(lon) != 1:
166
- raise NotImplementedError(f"Expected 1 longitude coordinate, got {len(lon)}")
178
+ if len(x) == 1 and len(y) == 1:
179
+ return self._x_y_provided(x, y, variable)
167
180
 
181
+ raise NotImplementedError(f"Cannot establish grid {coordinates}")
182
+
183
+ def _check_dims(self, variable, x_or_lon, y_or_lat):
184
+
185
+ x_dims = set(x_or_lon.variable.dims)
186
+ y_dims = set(y_or_lat.variable.dims)
187
+ variable_dims = set(variable.dims)
188
+
189
+ if not (x_dims <= variable_dims) or not (y_dims <= variable_dims):
190
+ raise ValueError(
191
+ f"Dimensions do not match {variable.name}{variable.dims} !="
192
+ f" {x_or_lon.name}{x_or_lon.variable.dims} and {y_or_lat.name}{y_or_lat.variable.dims}"
193
+ )
194
+
195
+ variable_dims = tuple(v for v in variable.dims if v in (x_dims | y_dims))
196
+ if x_dims == y_dims:
197
+ # It's unstructured
198
+ return variable_dims, True
199
+
200
+ if len(x_dims) == 1 and len(y_dims) == 1:
201
+ # It's a mesh
202
+ return variable_dims, False
203
+
204
+ raise ValueError(
205
+ f"Cannot establish grid for {variable.name}{variable.dims},"
206
+ f" {x_or_lon.name}{x_or_lon.variable.dims},"
207
+ f" {y_or_lat.name}{y_or_lat.variable.dims}"
208
+ )
209
+
210
+ def _lat_lon_provided(self, lat, lon, variable):
168
211
  lat = lat[0]
169
212
  lon = lon[0]
170
213
 
171
- if (lat.name, lon.name) in self._cache:
172
- return self._cache[(lat.name, lon.name)]
214
+ dim_vars, unstructured = self._check_dims(variable, lon, lat)
215
+
216
+ if (lat.name, lon.name, dim_vars) in self._cache:
217
+ return self._cache[(lat.name, lon.name, dim_vars)]
173
218
 
174
- assert len(lat.variable.shape) == len(lon.variable.shape), (lat.variable.shape, lon.variable.shape)
175
- if len(lat.variable.shape) == 1:
176
- grid = MeshedGrid(lat, lon)
219
+ if unstructured:
220
+ grid = UnstructuredGrid(lat, lon, dim_vars)
177
221
  else:
178
- grid = UnstructuredGrid(lat, lon)
222
+ grid = MeshedGrid(lat, lon, dim_vars)
179
223
 
180
- self._cache[(lat.name, lon.name)] = grid
224
+ self._cache[(lat.name, lon.name, dim_vars)] = grid
181
225
  return grid
182
226
 
227
+ def _x_y_provided(self, x, y, variable):
228
+ x = x[0]
229
+ y = y[0]
230
+
231
+ _, unstructured = self._check_dims(variable, x, y)
232
+
233
+ if x.variable.dims != y.variable.dims:
234
+ raise ValueError(f"Dimensions do not match {x.name}{x.variable.dims} != {y.name}{y.variable.dims}")
235
+
236
+ if (x.name, y.name) in self._cache:
237
+ return self._cache[(x.name, y.name)]
238
+
239
+ if (x.name, y.name) in self._cache:
240
+ return self._cache[(x.name, y.name)]
241
+
242
+ assert len(x.variable.shape) == len(y.variable.shape), (x.variable.shape, y.variable.shape)
243
+
244
+ grid_mapping = variable.attrs.get("grid_mapping", None)
245
+
246
+ if grid_mapping is None:
247
+ LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
248
+ LOG.warning("Trying to guess...")
249
+
250
+ PROBE = {
251
+ "prime_meridian_name",
252
+ "reference_ellipsoid_name",
253
+ "crs_wkt",
254
+ "horizontal_datum_name",
255
+ "semi_major_axis",
256
+ "spatial_ref",
257
+ "inverse_flattening",
258
+ "semi_minor_axis",
259
+ "geographic_crs_name",
260
+ "GeoTransform",
261
+ "grid_mapping_name",
262
+ "longitude_of_prime_meridian",
263
+ }
264
+ candidate = None
265
+ for v in self.ds.variables:
266
+ var = self.ds[v]
267
+ if not is_scalar(var):
268
+ continue
269
+
270
+ if PROBE.intersection(var.attrs.keys()):
271
+ if candidate:
272
+ raise ValueError(f"Multiple candidates for 'grid_mapping': {candidate} and {v}")
273
+ candidate = v
274
+
275
+ if candidate:
276
+ LOG.warning(f"Using '{candidate}' as 'grid_mapping'")
277
+ grid_mapping = candidate
278
+ else:
279
+ LOG.warning("Could not fine a candidate for 'grid_mapping'")
280
+
281
+ if grid_mapping is None:
282
+ if "crs" in self.ds[variable].attrs:
283
+ grid_mapping = self.ds[variable].attrs["crs"]
284
+ LOG.warning(f"Using CRS {grid_mapping} from variable '{variable.name}' attributes")
285
+
286
+ if grid_mapping is None:
287
+ if "crs" in self.ds.attrs:
288
+ grid_mapping = self.ds.attrs["crs"]
289
+ LOG.warning(f"Using CRS {grid_mapping} from global attributes")
290
+
291
+ if grid_mapping is not None:
292
+ if unstructured:
293
+ return UnstructuredProjectionGrid(x, y, grid_mapping)
294
+ else:
295
+ return MeshProjectionGrid(x, y, grid_mapping)
296
+
297
+ LOG.error("Could not fine a candidate for 'grid_mapping'")
298
+ raise NotImplementedError(f"Unstructured grid {x.name} {y.name}")
299
+
183
300
 
184
301
  class DefaultCoordinateGuesser(CoordinateGuesser):
185
302
  def __init__(self, ds):