anemoi-datasets 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. anemoi/datasets/__main__.py +7 -51
  2. anemoi/datasets/_version.py +2 -2
  3. anemoi/datasets/commands/__init__.py +5 -59
  4. anemoi/datasets/commands/copy.py +141 -83
  5. anemoi/datasets/commands/create.py +14 -3
  6. anemoi/datasets/commands/inspect/__init__.py +1 -5
  7. anemoi/datasets/compute/{perturbations.py → recentre.py} +24 -23
  8. anemoi/datasets/create/__init__.py +3 -0
  9. anemoi/datasets/create/config.py +7 -1
  10. anemoi/datasets/create/functions/sources/accumulations.py +7 -3
  11. anemoi/datasets/create/functions/sources/hindcasts.py +437 -0
  12. anemoi/datasets/create/functions/sources/mars.py +13 -7
  13. anemoi/datasets/create/functions/sources/{perturbations.py → recentre.py} +5 -5
  14. anemoi/datasets/create/input.py +0 -5
  15. anemoi/datasets/create/loaders.py +36 -0
  16. anemoi/datasets/create/persistent.py +1 -3
  17. anemoi/datasets/create/statistics/__init__.py +7 -17
  18. anemoi/datasets/create/statistics/summary.py +1 -4
  19. anemoi/datasets/create/writer.py +4 -3
  20. anemoi/datasets/data/indexing.py +1 -3
  21. anemoi/datasets/data/stores.py +2 -6
  22. anemoi/datasets/data/unchecked.py +1 -6
  23. anemoi/datasets/grids.py +2 -2
  24. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/METADATA +30 -21
  25. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/RECORD +29 -28
  26. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/LICENSE +0 -0
  27. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/WHEEL +0 -0
  28. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/entry_points.txt +0 -0
  29. {anemoi_datasets-0.2.0.dist-info → anemoi_datasets-0.3.0.dist-info}/top_level.txt +0 -0
@@ -8,64 +8,20 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
  #
10
10
 
11
-
12
- import argparse
13
- import logging
14
- import sys
15
- import traceback
11
+ from anemoi.utils.cli import cli_main
12
+ from anemoi.utils.cli import make_parser
16
13
 
17
14
  from . import __version__
18
15
  from .commands import COMMANDS
19
16
 
20
- LOG = logging.getLogger(__name__)
21
-
22
-
23
- def main():
24
- parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
25
-
26
- parser.add_argument(
27
- "--version",
28
- "-V",
29
- action="store_true",
30
- help="show the version and exit",
31
- )
32
- parser.add_argument(
33
- "--debug",
34
- "-d",
35
- action="store_true",
36
- help="Debug mode",
37
- )
38
17
 
39
- subparsers = parser.add_subparsers(help="commands:", dest="command")
40
- for name, command in COMMANDS.items():
41
- command_parser = subparsers.add_parser(name, help=command.__doc__)
42
- command.add_arguments(command_parser)
18
+ # For read-the-docs
19
+ def create_parser():
20
+ return make_parser(__doc__, COMMANDS)
43
21
 
44
- args = parser.parse_args()
45
22
 
46
- if args.version:
47
- print(__version__)
48
- return
49
-
50
- if args.command is None:
51
- parser.print_help()
52
- return
53
-
54
- cmd = COMMANDS[args.command]
55
-
56
- logging.basicConfig(
57
- format="%(asctime)s %(levelname)s %(message)s",
58
- datefmt="%Y-%m-%d %H:%M:%S",
59
- level=logging.DEBUG if args.debug else logging.INFO,
60
- )
61
-
62
- try:
63
- cmd.run(args)
64
- except ValueError as e:
65
- traceback.print_exc()
66
- LOG.error("\n💣 %s", str(e).lstrip())
67
- LOG.error("💣 Exiting")
68
- sys.exit(1)
23
+ def main():
24
+ cli_main(__version__, __doc__, COMMANDS)
69
25
 
70
26
 
71
27
  if __name__ == "__main__":
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.0'
16
- __version_tuple__ = version_tuple = (0, 2, 0)
15
+ __version__ = version = '0.3.0'
16
+ __version_tuple__ = version_tuple = (0, 3, 0)
@@ -8,69 +8,15 @@
8
8
  # nor does it submit to any jurisdiction.
9
9
  #
10
10
 
11
- import argparse
12
- import importlib
13
- import logging
14
11
  import os
15
- import sys
16
12
 
17
- LOG = logging.getLogger(__name__)
13
+ from anemoi.utils.cli import Command
14
+ from anemoi.utils.cli import Failed
15
+ from anemoi.utils.cli import register_commands
18
16
 
17
+ __all__ = ["Command"]
19
18
 
20
- def register(here, package, select, fail=None):
21
- result = {}
22
- not_available = {}
23
-
24
- for p in os.listdir(here):
25
- full = os.path.join(here, p)
26
- if p.startswith("_"):
27
- continue
28
- if not (p.endswith(".py") or (os.path.isdir(full) and os.path.exists(os.path.join(full, "__init__.py")))):
29
- continue
30
-
31
- name, _ = os.path.splitext(p)
32
-
33
- try:
34
- imported = importlib.import_module(
35
- f".{name}",
36
- package=package,
37
- )
38
- except ImportError as e:
39
- not_available[name] = e
40
- continue
41
-
42
- obj = select(imported)
43
- if obj is not None:
44
- result[name] = obj
45
-
46
- for name, e in not_available.items():
47
- if fail is None:
48
- pass
49
- if callable(fail):
50
- result[name] = fail(name, e)
51
-
52
- return result
53
-
54
-
55
- class Command:
56
- def run(self, args):
57
- raise NotImplementedError(f"Command not implemented: {args.command}")
58
-
59
-
60
- class Failed(Command):
61
- def __init__(self, name, error):
62
- self.name = name
63
- self.error = error
64
-
65
- def add_arguments(self, command_parser):
66
- command_parser.add_argument("x", nargs=argparse.REMAINDER)
67
-
68
- def run(self, args):
69
- print(f"Command '{self.name}' not available: {self.error}")
70
- sys.exit(1)
71
-
72
-
73
- COMMANDS = register(
19
+ COMMANDS = register_commands(
74
20
  os.path.dirname(__file__),
75
21
  __name__,
76
22
  lambda x: x.command(),
@@ -41,24 +41,19 @@ zinfo https://object-store.os-api.cci1.ecmwf.int/
41
41
  """
42
42
 
43
43
 
44
- class CopyMixin:
45
- internal = True
46
- timestamp = True
47
-
48
- def add_arguments(self, command_parser):
49
- command_parser.add_argument("--transfers", type=int, default=8)
50
- command_parser.add_argument("--block-size", type=int, default=100)
51
- command_parser.add_argument("--overwrite", action="store_true")
52
- command_parser.add_argument("--progress", action="store_true")
53
- command_parser.add_argument("--nested", action="store_true", help="Use ZARR's nested directory backend.")
54
- command_parser.add_argument(
55
- "--rechunk",
56
- nargs="+",
57
- help="Rechunk given array.",
58
- metavar="array=i,j,k,l",
59
- )
60
- command_parser.add_argument("source")
61
- command_parser.add_argument("target")
44
+ class Copier:
45
+ def __init__(self, source, target, transfers, block_size, overwrite, resume, progress, nested, rechunk, **kwargs):
46
+ self.source = source
47
+ self.target = target
48
+ self.transfers = transfers
49
+ self.block_size = block_size
50
+ self.overwrite = overwrite
51
+ self.resume = resume
52
+ self.progress = progress
53
+ self.nested = nested
54
+ self.rechunk = rechunk
55
+
56
+ self.rechunking = rechunk.split(",") if rechunk else []
62
57
 
63
58
  def _store(self, path, nested=False):
64
59
  if nested:
@@ -67,30 +62,56 @@ class CopyMixin:
67
62
  return zarr.storage.NestedDirectoryStore(path)
68
63
  return path
69
64
 
70
- def copy_chunk(self, n, m, source, target, block_size, _copy, progress):
65
+ def copy_chunk(self, n, m, source, target, _copy, progress):
71
66
  if _copy[n:m].all():
72
67
  LOG.info(f"Skipping {n} to {m}")
73
68
  return None
74
69
 
75
- for i in tqdm.tqdm(
76
- range(n, m),
77
- desc=f"Copying {n} to {m}",
78
- leave=False,
79
- disable=not isatty and not progress,
80
- ):
81
- target[i] = source[i]
70
+ if self.block_size % self.data_chunks[0] == 0:
71
+ target[slice(n, m)] = source[slice(n, m)]
72
+ else:
73
+ LOG.warning(
74
+ f"Block size ({self.block_size}) is not a multiple of target chunk size ({self.data_chunks[0]}). Slow copy expected."
75
+ )
76
+ if self.transfers > 1:
77
+ # race condition, different threads might copy the same data to the same chunk
78
+ raise NotImplementedError(
79
+ "Block size is not a multiple of target chunk size. Parallel copy not supported."
80
+ )
81
+ for i in tqdm.tqdm(
82
+ range(n, m),
83
+ desc=f"Copying {n} to {m}",
84
+ leave=False,
85
+ disable=not isatty and not progress,
86
+ ):
87
+ target[i] = source[i]
88
+
82
89
  return slice(n, m)
83
90
 
84
- def copy_data(self, source, target, transfers, block_size, _copy, progress, rechunking):
91
+ def parse_rechunking(self, rechunking, source_data):
92
+ shape = source_data.shape
93
+ chunks = list(source_data.chunks)
94
+ for i, c in enumerate(rechunking):
95
+ if not c:
96
+ continue
97
+ elif c == "full":
98
+ chunks[i] = shape[i]
99
+ c = int(c)
100
+ c = min(c, shape[i])
101
+ chunks[i] = c
102
+ chunks = tuple(chunks)
103
+
104
+ if chunks != source_data.chunks:
105
+ LOG.info(f"Rechunking data from {source_data.chunks} to {chunks}")
106
+ # if self.transfers > 1:
107
+ # raise NotImplementedError("Rechunking with multiple transfers is not implemented")
108
+ return chunks
109
+
110
+ def copy_data(self, source, target, _copy, progress):
85
111
  LOG.info("Copying data")
86
112
  source_data = source["data"]
87
113
 
88
- chunks = list(source_data.chunks)
89
- if "data" in rechunking:
90
- assert len(chunks) == len(rechunking["data"]), (chunks, rechunking["data"])
91
- for i, c in enumerate(rechunking["data"]):
92
- if c != -1:
93
- chunks[i] = c
114
+ self.data_chunks = self.parse_rechunking(self.rechunking, source_data)
94
115
 
95
116
  target_data = (
96
117
  target["data"]
@@ -98,12 +119,12 @@ class CopyMixin:
98
119
  else target.create_dataset(
99
120
  "data",
100
121
  shape=source_data.shape,
101
- chunks=chunks,
122
+ chunks=self.data_chunks,
102
123
  dtype=source_data.dtype,
103
124
  )
104
125
  )
105
126
 
106
- executor = ThreadPoolExecutor(max_workers=transfers)
127
+ executor = ThreadPoolExecutor(max_workers=self.transfers)
107
128
  tasks = []
108
129
  n = 0
109
130
  while n < target_data.shape[0]:
@@ -111,15 +132,14 @@ class CopyMixin:
111
132
  executor.submit(
112
133
  self.copy_chunk,
113
134
  n,
114
- min(n + block_size, target_data.shape[0]),
135
+ min(n + self.block_size, target_data.shape[0]),
115
136
  source_data,
116
137
  target_data,
117
- block_size,
118
138
  _copy,
119
139
  progress,
120
140
  )
121
141
  )
122
- n += block_size
142
+ n += self.block_size
123
143
 
124
144
  for future in tqdm.tqdm(as_completed(tasks), total=len(tasks), smoothing=0):
125
145
  copied = future.result()
@@ -131,7 +151,7 @@ class CopyMixin:
131
151
 
132
152
  LOG.info("Copied data")
133
153
 
134
- def copy_array(self, name, source, target, transfers, block_size, _copy, progress, rechunking):
154
+ def copy_array(self, name, source, target, _copy, progress):
135
155
  for k, v in source.attrs.items():
136
156
  target.attrs[k] = v
137
157
 
@@ -139,14 +159,14 @@ class CopyMixin:
139
159
  return
140
160
 
141
161
  if name == "data":
142
- self.copy_data(source, target, transfers, block_size, _copy, progress, rechunking)
162
+ self.copy_data(source, target, _copy, progress)
143
163
  return
144
164
 
145
165
  LOG.info(f"Copying {name}")
146
166
  target[name] = source[name]
147
167
  LOG.info(f"Copied {name}")
148
168
 
149
- def copy_group(self, source, target, transfers, block_size, _copy, progress, rechunking):
169
+ def copy_group(self, source, target, _copy, progress):
150
170
  import zarr
151
171
 
152
172
  for k, v in source.attrs.items():
@@ -158,25 +178,19 @@ class CopyMixin:
158
178
  self.copy_group(
159
179
  source[name],
160
180
  group,
161
- transfers,
162
- block_size,
163
181
  _copy,
164
182
  progress,
165
- rechunking,
166
183
  )
167
184
  else:
168
185
  self.copy_array(
169
186
  name,
170
187
  source,
171
188
  target,
172
- transfers,
173
- block_size,
174
189
  _copy,
175
190
  progress,
176
- rechunking,
177
191
  )
178
192
 
179
- def copy(self, source, target, transfers, block_size, progress, rechunking):
193
+ def copy(self, source, target, progress):
180
194
  import zarr
181
195
 
182
196
  if "_copy" not in target:
@@ -187,32 +201,26 @@ class CopyMixin:
187
201
  _copy = target["_copy"]
188
202
  _copy_np = _copy[:]
189
203
 
190
- self.copy_group(source, target, transfers, block_size, _copy_np, progress, rechunking)
204
+ self.copy_group(source, target, _copy_np, progress)
191
205
  del target["_copy"]
192
206
 
193
- def run(self, args):
207
+ def run(self):
194
208
  import zarr
195
209
 
196
210
  # base, ext = os.path.splitext(os.path.basename(args.source))
197
211
  # assert ext == ".zarr", ext
198
212
  # assert "." not in base, base
199
- LOG.info(f"Copying {args.source} to {args.target}")
200
-
201
- rechunking = {}
202
- if args.rechunk:
203
- for r in args.rechunk:
204
- k, v = r.split("=")
205
- if k != "data":
206
- raise ValueError(f"Only rechunking data is supported: {k}")
207
- values = v.split(",")
208
- values = [-1 if x == "" else x for x in values]
209
- values = tuple(int(x) for x in values)
210
- rechunking[k] = values
211
- for k, v in rechunking.items():
212
- LOG.info(f"Rechunking {k} to {v}")
213
-
214
- try:
215
- target = zarr.open(self._store(args.target, args.nested), mode="r")
213
+ LOG.info(f"Copying {self.source} to {self.target}")
214
+
215
+ def target_exists():
216
+ try:
217
+ zarr.open(self._store(self.target), mode="r")
218
+ return True
219
+ except ValueError:
220
+ return False
221
+
222
+ def target_finished():
223
+ target = zarr.open(self._store(self.target), mode="r")
216
224
  if "_copy" in target:
217
225
  done = sum(1 if x else 0 for x in target["_copy"])
218
226
  todo = len(target["_copy"])
@@ -222,26 +230,76 @@ class CopyMixin:
222
230
  todo,
223
231
  int(done / todo * 100 + 0.5),
224
232
  )
233
+ return False
225
234
  elif "sums" in target and "data" in target: # sums is copied last
226
- LOG.error("Target already exists")
227
- return
228
- except ValueError as e:
229
- LOG.info(f"Target does not exist: {e}")
230
- pass
231
-
232
- source = zarr.open(self._store(args.source), mode="r")
233
- if args.overwrite:
234
- target = zarr.open(self._store(args.target, args.nested), mode="w")
235
- else:
236
- try:
237
- target = zarr.open(self._store(args.target, args.nested), mode="w+")
238
- except ValueError:
239
- target = zarr.open(self._store(args.target, args.nested), mode="w")
240
- self.copy(source, target, args.transfers, args.block_size, args.progress, rechunking)
235
+ return True
236
+ return False
237
+
238
+ def open_target():
239
+
240
+ if not target_exists():
241
+ return zarr.open(self._store(self.target, self.nested), mode="w")
242
+
243
+ if self.overwrite:
244
+ LOG.error("Target already exists, overwriting.")
245
+ return zarr.open(self._store(self.target, self.nested), mode="w")
246
+
247
+ if self.resume:
248
+ if target_finished():
249
+ LOG.error("Target already exists and is finished.")
250
+ sys.exit(0)
251
+
252
+ LOG.error("Target already exists, resuming copy.")
253
+ return zarr.open(self._store(self.target, self.nested), mode="w+")
254
+
255
+ LOG.error("Target already exists, use either --overwrite or --resume.")
256
+ sys.exit(1)
257
+
258
+ target = open_target()
259
+
260
+ assert target is not None, target
261
+
262
+ source = zarr.open(self._store(self.source), mode="r")
263
+ self.copy(source, target, self.progress)
264
+
265
+
266
+ class CopyMixin:
267
+ internal = True
268
+ timestamp = True
269
+
270
+ def add_arguments(self, command_parser):
271
+ group = command_parser.add_mutually_exclusive_group()
272
+ group.add_argument(
273
+ "--overwrite",
274
+ action="store_true",
275
+ help="Overwrite existing dataset. This will delete the target dataset if it already exists. Cannot be used with --resume.",
276
+ )
277
+ group.add_argument(
278
+ "--resume", action="store_true", help="Resume copying an existing dataset. Cannot be used with --overwrite."
279
+ )
280
+ command_parser.add_argument("--transfers", type=int, default=8, help="Number of parallel transfers.")
281
+ command_parser.add_argument(
282
+ "--progress", action="store_true", help="Force show progress bar, even if not in an interactive shell."
283
+ )
284
+ command_parser.add_argument("--nested", action="store_true", help="Use ZARR's nested directory backend.")
285
+ command_parser.add_argument(
286
+ "--rechunk", help="Rechunk the target data array. Rechunk size should be a divisor of the block size."
287
+ )
288
+ command_parser.add_argument(
289
+ "--block-size",
290
+ type=int,
291
+ default=100,
292
+ help="For optimisation purposes, data is transferred by blocks. Default is 100.",
293
+ )
294
+ command_parser.add_argument("source", help="Source location.")
295
+ command_parser.add_argument("target", help="Target location.")
296
+
297
+ def run(self, args):
298
+ Copier(**vars(args)).run()
241
299
 
242
300
 
243
301
  class Copy(CopyMixin, Command):
244
- pass
302
+ """Copy a dataset from one location to another."""
245
303
 
246
304
 
247
305
  command = Copy
@@ -4,13 +4,24 @@ from . import Command
4
4
 
5
5
 
6
6
  class Create(Command):
7
+ """Create a dataset."""
8
+
7
9
  internal = True
8
10
  timestamp = True
9
11
 
10
12
  def add_arguments(self, command_parser):
11
- command_parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files")
12
- command_parser.add_argument("config", help="Configuration file")
13
- command_parser.add_argument("path", help="Path to store the created data")
13
+ command_parser.add_argument(
14
+ "--overwrite",
15
+ action="store_true",
16
+ help="Overwrite existing files. This will delete the target dataset if it already exists.",
17
+ )
18
+ command_parser.add_argument(
19
+ "--test",
20
+ action="store_true",
21
+ help="Build a small dataset, using only the first dates. And, when possible, using low resolution and less ensemble members.",
22
+ )
23
+ command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the dataset.")
24
+ command_parser.add_argument("path", help="Path to store the created data.")
14
25
 
15
26
  def run(self, args):
16
27
  kwargs = vars(args)
@@ -11,16 +11,12 @@ import os
11
11
  from .. import Command
12
12
  from .zarr import InspectZarr
13
13
 
14
- # from .checkpoint import InspectCheckpoint
15
-
16
14
 
17
15
  class Inspect(Command, InspectZarr):
18
- # class Inspect(Command, InspectCheckpoint, InspectZarr):
19
- """Inspect a checkpoint or zarr file."""
16
+ """Inspect a zarr dataset."""
20
17
 
21
18
  def add_arguments(self, command_parser):
22
19
  # g = command_parser.add_mutually_exclusive_group()
23
- # g.add_argument("--inspect", action="store_true", help="Inspect weights")
24
20
  command_parser.add_argument("path", metavar="PATH", nargs="+")
25
21
  command_parser.add_argument("--detailed", action="store_true")
26
22
  # command_parser.add_argument("--probe", action="store_true")
@@ -32,7 +32,7 @@ CLIP_VARIABLES = (
32
32
  SKIP = ("class", "stream", "type", "number", "expver", "_leg_number", "anoffset")
33
33
 
34
34
 
35
- def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
35
+ def check_compatible(f1, f2, centre_field_as_mars, ensemble_field_as_mars):
36
36
  assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid)
37
37
  assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area)
38
38
  assert f1.shape == f2.shape, (f1.shape, f2.shape)
@@ -43,21 +43,22 @@ def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
43
43
  f2.metadata("valid_datetime"),
44
44
  )
45
45
 
46
- for k in set(center_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()):
46
+ for k in set(centre_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()):
47
47
  if k in SKIP:
48
48
  continue
49
- assert center_field_as_mars[k] == ensemble_field_as_mars[k], (
49
+ assert centre_field_as_mars[k] == ensemble_field_as_mars[k], (
50
50
  k,
51
- center_field_as_mars[k],
51
+ centre_field_as_mars[k],
52
52
  ensemble_field_as_mars[k],
53
53
  )
54
54
 
55
55
 
56
- def perturbations(
56
+ def recentre(
57
57
  *,
58
58
  members,
59
- center,
59
+ centre,
60
60
  clip_variables=CLIP_VARIABLES,
61
+ alpha=1.0,
61
62
  output=None,
62
63
  ):
63
64
 
@@ -70,16 +71,16 @@ def perturbations(
70
71
 
71
72
  LOG.info("Ordering fields")
72
73
  members = members.order_by(*keys)
73
- center = center.order_by(*keys)
74
+ centre = centre.order_by(*keys)
74
75
  LOG.info("Done")
75
76
 
76
- if len(center) * n_numbers != len(members):
77
- LOG.error("%s %s %s", len(center), n_numbers, len(members))
77
+ if len(centre) * n_numbers != len(members):
78
+ LOG.error("%s %s %s", len(centre), n_numbers, len(members))
78
79
  for f in members:
79
80
  LOG.error("Member: %r", f)
80
- for f in center:
81
- LOG.error("Center: %r", f)
82
- raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")
81
+ for f in centre:
82
+ LOG.error("centre: %r", f)
83
+ raise ValueError(f"Inconsistent number of fields: {len(centre)} * {n_numbers} != {len(members)}")
83
84
 
84
85
  if output is None:
85
86
  # prepare output tmp file so we can read it back
@@ -93,32 +94,32 @@ def perturbations(
93
94
 
94
95
  seen = set()
95
96
 
96
- for i, center_field in enumerate(center):
97
- param = center_field.metadata("param")
98
- center_field_as_mars = center_field.as_mars()
97
+ for i, centre_field in enumerate(centre):
98
+ param = centre_field.metadata("param")
99
+ centre_field_as_mars = centre_field.as_mars()
99
100
 
100
- # load the center field
101
- center_np = center_field.to_numpy()
101
+ # load the centre field
102
+ centre_np = centre_field.to_numpy()
102
103
 
103
104
  # load the ensemble fields and compute the mean
104
- members_np = np.zeros((n_numbers, *center_np.shape))
105
+ members_np = np.zeros((n_numbers, *centre_np.shape))
105
106
 
106
107
  for j in range(n_numbers):
107
108
  ensemble_field = members[i * n_numbers + j]
108
109
  ensemble_field_as_mars = ensemble_field.as_mars()
109
- check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars)
110
+ check_compatible(centre_field, ensemble_field, centre_field_as_mars, ensemble_field_as_mars)
110
111
  members_np[j] = ensemble_field.to_numpy()
111
112
 
112
113
  ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items()))
113
114
  assert ensemble_field_as_mars not in seen, ensemble_field_as_mars
114
115
  seen.add(ensemble_field_as_mars)
115
116
 
116
- # cmin=np.amin(center_np)
117
+ # cmin=np.amin(centre_np)
117
118
  # emin=np.amin(members_np)
118
119
 
119
120
  # if cmin < 0 and emin >= 0:
120
121
  # LOG.warning(f"Negative values in {param} cmin={cmin} emin={emin}")
121
- # LOG.warning(f"Center: {center_field_as_mars}")
122
+ # LOG.warning(f"centre: {centre_field_as_mars}")
122
123
 
123
124
  mean_np = members_np.mean(axis=0)
124
125
 
@@ -126,11 +127,11 @@ def perturbations(
126
127
  template = members[i * n_numbers + j]
127
128
  e = members_np[j]
128
129
  m = mean_np
129
- c = center_np
130
+ c = centre_np
130
131
 
131
132
  assert e.shape == c.shape == m.shape, (e.shape, c.shape, m.shape)
132
133
 
133
- x = c - m + e
134
+ x = c + (e - m) * alpha
134
135
 
135
136
  if param in clip_variables:
136
137
  # LOG.warning(f"Clipping {param} to be positive")
@@ -19,6 +19,7 @@ class Creator:
19
19
  print=print,
20
20
  statistics_tmp=None,
21
21
  overwrite=False,
22
+ test=None,
22
23
  **kwargs,
23
24
  ):
24
25
  self.path = path # Output path
@@ -27,6 +28,7 @@ class Creator:
27
28
  self.print = print
28
29
  self.statistics_tmp = statistics_tmp
29
30
  self.overwrite = overwrite
31
+ self.test = test
30
32
 
31
33
  def init(self, check_name=False):
32
34
  # check path
@@ -43,6 +45,7 @@ class Creator:
43
45
  config=self.config,
44
46
  statistics_tmp=self.statistics_tmp,
45
47
  print=self.print,
48
+ test=self.test,
46
49
  )
47
50
  obj.initialise(check_name=check_name)
48
51