anemoi-datasets 0.4.5__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. anemoi/datasets/_version.py +2 -2
  2. anemoi/datasets/commands/create.py +3 -2
  3. anemoi/datasets/commands/inspect.py +1 -1
  4. anemoi/datasets/commands/publish.py +30 -0
  5. anemoi/datasets/create/__init__.py +72 -35
  6. anemoi/datasets/create/check.py +6 -0
  7. anemoi/datasets/create/config.py +4 -3
  8. anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py +57 -0
  9. anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py +57 -0
  10. anemoi/datasets/create/functions/filters/rename.py +2 -3
  11. anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py +54 -0
  12. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py +59 -0
  13. anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py +115 -0
  14. anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py +390 -0
  15. anemoi/datasets/create/functions/filters/speeddir_to_uv.py +77 -0
  16. anemoi/datasets/create/functions/filters/uv_to_speeddir.py +55 -0
  17. anemoi/datasets/create/functions/sources/__init__.py +7 -1
  18. anemoi/datasets/create/functions/sources/accumulations.py +2 -0
  19. anemoi/datasets/create/functions/sources/grib.py +87 -2
  20. anemoi/datasets/create/functions/sources/hindcasts.py +14 -73
  21. anemoi/datasets/create/functions/sources/mars.py +9 -3
  22. anemoi/datasets/create/functions/sources/xarray/__init__.py +6 -1
  23. anemoi/datasets/create/functions/sources/xarray/coordinates.py +6 -1
  24. anemoi/datasets/create/functions/sources/xarray/field.py +20 -5
  25. anemoi/datasets/create/functions/sources/xarray/fieldlist.py +16 -16
  26. anemoi/datasets/create/functions/sources/xarray/flavour.py +126 -12
  27. anemoi/datasets/create/functions/sources/xarray/grid.py +106 -17
  28. anemoi/datasets/create/functions/sources/xarray/metadata.py +6 -12
  29. anemoi/datasets/create/functions/sources/xarray/time.py +1 -5
  30. anemoi/datasets/create/functions/sources/xarray/variable.py +10 -10
  31. anemoi/datasets/create/input/__init__.py +69 -0
  32. anemoi/datasets/create/input/action.py +123 -0
  33. anemoi/datasets/create/input/concat.py +92 -0
  34. anemoi/datasets/create/input/context.py +59 -0
  35. anemoi/datasets/create/input/data_sources.py +71 -0
  36. anemoi/datasets/create/input/empty.py +42 -0
  37. anemoi/datasets/create/input/filter.py +76 -0
  38. anemoi/datasets/create/input/function.py +122 -0
  39. anemoi/datasets/create/input/join.py +57 -0
  40. anemoi/datasets/create/input/misc.py +85 -0
  41. anemoi/datasets/create/input/pipe.py +33 -0
  42. anemoi/datasets/create/input/repeated_dates.py +217 -0
  43. anemoi/datasets/create/input/result.py +413 -0
  44. anemoi/datasets/create/input/step.py +99 -0
  45. anemoi/datasets/create/{template.py → input/template.py} +0 -42
  46. anemoi/datasets/create/persistent.py +1 -1
  47. anemoi/datasets/create/statistics/__init__.py +1 -1
  48. anemoi/datasets/create/utils.py +3 -0
  49. anemoi/datasets/create/zarr.py +4 -2
  50. anemoi/datasets/data/dataset.py +11 -1
  51. anemoi/datasets/data/debug.py +5 -1
  52. anemoi/datasets/data/masked.py +2 -2
  53. anemoi/datasets/data/rescale.py +147 -0
  54. anemoi/datasets/data/stores.py +20 -7
  55. anemoi/datasets/dates/__init__.py +113 -30
  56. anemoi/datasets/dates/groups.py +92 -19
  57. anemoi/datasets/fields.py +66 -0
  58. anemoi/datasets/utils/fields.py +47 -0
  59. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/METADATA +10 -19
  60. anemoi_datasets-0.5.5.dist-info/RECORD +121 -0
  61. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/WHEEL +1 -1
  62. anemoi/datasets/create/input.py +0 -1065
  63. anemoi_datasets-0.4.5.dist-info/RECORD +0 -96
  64. /anemoi/datasets/create/{trace.py → input/trace.py} +0 -0
  65. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/LICENSE +0 -0
  66. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/entry_points.txt +0 -0
  67. {anemoi_datasets-0.4.5.dist-info → anemoi_datasets-0.5.5.dist-info}/top_level.txt +0 -0
@@ -11,9 +11,87 @@
11
11
  import glob
12
12
 
13
13
  from earthkit.data import from_source
14
+ from earthkit.data.indexing.fieldlist import FieldArray
14
15
  from earthkit.data.utils.patterns import Pattern
15
16
 
16
17
 
18
+ def _load(context, name, record):
19
+ ds = None
20
+
21
+ param = record["param"]
22
+
23
+ if "path" in record:
24
+ context.info(f"Using {name} from {record['path']} (param={param})")
25
+ ds = from_source("file", record["path"])
26
+
27
+ if "url" in record:
28
+ context.info(f"Using {name} from {record['url']} (param={param})")
29
+ ds = from_source("url", record["url"])
30
+
31
+ ds = ds.sel(param=param)
32
+
33
+ assert len(ds) == 1, f"{name} {param}, expected one field, got {len(ds)}"
34
+ ds = ds[0]
35
+
36
+ return ds.to_numpy(flatten=True), ds.metadata("uuidOfHGrid")
37
+
38
+
39
class Geography:
    """Hold the latitudes and longitudes of an unstructured grid.

    The two coordinate fields are loaded when the object is created. Fields
    can then be checked against the grid identifier (``uuidOfHGrid``) that
    the coordinates were loaded with.
    """

    def __init__(self, context, latitudes, longitudes):
        """Load both coordinate fields and verify that they are consistent.

        ``latitudes`` and ``longitudes`` are the records passed to ``_load``.
        An ``AssertionError`` is raised when the two fields do not share the
        same ``uuidOfHGrid`` or do not have the same number of values.
        """
        latitudes, uuidOfHGrid_lat = _load(context, "latitudes", latitudes)
        longitudes, uuidOfHGrid_lon = _load(context, "longitudes", longitudes)

        # Both coordinate fields must describe the same horizontal grid.
        assert (
            uuidOfHGrid_lat == uuidOfHGrid_lon
        ), f"uuidOfHGrid mismatch: lat={uuidOfHGrid_lat} != lon={uuidOfHGrid_lon}"

        context.info(f"Latitudes: {len(latitudes)}, Longitudes: {len(longitudes)}")
        assert len(latitudes) == len(longitudes)

        self.first = True
        self.latitudes = latitudes
        self.longitudes = longitudes
        self.uuidOfHGrid = uuidOfHGrid_lat

    def check(self, field):
        """Verify that ``field`` carries the same ``uuidOfHGrid`` as this grid.

        Only the first field passed in is actually checked (for performance
        reasons); later calls return immediately.
        """
        if not self.first:
            return

        assert (
            field.metadata("uuidOfHGrid") == self.uuidOfHGrid
        ), f"uuidOfHGrid mismatch: {field.metadata('uuidOfHGrid')} != {self.uuidOfHGrid}"
        self.first = False
68
+
69
+
70
class AddGrid:
    """A field wrapper that adds grid information to an earthkit.data field.

    Attribute access that is not resolved on the wrapper itself is forwarded
    to the wrapped field; ``grid_points`` and ``resolution`` are answered by
    the wrapper.
    """

    def __init__(self, field, geography):
        """Wrap ``field`` and attach the coordinates held by ``geography``.

        ``geography`` must provide ``check(field)``, ``latitudes`` and
        ``longitudes``. ``check`` is called first, so an incompatible field
        is rejected by whatever ``check`` raises.
        """
        self._field = field

        geography.check(field)

        self._latitudes = geography.latitudes
        self._longitudes = geography.longitudes

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If `_field` itself
        # is missing (instance created without __init__, as copy/pickle do
        # before restoring state), delegating through `self._field` would
        # recurse until RecursionError; report a plain missing attribute.
        if name == "_field":
            raise AttributeError(name)
        return getattr(self._field, name)

    def __repr__(self) -> str:
        return repr(self._field)

    def grid_points(self):
        """Return the ``(latitudes, longitudes)`` pair taken from the geography."""
        return self._latitudes, self._longitudes

    @property
    def resolution(self):
        """Resolution label; always the string ``"unknown"``."""
        return "unknown"
93
+
94
+
17
95
  def check(ds, paths, **kwargs):
18
96
  count = 1
19
97
  for k, v in kwargs.items():
@@ -34,9 +112,13 @@ def _expand(paths):
34
112
  yield path
35
113
 
36
114
 
37
- def execute(context, dates, path, *args, **kwargs):
115
+ def execute(context, dates, path, latitudes=None, longitudes=None, *args, **kwargs):
38
116
  given_paths = path if isinstance(path, list) else [path]
39
117
 
118
+ geography = None
119
+ if latitudes is not None and longitudes is not None:
120
+ geography = Geography(context, latitudes, longitudes)
121
+
40
122
  ds = from_source("empty")
41
123
  dates = [d.isoformat() for d in dates]
42
124
 
@@ -53,7 +135,10 @@ def execute(context, dates, path, *args, **kwargs):
53
135
  s = s.sel(valid_datetime=dates, **kwargs)
54
136
  ds = ds + s
55
137
 
56
- if kwargs:
138
+ if kwargs and not context.partial_ok:
57
139
  check(ds, given_paths, valid_datetime=dates, **kwargs)
58
140
 
141
+ if geography is not None:
142
+ ds = FieldArray([AddGrid(_, geography) for _ in ds])
143
+
59
144
  return ds
@@ -6,7 +6,6 @@
6
6
  # granted to it by virtue of its status as an intergovernmental organisation
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
- import datetime
10
9
  import logging
11
10
 
12
11
  from earthkit.data.core.fieldlist import MultiFieldList
@@ -14,7 +13,6 @@ from earthkit.data.core.fieldlist import MultiFieldList
14
13
  from anemoi.datasets.create.functions.sources.mars import mars
15
14
 
16
15
  LOGGER = logging.getLogger(__name__)
17
- DEBUG = True
18
16
 
19
17
 
20
18
  def _to_list(x):
@@ -23,91 +21,34 @@ def _to_list(x):
23
21
  return [x]
24
22
 
25
23
 
26
- class HindcastCompute:
27
- def __init__(self, base_times, available_steps, request):
28
- self.base_times = base_times
29
- self.available_steps = available_steps
30
- self.request = request
31
-
32
- def compute_hindcast(self, date):
33
- result = []
34
- for step in sorted(self.available_steps): # Use the shortest step
35
- start_date = date - datetime.timedelta(hours=step)
36
- hours = start_date.hour
37
- if hours in self.base_times:
38
- r = self.request.copy()
39
- r["date"] = start_date
40
- r["time"] = f"{start_date.hour:02d}00"
41
- r["step"] = step
42
- result.append(r)
43
-
44
- if not result:
45
- raise ValueError(
46
- f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, "
47
- f"available_steps={self.available_steps})"
48
- )
49
-
50
- if len(result) > 1:
51
- raise ValueError(
52
- f"Multiple requests for {self.request} for {date} (base_times={self.base_times}, "
53
- f"available_steps={self.available_steps})"
54
- )
55
-
56
- return result[0]
57
-
58
-
59
- def use_reference_year(reference_year, request):
60
- request = request.copy()
61
- hdate = request.pop("date")
62
-
63
- if hdate.year >= reference_year:
64
- return None, False
24
+ def hindcasts(context, dates, **request):
65
25
 
66
- try:
67
- date = datetime.datetime(reference_year, hdate.month, hdate.day)
68
- except ValueError:
69
- if hdate.month == 2 and hdate.day == 29:
70
- return None, False
71
- raise
26
+ from anemoi.datasets.dates import HindcastsDates
72
27
 
73
- request.update(date=date.strftime("%Y-%m-%d"), hdate=hdate.strftime("%Y-%m-%d"))
74
- return request, True
28
+ provider = context.dates_provider
29
+ assert isinstance(provider, HindcastsDates)
75
30
 
31
+ context.trace("H️", f"hindcasts {len(dates)=}")
76
32
 
77
- def hindcasts(context, dates, **request):
78
33
  request["param"] = _to_list(request["param"])
79
- request["step"] = _to_list(request["step"])
34
+ request["step"] = _to_list(request.get("step", 0))
80
35
  request["step"] = [int(_) for _ in request["step"]]
81
36
 
82
- if request.get("stream") == "enfh" and "base_times" not in request:
83
- request["base_times"] = [0]
84
-
85
- available_steps = request.pop("step")
86
- available_steps = _to_list(available_steps)
87
-
88
- base_times = request.pop("base_times")
89
-
90
- reference_year = request.pop("reference_year")
37
+ context.trace("H️", f"hindcast {request}")
91
38
 
92
- context.trace("H️", f"hindcast {request} {base_times} {available_steps} {reference_year}")
93
-
94
- c = HindcastCompute(base_times, available_steps, request)
95
39
  requests = []
96
40
  for d in dates:
97
- req = c.compute_hindcast(d)
98
- req, ok = use_reference_year(reference_year, req)
99
- if ok:
100
- requests.append(req)
101
-
102
- # print("HINDCASTS requests", reference_year, base_times, available_steps)
103
- # print("HINDCASTS dates", compress_dates(dates))
41
+ r = request.copy()
42
+ hindcast = provider.mapping[d]
43
+ r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d")
44
+ r["date"] = hindcast.refdate.strftime("%Y-%m-%d")
45
+ r["time"] = hindcast.refdate.strftime("%H")
46
+ r["step"] = hindcast.step
47
+ requests.append(r)
104
48
 
105
49
  if len(requests) == 0:
106
- # print("HINDCASTS no requests")
107
50
  return MultiFieldList([])
108
51
 
109
- # print("HINDCASTS requests", requests)
110
-
111
52
  return mars(
112
53
  context,
113
54
  dates,
@@ -203,16 +203,22 @@ def mars(context, dates, *requests, request_already_using_valid_datetime=False,
203
203
  request_already_using_valid_datetime=request_already_using_valid_datetime,
204
204
  date_key=date_key,
205
205
  )
206
+
207
+ requests = list(requests)
208
+
206
209
  ds = from_source("empty")
210
+ context.trace("✅", f"{[str(d) for d in dates]}")
211
+ context.trace("✅", f"Will run {len(requests)} requests")
212
+ for r in requests:
213
+ r = {k: v for k, v in r.items() if v != ("-",)}
214
+ context.trace("✅", f"mars {r}")
215
+
207
216
  for r in requests:
208
217
  r = {k: v for k, v in r.items() if v != ("-",)}
209
218
 
210
219
  if context.use_grib_paramid and "param" in r:
211
220
  r = use_grib_paramid(r)
212
221
 
213
- if DEBUG:
214
- context.trace("✅", f"from_source(mars, {r}")
215
-
216
222
  for k, v in r.items():
217
223
  if k not in MARS_KEYS:
218
224
  raise ValueError(
@@ -12,6 +12,7 @@ import logging
12
12
  from earthkit.data.core.fieldlist import MultiFieldList
13
13
 
14
14
  from anemoi.datasets.data.stores import name_to_zarr_store
15
+ from anemoi.datasets.utils.fields import NewMetadataField as NewMetadataField
15
16
 
16
17
  from .. import iterate_patterns
17
18
  from .fieldlist import XarrayFieldList
@@ -49,7 +50,11 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
49
50
  data = xr.open_dataset(dataset, **options)
50
51
 
51
52
  fs = XarrayFieldList.from_xarray(data, flavour)
52
- result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
53
+
54
+ if len(dates) == 0:
55
+ return fs.sel(**kwargs)
56
+ else:
57
+ result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
53
58
 
54
59
  if len(result) == 0:
55
60
  LOG.warning(f"No data found for {dataset} and dates {dates} and {kwargs}")
@@ -56,6 +56,8 @@ class Coordinate:
56
56
  is_step = False
57
57
  is_date = False
58
58
  is_member = False
59
+ is_x = False
60
+ is_y = False
59
61
 
60
62
  def __init__(self, variable):
61
63
  self.variable = variable
@@ -66,10 +68,11 @@ class Coordinate:
66
68
  return 1 if self.scalar else len(self.variable)
67
69
 
68
70
  def __repr__(self):
69
- return "%s[name=%s,values=%s]" % (
71
+ return "%s[name=%s,values=%s,shape=%s]" % (
70
72
  self.__class__.__name__,
71
73
  self.variable.name,
72
74
  self.variable.values if self.scalar else len(self),
75
+ self.variable.shape,
73
76
  )
74
77
 
75
78
  def reduced(self, i):
@@ -225,11 +228,13 @@ class LatitudeCoordinate(Coordinate):
225
228
 
226
229
  class XCoordinate(Coordinate):
227
230
  is_grid = True
231
+ is_x = True
228
232
  mars_names = ("x",)
229
233
 
230
234
 
231
235
  class YCoordinate(Coordinate):
232
236
  is_grid = True
237
+ is_y = True
233
238
  mars_names = ("y",)
234
239
 
235
240
 
@@ -7,6 +7,7 @@
7
7
  # nor does it submit to any jurisdiction.
8
8
  #
9
9
 
10
+ import datetime
10
11
  import logging
11
12
 
12
13
  from earthkit.data.core.fieldlist import Field
@@ -71,13 +72,18 @@ class XArrayField(Field):
71
72
  def shape(self):
72
73
  return self._shape
73
74
 
74
- def to_numpy(self, flatten=False, dtype=None):
75
- values = self.selection.values
75
+ def to_numpy(self, flatten=False, dtype=None, index=None):
76
+ if index is not None:
77
+ values = self.selection[index]
78
+ else:
79
+ values = self.selection
76
80
 
77
81
  assert dtype is None
82
+
78
83
  if flatten:
79
- return values.flatten()
80
- return values.reshape(self.shape)
84
+ return values.values.flatten()
85
+
86
+ return values # .reshape(self.shape)
81
87
 
82
88
  def _make_metadata(self):
83
89
  return XArrayMetadata(self)
@@ -103,7 +109,16 @@ class XArrayField(Field):
103
109
 
104
110
  @property
105
111
  def forecast_reference_time(self):
106
- return self.owner.forecast_reference_time
112
+ date, time = self.metadata("date", "time")
113
+ assert len(time) == 4, time
114
+ assert len(date) == 8, date
115
+ yyyymmdd = int(date)
116
+ time = int(time) // 100
117
+ return datetime.datetime(yyyymmdd // 10000, yyyymmdd // 100 % 100, yyyymmdd % 100, time)
107
118
 
108
119
  def __repr__(self):
109
120
  return repr(self._metadata)
121
+
122
+ def _values(self):
123
+ # we don't use .values as this will download the data
124
+ return self.selection
@@ -70,10 +70,10 @@ class XarrayFieldList(FieldList):
70
70
  skip.update(attr_val.split(" "))
71
71
 
72
72
  for name in ds.data_vars:
73
- v = ds[name]
74
- _skip_attr(v, "coordinates")
75
- _skip_attr(v, "bounds")
76
- _skip_attr(v, "grid_mapping")
73
+ variable = ds[name]
74
+ _skip_attr(variable, "coordinates")
75
+ _skip_attr(variable, "bounds")
76
+ _skip_attr(variable, "grid_mapping")
77
77
 
78
78
  # Select only geographical variables
79
79
  for name in ds.data_vars:
@@ -81,14 +81,14 @@ class XarrayFieldList(FieldList):
81
81
  if name in skip:
82
82
  continue
83
83
 
84
- v = ds[name]
84
+ variable = ds[name]
85
85
  coordinates = []
86
86
 
87
- for coord in v.coords:
87
+ for coord in variable.coords:
88
88
 
89
89
  c = guess.guess(ds[coord], coord)
90
90
  assert c, f"Could not guess coordinate for {coord}"
91
- if coord not in v.dims:
91
+ if coord not in variable.dims:
92
92
  c.is_dim = False
93
93
  coordinates.append(c)
94
94
 
@@ -98,17 +98,17 @@ class XarrayFieldList(FieldList):
98
98
  if grid_coords < 2:
99
99
  continue
100
100
 
101
- variables.append(
102
- Variable(
103
- ds=ds,
104
- var=v,
105
- coordinates=coordinates,
106
- grid=guess.grid(coordinates),
107
- time=Time.from_coordinates(coordinates),
108
- metadata={},
109
- )
101
+ v = Variable(
102
+ ds=ds,
103
+ variable=variable,
104
+ coordinates=coordinates,
105
+ grid=guess.grid(coordinates, variable),
106
+ time=Time.from_coordinates(coordinates),
107
+ metadata={},
110
108
  )
111
109
 
110
+ variables.append(v)
111
+
112
112
  return cls(ds, variables)
113
113
 
114
114
  def sel(self, **kwargs):
@@ -8,6 +8,8 @@
8
8
  #
9
9
 
10
10
 
11
+ import logging
12
+
11
13
  from .coordinates import DateCoordinate
12
14
  from .coordinates import EnsembleCoordinate
13
15
  from .coordinates import LatitudeCoordinate
@@ -18,8 +20,13 @@ from .coordinates import StepCoordinate
18
20
  from .coordinates import TimeCoordinate
19
21
  from .coordinates import XCoordinate
20
22
  from .coordinates import YCoordinate
23
+ from .coordinates import is_scalar
21
24
  from .grid import MeshedGrid
25
+ from .grid import MeshProjectionGrid
22
26
  from .grid import UnstructuredGrid
27
+ from .grid import UnstructuredProjectionGrid
28
+
29
+ LOG = logging.getLogger(__name__)
23
30
 
24
31
 
25
32
  class CoordinateGuesser:
@@ -155,31 +162,138 @@ class CoordinateGuesser:
155
162
  f" {long_name=}, {standard_name=}, units\n\n{c}\n\n{type(c.values)} {c.shape}"
156
163
  )
157
164
 
158
- def grid(self, coordinates):
165
+ def grid(self, coordinates, variable):
159
166
  lat = [c for c in coordinates if c.is_lat]
160
167
  lon = [c for c in coordinates if c.is_lon]
161
168
 
162
- if len(lat) != 1:
163
- raise NotImplementedError(f"Expected 1 latitude coordinate, got {len(lat)}")
169
+ if len(lat) == 1 and len(lon) == 1:
170
+ return self._lat_lon_provided(lat, lon, variable)
171
+
172
+ x = [c for c in coordinates if c.is_x]
173
+ y = [c for c in coordinates if c.is_y]
174
+
175
+ if len(x) == 1 and len(y) == 1:
176
+ return self._x_y_provided(x, y, variable)
177
+
178
+ raise NotImplementedError(f"Cannot establish grid {coordinates}")
179
+
180
+ def _check_dims(self, variable, x_or_lon, y_or_lat):
181
+
182
+ x_dims = set(x_or_lon.variable.dims)
183
+ y_dims = set(y_or_lat.variable.dims)
184
+ variable_dims = set(variable.dims)
164
185
 
165
- if len(lon) != 1:
166
- raise NotImplementedError(f"Expected 1 longitude coordinate, got {len(lon)}")
186
+ if not (x_dims <= variable_dims) or not (y_dims <= variable_dims):
187
+ raise ValueError(
188
+ f"Dimensions do not match {variable.name}{variable.dims} !="
189
+ f" {x_or_lon.name}{x_or_lon.variable.dims} and {y_or_lat.name}{y_or_lat.variable.dims}"
190
+ )
191
+
192
+ variable_dims = tuple(v for v in variable.dims if v in (x_dims | y_dims))
193
+ if x_dims == y_dims:
194
+ # It's unstructured
195
+ return variable_dims, True
167
196
 
197
+ if len(x_dims) == 1 and len(y_dims) == 1:
198
+ # It's a mesh
199
+ return variable_dims, False
200
+
201
+ raise ValueError(
202
+ f"Cannot establish grid for {variable.name}{variable.dims},"
203
+ f" {x_or_lon.name}{x_or_lon.variable.dims},"
204
+ f" {y_or_lat.name}{y_or_lat.variable.dims}"
205
+ )
206
+
207
+ def _lat_lon_provided(self, lat, lon, variable):
168
208
  lat = lat[0]
169
209
  lon = lon[0]
170
210
 
171
- if (lat.name, lon.name) in self._cache:
172
- return self._cache[(lat.name, lon.name)]
211
+ dim_vars, unstructured = self._check_dims(variable, lon, lat)
212
+
213
+ if (lat.name, lon.name, dim_vars) in self._cache:
214
+ return self._cache[(lat.name, lon.name, dim_vars)]
173
215
 
174
- assert len(lat.variable.shape) == len(lon.variable.shape), (lat.variable.shape, lon.variable.shape)
175
- if len(lat.variable.shape) == 1:
176
- grid = MeshedGrid(lat, lon)
216
+ if unstructured:
217
+ grid = UnstructuredGrid(lat, lon, dim_vars)
177
218
  else:
178
- grid = UnstructuredGrid(lat, lon)
219
+ grid = MeshedGrid(lat, lon, dim_vars)
179
220
 
180
- self._cache[(lat.name, lon.name)] = grid
221
+ self._cache[(lat.name, lon.name, dim_vars)] = grid
181
222
  return grid
182
223
 
224
def _x_y_provided(self, x, y, variable):
    """Build a projection grid from exactly one x and one y coordinate.

    Parameters
    ----------
    x, y : list
        One-element lists holding the x and y coordinate objects.
    variable
        The data variable the grid is built for; its ``attrs`` and ``name``
        are read.

    Returns
    -------
    UnstructuredProjectionGrid or MeshProjectionGrid
        Chosen from the ``unstructured`` flag returned by ``_check_dims``.

    Raises
    ------
    ValueError
        If the x and y dimensions differ, or if several scalar variables
        look like a grid mapping (plus anything ``_check_dims`` raises).
    NotImplementedError
        If no grid mapping / CRS can be found.
    """
    x = x[0]
    y = y[0]

    _, unstructured = self._check_dims(variable, x, y)

    # NOTE(review): `_check_dims` reports a mesh (unstructured=False) only
    # when x and y have different dimensions, so this check rejects every
    # mesh and MeshProjectionGrid below looks unreachable. Behaviour is kept
    # as is; confirm the intent.
    if x.variable.dims != y.variable.dims:
        raise ValueError(f"Dimensions do not match {x.name}{x.variable.dims} != {y.name}{y.variable.dims}")

    # NOTE(review): nothing in this method stores a grid under this key.
    if (x.name, y.name) in self._cache:
        return self._cache[(x.name, y.name)]

    assert len(x.variable.shape) == len(y.variable.shape), (x.variable.shape, y.variable.shape)

    grid_mapping = variable.attrs.get("grid_mapping", None)

    if grid_mapping is None:
        LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'")
        LOG.warning("Trying to guess...")

        # Attribute names whose presence on a scalar variable suggests that
        # the variable describes a grid mapping.
        PROBE = {
            "prime_meridian_name",
            "reference_ellipsoid_name",
            "crs_wkt",
            "horizontal_datum_name",
            "semi_major_axis",
            "spatial_ref",
            "inverse_flattening",
            "semi_minor_axis",
            "geographic_crs_name",
            "GeoTransform",
            "grid_mapping_name",
            "longitude_of_prime_meridian",
        }
        candidate = None
        for v in self.ds.variables:
            var = self.ds[v]
            if not is_scalar(var):
                continue

            if PROBE.intersection(var.attrs.keys()):
                if candidate:
                    # Refuse to pick one arbitrarily.
                    raise ValueError(f"Multiple candidates for 'grid_mapping': {candidate} and {v}")
                candidate = v

        if candidate:
            LOG.warning(f"Using '{candidate}' as 'grid_mapping'")
            grid_mapping = candidate
        else:
            LOG.warning("Could not fine a candidate for 'grid_mapping'")

    if grid_mapping is None:
        # Fall back on a 'crs' attribute of the variable itself. `variable`
        # is the data variable object, not a key of the dataset, so its
        # attributes are read directly.
        if "crs" in variable.attrs:
            grid_mapping = variable.attrs["crs"]
            LOG.warning(f"Using CRS {grid_mapping} from variable '{variable.name}' attributes")

    if grid_mapping is None:
        # Last resort: a dataset-wide 'crs' attribute.
        if "crs" in self.ds.attrs:
            grid_mapping = self.ds.attrs["crs"]
            LOG.warning(f"Using CRS {grid_mapping} from global attributes")

    if grid_mapping is not None:
        if unstructured:
            return UnstructuredProjectionGrid(x, y, grid_mapping)
        else:
            return MeshProjectionGrid(x, y, grid_mapping)

    LOG.error("Could not fine a candidate for 'grid_mapping'")
    raise NotImplementedError(f"Unstructured grid {x.name} {y.name}")
296
+
183
297
 
184
298
  class DefaultCoordinateGuesser(CoordinateGuesser):
185
299
  def __init__(self, ds):