nuthatch 0.1.0__py3-none-any.whl



@@ -0,0 +1,199 @@
+ from nuthatch.backend import DatabaseBackend, FileBackend, register_backend
+ import shutil
+ from pathlib import Path
+ import terracotta as tc
+ import sqlalchemy
+ import xarray as xr
+ import rioxarray  # noqa: F401  registers the .rio accessor used below
+ import numpy as np
+ from rasterio.io import MemoryFile
+ from rasterio.enums import Resampling
+
+ def base360_to_base180(lons):
+     """Converts a list of longitudes from base 360 to base 180.
+
+     Args:
+         lons (list, float): A list of longitudes, or a single longitude
+     """
+     if not isinstance(lons, np.ndarray) and not isinstance(lons, list):
+         lons = [lons]
+     val = [x - 360.0 if x >= 180.0 else x for x in lons]
+     if len(val) == 1:
+         return val[0]
+     return np.array(val)
+
+
+ def base180_to_base360(lons):
+     """Converts a list of longitudes from base 180 to base 360.
+
+     Args:
+         lons (list, float): A list of longitudes, or a single longitude
+     """
+     if not isinstance(lons, np.ndarray) and not isinstance(lons, list):
+         lons = [lons]
+     val = [x + 360.0 if x < 0.0 else x for x in lons]
+     if len(val) == 1:
+         return val[0]
+     return np.array(val)
+
+
+
+ def is_wrapped(lons):
+     """Check if the longitudes are wrapped.
+
+     Works for both base180 and base360 longitudes. Requires that
+     longitudes are in increasing order, outside of a wrap point.
+     """
+     wraps = (np.diff(lons) < 0.0).sum()
+     if wraps > 1:
+         raise ValueError("Only one wrapping discontinuity allowed.")
+     elif wraps == 1:
+         return True
+     return False
+
+
+ def lon_base_change(ds, to_base="base180", lon_dim='lon'):
+     """Change the base of the dataset from base 360 to base 180 or vice versa.
+
+     Args:
+         ds (xr.Dataset): Dataset to change.
+         to_base (str): The base to change to. One of:
+             - base180
+             - base360
+         lon_dim (str): The longitude column name.
+     """
+     if to_base == "base180":
+         if (ds[lon_dim] < 0.0).any():
+             print("Longitude already in base 180 format.")
+             return ds
+         lons = base360_to_base180(ds[lon_dim].values)
+     elif to_base == "base360":
+         if (ds[lon_dim] > 180.0).any():
+             print("Longitude already in base 360 format.")
+             return ds
+         lons = base180_to_base360(ds[lon_dim].values)
+     else:
+         raise ValueError(f"Invalid base {to_base}.")
+
+     # Check if original data is wrapped
+     wrapped = is_wrapped(ds[lon_dim].values)
+
+     # Then assign new coordinates
+     ds = ds.assign_coords({lon_dim: lons})
+
+     # Sort the lons after conversion, unless the slice
+     # you're considering wraps around the meridian
+     # in the resultant base.
+     if not wrapped:
+         ds = ds.sortby(lon_dim)
+     return ds
+
+
+ @register_backend
+ class TerracottaBackend(DatabaseBackend, FileBackend):
+     """
+     Terracotta backend for caching geospatial data in a terracotta database.
+
+     This backend supports xarray datasets.
+     """
+
+     backend_name = 'terracotta'
+     config_parameters = DatabaseBackend.config_parameters + FileBackend.config_parameters + ['override_path']
+
+     def __init__(self, cacheable_config, cache_key, namespace, args, backend_kwargs):
+         # Explicitly initialize both parent backends
+         DatabaseBackend.__init__(self, cacheable_config, cache_key, namespace, args, backend_kwargs)
+         FileBackend.__init__(self, cacheable_config, cache_key, namespace, args, backend_kwargs, 'tif')
+
+         tc.update_settings(SQL_USER=self.config['write_username'], SQL_PASSWORD=self.config['write_password'])
+         self.driver = tc.get_driver(self.write_uri)
+
+         try:
+             self.driver.get_keys()
+         except sqlalchemy.exc.DatabaseError:
+             # Create a metastore
+             print("Creating new terracotta metastore")
+             self.driver.create(['key'])
+
+         if 'override_path' in backend_kwargs:
+             base_path = Path(backend_kwargs['override_path'])
+
+             if namespace:
+                 self.raw_override_path = base_path.joinpath(namespace, cache_key)
+             else:
+                 self.raw_override_path = base_path.joinpath(cache_key)
+
+             self.override_path = str(self.raw_override_path) + '.tif'
+
+     def write(self, ds, upsert=False, primary_keys=None):
+
+         if not isinstance(ds, xr.Dataset):
+             raise NotImplementedError("Terracotta backend only supports xarray datasets")
+
+         # Check to make sure this is geospatial data
+         lats = ['lat', 'y', 'latitude']
+         lons = ['lon', 'x', 'longitude']
+         if len(ds.dims) != 2:
+             if len(ds.dims) != 3 or 'time' not in ds.dims:
+                 raise RuntimeError("Can only store two dimensional (or three dimensional with time) geospatial data to terracotta")
+
+         foundx = False
+         foundy = False
+         for y in lats:
+             if y in ds.dims:
+                 ds = ds.rename({y: 'y'})
+                 foundy = True
+         for x in lons:
+             if x in ds.dims:
+                 ds = ds.rename({x: 'x'})
+                 foundx = True
+
+         if not foundx or not foundy:
+             raise RuntimeError("Can only store two or three dimensional (with time) geospatial data to terracotta")
+
+         # Adjust coordinates
+         if (ds['x'] > 180.0).any():
+             ds = lon_base_change(ds, lon_dim='x')
+             ds = ds.sortby(['x'])
+
+         # Adapt the CRS
+         ds.rio.write_crs("epsg:4326", inplace=True)
+         ds = ds.rio.reproject('EPSG:3857', resampling=Resampling.nearest, nodata=np.nan)
+         ds.rio.write_crs("epsg:3857", inplace=True)
+
+         # Insert the parameters.
+         with self.driver.connect():
+             if 'time' in ds.dims:
+                 for t in ds.time:
+                     # Select just this time and squeeze the dimension
+                     sub_ds = ds.sel(time=t)
+                     sub_ds = sub_ds.reset_coords('time', drop=True)
+
+                     # add the time to the cache_key
+                     sub_cache_key = self.cache_key + '_' + str(t.values)
+                     sub_path = str(self.raw_cache_path) + '_' + str(t.values) + '.tif'
+                     sub_override_path = str(self.raw_override_path) + '_' + str(t.values) + '.tif'
+
+                     self.write_individual_raster(self.driver, sub_ds, sub_path, sub_cache_key, sub_override_path)
+             else:
+                 self.write_individual_raster(self.driver, ds, self.path, self.cache_key, self.override_path)
+
+
+     def write_individual_raster(self, driver, ds, path, cache_key, override_path):
+         # Write the raster to an in-memory COG, copy it to the cache filesystem, and register it in terracotta
+         with MemoryFile() as mem_dst:
+             ds.rio.to_raster(mem_dst.name, driver="COG")
+
+             with self.fs.open(path, 'wb') as f_out:
+                 shutil.copyfileobj(mem_dst, f_out)
+
+             driver.insert({'key': cache_key.replace('/', '_')}, mem_dst,
+                           override_path=override_path, skip_metadata=False)
+
+             print(f"Inserted {cache_key.replace('/', '_')} into the terracotta database.")
+
+     def read(self, engine):
+         raise NotImplementedError("Cannot read from the terracotta backend.")
+
+     def delete(self):
+         raise NotImplementedError("Cannot delete from the terracotta backend.")
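The longitude helpers at the top of this file are self-contained, so a short usage sketch may help. The import path below is hypothetical (the wheel's file layout is not shown in this diff), and the coordinate values are purely illustrative:

import numpy as np
import xarray as xr
from nuthatch.terracotta import base360_to_base180, base180_to_base360, lon_base_change  # hypothetical path

# Toy base-360 dataset: longitudes 0..330 in 30-degree steps.
ds = xr.Dataset(
    {"t2m": (("lat", "lon"), np.zeros((3, 12)))},
    coords={"lat": [-10.0, 0.0, 10.0], "lon": np.arange(0.0, 360.0, 30.0)},
)

print(base360_to_base180(350.0))          # -10.0 (a single value comes back as a scalar)
print(base180_to_base360([-10.0, 20.0]))  # [350.  20.]

ds180 = lon_base_change(ds, to_base="base180")
print(ds180.lon.values)  # sorted into [-180, 180), since this toy slice is not wrapped

TerracottaBackend.write() builds on the same helpers: it renames the spatial dimensions to x/y, shifts any base-360 longitudes, reprojects to EPSG:3857, and then inserts one COG per time step into the terracotta metastore.
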
@@ -0,0 +1,207 @@
+ from nuthatch.backend import FileBackend, register_backend
+ import xarray as xr
+ import numpy as np
+
+ CHUNK_SIZE_UPPER_LIMIT_MB = 300
+ CHUNK_SIZE_LOWER_LIMIT_MB = 30
+
+
+ def get_chunk_size(ds, size_in='MB'):
+     """Get the approximate chunk size of a dataset and its per-dimension chunk lengths.
+
+     Args:
+         ds (xr.Dataset): The dataset to get the chunk size of.
+         size_in (str): The unit to return the chunk size in. One of:
+             'KB', 'MB', 'GB', 'TB' for kilo, mega, giga, and terabytes respectively.
+     """
+     chunk_groups = [(dim, np.median(chunks)) for dim, chunks in ds.chunks.items()]
+     div = {'KB': 10**3, 'MB': 10**6, 'GB': 10**9, 'TB': 10**12}[size_in]
+     chunk_sizes = [x[1] for x in chunk_groups]
+     return np.prod(chunk_sizes) * 4 / div, chunk_groups  # assumes 4-byte (float32) values
+
+
+ def merge_chunk_by_arg(chunking, chunk_by_arg, kwargs):
+     """Merge chunking and chunking modifiers into a single chunking dict.
+
+     Args:
+         chunking (dict): The chunking to merge.
+         chunk_by_arg (dict): The chunking modifiers to merge.
+         kwargs (dict): The kwargs to check for chunking modifiers.
+     """
+     if chunk_by_arg is None:
+         return chunking
+
+     for k in chunk_by_arg:
+         if k not in kwargs:
+             raise ValueError(f"Chunking modifier {k} not found in kwargs.")
+
+         if kwargs[k] in chunk_by_arg[k]:
+             # If argument value in chunk_by_arg then merge the chunking
+             chunk_dict = chunk_by_arg[k][kwargs[k]]
+             chunking.update(chunk_dict)
+
+     return chunking
+
+
+ # TODO: Why was this the way it was in sheerwater?
+ def prune_chunking_dimensions(ds, chunking):
+     """Prune the chunking dimensions to only those that exist in the dataset.
+
+     Args:
+         ds (xr.Dataset): The dataset to check for chunking dimensions.
+         chunking (dict): The chunking dimensions to prune.
+     """
+     # Drop any dimensions that don't exist in the dataset
+     for dim in list(chunking):  # iterate over a copy so keys can be deleted safely
+         if dim not in ds.dims:
+             del chunking[dim]
+
+     return chunking
+
+
+ def chunking_compare(ds, chunking):
+     """Compare the chunking of a dataset to a specified chunking.
+
+     Args:
+         ds (xr.Dataset): The dataset to check the chunking of.
+         chunking (dict): The chunking to compare to.
+     """
+     # Get the chunks for the dataset
+     ds_chunks = {dim: ds.chunks[dim][0] for dim in ds.chunks}
+     chunking = prune_chunking_dimensions(ds, chunking)
+     return ds_chunks == chunking
+
+
+ def drop_encoded_chunks(ds):
+     """Drop the encoded chunks from a dataset."""
+     for var in ds.data_vars:
+         if 'chunks' in ds[var].encoding:
+             del ds[var].encoding['chunks']
+         if 'preferred_chunks' in ds[var].encoding:
+             del ds[var].encoding['preferred_chunks']
+
+     for coord in ds.coords:
+         if 'chunks' in ds[coord].encoding:
+             del ds[coord].encoding['chunks']
+         if 'preferred_chunks' in ds[coord].encoding:
+             del ds[coord].encoding['preferred_chunks']
+
+     return ds
+
+
+ @register_backend
+ class ZarrBackend(FileBackend):
+     """
+     Zarr backend for caching data in a zarr store.
+
+     This backend supports xarray datasets.
+
+     Possible backend_kwargs:
+         chunking (dict): Chunk sizes to apply for each coordinate that exists in the
+             dataset. Chunking for coordinates that do not exist is dropped.
+         chunk_by_arg (dict): Chunking modifiers keyed by cached argument,
+             e.g. grid resolution. For example:
+             chunk_by_arg={
+                 'grid': {
+                     'global0_25': {"lat": 721, "lon": 1440, 'time': 30},
+                     'global1_5': {"lat": 121, "lon": 240, 'time': 1000},
+                 }
+             }
+             will modify the chunking dict values for lat, lon, and time, depending
+             on the value of the 'grid' argument. If multiple cache arguments specify
+             modifiers for the same chunking dimension, the last one specified will prevail.
+         auto_rechunk (bool): If True, aggressively rechunk the cache on load.
+     """
+
+     backend_name = 'zarr'
+     default_for_type = xr.Dataset
+
+     def __init__(self, cacheable_config, cache_key, namespace, args, backend_kwargs):
+         super().__init__(cacheable_config, cache_key, namespace, args, backend_kwargs, 'zarr')
+
+         if backend_kwargs and 'chunking' in backend_kwargs and 'chunk_by_arg' in backend_kwargs:
+             self.chunking = merge_chunk_by_arg(backend_kwargs['chunking'], backend_kwargs['chunk_by_arg'], args)
+         elif backend_kwargs and 'chunking' in backend_kwargs:
+             self.chunking = backend_kwargs['chunking']
+         else:
+             self.chunking = 'auto'
+
+         if backend_kwargs and 'auto_rechunk' in backend_kwargs and backend_kwargs['auto_rechunk']:
+             self.auto_rechunk = True
+         else:
+             self.auto_rechunk = False
+
+     def write(self, data, upsert=False, primary_keys=None):
+         if upsert:
+             raise NotImplementedError("Zarr backend does not support upsert.")
+
+         if isinstance(data, xr.Dataset):
+             self.chunk_to_zarr(data, self.path)
+         else:
+             raise NotImplementedError("Zarr backend only supports caching of xarray datasets.")
+
+     def read(self, engine):
+         if engine == 'xarray' or engine == xr.Dataset or engine is None:
+             # Open with chunks={} so xarray uses the underlying zarr chunking where possible.
+             # Setting chunks=True triggers what appears to be an xarray/zarr engine bug where
+             # every chunk is only 4B!
+             if self.auto_rechunk:
+                 # If auto_rechunk is set, check whether the cached chunking
+                 # matches the requested chunking. If not, rechunk.
+                 ds_remote = xr.open_dataset(self.path, engine='zarr', chunks={}, decode_timedelta=True)
+                 if not isinstance(self.chunking, dict):
+                     raise ValueError("If auto_rechunk is True, a chunking dict must be supplied.")
+
+                 # Compare the cached chunking to the requested chunking
+                 if not chunking_compare(ds_remote, self.chunking):
+                     print("Rechunk was passed and cached chunks do not match rechunk request. "
+                           "Performing rechunking.")
+
+                     # Write to a temp cache map.
+                     # Writing to a temp cache is necessary because overwriting
+                     # the original cache map would start writing it before the
+                     # data has been read, leading to corruption.
+                     self.chunk_to_zarr(ds_remote, self.temp_path)
+
+                     # Remove the old cache and verify files
+                     if self.fs.exists(self.path):
+                         self.fs.rm(self.path, recursive=True)
+
+                     self.fs.mv(self.temp_path, self.path, recursive=True)
+
+                     # Reopen the dataset - will use the appropriate global or local cache
+                     return xr.open_dataset(self.path, engine='zarr',
+                                            chunks={}, decode_timedelta=True)
+                 else:
+                     # Cached chunks already match the rechunk request.
+                     return xr.open_dataset(self.path, engine='zarr',
+                                            chunks={}, decode_timedelta=True)
+             else:
+                 return xr.open_dataset(self.path, engine='zarr', chunks={}, decode_timedelta=True)
+         else:
+             raise NotImplementedError(f"Zarr backend does not support reading zarrs to {engine} engine")
+
+
+     def chunk_to_zarr(self, ds, path):
+         """Write a dataset to a zarr cache map and check the chunking."""
+         ds = drop_encoded_chunks(ds)
+
+         chunking = self.chunking
+         if isinstance(self.chunking, dict):
+             # No need to prune if chunking is None or 'auto'
+             chunking = prune_chunking_dimensions(ds, self.chunking)
+
+         ds = ds.chunk(chunks=chunking)
+
+         try:
+             chunk_size, chunk_with_labels = get_chunk_size(ds)
+
+             if chunk_size > CHUNK_SIZE_UPPER_LIMIT_MB or chunk_size < CHUNK_SIZE_LOWER_LIMIT_MB:
+                 print(f"WARNING: Chunk size is {chunk_size}MB. Target approx 100MB.")
+                 print(chunk_with_labels)
+         except ValueError:
+             print("Failed to get chunk size! Continuing with unknown chunking...")
+
+         ds.to_zarr(store=path, mode='w')
+
+
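The chunking helpers above compose in a simple way; a minimal sketch follows (the import path is hypothetical and the grid names mirror the docstring example):

import numpy as np
import xarray as xr
from nuthatch.zarr import merge_chunk_by_arg, chunking_compare, get_chunk_size  # hypothetical path

chunking = {"lat": 121, "lon": 240, "time": 1000}
chunk_by_arg = {"grid": {"global0_25": {"lat": 721, "lon": 1440, "time": 30}}}

# The value of the 'grid' argument selects a modifier set that overrides the base chunking.
merged = merge_chunk_by_arg(chunking, chunk_by_arg, {"grid": "global0_25"})
print(merged)  # {'lat': 721, 'lon': 1440, 'time': 30}

# chunking_compare prunes dimensions the dataset lacks, then compares chunk sizes.
ds = xr.Dataset({"v": (("lat", "lon"), np.zeros((10, 20)))}).chunk({"lat": 5, "lon": 10})
print(chunking_compare(ds, {"lat": 5, "lon": 10, "time": 30}))  # True: 'time' is pruned

size_mb, per_dim = get_chunk_size(ds)
print(size_mb, per_dim)  # ~0.0002 MB (the helper assumes 4-byte values per element)

Note that merge_chunk_by_arg and prune_chunking_dimensions modify the chunking dict they are given in place rather than returning a copy, which is worth keeping in mind if the same dict is reused across ZarrBackend instances.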