recursive-diff 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,858 @@
1
+ import math
2
+ from copy import deepcopy
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import pytest
7
+ import xarray
8
+ from packaging.version import Version
9
+
10
+ from recursive_diff import cast, recursive_diff
11
+ from recursive_diff.tests import requires_dask
12
+
13
+ PANDAS_GE_200 = Version(pd.__version__).release >= (2, 0)
14
+
15
+
16
class Rectangle:
    """Sample class used to exercise user-registered comparison hooks."""

    def __init__(self, w, h):
        self.w = w
        self.h = h

    def __eq__(self, other):
        # Field-by-field comparison; width first, then height
        if self.w != other.w:
            return False
        return self.h == other.h

    def __repr__(self):
        return f"Rectangle({self.w}, {self.h})"
28
+
29
+
30
class Drawing:
    """Another class that is not Rectangle but just happens to be cast to the
    same dict
    """

    def __init__(self, w, h):
        self.w = w
        self.h = h

    def __eq__(self, other):
        # Same contract as Rectangle.__eq__: width first, then height
        if self.w != other.w:
            return False
        return self.h == other.h
41
+
42
+
43
@cast.register(Rectangle)
@cast.register(Drawing)
def _(obj, brief_dims):
    # Cast hook: both Rectangle and Drawing are reduced to the same dict,
    # so recursive_diff compares them field by field (the type mismatch
    # is still reported separately).
    return {"w": obj.w, "h": obj.h}
47
+
48
+
49
class Circle:
    """A class that supports == but has no registered cast hook."""

    def __init__(self, radius):
        self.radius = radius

    def __eq__(self, other):
        return other.radius == self.radius

    def __repr__(self):
        return f"Circle({self.radius})"
60
+
61
+
62
class Square:
    """Another class with == support but no registered cast hook."""

    def __init__(self, side):
        self.side = side

    def __eq__(self, other):
        return other.side == self.side

    def __repr__(self):
        return f"Square({self.side})"
73
+
74
+
75
def check(lhs, rhs, *expect, rel_tol=1e-09, abs_tol=0.0, brief_dims=()):
    """Run recursive_diff on lhs/rhs and assert the reported differences
    match *expect exactly.

    Both sides are sorted, so expectation order does not matter; empty
    expectation strings are skipped (used for version-dependent messages).
    """
    expected = sorted(filter(None, expect))
    found = sorted(
        recursive_diff(
            lhs, rhs, rel_tol=rel_tol, abs_tol=abs_tol, brief_dims=brief_dims
        )
    )
    assert found == expected
83
+
84
+
85
@pytest.mark.parametrize(
    "x",
    [
        123,
        "blah",
        "a\nb",
        math.nan,
        np.nan,
        True,
        False,
        [1, 2],
        (1, 2),
        np.int8(1),
        np.uint8(1),
        np.int64(1),
        np.uint64(1),
        np.float32(1),
        np.float64(1),
        {1: 2, 3: 4},
        {1, 2},
        frozenset([1, 2]),
        np.arange(10),
        np.arange(10, dtype=np.float64),
        pd.Series([1, 2]),
        pd.Series([1, 2], index=[3, 4]),
        pd.RangeIndex(10),
        pd.RangeIndex(1, 10, 3),
        pd.Index([1, 2, 3]),
        pd.MultiIndex.from_tuples(
            [("bar", "one"), ("bar", "two"), ("baz", "one")], names=["l1", "l2"]
        ),
        pd.DataFrame([[1, 2], [3, 4]]),
        pd.DataFrame([[1, 2], [3, 4]], index=["i1", "i2"], columns=["c1", "c2"]),
        xarray.DataArray([1, 2]),
        xarray.DataArray([1, 2], dims=["x"], coords={"x": [3, 4]}),
        Rectangle(1, 2),
        Circle(1),
    ],
)
def test_identical(x):
    """Any supported object compared against a deep copy of itself
    yields no differences (including NaN vs. NaN).
    """
    assert not list(recursive_diff(x, deepcopy(x)))
126
+
127
+
128
def test_simple():
    """Scalar comparisons of mismatched primitives."""
    check(1, 0, "1 != 0 (abs: -1.0e+00, rel: -1.0e+00)")
    check("asd", "lol", "asd != lol")
    check(b"asd", b"lol", "b'asd' != b'lol'")
    check(True, False, "True != False")


def test_object_type_differs():
    """Type mismatches are reported; note that bool is not conflated
    with int, and set is not conflated with frozenset.
    """
    check(1, "1", "1 != 1", "object type differs: int != str")
    check(True, 1, "object type differs: bool != int")
    check(False, 0, "object type differs: bool != int")
    check([1, 2], (1, 2), "object type differs: list != tuple")
    check({1, 2}, frozenset([1, 2]), "object type differs: set != frozenset")


def test_collections():
    """Diffs of lists (positional), sets (membership) and dicts (by key)."""
    check([1, 2], [1, 2, 3], "RHS has 1 more elements than LHS: [3]")
    check({1, 2}, {1, 2, (3, 4)}, "(3, 4) is in RHS only")
    check(
        {"x": 10, "y": 20},
        {"x": 10, "y": 30},
        "[y]: 20 != 30 (abs: 1.0e+01, rel: 5.0e-01)",
    )
    check({2: 20}, {1: 10}, "Pair 1:10 is in RHS only", "Pair 2:20 is in LHS only")


def test_limit_str_length():
    """Long and multi-line strings are truncated"""
    check("a" * 100, "a" * 100)
    # Truncation happens at 76 characters
    check("a" * 100, "a" * 101, "{} ... != {} ...".format("a" * 76, "a" * 76))
    check("a\nb", "a\nb")
    # Only the first line of a multi-line string is shown
    check("a\nb", "a\nc", "a ... != a ...")
160
+
161
+
162
@pytest.mark.parametrize("nan", [np.nan, math.nan])
def test_nan(nan):
    """NaN compares equal to NaN regardless of origin (math vs. numpy);
    NaN vs. a number reports nan deltas.
    """
    check(nan, nan)
    check(nan, math.nan)
    check(nan, np.nan)
    check(0.0, nan, "0.0 != nan (abs: nan, rel: nan)")
    check(nan, 0.0, "nan != 0.0 (abs: nan, rel: nan)")
169
+
170
+
171
def test_float():
    """Float comparison with tolerance"""
    # Test that floats are not accidentally rounded when printing
    check(
        123456.7890123456,
        123456.789,
        "123456.7890123456 != 123456.789 (abs: -1.2e-05, rel: -1.0e-10)",
        rel_tol=0,
        abs_tol=0,
    )

    check(123, 123.0000000000001)  # difference is below rel_tol=1e-9

    # Same difference, below/above each tolerance knob separately
    check(
        123456.7890123456,
        123456.789,
        "123456.7890123456 != 123456.789 (abs: -1.2e-05, rel: -1.0e-10)",
        rel_tol=1e-11,
        abs_tol=0,
    )
    check(
        123456.7890123456,
        123456.789,
        "123456.7890123456 != 123456.789 (abs: -1.2e-05, rel: -1.0e-10)",
        rel_tol=0,
        abs_tol=1e-5,
    )

    check(123456.7890123456, 123456.789, rel_tol=0, abs_tol=1e-4)
    check(123456.7890123456, 123456.789, rel_tol=1e-7, abs_tol=0)

    # Abs tol is RHS - LHS; rel tol is RHS / LHS - 1
    check(80.0, 175.0, "80.0 != 175.0 (abs: 9.5e+01, rel: 1.2e+00)")

    # Division by zero in relative delta
    check(1.0, 0.0, "1.0 != 0.0 (abs: -1.0e+00, rel: -1.0e+00)")
    check(0.0, 1.0, "0.0 != 1.0 (abs: 1.0e+00, rel: nan)")

    # tolerance settings are retained when descending into containers
    check(
        [{"x": (1.0, 2.0)}],
        [{"x": (1.1, 2.01)}],
        "[0][x][0]: 1.0 != 1.1 (abs: 1.0e-01, rel: 1.0e-01)",
        rel_tol=0.05,
        abs_tol=0,
    )

    # tolerance > 1 in a comparison among int's
    # note how int's are not cast to float when both lhs and rhs are int
    check(1, 2, abs_tol=2)
    check(2, 5, "2 != 5 (abs: 3.0e+00, rel: 1.5e+00)", abs_tol=2)
222
+
223
+
224
def test_int_vs_float():
    """ints are silently cast to float and do not cause an
    'object type differs' error.
    """
    check(123, 123.0)
    check(123, 123.0000000000001)  # difference is below rel_tol=1e-9
    check(1, 1.01, "1.0 != 1.01 (abs: 1.0e-02, rel: 1.0e-02)", abs_tol=0.001)
    check(1, 1.01, abs_tol=0.1)


def test_numpy_types():
    """scalar numpy data types (not to be confused with numpy arrays)
    are silently cast to pure numpy types and do not cause an
    'object type differs' error. They're compared with tolerance.
    """
    # Mixed widths and int vs. float never raise a type-mismatch
    check(123, np.int32(123))
    check(np.int64(123), np.int32(123))
    check(123, np.float64(123))
    check(np.float32(123), np.float64(123))
    check(
        np.float64(1),
        np.float64(1.01),
        "1.0 != 1.01 (abs: 1.0e-02, rel: 1.0e-02)",
        abs_tol=0.001,
    )
    check(np.float32(1), np.float32(1.01), abs_tol=0.1)
    check(np.float64(1), np.float64(1.01), abs_tol=0.1)
251
+
252
+
253
def test_numpy():
    """Diffs of numpy arrays: tolerance, dtype, shape and container
    mismatches.
    """
    # test tolerance and comparison of float vs. int
    check(
        np.array([1.0, 2.0, 3.01, 4.0001, 5.0]),
        np.array([1, 4, 3, 4], dtype=np.int64),
        "[data][1]: 2.0 != 4.0 (abs: 2.0e+00, rel: 1.0e+00)",
        "[data][2]: 3.01 != 3.0 (abs: -1.0e-02, rel: -3.3e-03)",
        "[dim_0]: LHS has 1 more elements than RHS",
        "object type differs: ndarray<float64> != ndarray<int64>",
        abs_tol=0.001,
    )

    # Tolerance > 1 in a comparison among int's
    # Make sure that tolerance is not applied to RangeIndex comparison
    check(
        np.array([1, 2]),
        np.array([2, 20, 3, 4]),
        "[data][1]: 2 != 20 (abs: 1.8e+01, rel: 9.0e+00)",
        "[dim_0]: RHS has 2 more elements than LHS",
        abs_tol=10,
    )

    # array of numbers vs. dates; mismatched size
    check(
        np.array([1, 2], dtype=np.int64),
        pd.to_datetime(["2000-01-01", "2000-01-02", "2000-01-03"]).values,
        "[data][0]: 1 != 2000-01-01 00:00:00",
        "[data][1]: 2 != 2000-01-02 00:00:00",
        "[dim_0]: RHS has 1 more elements than LHS",
        "object type differs: ndarray<int64> != ndarray<datetime64>",
    )

    # array of numbers vs. strings; mismatched size
    check(
        np.array([1, 2, 3], dtype=np.int64),
        np.array(["foo", "bar"]),
        "[data][0]: 1 != foo",
        "[data][1]: 2 != bar",
        "[dim_0]: LHS has 1 more elements than RHS",
        "object type differs: ndarray<int64> != ndarray<<U...>",
    )

    # Mismatched dimensions
    check(
        np.array([1, 2, 3, 4]),
        np.array([[1, 2], [3, 4]]),
        "[dim_0]: LHS has 2 more elements than RHS",
        "Dimension dim_1 is in RHS only",
    )

    # numpy vs. list
    check(
        np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64),
        [[1, 4, 3], [4, 5, 6]],
        "object type differs: ndarray<int64> != list",
        "[data][0, 1]: 2 != 4 (abs: 2.0e+00, rel: 1.0e+00)",
    )

    # numpy vs. other object
    check(
        np.array([0, 0], dtype=np.int64),
        0,
        "Dimension dim_0 is in LHS only",
        "object type differs: ndarray<int64> != int",
    )
318
+
319
+
320
def test_numpy_strings():
    """Strings in numpy can be unicode (<U...), binary ascii (<S...)
    or Python variable-length (object).
    Test that these three types are not considered equivalent.
    """
    a = np.array(["foo"], dtype=object)
    b = np.array(["foo"], dtype="U")
    c = np.array(["foo"], dtype="S")
    check(a, b, "object type differs: ndarray<object> != ndarray<<U...>")
    check(
        a,
        c,
        "object type differs: ndarray<object> != ndarray<|S...>",
        "[data][0]: foo != b'foo'",
    )
    check(
        b,
        c,
        "object type differs: ndarray<<U...> != ndarray<|S...>",
        "[data][0]: foo != b'foo'",
    )


@pytest.mark.parametrize("x,y", [("foo", "barbaz"), (b"foo", b"babaz")])
def test_numpy_string_slice(x, y):
    """When slicing an array of strings, the output sub-dtype won't change.
    Test that string that differs only by dtype-length are considered
    equivalent.
    """
    a = np.array([x, y])  # dtype=<U6/<S6
    b = a[:1]  # dtype=<U6/<S6
    c = np.array([x])  # dtype=<U3/<S3
    assert a.dtype == b.dtype
    assert a.dtype != c.dtype
    check(b, c)
355
+
356
+
357
@pytest.mark.filterwarnings(
    "ignore:Converting non-nanosecond precision datetime:UserWarning"
)
def test_numpy_dates():
    """datetime64 arrays: NaT handling and ignoring sub-dtype
    differences (days vs. nanoseconds precision).
    """
    a = pd.to_datetime(["2000-01-01", "2000-01-02", "2000-01-03", "NaT"]).values.astype(
        "<M8[D]"
    )
    b = pd.to_datetime(
        [
            "2000-01-01",  # identical
            "2000-01-04",  # differs, both LHS and RHS are non-NaT
            "NaT",  # non-NaT vs. NaT
            "NaT",  # NaT == NaT
            # differences in sub-type must be ignored
        ]
    ).values.astype("<M8[ns]")
    check(
        a,
        b,
        "[data][1]: 2000-01-02 00:00:00 != 2000-01-04 00:00:00",
        "[data][2]: 2000-01-03 00:00:00 != NaT",
    )


def test_numpy_scalar():
    """0-dimensional arrays and numpy scalar values."""
    check(
        np.array(1, dtype=np.int64),
        np.array(2.5),
        "[data]: 1.0 != 2.5 (abs: 1.5e+00, rel: 1.5e+00)",
        "object type differs: ndarray<int64> != ndarray<float64>",
    )
    check(
        np.array(1, dtype=np.int64),
        2,
        "[data]: 1 != 2 (abs: 1.0e+00, rel: 1.0e+00)",
        "object type differs: ndarray<int64> != int",
    )
    check(np.array("foo"), np.array("bar"), "[data]: foo != bar")
    # Note: datetime64 are not 0-dimensional arrays
    check(
        np.datetime64("2000-01-01"),
        np.datetime64("2000-01-02"),
        "2000-01-01 != 2000-01-02",
    )
    check(np.datetime64("2000-01-01"), np.datetime64("NaT"), "2000-01-01 != NaT")
402
+
403
+
404
def test_pandas_series():
    """pd.Series: values are matched by index label, not by position."""
    # pd.Series
    # Note that we're also testing that order is ignored
    check(
        pd.Series([1, 2, 3], index=["foo", "bar", "baz"], name="hello"),
        pd.Series([1, 3, 4], index=["foo", "baz", "bar"], name="world"),
        "[data][index=bar]: 2 != 4 (abs: 2.0e+00, rel: 1.0e+00)",
        "[name]: hello != world",
    )


def test_pandas_dataframe():
    """pd.DataFrame: per-cell diffs plus column-set differences."""
    df1 = pd.DataFrame(
        [[1, 2, 3], [4, 5, 6]], index=["x1", "x2"], columns=["y1", "y2", "y3"]
    )
    df2 = pd.DataFrame(
        [[1, 3, 2], [4, 7, 5]], index=["x1", "x2"], columns=["y1", "y3", "y4"]
    )

    check(
        df1,
        df2,
        "[data][column=y3, index=x2]: 6 != 7 (abs: 1.0e+00, rel: 1.7e-01)",
        "[columns]: y2 is in LHS only",
        "[columns]: y4 is in RHS only",
    )
430
+
431
+
432
def test_pandas_index():
    """pd.Index is compared as a set (order-insensitive, no tolerance)."""
    # Regular index
    # Test that order is ignored
    # Use huge abs_tol and rel_tol to test that tolerance is ignored
    check(
        pd.Index([1, 2, 3, 4]),
        pd.Index([1, 3.000001, 2]),
        "3.0 is in LHS only",
        "3.000001 is in RHS only",
        "4.0 is in LHS only",
        "" if PANDAS_GE_200 else "object type differs: Int64Index != Float64Index",
        rel_tol=10,
        abs_tol=10,
    )

    check(pd.Index(["x", "y", "z"]), pd.Index(["y", "x"]), "z is in LHS only")


def test_pandas_rangeindex():
    """pd.RangeIndex is compared by its start/stop/step/name parameters,
    not element by element.
    """
    # RangeIndex(stop)
    check(pd.RangeIndex(10), pd.RangeIndex(10))
    check(pd.RangeIndex(8), pd.RangeIndex(10), "RHS has 2 more elements than LHS")
    check(pd.RangeIndex(10), pd.RangeIndex(8), "LHS has 2 more elements than RHS")

    # RangeIndex(start, stop, step, name)
    check(pd.RangeIndex(1, 2, 3, name="x"), pd.RangeIndex(1, 2, 3, name="x"))
    check(
        pd.RangeIndex(0, 4, 1),
        pd.RangeIndex(1, 4, 1),
        "RangeIndex(start=0, stop=4, step=1) != RangeIndex(start=1, stop=4, step=1)",
    )
    check(
        pd.RangeIndex(0, 4, 2),
        pd.RangeIndex(0, 5, 2),
        "RangeIndex(start=0, stop=4, step=2) != RangeIndex(start=0, stop=5, step=2)",
    )
    check(
        pd.RangeIndex(0, 4, 2),
        pd.RangeIndex(0, 4, 3),
        "RangeIndex(start=0, stop=4, step=2) != RangeIndex(start=0, stop=4, step=3)",
    )
    check(
        pd.RangeIndex(4, name="foo"),
        pd.RangeIndex(4, name="bar"),
        "RangeIndex(start=0, stop=4, step=1, name='foo') != "
        "RangeIndex(start=0, stop=4, step=1, name='bar')",
    )

    # RangeIndex vs regular index
    int_index = "Index" if PANDAS_GE_200 else "Int64Index"
    check(
        pd.RangeIndex(4),
        pd.Index([0, 1, 2]),
        "3 is in LHS only",
        f"object type differs: RangeIndex != {int_index}",
    )
488
+
489
+
490
def test_pandas_multiindex():
    """pd.MultiIndex: set-like diff of tuples plus level-name diffs."""
    lhs = pd.MultiIndex.from_tuples(
        [("bar", "one"), ("bar", "two"), ("baz", "one")], names=["l1", "l2"]
    )
    rhs = pd.MultiIndex.from_tuples(
        [("baz", "one"), ("bar", "three"), ("bar", "one"), ("baz", "four")],
        names=["l1", "l3"],
    )
    check(
        lhs,
        rhs,
        "[data]: ('bar', 'three') is in RHS only",
        "[data]: ('bar', 'two') is in LHS only",
        "[data]: ('baz', 'four') is in RHS only",
        "[names][1]: l2 != l3",
    )

    # MultiIndex vs. regular index
    int_index = "Index" if PANDAS_GE_200 else "Int64Index"
    check(
        lhs,
        pd.Index([0, 1, 2]),
        "Cannot compare objects: MultiIndex([('bar', 'one'), "
        f"..., {int_index}([0, 1, 2], dtype='int64')",
        f"object type differs: MultiIndex != {int_index}",
    )
516
+
517
+
518
def test_xarray():
    """xarray.Dataset and DataArray: diffs of data_vars, coords
    (index and non-index), attrs and name.
    """
    # xarray.Dataset
    ds1 = xarray.Dataset(
        data_vars={"d1": ("x", [1, 2, 3]), "d2": (("y", "x"), [[4, 5, 6], [7, 8, 9]])},
        coords={
            "x": ("x", ["x1", "x2", "x3"]),
            "y": ("y", ["y1", "y2"]),
            "nonindex": ("x", ["ni1", "ni2", "ni3"]),
        },
        attrs={"some": "attr", "some2": 1},
    )

    # Mutate a deep copy: drop a variable, change a cell, a non-index
    # coord and the attrs
    ds2 = ds1.copy(deep=True)
    del ds2["d1"]
    ds2["d2"][0, 0] = 10
    ds2["nonindex"][1] = "ni4"
    ds2.attrs["some2"] = 2
    ds2.attrs["other"] = "someval"

    # Older versions of xarray don't have the 'Size 24B' bit
    d1_str = str(ds1["d1"].stack({"__stacked__": ["x"]})).splitlines()[0].strip()

    check(
        ds1,
        ds2,
        "[attrs]: Pair other:someval is in RHS only",
        "[attrs][some2]: 1 != 2 (abs: 1.0e+00, rel: 1.0e+00)",
        "[coords][nonindex][x=x2]: ni2 != ni4",
        f"[data_vars]: Pair d1:{d1_str} ... is in LHS only",
        "[data_vars][d2][x=x1, y=y1]: 4 != 10 (abs: 6.0e+00, rel: 1.5e+00)",
    )

    # With a large abs_tol the numeric diffs disappear
    check(
        ds1,
        ds2,
        "[attrs]: Pair other:someval is in RHS only",
        "[coords][nonindex][x=x2]: ni2 != ni4",
        f"[data_vars]: Pair d1:{d1_str} ... is in LHS only",
        abs_tol=7,
    )

    # xarray.DataArray
    # Note: this sample has a non-index coordinate
    # In Linux, int maps to int64 while in Windows it maps to int32
    da1 = ds1["d2"].astype(np.int64)
    da1.name = "foo"
    da1.attrs["attr1"] = 1.0
    da1.attrs["attr2"] = 1.0

    # Test dimension order does not matter
    check(da1, da1.T)

    da2 = da1.copy(deep=True).astype(float)
    da2[0, 0] *= 1.0 + 1e-7
    da2[0, 1] *= 1.0 + 1e-10
    da2["nonindex"][1] = "ni4"
    da2.name = "bar"
    da2.attrs["attr1"] = 1.0 + 1e-7
    da2.attrs["attr2"] = 1.0 + 1e-10
    da2.attrs["attr3"] = "new"

    # The 1e-10 deltas are below the default rel_tol and not reported
    check(
        da1,
        da2,
        "[attrs]: Pair attr3:new is in RHS only",
        "[attrs][attr1]: 1.0 != 1.0000001 (abs: 1.0e-07, rel: 1.0e-07)",
        "[coords][nonindex][x=x2]: ni2 != ni4",
        "[data][x=x1, y=y1]: 4.0 != 4.0000004 (abs: 4.0e-07, rel: 1.0e-07)",
        "[name]: foo != bar",
        "object type differs: DataArray<int64> != DataArray<float64>",
    )
589
+
590
+
591
def test_xarray_scalar():
    """0-dimensional DataArrays are compared with tolerance."""
    da1 = xarray.DataArray(1.0)
    da2 = xarray.DataArray(1.0 + 1e-7)
    check(da1, da2, "[data]: 1.0 != 1.0000001 (abs: 1.0e-07, rel: 1.0e-07)")
    # Below the default rel_tol=1e-9: no differences
    da2 = xarray.DataArray(1.0 + 1e-10)
    check(da1, da2)


def test_xarray_no_coords():
    """DataArrays without coordinates fall back to positional labels."""
    check(
        xarray.DataArray([0, 1]),
        xarray.DataArray([0, 2]),
        "[data][1]: 1 != 2 (abs: 1.0e+00, rel: 1.0e+00)",
    )


def test_xarray_mismatched_dims():
    """Mismatched dimension sets are reported as index differences."""
    # 0-dimensional vs. 1+-dimensional
    check(
        xarray.DataArray(1.0),
        xarray.DataArray([0.0, 0.1]),
        "[index]: Dimension dim_0 is in RHS only",
    )

    # both arrays are 1+-dimensional
    check(
        xarray.DataArray([0, 1], dims=["x"]),
        xarray.DataArray([[0, 1], [2, 3]], dims=["x", "y"]),
        "[index]: Dimension y is in RHS only",
    )


def test_xarray_size0():
    """Empty vs. non-empty DataArray along the same dimension."""
    check(
        xarray.DataArray([]),
        xarray.DataArray([1.0]),
        "[index][dim_0]: RHS has 1 more elements than LHS",
    )
629
+
630
+
631
def test_xarray_stacked():
    """Stacked vs. unstacked DataArrays are still compared point by
    point, while the stacking difference itself is also reported.
    """
    # Pre-stacked dims, mixed with non-stacked ones
    da1 = xarray.DataArray(
        [[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
        dims=["x", "y", "z"],
        coords={"x": ["x1", "x2"]},
    )

    # Stacked and unstacked dims are compared point by point,
    # while still pointing out the difference in stacking
    da2 = da1.copy(deep=True)
    da2[0, 0, 0] = 10
    da2 = da2.stack(s=["x", "y"])
    check(
        da1,
        da2,
        "[data][x=x1, y=0, z=0]: 0 != 10 (abs: 1.0e+01, rel: nan)",
        "[index]: Dimension s is in RHS only",
        "[index]: Dimension x is in LHS only",
        "[index]: Dimension y is in LHS only",
    )
652
+
653
+
654
def test_brief_dims_1d():
    """brief_dims collapses per-point diffs into a count when all dims
    are brief.
    """
    # all dims are brief
    da1 = xarray.DataArray([1, 2, 3], dims=["x"])
    da2 = xarray.DataArray([1, 3, 4], dims=["x"])
    check(
        da1,
        da2,
        "[data][x=1]: 2 != 3 (abs: 1.0e+00, rel: 5.0e-01)",
        "[data][x=2]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
    )
    check(da1, da2, "[data]: 2 differences", brief_dims=["x"])
    check(da1, da2, "[data]: 2 differences", brief_dims="all")

    # Identical arrays report nothing regardless of brief_dims
    check(da1, da1)
    check(da1, da1, brief_dims=["x"])
    check(da1, da1, brief_dims="all")


def test_brief_dims_nd():
    """brief_dims collapses only the brief dimensions; the others keep
    their per-point labels.
    """
    # some dims are brief
    da1 = xarray.DataArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dims=["r", "c"])
    da2 = xarray.DataArray([[1, 5, 4], [4, 5, 6], [7, 8, 0]], dims=["r", "c"])
    check(
        da1,
        da2,
        "[data][c=1, r=0]: 2 != 5 (abs: 3.0e+00, rel: 1.5e+00)",
        "[data][c=2, r=0]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
        "[data][c=2, r=2]: 9 != 0 (abs: -9.0e+00, rel: -1.0e+00)",
    )
    check(
        da1,
        da2,
        "[data][c=1]: 1 differences",
        "[data][c=2]: 2 differences",
        brief_dims=["r"],
    )
    check(da1, da2, "[data]: 3 differences", brief_dims="all")

    check(da1, da1)
    check(da1, da1, brief_dims=["r"])
    check(da1, da1, brief_dims="all")
695
+
696
+
697
def test_brief_dims_nested():
    """xarray object not at the first level, and not all variables have all
    brief_dims
    """
    lhs = {
        "foo": xarray.Dataset(
            data_vars={
                "x": (("r", "c"), [[1, 2, 3], [4, 5, 6]]),
                "y": ("c", [1, 2, 3]),
            }
        )
    }
    rhs = {
        "foo": xarray.Dataset(
            data_vars={
                "x": (("r", "c"), [[1, 2, 4], [4, 5, 6]]),
                "y": ("c", [1, 2, 4]),
            }
        )
    }
    check(
        lhs,
        rhs,
        "[foo][data_vars][x][c=2, r=0]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
        "[foo][data_vars][y][c=2]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
    )
    # 'y' has no 'r' dimension, so only 'x' is collapsed
    check(
        lhs,
        rhs,
        "[foo][data_vars][x][c=2]: 1 differences",
        "[foo][data_vars][y][c=2]: 3 != 4 (abs: 1.0e+00, rel: 3.3e-01)",
        brief_dims=["r"],
    )
    check(
        lhs,
        rhs,
        "[foo][data_vars][x]: 1 differences",
        "[foo][data_vars][y]: 1 differences",
        brief_dims="all",
    )
737
+
738
+
739
def test_nested1():
    """Subclasses of supported container types are traversed normally
    and only produce a type-mismatch error.
    """
    # Subclasses of the supported types must only produce a type error
    class MyDict(dict):
        pass

    class MyList(list):
        pass

    class MyTuple(tuple):
        pass

    # Two complex arrays which are identical
    # (5.2 vs. 5.20000000001 is within the default rel_tol)
    lhs = {
        "foo": [1, 2, (5.2, "asd")],
        "bar": None,
        "baz": np.array([1, 2, 3]),
        None: [np.array([1, 2, 3])],
    }
    rhs = MyDict(
        {
            "foo": MyList([1, 2, MyTuple((5.20000000001, "asd"))]),
            "bar": None,
            "baz": np.array([1, 2, 3]),
            None: [np.array([1, 2, 3])],
        }
    )
    check(
        lhs,
        rhs,
        "[foo]: object type differs: list != MyList",
        "[foo][2]: object type differs: tuple != MyTuple",
        "object type differs: dict != MyDict",
    )
772
+
773
+
774
def test_nested2():
    """Several kinds of mismatch at different nesting levels, reported
    together with their full paths.
    """
    lhs = {
        "foo": [1, 2, ("asd", 5.2), 4],
        "bar": np.array([1, 2, 3, 4], dtype=np.int64),
        "baz": np.array([1, 2, 3], dtype=np.int64),
        "key_only_lhs": None,
    }
    rhs = {
        # type changed from tuple to list
        # a string content has changed
        # LHS outermost list is longer
        # RHS innermost list is longer
        "foo": [1, 2, ["lol", 5.2, 3]],
        # numpy dtype has changed
        # LHS is longer
        "bar": np.array([1, 2, 3], dtype=np.float64),
        # numpy vs. list
        "baz": [1, 2, 3],
        # Test string truncation
        "key_only_rhs": "a" * 200,
    }

    check(
        lhs,
        rhs,
        "[bar]: object type differs: ndarray<int64> != ndarray<float64>",
        "[bar][dim_0]: LHS has 1 more elements than RHS",
        "[baz]: object type differs: ndarray<int64> != list",
        "[foo]: LHS has 1 more elements than RHS: [4]",
        "[foo][2]: RHS has 1 more elements than LHS: [3]",
        "[foo][2]: object type differs: tuple != list",
        "[foo][2][0]: asd != lol",
        "Pair key_only_lhs:None is in LHS only",
        "Pair key_only_rhs:%s ... is in RHS only" % ("a" * 76),
    )
809
+
810
+
811
def test_custom_classes():
    """Classes registered with @cast recurse with tolerance; unregistered
    classes fall back to == without tolerance; classes that cannot even
    be compared with == produce a 'Cannot compare objects' error.
    """
    # Registered class: field-by-field diff, abs_tol honoured (w is
    # within tolerance, h is not)
    check(
        Rectangle(1, 2),
        Rectangle(1.1, 2.7),
        "[h]: 2.0 != 2.7 (abs: 7.0e-01, rel: 3.5e-01)",
        abs_tol=0.5,
    )

    # Rectangle and Drawing cast to the same dict, so they recurse too
    check(
        Rectangle(1, 2),
        Drawing(3, 2),
        "[w]: 1 != 3 (abs: 2.0e+00, rel: 2.0e+00)",
        "object type differs: Rectangle != Drawing",
    )

    # Unregistered classes can still be compared but without
    # tolerance or recursion
    check(Circle(4), Circle(4.1), "Circle(4) != Circle(4.1)", abs_tol=0.5)

    check(
        Rectangle(4, 4),
        Square(4),
        "Cannot compare objects: Rectangle(4, 4), Square(4)",
        "object type differs: Rectangle != Square",
    )

    check(
        Circle(4),
        Square(4),
        "Cannot compare objects: Circle(4), Square(4)",
        "object type differs: Circle != Square",
    )
843
+
844
+
845
@requires_dask
@pytest.mark.parametrize(
    "chunk_lhs,chunk_rhs",
    [(None, None), (None, -1), (None, 2), ({"x": (1, 2)}, {"x": (2, 1)})],
)
def test_dask(chunk_lhs, chunk_rhs):
    """Comparison works on dask-backed DataArrays, whether one or both
    sides are chunked and regardless of chunk alignment.
    """
    lhs = xarray.DataArray(["a", "b", "c"], dims=["x"])
    rhs = xarray.DataArray(["a", "b", "d"], dims=["x"])
    if chunk_lhs:
        lhs = lhs.chunk(chunk_lhs)
    if chunk_rhs:
        rhs = rhs.chunk(chunk_rhs)

    check(lhs, rhs, "[data][x=2]: c != d")