xdiffly 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xdiff/__init__.py +25 -0
- xdiff/comparators/__init__.py +19 -0
- xdiff/comparators/base.py +20 -0
- xdiff/comparators/netcdf.py +286 -0
- xdiff/compare/__init__.py +38 -0
- xdiff/compare/ncdiff.py +34 -0
- xdiff/conf/__init__.py +26 -0
- xdiff/conf/global_settings.py +26 -0
- xdiff/core/__init__.py +15 -0
- xdiff/core/dask_runtime.py +114 -0
- xdiff/core/main.py +84 -0
- xdiff/core/service.py +242 -0
- xdiff/discovery/__init__.py +5 -0
- xdiff/discovery/filesystem.py +20 -0
- xdiff/exceptions/__init__.py +13 -0
- xdiff/exceptions/all_nan.py +2 -0
- xdiff/exceptions/last_timestep.py +2 -0
- xdiff/exceptions/no_match.py +5 -0
- xdiff/exceptions/unsupported_artifact.py +5 -0
- xdiff/management/__init__.py +7 -0
- xdiff/management/cli.py +248 -0
- xdiff/matching/__init__.py +5 -0
- xdiff/matching/default.py +42 -0
- xdiff/model/__init__.py +20 -0
- xdiff/model/artifact.py +58 -0
- xdiff/model/compare_result.py +32 -0
- xdiff/model/comparison.py +88 -0
- xdiff/model/match.py +15 -0
- xdiff/model/report.py +42 -0
- xdiff/model/request.py +88 -0
- xdiff/printlib/__init__.py +2 -0
- xdiff/printlib/formatter.py +219 -0
- xdiff/printlib/progress.py +335 -0
- xdiff/utils/__init__.py +0 -0
- xdiff/utils/log.py +64 -0
- xdiff/utils/module_loading.py +30 -0
- xdiff/utils/regex.py +65 -0
- xdiffly-0.2.6.dist-info/METADATA +210 -0
- xdiffly-0.2.6.dist-info/RECORD +41 -0
- xdiffly-0.2.6.dist-info/WHEEL +4 -0
- xdiffly-0.2.6.dist-info/entry_points.txt +3 -0
xdiff/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from xdiff import conf as settings
|
|
6
|
+
from xdiff.utils.log import configure_logging
|
|
7
|
+
|
|
8
|
+
# Supported interpreter range: MIN is inclusive, MAX is exclusive.
MIN_SUPPORTED_PYTHON = (3, 10)
MAX_SUPPORTED_PYTHON = (3, 14)


def validate_runtime(version: tuple[int, int] | None = None) -> None:
    """Raise ``RuntimeError`` unless the Python version is supported.

    Args:
        version: ``(major, minor)`` to check. Defaults to the running
            interpreter's ``sys.version_info[:2]``.

    Raises:
        RuntimeError: if *version* is outside
            ``[MIN_SUPPORTED_PYTHON, MAX_SUPPORTED_PYTHON)``.
    """
    if version is None:
        version = sys.version_info[:2]
    else:
        version = tuple(version[:2])

    if MIN_SUPPORTED_PYTHON <= version < MAX_SUPPORTED_PYTHON:
        return

    # The message is built from the constants so it cannot drift from the check.
    # MAX is exclusive, so the last supported minor is one below it.
    last_supported_minor = MAX_SUPPORTED_PYTHON[1] - 1
    raise RuntimeError(
        f"xdiff supports Python {MIN_SUPPORTED_PYTHON[0]}.{MIN_SUPPORTED_PYTHON[1]} "
        f"through {MAX_SUPPORTED_PYTHON[0]}.{last_supported_minor}. "
        f"The current interpreter is Python {version[0]}.{version[1]}."
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def setup():
    """Prepare the package for use.

    First checks the interpreter with :func:`validate_runtime`, which raises
    ``RuntimeError`` on an unsupported Python. Then applies the logging
    configuration taken from :mod:`xdiff.conf`.
    """
    validate_runtime()
    logging_callable_path = settings.LOGGING_CONFIG
    logging_overrides = settings.LOGGING
    configure_logging(logging_callable_path, logging_overrides)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Comparator implementations for each artifact type."""
|
|
2
|
+
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
from xdiff.comparators.base import ArtifactComparator
|
|
6
|
+
|
|
7
|
+
__all__ = ["ArtifactComparator", "NetcdfComparator"]


def __getattr__(name: str):
    """Resolve ``NetcdfComparator`` on first access (PEP 562 module hook).

    The :mod:`xdiff.comparators.netcdf` import is deferred until the name is
    requested. The resolved class is then stored in this module's globals, so
    later lookups no longer reach this hook.
    """
    if name != "NetcdfComparator":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    netcdf_module = import_module("xdiff.comparators.netcdf")
    comparator = getattr(netcdf_module, name)
    globals()[name] = comparator
    return comparator


def __dir__() -> list[str]:
    """Return the public names, including the lazily resolved ones, sorted."""
    names = list(__all__)
    names.sort()
    return names
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Base interfaces for artifact comparators."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
|
|
7
|
+
from xdiff.model.artifact import ArtifactKind
|
|
8
|
+
from xdiff.model.comparison import Comparison
|
|
9
|
+
from xdiff.model.match import ArtifactMatch
|
|
10
|
+
from xdiff.model.request import CompareRequest
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ArtifactComparator(ABC):
    """Abstract interface for comparing one matched pair of artifacts.

    Concrete subclasses declare which kind of artifact they handle through
    ``artifact_kind`` and implement :meth:`compare`.
    """

    # Kind of artifact handled by the subclass; assigned by each concrete comparator.
    artifact_kind: ArtifactKind

    @abstractmethod
    def compare(self, match: ArtifactMatch, request: CompareRequest) -> Comparison:
        """Compare the two sides of *match* under *request* and return the outcome."""
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""Comparator for netCDF artifacts and related numeric helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
from functools import lru_cache
|
|
9
|
+
from importlib import import_module
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Iterable
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
import xdiff.conf as settings
|
|
15
|
+
|
|
16
|
+
from xdiff.comparators.base import ArtifactComparator
|
|
17
|
+
from xdiff.exceptions import AllNaN, LastTimestepTimeCheckException
|
|
18
|
+
from xdiff.model import CompareResult
|
|
19
|
+
from xdiff.model.artifact import ArtifactKind
|
|
20
|
+
from xdiff.model.comparison import Comparison
|
|
21
|
+
from xdiff.model.match import ArtifactMatch
|
|
22
|
+
from xdiff.model.request import CompareRequest
|
|
23
|
+
|
|
24
|
+
# NumPy's nan-reductions warn "All-NaN slice encountered" when a slice holds only
# NaN. filterwarnings() adds the ignore rule to the process-wide warning filters
# at import time.
warnings.filterwarnings("ignore", message="All-NaN slice encountered")

# Package logger shared with the other xdiff modules.
logger = logging.getLogger("xdiff")
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
import xarray as xr
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@lru_cache(maxsize=1)
def load_xarray():
    """Return the :mod:`xarray` module, importing it on first call.

    A successful import is cached by ``lru_cache``. This keeps xarray off the
    import path of code that never compares netCDF files.

    Raises:
        RuntimeError: if xarray cannot be imported. The original ``ImportError``
            is chained as ``__cause__``.
    """
    try:
        xarray_module = import_module("xarray")
    except ImportError as exc:
        raise RuntimeError(
            "xarray is required to compare netCDF files. Install the package dependencies before using this feature."
        ) from exc
    return xarray_module
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class NetcdfComparator(ArtifactComparator):
    """Comparator for netCDF pairs; the numeric work is done by :func:`compare_files`."""

    artifact_kind = ArtifactKind.NETCDF

    def compare(self, match: ArtifactMatch, request: CompareRequest) -> Comparison:
        """Compare ``match.reference`` with ``match.comparison``.

        The per-variable results from :func:`compare_files` are appended to a new
        :class:`Comparison`, which is returned. The selected variables and the
        last-time-step flag come from *request*.

        Raises:
            ValueError: if *match* has no comparison artifact.
        """
        outcome = Comparison(
            reference_artifact=match.reference,
            comparison_artifact=match.comparison,
        )

        counterpart = match.comparison
        if counterpart is None:
            raise ValueError("A comparison artifact is required for netCDF comparison")

        variable_results = compare_files(
            match.reference.path,
            counterpart.path,
            request.variables,
            last_time_step=request.last_time_step,
        )
        outcome.extend(variable_results)
        return outcome
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def compare_files(
    file1,
    file2,
    variables: tuple[str, ...] | list[str] | object | None,
    *,
    last_time_step: bool,
) -> list[CompareResult]:
    """Open two netCDF files and compare their selected variables.

    The variable list is derived from *file1* via :func:`get_dataset_variables`.
    Both datasets are opened with ``xarray.open_dataset`` and closed again by
    their context managers before this function returns.

    Returns:
        One CompareResult per compared variable, from :func:`compare_datasets`.
    """
    xarray_module = load_xarray()
    with xarray_module.open_dataset(file1) as reference_ds:
        with xarray_module.open_dataset(file2) as comparison_ds:
            selected = get_dataset_variables(reference_ds, variables)
            return compare_datasets(
                reference_ds,
                comparison_ds,
                selected,
                last_time_step=last_time_step,
            )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def compare_datasets(
    reference: xr.Dataset,
    comparison: xr.Dataset,
    variables: list[str],
    *,
    last_time_step: bool,
) -> list[CompareResult]:
    """Compare each name in *variables* between two datasets.

    Errors are caught per variable, so one failing variable does not stop the
    others. Any exception raised while looking up or comparing a variable
    becomes a CompareResult carrying the exception text as its description.

    Returns:
        Results in the same order as *variables*.
    """
    collected: list[CompareResult] = []

    for name in variables:
        logger.info("Comparing %s", name)
        try:
            outcome = compare_variables(
                reference[name],
                comparison[name],
                name,
                last_time_step=last_time_step,
            )
        except Exception as error:
            outcome = CompareResult(variable=name, description=str(error))
        collected.append(outcome)

    return collected
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _widen_for_subtraction(values: np.ndarray) -> np.ndarray:
    """Return *values* in a dtype where ``a - b`` neither wraps around nor raises.

    Fixed-width integers wrap on overflow (``uint8(1) - uint8(2) == 255``), and
    NumPy rejects ``-`` on booleans. Bool and integer arrays are therefore
    promoted to ``int64``; ``uint64`` goes to ``float64`` because it does not fit
    in ``int64``. Any other dtype (float, complex, datetime, timedelta) is
    returned unchanged.
    """
    dtype = values.dtype
    if dtype.kind == "u" and dtype.itemsize >= 8:
        return values.astype(np.float64)
    if dtype.kind in ("b", "i", "u"):
        return values.astype(np.int64, copy=False)
    return values


def compare_variables(
    ref_da: xr.DataArray,
    cmp_da: xr.DataArray,
    variable: str,
    *,
    last_time_step: bool,
) -> CompareResult:
    """Compare one variable between the reference and comparison datasets.

    Args:
        ref_da: reference values of *variable*.
        cmp_da: comparison values of *variable*.
        variable: variable name, recorded on the result.
        last_time_step: when true, each array is reduced to the last step along
            its time dimension (see :func:`select_last_time_step`) before
            comparing.

    Returns:
        A CompareResult holding:
        - the maximum relative error;
        - the NaN-ignoring min and max of ``reference - comparison``;
        - whether the two arrays' masks are equal.

    Raises:
        LastTimestepTimeCheckException: if *last_time_step* is set and *variable*
            is the time coordinate itself.
        ValueError: if dimensions, sizes or dtypes differ
            (raised by :func:`validate_matching_metadata`).
        AllNaN: if every element of the difference is NaN.
    """
    if last_time_step:
        if is_time_coordinate_variable(variable, ref_da, cmp_da):
            raise LastTimestepTimeCheckException("Can't compare time if last time step is enabled")
        ref_da = select_last_time_step(ref_da)
        cmp_da = select_last_time_step(cmp_da)

    validate_matching_metadata(ref_da, cmp_da)

    # Promote integer/bool data before subtracting. Otherwise unsigned or narrow
    # integer variables silently wrap around, and bool variables raise TypeError.
    reference_values = _widen_for_subtraction(ref_da.values)
    comparison_values = _widen_for_subtraction(cmp_da.values)
    reference_masked = ref_da.to_masked_array()
    comparison_masked = cmp_da.to_masked_array()

    difference_field = reference_values - comparison_values

    if np.isnan(difference_field).all():
        raise AllNaN("All nan values found")

    mask_is_equal = np.array_equal(reference_masked.mask, comparison_masked.mask)

    return CompareResult(
        relative_error=compute_relative_error(difference_field, comparison_values),
        min_diff=np.nanmin(difference_field),
        max_diff=np.nanmax(difference_field),
        mask_equal=mask_is_equal,
        variable=variable,
    )
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def select_last_time_step(field: xr.DataArray) -> xr.DataArray:
    """Keep only the final step of *field* along its time dimension.

    The time dimension is located with :func:`find_time_dims_name`. *field* is
    returned untouched when it has no time dimension or that dimension has at
    most one entry. Otherwise the time dimension is kept with length 1 (slice
    ``-1:``) rather than dropped.
    """
    time_dim = find_time_dims_name(field.dims)
    if time_dim is None or field.sizes[time_dim] <= 1:
        return field
    return field.isel({time_dim: slice(-1, None)})
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def find_time_dims_name(dims: Iterable) -> Any | None:
    """Return the single dimension name that contains ``"time"``, if any.

    Non-string dimension names are ignored. xarray allows any hashable as a
    dimension name, and ``"time" in 5`` would raise TypeError.

    Args:
        dims: dimension names, e.g. ``DataArray.dims``.

    Returns:
        The matching name, or ``None`` when no name contains ``"time"``.

    Raises:
        ValueError: if more than one dimension name contains ``"time"``.
    """
    candidates = [dimension for dimension in dims if isinstance(dimension, str) and "time" in dimension]
    if not candidates:
        return None
    if len(candidates) > 1:
        raise ValueError(f"Found more than 1 time dimension: {', '.join(candidates)}")
    return candidates[0]
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def compute_relative_error(diff: np.ndarray, field2: np.ndarray):
    """Return the largest element-wise ``|diff| / |field2|``, ignoring NaNs.

    Special cases:
    - ``diff`` entirely zero: returns ``0.0`` immediately.
    - Infinite ratios (non-zero diff over zero ``field2``): replaced by NaN
      before the maximum is taken.
    - Any exception while dividing or reducing: logged at debug level, and the
      result is NaN.
    - ``field2`` with a datetime/timedelta dtype: it is reinterpreted as int64
      for the division, and the result is divided by ``np.timedelta64(1, "s")``.
    """
    if np.all(diff == 0.0):
        return 0.0

    time_like = is_time_dtype(field2.dtype)
    divisor_source = field2.view("int64") if time_like else field2

    magnitude = np.abs(diff)
    divisor = np.abs(divisor_source)

    try:
        with np.errstate(divide="ignore", invalid="ignore"):
            ratios = magnitude / divisor
            infinite = np.isinf(ratios)
            if infinite.any():
                ratios[infinite] = np.nan
            worst = np.nanmax(ratios)
    except Exception as exc:
        logger.debug("An error occurred when computing relative error: %s", exc)
        worst = np.nan

    if not time_like:
        return worst
    return worst / np.timedelta64(1, "s")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def is_time_dtype(dtype) -> bool:
    """Tell whether *dtype* is a NumPy datetime64 or timedelta64 type.

    *dtype* may be anything ``np.dtype`` accepts (a dtype, a string such as
    ``"<M8[ns]"``, or a scalar type). Kind ``"M"`` is datetime64 and kind ``"m"``
    is timedelta64.
    """
    return np.dtype(dtype).kind in ("M", "m")
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def is_time_coordinate_variable(variable: str, ref_da: xr.DataArray, cmp_da: xr.DataArray) -> bool:
    """Tell whether *variable* is the time coordinate itself in both arrays.

    All of the following must hold:
    - both arrays resolve to the same time dimension name (via
      :func:`find_time_dims_name`);
    - *variable* is named after that dimension;
    - both arrays are one-dimensional along it;
    - both arrays have a datetime/timedelta dtype.
    """
    ref_time_dim = find_time_dims_name(ref_da.dims)
    cmp_time_dim = find_time_dims_name(cmp_da.dims)

    if ref_time_dim is None or ref_time_dim != cmp_time_dim:
        return False

    only_time = (ref_time_dim,)
    names_match = variable == ref_time_dim
    shapes_match = ref_da.dims == only_time and cmp_da.dims == only_time
    return names_match and shapes_match and is_time_dtype(ref_da.dtype) and is_time_dtype(cmp_da.dtype)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def validate_matching_metadata(ref_da: xr.DataArray, cmp_da: xr.DataArray) -> None:
    """Check that two arrays can be compared element-wise.

    Dimension names, dimension sizes and dtypes must match; otherwise
    ``ValueError`` is raised. Coordinate differences are only reported at debug
    level:
    - differing coordinate names;
    - and, for each shared coordinate, differing dims, dtype or values.
    At most one message is logged per shared coordinate.
    """
    if ref_da.dims != cmp_da.dims:
        raise ValueError(f"Dimension mismatch: '{ref_da.dims}' - '{cmp_da.dims}'")

    ref_shape = tuple(ref_da.sizes[dim] for dim in ref_da.dims)
    cmp_shape = tuple(cmp_da.sizes[dim] for dim in cmp_da.dims)
    if ref_shape != cmp_shape:
        raise ValueError(f"Dimension size mismatch: '{ref_shape}' - '{cmp_shape}'")

    if np.dtype(ref_da.dtype) != np.dtype(cmp_da.dtype):
        raise ValueError(f"Data type mismatch: '{ref_da.dtype}' - '{cmp_da.dtype}'")

    ref_coord_names = set(ref_da.coords)
    cmp_coord_names = set(cmp_da.coords)
    if ref_coord_names != cmp_coord_names:
        logger.debug(
            "Coordinate mismatch: '%s' - '%s'",
            ", ".join(sorted(ref_coord_names)) or "-",
            ", ".join(sorted(cmp_coord_names)) or "-",
        )

    for name in sorted(ref_coord_names & cmp_coord_names):
        ref_coord = ref_da.coords[name]
        cmp_coord = cmp_da.coords[name]

        if ref_coord.dims != cmp_coord.dims:
            logger.debug(
                "Coordinate dimension mismatch for '%s': '%s' - '%s'",
                name,
                ref_coord.dims,
                cmp_coord.dims,
            )
        elif np.dtype(ref_coord.dtype) != np.dtype(cmp_coord.dtype):
            logger.debug(
                "Coordinate type mismatch for '%s': '%s' - '%s'",
                name,
                ref_coord.dtype,
                cmp_coord.dtype,
            )
        elif not ref_coord.equals(cmp_coord):
            logger.debug("Coordinate values mismatch for '%s'", name)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def get_dataset_variables(dataset: xr.Dataset, variables: tuple[str, ...] | list[str] | object | None) -> list[str]:
    """Return the names in *dataset* that should be compared numerically.

    Args:
        dataset: dataset the names are looked up in.
        variables: which names to check.
            - A single string is treated as one name.
            - ``None`` or the ``settings.DEFAULT_VARIABLES_TO_CHECK`` sentinel
              selects every data variable plus every dimension name.
            - Any other iterable is used as given.

    Returns:
        The candidate names, in order, that exist in *dataset* and whose dtype
        kind is not string, bytes or object (``U``, ``S``, ``O``, ``a``).
    """
    if isinstance(variables, str):
        # list("temp") would yield ["t", "e", "m", "p"].
        variables_to_check = [variables]
    elif variables is None or variables is settings.DEFAULT_VARIABLES_TO_CHECK:
        # Identity, not ``in``: ``in`` falls back to ``==``, which fails on
        # array-like inputs whose equality is element-wise.
        variables_to_check = list(dataset.data_vars) + list(dataset.dims)
    else:
        variables_to_check = list(variables)

    selected_variables: list[str] = []
    for variable in variables_to_check:
        if variable not in dataset:
            continue

        dtype_kind = dataset[variable].dtype.kind
        if dtype_kind in ("U", "S", "O", "a"):
            logger.debug("Skipping variable %s due to datatype %s", variable, dtype_kind)
            continue

        selected_variables.append(variable)

    return selected_variables
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""Public comparison exports."""
|
|
2
|
+
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
_NETCDF_COMPARATORS = "xdiff.comparators.netcdf"

# public name -> (module defining it, attribute name inside that module)
_LAZY_EXPORTS = {
    "compare": ("xdiff.compare.ncdiff", "compare"),
    "compare_datasets": (_NETCDF_COMPARATORS, "compare_datasets"),
    "compare_files": (_NETCDF_COMPARATORS, "compare_files"),
    "compare_variables": (_NETCDF_COMPARATORS, "compare_variables"),
    "compute_relative_error": (_NETCDF_COMPARATORS, "compute_relative_error"),
    "find_time_dims_name": (_NETCDF_COMPARATORS, "find_time_dims_name"),
    "get_dataset_variables": (_NETCDF_COMPARATORS, "get_dataset_variables"),
    "select_last_time_step": (_NETCDF_COMPARATORS, "select_last_time_step"),
}

__all__ = [
    "compare",
    "compare_datasets",
    "compare_files",
    "compare_variables",
    "compute_relative_error",
    "find_time_dims_name",
    "get_dataset_variables",
    "select_last_time_step",
]


def __getattr__(name: str):
    """Import a public name from its defining module on first access (PEP 562).

    The resolved object is stored in this module's globals, so this hook runs
    at most once per name.
    """
    target = _LAZY_EXPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    module_name, attribute_name = target
    resolved = getattr(import_module(module_name), attribute_name)
    globals()[name] = resolved
    return resolved


def __dir__() -> list[str]:
    """Return the public names, including lazily resolved ones, sorted."""
    names = list(__all__)
    names.sort()
    return names
|
xdiff/compare/ncdiff.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""NetCDF comparison orchestration for matched file pairs."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Iterable
|
|
5
|
+
|
|
6
|
+
from xdiff.exceptions import NoMatchFound
|
|
7
|
+
from xdiff.model.comparison import Comparison
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def compare_files(*args, **kwargs):
    """Forward every argument to ``xdiff.comparators.netcdf.compare_files``.

    The import runs at call time rather than module import time. This keeps
    loading this module cheap and lets tests patch the implementation in its
    defining module.
    """
    from xdiff.comparators.netcdf import compare_files as netcdf_compare_files

    result = netcdf_compare_files(*args, **kwargs)
    return result
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def compare(
    compare_match: dict[Path, list[Path]],
    variables: Iterable[str] | tuple[str, ...] | list[str] | object | None,
    last_time_step: bool,
):
    """Yield a Comparison for every (reference, candidate) pair in *compare_match*.

    Args:
        compare_match: maps each reference path to its candidate paths.
        variables: variable selection forwarded to :func:`compare_files`.
        last_time_step: forwarded to :func:`compare_files`.

    Yields:
        - For a reference with no candidates: one Comparison carrying a
          ``NoMatchFound`` error.
        - Otherwise, one Comparison per candidate. A failing pair yields a
          Comparison carrying the raised exception (with no comparison path,
          as before).

    Errors are caught per candidate, so one failing pair no longer aborts the
    remaining candidates of the same reference.
    """
    for reference, to_compares in compare_match.items():
        if not to_compares:
            yield Comparison.from_paths(reference, None, NoMatchFound(f"No match found for {reference}"))
            continue

        for to_compare in to_compares:
            try:
                comparison = Comparison.from_paths(reference, to_compare)
                comparison.extend(compare_files(reference, to_compare, variables, last_time_step=last_time_step))
            except Exception as e:
                comparison = Comparison.from_paths(reference, None, e)
            # Yield outside the try so exceptions thrown into the generator by the
            # consumer are not mistaken for comparison failures.
            yield comparison
|
xdiff/conf/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Public configuration exports."""
|
|
2
|
+
|
|
3
|
+
from xdiff.conf import global_settings as settings
|
|
4
|
+
|
|
5
|
+
# Re-export each default from ``global_settings`` under the same name, so that
# callers can read e.g. ``xdiff.conf.DEFAULT_MAXDEPTH`` directly.
DEFAULT_COMMON_PATTERN = settings.DEFAULT_COMMON_PATTERN
DEFAULT_MAXDEPTH = settings.DEFAULT_MAXDEPTH
DEFAULT_NAME_TO_COMPARE = settings.DEFAULT_NAME_TO_COMPARE
DEFAULT_VARIABLES_TO_CHECK = settings.DEFAULT_VARIABLES_TO_CHECK
DTYPE_NOT_CHECKED = settings.DTYPE_NOT_CHECKED
TIME_DTYPE = settings.TIME_DTYPE

# Logging and debug switches.
LOGGING = settings.LOGGING
LOGGING_CONFIG = settings.LOGGING_CONFIG
DEBUG = settings.DEBUG

__all__ = [
    "settings",
    "DEFAULT_COMMON_PATTERN",
    "DEFAULT_MAXDEPTH",
    "DEFAULT_NAME_TO_COMPARE",
    "DEFAULT_VARIABLES_TO_CHECK",
    "DTYPE_NOT_CHECKED",
    "TIME_DTYPE",
    "LOGGING",
    "LOGGING_CONFIG",
    "DEBUG",
]
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Default X-Diff settings.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
########################
# SETTINGS
########################

# Default maximum depth (presumably of the directory search — confirm);
# a negative value removes the depth limit.
DEFAULT_MAXDEPTH = 1
DEFAULT_NAME_TO_COMPARE = "*.nc"
# Sentinel meaning "no explicit variable list". It is compared by identity, and
# when given, all data variables and dimension names of a dataset are checked.
DEFAULT_VARIABLES_TO_CHECK = object()
DEFAULT_COMMON_PATTERN = None
DTYPE_NOT_CHECKED = ["S8", "S1", "O"]  # S8|S1: char, O: string
TIME_DTYPE = ["datetime64[ns]", "<M8[ns]"]

########################
# LOG
########################

# Dotted path of the callable used to configure logging.
LOGGING_CONFIG = "logging.config.dictConfig"

# Custom logging configuration passed to LOGGING_CONFIG.
LOGGING = {}

DEBUG = False
|
xdiff/core/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Public core entrypoints for the package."""
|
|
2
|
+
|
|
3
|
+
import xdiff
|
|
4
|
+
|
|
5
|
+
from xdiff.core import main as main_module
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def execute(**kwargs):
    """Set up the package, then run :func:`xdiff.core.main.execute`.

    ``get_version`` is popped from *kwargs*. When it is truthy, the process exits
    with status 0 before any comparison runs. Nothing is printed here.
    NOTE(review): presumably the caller reports the version — confirm.
    All remaining keyword arguments are forwarded unchanged.
    """
    xdiff.setup()
    get_version = kwargs.pop("get_version", False)

    if get_version:
        # The builtin ``exit`` is injected by the ``site`` module for interactive
        # use and may be missing (``python -S``, embedded/frozen interpreters).
        # Raising SystemExit is the guaranteed equivalent.
        raise SystemExit(0)

    return main_module.execute(**kwargs)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Helpers for optional Dask execution.
|
|
2
|
+
|
|
3
|
+
The application still defaults to serial execution. This module is only used
|
|
4
|
+
when the caller explicitly selects a Dask-backed execution mode, which keeps
|
|
5
|
+
the regular CLI startup path lightweight.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
|
|
13
|
+
from contextlib import contextmanager
|
|
14
|
+
from importlib import import_module
|
|
15
|
+
from typing import TYPE_CHECKING, Iterator
|
|
16
|
+
|
|
17
|
+
from xdiff.model.request import CompareRequest
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from distributed import Client
|
|
21
|
+
|
|
22
|
+
# Package-wide logger shared with the other xdiff modules.
logger = logging.getLogger(name="xdiff")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def log_local_worker_advisories(request: CompareRequest) -> None:
    """Warn, without failing, when more workers are requested than CPUs are visible.

    Does nothing if ``request.dask_workers`` is ``None`` or ``os.cpu_count()``
    cannot tell how many CPUs there are.
    """
    requested = request.dask_workers
    if requested is None:
        return

    cpu_total = os.cpu_count()
    if cpu_total is None or requested <= cpu_total:
        return

    logger.warning(
        "Requested %s Dask worker(s) but only %s CPU(s) are visible on this node.",
        requested,
        cpu_total,
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def log_file_mode_advisories(request: CompareRequest, comparable_file_pairs: int) -> None:
    """Emit worker-count warnings for file-pair mode without blocking the run.

    Always runs :func:`log_local_worker_advisories` first. Then, when a worker
    count was requested, warns if it exceeds the number of comparable file pairs
    (the surplus workers would sit idle).
    """
    log_local_worker_advisories(request)

    requested = request.dask_workers
    if requested is not None and comparable_file_pairs < requested:
        logger.warning(
            "Requested %s Dask worker(s) for %s comparable file pair(s); some workers will remain idle.",
            requested,
            comparable_file_pairs,
        )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def log_local_cluster_address(cluster, client) -> None:
    """Log the local scheduler address and dashboard link at WARNING level.

    The scheduler address is read from ``cluster.scheduler_address``, falling back
    to ``client.scheduler.address``. Each message is emitted only when its value
    is known.
    """
    address = getattr(cluster, "scheduler_address", None)
    if address is None:
        address = getattr(getattr(client, "scheduler", None), "address", None)
    if address is not None:
        logger.warning("Local Dask scheduler available at %s", address)

    link = getattr(cluster, "dashboard_link", None)
    if link is not None:
        logger.warning("Local Dask dashboard available at %s", link)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@contextmanager
def client_from_request(request: CompareRequest) -> Iterator["Client"]:
    """Yield a Dask client for *request* and close everything it opened on exit.

    If ``request.uses_external_dask_scheduler`` is set, the client attaches to
    that scheduler. Otherwise a ``LocalCluster`` is started for the duration of
    the ``with`` block, with ``request.dask_workers`` worker processes of one
    thread each.

    Cluster and client creation happen inside the ``try``. A failure while
    building the client or logging its addresses therefore still shuts down a
    cluster that was already started, instead of leaking its worker processes.
    """
    distributed = _load_distributed()
    cluster = None
    client = None

    try:
        if request.uses_external_dask_scheduler:
            client = _connect_to_scheduler(request, distributed.Client)
        else:
            # One worker process per requested slot is the safest default for Dask-backed netCDF I/O.
            cluster = distributed.LocalCluster(
                n_workers=request.dask_workers,
                threads_per_worker=1,
                processes=True,
                dashboard_address=":8787",
            )
            client = distributed.Client(cluster)
            log_local_cluster_address(cluster, client)

        yield client
    finally:
        # Close the client first; the nested finally still closes the cluster if
        # client.close() raises.
        try:
            if client is not None:
                client.close()
        finally:
            if cluster is not None:
                cluster.close()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def iterate_results_as_completed(futures):
    """Yield items from ``distributed.as_completed(futures, with_results=True)``.

    Items arrive in completion order, each pairing a future with its result.
    ``distributed`` is imported lazily through :func:`_load_distributed`.
    """
    distributed_module = _load_distributed()
    completed = distributed_module.as_completed(futures, with_results=True)
    yield from completed
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _connect_to_scheduler(request: CompareRequest, client_type):
|
|
102
|
+
if request.dask_scheduler_file is not None:
|
|
103
|
+
return client_type(scheduler_file=str(request.dask_scheduler_file))
|
|
104
|
+
return client_type(request.dask_scheduler)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _load_distributed():
|
|
108
|
+
try:
|
|
109
|
+
return import_module("distributed")
|
|
110
|
+
except ImportError as exc:
|
|
111
|
+
raise RuntimeError(
|
|
112
|
+
"Dask support requires the 'distributed' package. "
|
|
113
|
+
"Install the project dependencies before using this feature."
|
|
114
|
+
) from exc
|
xdiff/core/main.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Compatibility layer for the application service entrypoint."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Iterable
|
|
5
|
+
|
|
6
|
+
import xdiff.conf as settings
|
|
7
|
+
|
|
8
|
+
from xdiff.core.service import ComparisonService
|
|
9
|
+
from xdiff.discovery import FileSystemArtifactDiscovery
|
|
10
|
+
from xdiff.model import CompareMode, CompareRequest, ComparisonReport, ExecutionMode
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from xdiff.printlib.progress import ProgressReporter
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def execute(
    reference_path: Path,
    comparison_path: Path,
    filter_name: str = settings.DEFAULT_NAME_TO_COMPARE,
    common_pattern: str | None = settings.DEFAULT_COMMON_PATTERN,
    variables: Iterable[str] | object = settings.DEFAULT_VARIABLES_TO_CHECK,
    last_time_step: bool = False,
    input_mode: CompareMode = CompareMode.DIRECTORIES,
    execution_mode: ExecutionMode | str = ExecutionMode.SERIAL,
    dask_scheduler: str | None = None,
    dask_scheduler_file: Path | None = None,
    dask_workers: int | None = None,
    progress_reporter: "ProgressReporter | None" = None,
) -> ComparisonReport:
    """Run a comparison through the default service and return its report.

    All arguments except *progress_reporter* are bundled by
    :func:`build_request`. *progress_reporter* is passed straight to
    ``ComparisonService.run``.
    """
    compare_request = build_request(
        reference_path=reference_path,
        comparison_path=comparison_path,
        input_mode=input_mode,
        filter_name=filter_name,
        common_pattern=common_pattern,
        variables=variables,
        last_time_step=last_time_step,
        execution_mode=execution_mode,
        dask_scheduler=dask_scheduler,
        dask_scheduler_file=dask_scheduler_file,
        dask_workers=dask_workers,
    )
    service = ComparisonService.default()
    return service.run(compare_request, progress_reporter=progress_reporter)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def build_request(
    reference_path: Path,
    comparison_path: Path,
    input_mode: CompareMode = CompareMode.DIRECTORIES,
    filter_name: str = settings.DEFAULT_NAME_TO_COMPARE,
    common_pattern: str | None = settings.DEFAULT_COMMON_PATTERN,
    variables: Iterable[str] | object = settings.DEFAULT_VARIABLES_TO_CHECK,
    last_time_step: bool = False,
    execution_mode: ExecutionMode | str = ExecutionMode.SERIAL,
    dask_scheduler: str | None = None,
    dask_scheduler_file: Path | None = None,
    dask_workers: int | None = None,
) -> CompareRequest:
    """Bundle the legacy :func:`execute` arguments into a ``CompareRequest``.

    Only *variables* is transformed (via :func:`normalize_variables`); every
    other argument is forwarded unchanged.
    """
    normalized_variables = normalize_variables(variables)
    return CompareRequest(
        reference_path=reference_path,
        comparison_path=comparison_path,
        input_mode=input_mode,
        filter_name=filter_name,
        common_pattern=common_pattern,
        variables=normalized_variables,
        last_time_step=last_time_step,
        execution_mode=execution_mode,
        dask_scheduler=dask_scheduler,
        dask_scheduler_file=dask_scheduler_file,
        dask_workers=dask_workers,
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def normalize_variables(variables: Iterable[str] | object) -> tuple[str, ...] | None:
|
|
76
|
+
if variables in (None, settings.DEFAULT_VARIABLES_TO_CHECK):
|
|
77
|
+
return None
|
|
78
|
+
return tuple(variables)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def load_files(directory: Path, filter_name: str) -> list[Path]:
    """Return ``FileSystemArtifactDiscovery().list_paths(directory, filter_name)``.

    Compatibility shim kept for existing callers and tests.
    """
    discovery = FileSystemArtifactDiscovery()
    return discovery.list_paths(directory, filter_name)
|