earthcatalog-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. earthcatalog/__init__.py +164 -0
  2. earthcatalog/async_http_client.py +1006 -0
  3. earthcatalog/config.py +97 -0
  4. earthcatalog/engines/__init__.py +308 -0
  5. earthcatalog/engines/rustac_engine.py +142 -0
  6. earthcatalog/engines/stac_geoparquet_engine.py +126 -0
  7. earthcatalog/exceptions.py +471 -0
  8. earthcatalog/grid_systems.py +1114 -0
  9. earthcatalog/ingestion_pipeline.py +2281 -0
  10. earthcatalog/input_readers.py +603 -0
  11. earthcatalog/job_tracking.py +485 -0
  12. earthcatalog/pipeline.py +606 -0
  13. earthcatalog/schema_generator.py +911 -0
  14. earthcatalog/spatial_resolver.py +1207 -0
  15. earthcatalog/stac_hooks.py +754 -0
  16. earthcatalog/statistics.py +677 -0
  17. earthcatalog/storage_backends.py +548 -0
  18. earthcatalog/tests/__init__.py +1 -0
  19. earthcatalog/tests/conftest.py +76 -0
  20. earthcatalog/tests/test_all_grids.py +793 -0
  21. earthcatalog/tests/test_async_http.py +700 -0
  22. earthcatalog/tests/test_cli_and_storage.py +230 -0
  23. earthcatalog/tests/test_config.py +245 -0
  24. earthcatalog/tests/test_dask_integration.py +580 -0
  25. earthcatalog/tests/test_e2e_synthetic.py +1624 -0
  26. earthcatalog/tests/test_engines.py +272 -0
  27. earthcatalog/tests/test_exceptions.py +346 -0
  28. earthcatalog/tests/test_file_structure.py +245 -0
  29. earthcatalog/tests/test_input_readers.py +666 -0
  30. earthcatalog/tests/test_integration.py +200 -0
  31. earthcatalog/tests/test_integration_async.py +283 -0
  32. earthcatalog/tests/test_job_tracking.py +603 -0
  33. earthcatalog/tests/test_multi_file_input.py +336 -0
  34. earthcatalog/tests/test_passthrough_hook.py +196 -0
  35. earthcatalog/tests/test_pipeline.py +684 -0
  36. earthcatalog/tests/test_pipeline_components.py +665 -0
  37. earthcatalog/tests/test_schema_generator.py +506 -0
  38. earthcatalog/tests/test_spatial_resolver.py +413 -0
  39. earthcatalog/tests/test_stac_hooks.py +776 -0
  40. earthcatalog/tests/test_statistics.py +477 -0
  41. earthcatalog/tests/test_storage_backends.py +236 -0
  42. earthcatalog/tests/test_validation.py +435 -0
  43. earthcatalog/tests/test_workers.py +653 -0
  44. earthcatalog/validation.py +921 -0
  45. earthcatalog/workers.py +682 -0
  46. earthcatalog-0.2.0.dist-info/METADATA +333 -0
  47. earthcatalog-0.2.0.dist-info/RECORD +50 -0
  48. earthcatalog-0.2.0.dist-info/WHEEL +5 -0
  49. earthcatalog-0.2.0.dist-info/entry_points.txt +3 -0
  50. earthcatalog-0.2.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,665 @@
+ # test_pipeline_components.py
+ """Tests for pipeline components: configurations, data classes, and utilities.
+
+ This module tests the new pipeline components including BatchConfig, ConsolidationConfig,
+ ShardInfo, PartitionResult, BatchResult, and utility functions.
+ """
+
+ from unittest.mock import MagicMock
+
+ import geopandas as gpd
+ import pytest
+ from shapely.geometry import Point
+
+ from earthcatalog.pipeline import (
+     BatchConfig,
+     BatchResult,
+     ConsolidationConfig,
+     PartitionResult,
+     ShardInfo,
+     chunk_urls,
+     group_shards_by_partition,
+     merge_geodataframes,
+ )
+
+
+ class TestBatchConfig:
+     """Tests for BatchConfig dataclass."""
+
+     def test_default_values(self):
+         """Test BatchConfig has sensible defaults."""
+         config = BatchConfig()
+         assert config.batch_size == 1000
+         assert config.items_per_shard == 10000
+         assert config.enable_concurrent_http is True
+         assert config.concurrent_requests == 50
+         assert config.connection_pool_size == 100
+         assert config.request_timeout == 30
+         assert config.retry_attempts == 3
+         assert config.retry_delay == 1.0
+
+     def test_custom_values(self):
+         """Test BatchConfig with custom values."""
+         config = BatchConfig(
+             batch_size=500,
+             items_per_shard=5000,
+             enable_concurrent_http=False,
+             concurrent_requests=25,
+         )
+         assert config.batch_size == 500
+         assert config.items_per_shard == 5000
+         assert config.enable_concurrent_http is False
+         assert config.concurrent_requests == 25
+
+     @pytest.mark.parametrize(
+         "kwargs,error_pattern",
+         [
+             ({"batch_size": 0}, "batch_size must be positive"),
+             ({"batch_size": -1}, "batch_size must be positive"),
+             ({"items_per_shard": 0}, "items_per_shard must be positive"),
+             ({"concurrent_requests": 0}, "concurrent_requests must be positive"),
+             ({"connection_pool_size": -5}, "connection_pool_size must be positive"),
+             ({"request_timeout": 0}, "request_timeout must be positive"),
+             ({"retry_attempts": -1}, "retry_attempts must be non-negative"),
+             ({"retry_delay": -0.5}, "retry_delay must be non-negative"),
+         ],
+         ids=[
+             "batch_size_zero",
+             "batch_size_negative",
+             "items_per_shard_zero",
+             "concurrent_requests_zero",
+             "connection_pool_size_negative",
+             "request_timeout_zero",
+             "retry_attempts_negative",
+             "retry_delay_negative",
+         ],
+     )
+     def test_validation_rejects_invalid_values(self, kwargs, error_pattern):
+         """Test validation rejects invalid configuration values."""
+         with pytest.raises(ValueError, match=error_pattern):
+             BatchConfig(**kwargs)
+
+     def test_repr(self):
+         """Test string representation."""
+         config = BatchConfig(batch_size=100)
+         repr_str = repr(config)
+         assert "BatchConfig" in repr_str
+         assert "batch_size=100" in repr_str
+
+     def test_bool_valid(self):
+         """Test __bool__ returns True for valid config."""
+         config = BatchConfig()
+         assert bool(config) is True
+
+     def test_bool_would_be_invalid(self):
+         """Test that a config that would be invalid fails at construction."""
+         # Since validation happens in __post_init__, we can't create an invalid config
+         # This test documents that behavior
+         with pytest.raises(ValueError):
+             BatchConfig(batch_size=-1)
+
+
+ class TestConsolidationConfig:
+     """Tests for ConsolidationConfig dataclass."""
+
+     def test_default_values(self):
+         """Test ConsolidationConfig has sensible defaults."""
+         config = ConsolidationConfig()
+         assert config.strategy == "efficient"
+         assert config.max_memory_per_partition_mb == 1024
+         assert config.enable_streaming_merge is True
+         assert config.s3_multipart_threshold_mb == 100
+         assert config.sort_key == "datetime"
+         assert config.sort_ascending is True
+         assert config.deduplicate_key == "id"
+         assert config.keep_duplicates == "last"
+
+     def test_custom_values(self):
+         """Test ConsolidationConfig with custom values."""
+         config = ConsolidationConfig(
+             strategy="legacy",
+             max_memory_per_partition_mb=2048,
+             enable_streaming_merge=False,
+         )
+         assert config.strategy == "legacy"
+         assert config.max_memory_per_partition_mb == 2048
+         assert config.enable_streaming_merge is False
+
+     @pytest.mark.parametrize(
+         "kwargs,error_pattern",
+         [
+             ({"strategy": "invalid"}, "strategy must be"),
+             ({"max_memory_per_partition_mb": 0}, "max_memory_per_partition_mb must be positive"),
+             ({"s3_multipart_threshold_mb": -10}, "s3_multipart_threshold_mb must be positive"),
+             ({"keep_duplicates": "middle"}, "keep_duplicates must be"),
+         ],
+         ids=["invalid_strategy", "memory_zero", "s3_threshold_negative", "invalid_keep_duplicates"],
+     )
+     def test_validation_rejects_invalid_values(self, kwargs, error_pattern):
+         """Test validation rejects invalid configuration values."""
+         with pytest.raises(ValueError, match=error_pattern):
+             ConsolidationConfig(**kwargs)
+
+     def test_repr(self):
+         """Test string representation."""
+         config = ConsolidationConfig(strategy="legacy")
+         repr_str = repr(config)
+         assert "ConsolidationConfig" in repr_str
+         assert "strategy='legacy'" in repr_str
+
+     def test_bool_valid(self):
+         """Test __bool__ returns True for valid config."""
+         config = ConsolidationConfig()
+         assert bool(config) is True
+
+
+ class TestShardInfo:
+     """Tests for ShardInfo dataclass."""
+
+     def test_basic_creation(self):
+         """Test basic ShardInfo creation."""
+         shard = ShardInfo(
+             shard_path="/scratch/shards/shard.parquet",
+             partition_key="dataset/partition=h3/level=2/82",
+             item_count=100,
+             worker_id="worker-0",
+         )
+         assert shard.shard_path == "/scratch/shards/shard.parquet"
+         assert shard.partition_key == "dataset/partition=h3/level=2/82"
+         assert shard.item_count == 100
+         assert shard.worker_id == "worker-0"
+         assert shard.shard_id == 0  # default
+
+     def test_with_shard_id(self):
+         """Test ShardInfo with custom shard_id."""
+         shard = ShardInfo(
+             shard_path="/path/shard.parquet",
+             partition_key="p1",
+             item_count=50,
+             worker_id="w1",
+             shard_id=5,
+         )
+         assert shard.shard_id == 5
+
+     def test_repr(self):
+         """Test string representation."""
+         shard = ShardInfo(
+             shard_path="/path/shard.parquet",
+             partition_key="p1",
+             item_count=50,
+             worker_id="w1",
+         )
+         repr_str = repr(shard)
+         assert "ShardInfo" in repr_str
+         assert "items=50" in repr_str
+
+     def test_bool_with_items(self):
+         """Test __bool__ returns True when shard has items."""
+         shard = ShardInfo(
+             shard_path="/path/shard.parquet",
+             partition_key="p1",
+             item_count=10,
+             worker_id="w1",
+         )
+         assert bool(shard) is True
+
+     def test_bool_without_items(self):
+         """Test __bool__ returns False when shard has no items."""
+         shard = ShardInfo(
+             shard_path="/path/shard.parquet",
+             partition_key="p1",
+             item_count=0,
+             worker_id="w1",
+         )
+         assert bool(shard) is False
+
+     def test_to_dict(self):
+         """Test conversion to dictionary."""
+         shard = ShardInfo(
+             shard_path="/path/shard.parquet",
+             partition_key="p1",
+             item_count=50,
+             worker_id="w1",
+             shard_id=3,
+         )
+         d = shard.to_dict()
+         assert d["shard_path"] == "/path/shard.parquet"
+         assert d["partition_key"] == "p1"
+         assert d["item_count"] == 50
+         assert d["worker_id"] == "w1"
+         assert d["shard_id"] == 3
+
+     def test_from_dict(self):
+         """Test creation from dictionary."""
+         data = {
+             "shard_path": "/path/shard.parquet",
+             "partition_key": "p1",
+             "item_count": 50,
+             "worker_id": "w1",
+             "shard_id": 3,
+         }
+         shard = ShardInfo.from_dict(data)
+         assert shard.shard_path == "/path/shard.parquet"
+         assert shard.partition_key == "p1"
+         assert shard.item_count == 50
+         assert shard.worker_id == "w1"
+         assert shard.shard_id == 3
+
+     def test_from_dict_minimal(self):
+         """Test creation from minimal dictionary."""
+         data = {
+             "shard_path": "/path/shard.parquet",
+             "item_count": 50,
+             "worker_id": "w1",
+         }
+         shard = ShardInfo.from_dict(data)
+         assert shard.shard_path == "/path/shard.parquet"
+         assert shard.partition_key == ""  # default
+         assert shard.shard_id == 0  # default
+
+
+ class TestPartitionResult:
+     """Tests for PartitionResult dataclass."""
+
+     def test_basic_creation(self):
+         """Test basic PartitionResult creation."""
+         result = PartitionResult(
+             partition_key="dataset/partition=h3/level=2/82",
+             item_count=1000,
+             existing_count=500,
+             new_count=550,
+             duplicates_removed=50,
+         )
+         assert result.partition_key == "dataset/partition=h3/level=2/82"
+         assert result.item_count == 1000
+         assert result.existing_count == 500
+         assert result.new_count == 550
+         assert result.duplicates_removed == 50
+         assert result.success is True
+         assert result.error == ""
+
+     def test_repr(self):
+         """Test string representation."""
+         result = PartitionResult(partition_key="p1", item_count=100)
+         repr_str = repr(result)
+         assert "PartitionResult" in repr_str
+         assert "items=100" in repr_str
+         assert "OK" in repr_str
+
+     def test_repr_failed(self):
+         """Test string representation for failed result."""
+         result = PartitionResult(partition_key="p1", success=False, error="Disk full")
+         repr_str = repr(result)
+         assert "FAILED" in repr_str
+         assert "Disk full" in repr_str
+
+     @pytest.mark.parametrize(
+         "item_count,success,expected_bool",
+         [
+             (100, True, True),  # success with items
+             (0, True, False),  # success without items
+             (100, False, False),  # failed with items
+         ],
+         ids=["success_with_items", "success_no_items", "failed"],
+     )
+     def test_bool_behavior(self, item_count, success, expected_bool):
+         """Test __bool__ returns correct value based on success and item_count."""
+         result = PartitionResult(partition_key="p1", item_count=item_count, success=success)
+         assert bool(result) is expected_bool
+
+     def test_to_dict(self):
+         """Test conversion to dictionary."""
+         result = PartitionResult(
+             partition_key="p1",
+             item_count=100,
+             existing_count=50,
+             new_count=60,
+             duplicates_removed=10,
+             final_path="/path/to/file.parquet",
+         )
+         d = result.to_dict()
+         assert d["partition"] == "p1"
+         assert d["item_count"] == 100
+         assert d["existing_count"] == 50
+         assert d["new_count"] == 60
+         assert d["duplicates_removed"] == 10
+         assert d["final_path"] == "/path/to/file.parquet"
+
+     def test_from_dict(self):
+         """Test creation from dictionary."""
+         data = {
+             "partition": "p1",
+             "item_count": 100,
+             "existing_count": 50,
+             "new_count": 60,
+             "duplicates_removed": 10,
+         }
+         result = PartitionResult.from_dict(data)
+         assert result.partition_key == "p1"
+         assert result.item_count == 100
+
+     def test_empty_factory(self):
+         """Test empty factory method."""
+         result = PartitionResult.empty("p1")
+         assert result.partition_key == "p1"
+         assert result.item_count == 0
+         assert result.success is True
+
+     def test_failed_factory(self):
+         """Test failed factory method."""
+         result = PartitionResult.failed("p1", "Network error")
+         assert result.partition_key == "p1"
+         assert result.success is False
+         assert result.error == "Network error"
+
+
+ class TestBatchResult:
+     """Tests for BatchResult dataclass."""
+
+     def test_basic_creation(self):
+         """Test basic BatchResult creation."""
+         result = BatchResult(
+             worker_id="worker-0",
+             urls_processed=100,
+             urls_succeeded=95,
+             urls_failed=5,
+         )
+         assert result.worker_id == "worker-0"
+         assert result.urls_processed == 100
+         assert result.urls_succeeded == 95
+         assert result.urls_failed == 5
+         assert result.shards == []
+
+     def test_with_shards(self):
+         """Test BatchResult with shards."""
+         shards = [
+             ShardInfo(shard_path="/a.parquet", partition_key="p1", item_count=50, worker_id="w1"),
+             ShardInfo(shard_path="/b.parquet", partition_key="p2", item_count=45, worker_id="w1"),
+         ]
+         result = BatchResult(
+             worker_id="w1",
+             shards=shards,
+             urls_processed=100,
+             urls_succeeded=95,
+             urls_failed=5,
+         )
+         assert len(result.shards) == 2
+         assert result.total_items == 95
+
+     def test_success_rate(self):
+         """Test success_rate calculation."""
+         result = BatchResult(
+             worker_id="w1",
+             urls_processed=100,
+             urls_succeeded=80,
+             urls_failed=20,
+         )
+         assert result.success_rate == 0.8
+
+     def test_success_rate_zero_processed(self):
+         """Test success_rate when no URLs processed."""
+         result = BatchResult(worker_id="w1", urls_processed=0)
+         assert result.success_rate == 0.0
+
+     def test_total_items(self):
+         """Test total_items property."""
+         shards = [
+             ShardInfo(shard_path="/a.parquet", partition_key="p1", item_count=100, worker_id="w1"),
+             ShardInfo(shard_path="/b.parquet", partition_key="p2", item_count=200, worker_id="w1"),
+             ShardInfo(shard_path="/c.parquet", partition_key="p3", item_count=300, worker_id="w1"),
+         ]
+         result = BatchResult(worker_id="w1", shards=shards)
+         assert result.total_items == 600
+
+     def test_repr(self):
+         """Test string representation."""
+         result = BatchResult(
+             worker_id="w1",
+             urls_processed=100,
+             urls_succeeded=95,
+             urls_failed=5,
+         )
+         repr_str = repr(result)
+         assert "BatchResult" in repr_str
+         assert "worker='w1'" in repr_str
+
+     def test_bool_with_items(self):
+         """Test __bool__ returns True when result has items."""
+         shards = [ShardInfo(shard_path="/a.parquet", partition_key="p1", item_count=10, worker_id="w1")]
+         result = BatchResult(worker_id="w1", shards=shards)
+         assert bool(result) is True
+
+     def test_bool_without_items(self):
+         """Test __bool__ returns False when result has no items."""
+         result = BatchResult(worker_id="w1")
+         assert bool(result) is False
+
+     def test_to_dict(self):
+         """Test conversion to dictionary."""
+         shards = [ShardInfo(shard_path="/a.parquet", partition_key="p1", item_count=50, worker_id="w1")]
+         result = BatchResult(
+             worker_id="w1",
+             shards=shards,
+             urls_processed=100,
+             urls_succeeded=95,
+             urls_failed=5,
+         )
+         d = result.to_dict()
+         assert d["worker_id"] == "w1"
+         assert d["urls_processed"] == 100
+         assert d["total_items"] == 50
+         assert d["success_rate"] == 0.95
+         assert len(d["shards"]) == 1
+
+
+ class TestMergeGeoDataFrames:
+     """Tests for merge_geodataframes function."""
+
+     def test_merge_empty_list(self):
+         """Test merging empty list returns empty GeoDataFrame."""
+         result = merge_geodataframes([])
+         assert isinstance(result, gpd.GeoDataFrame)
+         assert len(result) == 0
+
+     def test_merge_single_dataframe(self):
+         """Test merging single dataframe returns copy."""
+         gdf = gpd.GeoDataFrame({"id": ["a", "b"], "value": [1, 2], "geometry": [Point(0, 0), Point(1, 1)]})
+         result = merge_geodataframes([gdf])
+         assert len(result) == 2
+         # Verify it's a copy, not the same object
+         assert result is not gdf
+
+     def test_merge_multiple_dataframes(self):
+         """Test merging multiple dataframes."""
+         gdf1 = gpd.GeoDataFrame({"id": ["a", "b"], "value": [1, 2], "geometry": [Point(0, 0), Point(1, 1)]})
+         gdf2 = gpd.GeoDataFrame({"id": ["c", "d"], "value": [3, 4], "geometry": [Point(2, 2), Point(3, 3)]})
+         result = merge_geodataframes([gdf1, gdf2])
+         assert len(result) == 4
+
+     def test_merge_with_deduplication(self):
+         """Test merging with duplicate removal."""
+         gdf1 = gpd.GeoDataFrame({"id": ["a", "b"], "value": [1, 2], "geometry": [Point(0, 0), Point(1, 1)]})
+         gdf2 = gpd.GeoDataFrame({"id": ["b", "c"], "value": [20, 3], "geometry": [Point(1, 1), Point(2, 2)]})
+         result = merge_geodataframes([gdf1, gdf2], deduplicate_key="id", keep="last")
+         assert len(result) == 3
+         # Check that "b" has the value from the second dataframe
+         b_row = result[result["id"] == "b"]
+         assert b_row["value"].iloc[0] == 20
+
+     def test_merge_keep_first(self):
+         """Test merging keeping first duplicate."""
+         gdf1 = gpd.GeoDataFrame({"id": ["a", "b"], "value": [1, 2], "geometry": [Point(0, 0), Point(1, 1)]})
+         gdf2 = gpd.GeoDataFrame({"id": ["b", "c"], "value": [20, 3], "geometry": [Point(1, 1), Point(2, 2)]})
+         result = merge_geodataframes([gdf1, gdf2], deduplicate_key="id", keep="first")
+         assert len(result) == 3
+         # Check that "b" has the value from the first dataframe
+         b_row = result[result["id"] == "b"]
+         assert b_row["value"].iloc[0] == 2
+
+     def test_merge_with_sort(self):
+         """Test merging with sorting."""
+         gdf1 = gpd.GeoDataFrame({"id": ["a", "c"], "value": [3, 1], "geometry": [Point(0, 0), Point(2, 2)]})
+         gdf2 = gpd.GeoDataFrame({"id": ["b"], "value": [2], "geometry": [Point(1, 1)]})
+         result = merge_geodataframes([gdf1, gdf2], sort_key="value", sort_ascending=True)
+         assert list(result["value"]) == [1, 2, 3]
+
+     def test_merge_sort_descending(self):
+         """Test merging with descending sort."""
+         gdf = gpd.GeoDataFrame(
+             {"id": ["a", "b", "c"], "value": [1, 3, 2], "geometry": [Point(0, 0), Point(1, 1), Point(2, 2)]}
+         )
+         result = merge_geodataframes([gdf], sort_key="value", sort_ascending=False)
+         assert list(result["value"]) == [3, 2, 1]
+
+
+ class TestGroupShardsByPartition:
+     """Tests for group_shards_by_partition function."""
+
+     def test_group_shard_info_objects(self):
+         """Test grouping ShardInfo objects."""
+         shards = [
+             ShardInfo(shard_path="/a.parquet", partition_key="p1", item_count=10, worker_id="w1"),
+             ShardInfo(shard_path="/b.parquet", partition_key="p1", item_count=20, worker_id="w2"),
+             ShardInfo(shard_path="/c.parquet", partition_key="p2", item_count=15, worker_id="w1"),
+         ]
+         groups = group_shards_by_partition(shards)
+         assert len(groups) == 2
+         assert len(groups["p1"]) == 2
+         assert len(groups["p2"]) == 1
+         assert "/a.parquet" in groups["p1"]
+         assert "/b.parquet" in groups["p1"]
+         assert "/c.parquet" in groups["p2"]
+
+     def test_group_dictionaries(self):
+         """Test grouping dictionary-based shard info."""
+         shards = [
+             {"shard_path": "/a.parquet", "partition_key": "p1", "item_count": 10},
+             {"shard_path": "/b.parquet", "partition_key": "p2", "item_count": 20},
+         ]
+         groups = group_shards_by_partition(shards)
+         assert len(groups) == 2
+         assert groups["p1"] == ["/a.parquet"]
+         assert groups["p2"] == ["/b.parquet"]
+
+     def test_group_empty_list(self):
+         """Test grouping empty list."""
+         groups = group_shards_by_partition([])
+         assert groups == {}
+
+     def test_group_skips_empty_partition_key(self):
+         """Test that shards with empty partition keys are skipped."""
+         shards = [
+             ShardInfo(shard_path="/a.parquet", partition_key="", item_count=10, worker_id="w1"),
+             ShardInfo(shard_path="/b.parquet", partition_key="p1", item_count=20, worker_id="w1"),
+         ]
+         groups = group_shards_by_partition(shards)
+         assert len(groups) == 1
+         assert "p1" in groups
+
+
+ class TestChunkUrls:
+     """Tests for chunk_urls function."""
+
+     @pytest.mark.parametrize(
+         "urls,num_chunks,expected_total_urls,expected_chunks_range",
+         [
+             (["url1", "url2", "url3", "url4"], 2, 4, (2, 2)),  # even split
+             (["url1", "url2", "url3", "url4", "url5"], 2, 5, (2, 3)),  # uneven split
+             (["url1", "url2"], 10, 2, (1, 2)),  # more chunks than urls
+             (["url1"], 3, 1, (1, 1)),  # single url
+             ([], 5, 0, (0, 0)),  # empty list
+         ],
+         ids=["even_split", "uneven_split", "more_chunks_than_urls", "single_url", "empty_list"],
+     )
+     def test_chunk_behavior(self, urls, num_chunks, expected_total_urls, expected_chunks_range):
+         """Test chunking behavior for various inputs."""
+         chunks = chunk_urls(urls, num_chunks)
+         total = sum(len(c) for c in chunks)
+         assert total == expected_total_urls
+         min_chunks, max_chunks = expected_chunks_range
+         assert (min_chunks <= len(chunks) <= max_chunks) or (expected_total_urls == 0 and len(chunks) == 0)
+
+     @pytest.mark.parametrize(
+         "invalid_num_chunks",
+         [0, -1],
+         ids=["zero", "negative"],
+     )
+     def test_chunk_invalid_num_chunks(self, invalid_num_chunks):
+         """Test that invalid num_chunks raises error."""
+         with pytest.raises(ValueError, match="num_chunks must be positive"):
+             chunk_urls(["url1"], invalid_num_chunks)
+
+
+ class TestIntegration:
+     """Integration tests for pipeline components."""
+
+     def test_shard_info_roundtrip(self):
+         """Test ShardInfo can roundtrip through dict conversion."""
+         original = ShardInfo(
+             shard_path="/path/to/shard.parquet",
+             partition_key="dataset/partition=h3/level=2/82",
+             item_count=100,
+             worker_id="worker-0-abc123",
+             shard_id=5,
+         )
+         d = original.to_dict()
+         restored = ShardInfo.from_dict(d)
+         assert restored.shard_path == original.shard_path
+         assert restored.partition_key == original.partition_key
+         assert restored.item_count == original.item_count
+         assert restored.worker_id == original.worker_id
+         assert restored.shard_id == original.shard_id
+
+     def test_partition_result_roundtrip(self):
+         """Test PartitionResult can roundtrip through dict conversion."""
+         original = PartitionResult(
+             partition_key="p1",
+             item_count=1000,
+             existing_count=500,
+             new_count=550,
+             duplicates_removed=50,
+             final_path="/path/to/partition.parquet",
+         )
+         d = original.to_dict()
+         restored = PartitionResult.from_dict(d)
+         assert restored.partition_key == original.partition_key
+         assert restored.item_count == original.item_count
+         assert restored.existing_count == original.existing_count
+         assert restored.new_count == original.new_count
+         assert restored.duplicates_removed == original.duplicates_removed
+
+     def test_batch_result_with_stats(self):
+         """Test BatchResult can hold stats object."""
+         # Create a mock stats object
+         mock_stats = MagicMock()
+         mock_stats.total_items = 100
+
+         result = BatchResult(
+             worker_id="w1",
+             urls_processed=100,
+             urls_succeeded=95,
+             urls_failed=5,
+             stats=mock_stats,
+         )
+         assert result.stats is not None
+         assert result.stats.total_items == 100
+
+     def test_configs_work_together(self):
+         """Test that BatchConfig and ConsolidationConfig can be used together."""
+         batch_config = BatchConfig(
+             batch_size=500,
+             items_per_shard=5000,
+             concurrent_requests=25,
+         )
+         consolidation_config = ConsolidationConfig(
+             strategy="efficient",
+             max_memory_per_partition_mb=512,
+             enable_streaming_merge=True,
+         )
+
+         # Both should be valid
+         assert bool(batch_config) is True
+         assert bool(consolidation_config) is True
+
+         # Can use values together
+         assert batch_config.items_per_shard < consolidation_config.max_memory_per_partition_mb * 1000
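For orientation, the tests above pin down a small public surface in earthcatalog.pipeline. Below is a minimal, hypothetical sketch of how those pieces compose, using only the constructors, fields, and helper functions the tests exercise; the URLs, file paths, and partition keys are placeholders, and the surrounding orchestration (workers, consolidation) is not shown.

# Illustrative only; grounded in the behavior asserted by the tests above.
from earthcatalog.pipeline import (
    BatchConfig,
    ConsolidationConfig,
    ShardInfo,
    chunk_urls,
    group_shards_by_partition,
)

# Configs validate in __post_init__; invalid values raise ValueError.
batch_config = BatchConfig(batch_size=500, concurrent_requests=25)
consolidation_config = ConsolidationConfig(strategy="efficient", keep_duplicates="last")

# Split the input URL list into roughly even chunks, one per worker.
urls = [f"https://example.com/items/{i}.json" for i in range(10)]  # placeholder URLs
chunks = chunk_urls(urls, 4)
assert sum(len(c) for c in chunks) == len(urls)

# Workers would emit ShardInfo records; group their shard paths by partition
# so each partition can be merged and deduplicated independently.
shards = [
    ShardInfo(shard_path="/tmp/a.parquet", partition_key="p1", item_count=10, worker_id="w0"),
    ShardInfo(shard_path="/tmp/b.parquet", partition_key="p1", item_count=20, worker_id="w1"),
    ShardInfo(shard_path="/tmp/c.parquet", partition_key="p2", item_count=5, worker_id="w0"),
]
by_partition = group_shards_by_partition(shards)
# by_partition == {"p1": ["/tmp/a.parquet", "/tmp/b.parquet"], "p2": ["/tmp/c.parquet"]}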