earthcatalog-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. earthcatalog/__init__.py +164 -0
  2. earthcatalog/async_http_client.py +1006 -0
  3. earthcatalog/config.py +97 -0
  4. earthcatalog/engines/__init__.py +308 -0
  5. earthcatalog/engines/rustac_engine.py +142 -0
  6. earthcatalog/engines/stac_geoparquet_engine.py +126 -0
  7. earthcatalog/exceptions.py +471 -0
  8. earthcatalog/grid_systems.py +1114 -0
  9. earthcatalog/ingestion_pipeline.py +2281 -0
  10. earthcatalog/input_readers.py +603 -0
  11. earthcatalog/job_tracking.py +485 -0
  12. earthcatalog/pipeline.py +606 -0
  13. earthcatalog/schema_generator.py +911 -0
  14. earthcatalog/spatial_resolver.py +1207 -0
  15. earthcatalog/stac_hooks.py +754 -0
  16. earthcatalog/statistics.py +677 -0
  17. earthcatalog/storage_backends.py +548 -0
  18. earthcatalog/tests/__init__.py +1 -0
  19. earthcatalog/tests/conftest.py +76 -0
  20. earthcatalog/tests/test_all_grids.py +793 -0
  21. earthcatalog/tests/test_async_http.py +700 -0
  22. earthcatalog/tests/test_cli_and_storage.py +230 -0
  23. earthcatalog/tests/test_config.py +245 -0
  24. earthcatalog/tests/test_dask_integration.py +580 -0
  25. earthcatalog/tests/test_e2e_synthetic.py +1624 -0
  26. earthcatalog/tests/test_engines.py +272 -0
  27. earthcatalog/tests/test_exceptions.py +346 -0
  28. earthcatalog/tests/test_file_structure.py +245 -0
  29. earthcatalog/tests/test_input_readers.py +666 -0
  30. earthcatalog/tests/test_integration.py +200 -0
  31. earthcatalog/tests/test_integration_async.py +283 -0
  32. earthcatalog/tests/test_job_tracking.py +603 -0
  33. earthcatalog/tests/test_multi_file_input.py +336 -0
  34. earthcatalog/tests/test_passthrough_hook.py +196 -0
  35. earthcatalog/tests/test_pipeline.py +684 -0
  36. earthcatalog/tests/test_pipeline_components.py +665 -0
  37. earthcatalog/tests/test_schema_generator.py +506 -0
  38. earthcatalog/tests/test_spatial_resolver.py +413 -0
  39. earthcatalog/tests/test_stac_hooks.py +776 -0
  40. earthcatalog/tests/test_statistics.py +477 -0
  41. earthcatalog/tests/test_storage_backends.py +236 -0
  42. earthcatalog/tests/test_validation.py +435 -0
  43. earthcatalog/tests/test_workers.py +653 -0
  44. earthcatalog/validation.py +921 -0
  45. earthcatalog/workers.py +682 -0
  46. earthcatalog-0.2.0.dist-info/METADATA +333 -0
  47. earthcatalog-0.2.0.dist-info/RECORD +50 -0
  48. earthcatalog-0.2.0.dist-info/WHEEL +5 -0
  49. earthcatalog-0.2.0.dist-info/entry_points.txt +3 -0
  50. earthcatalog-0.2.0.dist-info/top_level.txt +1 -0
earthcatalog/tests/test_pipeline.py
@@ -0,0 +1,684 @@
+ """Tests for core pipeline functionality."""
+
+ import shutil
+ import tempfile
+ from pathlib import Path
+ from typing import Any
+ from unittest.mock import Mock, patch
+
+ import pandas as pd
+ import pytest
+ import requests
+
+ # Import from earthcatalog package
+ from earthcatalog import ingestion_pipeline
+
+
+ # Module-level functions for multiprocessing tests (must be pickleable)
+ def mock_process_fn(chunk, worker_id):
+     """Mock processing function for multiprocessing tests."""
+     return f"processed_{len(chunk)}_items_by_worker_{worker_id}"
+
+
+ def mock_consolidate_fn(partition_key, shard_paths):
+     """Mock consolidation function for multiprocessing tests."""
+     return f"consolidated_{partition_key}_{len(shard_paths)}_shards"
+
+
+ class TestProcessingConfig:
+     """Test ProcessingConfig dataclass."""
+
+     def test_default_config(self):
+         """Test default configuration values."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet", output_catalog="./output", scratch_location="./scratch"
+         )
+
+         assert config.input_file == "test.parquet"
+         assert config.output_catalog == "./output"
+         assert config.scratch_location == "./scratch"
+         assert config.grid_system == "h3"
+         assert config.grid_resolution == 2
+         assert config.temporal_bin == "month"
+         assert config.sort_key == "datetime"
+         assert config.sort_ascending is True
+         assert config.items_per_shard == 10000
+         assert config.max_workers == 8
+
+     def test_custom_config(self):
+         """Test custom configuration values."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="custom.parquet",
+             output_catalog="./custom_output",
+             scratch_location="./custom_scratch",
+             grid_system="s2",
+             grid_resolution=10,
+             temporal_bin="day",
+             sort_key="properties.datetime",
+             sort_ascending=False,
+             items_per_shard=5000,
+             max_workers=4,
+         )
+
+         assert config.grid_system == "s2"
+         assert config.grid_resolution == 10
+         assert config.temporal_bin == "day"
+         assert config.sort_key == "properties.datetime"
+         assert config.sort_ascending is False
+         assert config.items_per_shard == 5000
+         assert config.max_workers == 4
+
+     def test_to_dict_includes_all_fields(self):
+         """to_dict should include all configuration fields."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="input.parquet",
+             output_catalog="./catalog",
+             scratch_location="./scratch",
+         )
+
+         data = config.to_dict()
+
+         # Check required fields
+         assert data["input_file"] == "input.parquet"
+         assert data["output_catalog"] == "./catalog"
+         assert data["scratch_location"] == "./scratch"
+
+         # Check some default fields
+         assert data["grid_system"] == "h3"
+         assert data["grid_resolution"] == 2
+         assert data["batch_size"] == 1000
+         assert data["enable_concurrent_http"] is True
+
+     def test_from_dict_restores_config(self):
+         """from_dict should restore config from dictionary."""
+         data = {
+             "input_file": "test.parquet",
+             "output_catalog": "s3://bucket/catalog",
+             "scratch_location": "s3://bucket/scratch",
+             "grid_resolution": 5,
+             "batch_size": 2000,
+             "concurrent_requests": 100,
+         }
+
+         config = ingestion_pipeline.ProcessingConfig.from_dict(data)
+
+         assert config.input_file == "test.parquet"
+         assert config.output_catalog == "s3://bucket/catalog"
+         assert config.grid_resolution == 5
+         assert config.batch_size == 2000
+         assert config.concurrent_requests == 100
+         # Defaults should still apply for missing fields
+         assert config.grid_system == "h3"
+
+     def test_to_dict_from_dict_roundtrip(self):
+         """to_dict and from_dict should roundtrip correctly."""
+         original = ingestion_pipeline.ProcessingConfig(
+             input_file="input.parquet",
+             output_catalog="./catalog",
+             scratch_location="./scratch",
+             grid_resolution=4,
+             batch_size=500,
+             concurrent_requests=75,
+         )
+
+         data = original.to_dict()
+         restored = ingestion_pipeline.ProcessingConfig.from_dict(data)
+
+         assert restored.input_file == original.input_file
+         assert restored.grid_resolution == original.grid_resolution
+         assert restored.batch_size == original.batch_size
+         assert restored.concurrent_requests == original.concurrent_requests
+
+     def test_config_hash_consistent(self):
+         """config_hash should return same hash for same settings."""
+         config1 = ingestion_pipeline.ProcessingConfig(
+             input_file="a.parquet",
+             output_catalog="./catalog1",
+             scratch_location="./scratch1",
+         )
+         config2 = ingestion_pipeline.ProcessingConfig(
+             input_file="b.parquet",  # Different input
+             output_catalog="./catalog2",  # Different output
+             scratch_location="./scratch2",  # Different scratch
+         )
+
+         # Same processing settings should give same hash
+         assert config1.config_hash() == config2.config_hash()
+
+     def test_config_hash_different_for_different_settings(self):
+         """config_hash should differ when processing settings differ."""
+         config1 = ingestion_pipeline.ProcessingConfig(
+             input_file="input.parquet",
+             output_catalog="./catalog",
+             scratch_location="./scratch",
+             grid_resolution=4,
+         )
+         config2 = ingestion_pipeline.ProcessingConfig(
+             input_file="input.parquet",
+             output_catalog="./catalog",
+             scratch_location="./scratch",
+             grid_resolution=5,  # Different resolution
+         )
+
+         assert config1.config_hash() != config2.config_hash()
+
+
+ class TestLocalProcessor:
+     """Test LocalProcessor class."""
+
+     def test_processor_initialization(self):
+         """Test processor initialization."""
+         processor = ingestion_pipeline.LocalProcessor(n_workers=4)
+         assert processor.executor._max_workers == 4
+         processor.close()
+
+     def test_process_urls(self):
+         """Test URL processing functionality."""
+         processor = ingestion_pipeline.LocalProcessor(n_workers=2)
+
+         url_chunks = [["url1", "url2"], ["url3", "url4"]]
+         results = processor.process_urls(url_chunks, mock_process_fn)
+
+         assert len(results) == 2
+         assert "processed_2_items_by_worker_0" in results
+         assert "processed_2_items_by_worker_1" in results
+
+         processor.close()
+
+     def test_consolidate_shards(self):
+         """Test shard consolidation functionality."""
+         processor = ingestion_pipeline.LocalProcessor(n_workers=2)
+
+         partition_items = [("partition1", ["shard1", "shard2"]), ("partition2", ["shard3"])]
+         results = processor.consolidate_shards(partition_items, mock_consolidate_fn)
+
+         assert len(results) == 2
+         assert "consolidated_partition1_2_shards" in results
+         assert "consolidated_partition2_1_shards" in results
+
+         processor.close()
+
+
+ class TestSTACIngestionPipeline:
+     """Test STACIngestionPipeline class."""
+
+     def setup_method(self):
+         """Set up test fixtures."""
+         self.temp_dir = tempfile.mkdtemp()
+         self.config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             items_per_shard=10,
+             max_workers=2,
+         )
+
+         # Create sample input parquet
+         urls = ["https://example.com/item1.json", "https://example.com/item2.json"]
+         df = pd.DataFrame({"url": urls})
+         df.to_parquet(self.config.input_file, index=False)
+
+         self.processor = ingestion_pipeline.LocalProcessor(n_workers=2)
+         self.pipeline = ingestion_pipeline.STACIngestionPipeline(self.config, self.processor)
+
+     def teardown_method(self):
+         """Clean up test fixtures."""
+         self.processor.close()
+         shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+     def test_pipeline_initialization(self):
+         """Test pipeline initialization."""
+         assert self.pipeline.config == self.config
+         assert self.pipeline.processor == self.processor
+         assert self.pipeline.grid is not None
+         assert self.pipeline.storage is not None
+         assert self.pipeline.scratch_storage is not None
+
+     def test_read_input_urls(self):
+         """Test reading URLs from input parquet."""
+         urls = self.pipeline._read_input_urls()
+         assert len(urls) == 2
+         assert "https://example.com/item1.json" in urls
+         assert "https://example.com/item2.json" in urls
+
+     def test_read_input_urls_missing_column(self):
+         """Test error handling for missing URL column."""
+         # Create parquet without URL column
+         df = pd.DataFrame({"other_column": [1, 2]})
+         bad_path = str(Path(self.temp_dir) / "bad.parquet")
+         df.to_parquet(bad_path, index=False)
+
+         self.config.input_file = bad_path
+         pipeline = ingestion_pipeline.STACIngestionPipeline(self.config, self.processor)
+
+         with pytest.raises(ValueError, match="must contain 'url' column"):
+             pipeline._read_input_urls()
+
+     def test_chunk_urls(self):
+         """Test URL chunking functionality."""
+         urls = [f"url_{i}" for i in range(10)]
+         chunks = self.pipeline._chunk_urls(urls, 3)
+
+         assert len(chunks) == 3
+         # With current implementation: chunk_size = 10 // 3 = 3
+         # chunks = [0:3], [3:6], [6:9], [9:10] -> last gets combined with previous
+         assert len(chunks[0]) == 3
+         assert len(chunks[1]) == 3
+         assert len(chunks[2]) == 4  # Gets the remainder
+         assert sum(len(chunk) for chunk in chunks) == 10
+
+     def test_extract_temporal_hive_parts(self):
+         """Test Hive-style temporal partition extraction."""
+         # Test different temporal bins
+         item_month = {"properties": {"datetime": "2024-01-15T10:30:00Z"}}
+         item_year = {"properties": {"datetime": "2024-01-15T10:30:00Z"}}
+         item_day = {"properties": {"datetime": "2024-01-15T10:30:00Z"}}
+         item_no_datetime: dict[str, Any] = {"properties": {}}
+
+         # Test month binning (Hive-style)
+         self.config.temporal_bin = "month"
+         result = self.pipeline._extract_temporal_hive_parts(item_month)
+         assert result == "year=2024/month=01"
+
+         # Test year binning (Hive-style)
+         self.config.temporal_bin = "year"
+         result = self.pipeline._extract_temporal_hive_parts(item_year)
+         assert result == "year=2024"
+
+         # Test day binning (Hive-style)
+         self.config.temporal_bin = "day"
+         result = self.pipeline._extract_temporal_hive_parts(item_day)
+         assert result == "year=2024/month=01/day=15"
+
+         # Test missing datetime
+         result = self.pipeline._extract_temporal_hive_parts(item_no_datetime)
+         assert result == "unknown"
+
+     def test_compute_partition_key(self):
+         """Test partition key computation with Hive-style temporal parts."""
+         # Mock grid system
+         mock_grid = Mock()
+         mock_grid.tiles_for_geometry_with_spanning_detection.return_value = (["h3_cell_123"], False)
+         mock_grid.get_global_partition_threshold.return_value = 10  # Return int instead of Mock
+         self.pipeline.grid = mock_grid
+
+         item = {
+             "geometry": {"type": "Point", "coordinates": [-105.0, 40.0]},
+             "properties": {"datetime": "2024-01-15T10:30:00Z"},
+         }
+
+         self.config.temporal_bin = "month"
+         partition_key = self.pipeline._compute_partition_key(item)
+
+         assert partition_key == "unknown_mission/partition=h3/level=2/h3_cell_123/year=2024/month=01"
+         mock_grid.tiles_for_geometry_with_spanning_detection.assert_called_once()
+
+     def test_get_final_partition_path(self):
+         """Test final partition path generation with Hive-style."""
+         partition_key = "mission/partition=h3/level=2/h3_cell_123/year=2024/month=01"
+         expected_path = f"{self.config.output_catalog}/{partition_key}/items.parquet"
+
+         result = self.pipeline._get_final_partition_path(partition_key)
+         assert result == expected_path
+
+     @patch("requests.get")
+     def test_download_stac_item_http(self, mock_get):
+         """Test downloading STAC item from HTTP URL."""
+         # Mock successful HTTP response
+         mock_response = Mock()
+         mock_response.json.return_value = {"id": "test_item", "type": "Feature"}
+         mock_response.raise_for_status.return_value = None
+         mock_get.return_value = mock_response
+
+         url = "https://example.com/item.json"
+         result = self.pipeline._download_stac_item(url)
+
+         assert result == {"id": "test_item", "type": "Feature"}
+         mock_get.assert_called_once_with(url, timeout=30)
+
+     @patch("fsspec.filesystem")
+     def test_download_stac_item_s3(self, mock_fs):
+         """Test downloading STAC item from S3."""
+         # Mock S3 filesystem
+         mock_s3_fs = Mock()
+         mock_file = Mock()
+         mock_file.__enter__ = Mock(return_value=mock_file)
+         mock_file.__exit__ = Mock(return_value=None)
+         mock_s3_fs.open.return_value = mock_file
+         mock_fs.return_value = mock_s3_fs
+
+         # Mock json.load
+         with patch("json.load", return_value={"id": "s3_item"}):
+             url = "s3://bucket/item.json"
+             result = self.pipeline._download_stac_item(url)
+
+         assert result == {"id": "s3_item"}
+         mock_fs.assert_called_once_with("s3")
+
+     def test_download_stac_item_error(self):
+         """Test error handling in STAC item download."""
+         url = "https://invalid-url.com/item.json"
+
+         with patch("requests.get", side_effect=requests.exceptions.RequestException("Network error")):
+             result = self.pipeline._download_stac_item(url)
+             assert result is None
+
+
+ class TestDaskDistributedProcessor:
+     """Test DaskDistributedProcessor class."""
+
+     def test_dask_processor_import_error(self):
+         """Test DaskDistributedProcessor raises ImportError when dask.distributed is not available."""
+         # Mock the dask.distributed module to raise ImportError
+         with patch.dict("sys.modules", {"dask.distributed": None}):
+             with patch("builtins.__import__") as mock_import:
+
+                 def import_side_effect(name, *args, **kwargs):
+                     if name == "dask.distributed":
+                         raise ImportError("No module named 'dask.distributed'")
+                     return __import__(name, *args, **kwargs)
+
+                 mock_import.side_effect = import_side_effect
+
+                 with pytest.raises(ImportError, match="Dask distributed required"):
+                     ingestion_pipeline.DaskDistributedProcessor()
+
+
+ class TestPartitionAwareSharding:
+     """Test partition-aware shard organization."""
+
+     def setup_method(self):
+         """Set up test fixtures."""
+         self.temp_dir = tempfile.mkdtemp()
+         self.config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             items_per_shard=2,  # Small shards for testing
+             max_workers=2,
+         )
+
+         # Create sample input parquet
+         urls = [
+             "https://example.com/item1.json",
+             "https://example.com/item2.json",
+             "https://example.com/item3.json",
+         ]
+         df = pd.DataFrame({"url": urls})
+         df.to_parquet(self.config.input_file, index=False)
+
+         self.processor = ingestion_pipeline.LocalProcessor(n_workers=2)
+         self.pipeline = ingestion_pipeline.STACIngestionPipeline(self.config, self.processor)
+
+     def teardown_method(self):
+         """Clean up test fixtures."""
+         self.processor.close()
+         shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+     def test_partition_aware_shard_path(self):
+         """Test that shard paths include partition information."""
+         # Mock sample STAC items
+         items = [
+             {
+                 "id": "item1",
+                 "geometry": {"type": "Point", "coordinates": [-105.0, 40.0]},
+                 "properties": {"datetime": "2024-01-15T10:30:00Z"},
+             },
+             {
+                 "id": "item2",
+                 "geometry": {"type": "Point", "coordinates": [-105.1, 40.1]},
+                 "properties": {"datetime": "2024-01-16T11:30:00Z"},
+             },
+         ]
+
+         # Mock grid system to return predictable tile
+         mock_grid = Mock()
+         mock_grid.tiles_for_geometry_with_spanning_detection.return_value = (["h3_cell_123"], False)
+         self.pipeline.grid = mock_grid
+
+         # Test shard path generation with Hive-style temporal parts
+         partition_key = self.pipeline._compute_partition_key(items[0])
+
+         # The new implementation should include Hive-style temporal parts in partition key
+         assert partition_key == "unknown_mission/partition=h3/level=2/h3_cell_123/year=2024/month=01"
+
+     def test_consolidation_config(self):
+         """Test consolidation configuration options."""
+         # Test that config options are available
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="./output",
+             scratch_location="./scratch",
+             max_memory_per_partition_mb=512,
+             enable_streaming_merge=True,
+         )
+
+         assert config.max_memory_per_partition_mb == 512
+         assert config.enable_streaming_merge is True
+
+     def test_shard_grouping_optimization(self):
+         """Test that partition-aware sharding simplifies grouping with Hive-style paths."""
+         # Test the new grouping logic structure with Hive-style temporal parts
+         expected_groups = {
+             "h3_cell_123/year=2024/month=01": [
+                 f"{self.config.scratch_location}/shards/h3_cell_123/year=2024/month=01/worker-1.parquet",
+                 f"{self.config.scratch_location}/shards/h3_cell_123/year=2024/month=01/worker-2.parquet",
+             ],
+             "h3_cell_456/year=2024/month=01": [
+                 f"{self.config.scratch_location}/shards/h3_cell_456/year=2024/month=01/worker-1.parquet"
+             ],
+         }
+
+         # We'll implement the new grouping logic that uses path structure
+         # For now, test the expected structure
+         assert len(expected_groups) == 2
+         assert len(expected_groups["h3_cell_123/year=2024/month=01"]) == 2
+         assert len(expected_groups["h3_cell_456/year=2024/month=01"]) == 1
+
+
+ class TestS3StorageEnhancements:
+     """Test enhanced S3 storage backend operations."""
+
+     def test_atomic_s3_operations_config(self):
+         """Test S3 atomic operations configuration."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="s3://bucket/catalog",
+             scratch_location="s3://bucket/scratch",
+             s3_multipart_threshold_mb=100,
+             temp_dir_location="/tmp",
+         )
+
+         assert config.s3_multipart_threshold_mb == 100
+         assert config.temp_dir_location == "/tmp"
+
+     def test_streaming_merge_memory_limits(self):
+         """Test memory-bounded streaming merge."""
+         # Test configuration for memory limits
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="./output",
+             scratch_location="./scratch",
+             max_memory_per_partition_mb=256,
+         )
+
+         # Should use streaming when partition size exceeds limit
+         assert config.max_memory_per_partition_mb == 256
+
+
+ class TestIntegratedPipelineWorkflow:
+     """Test integrated pipeline workflow with improvements."""
+
+     def setup_method(self):
+         """Set up test fixtures."""
+         self.temp_dir = tempfile.mkdtemp()
+
+     def teardown_method(self):
+         """Clean up test fixtures."""
+         shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+     def test_end_to_end_partition_aware_workflow(self):
+         """Test end-to-end workflow with partition-aware sharding."""
+         # This will be a comprehensive test once we implement the changes
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             max_workers=2,
+             items_per_shard=2,
+         )
+
+         # Create minimal test input
+         urls = ["https://example.com/item1.json"]
+         df = pd.DataFrame({"url": urls})
+         df.to_parquet(config.input_file, index=False)
+
+         processor = ingestion_pipeline.LocalProcessor(n_workers=2)
+         pipeline = ingestion_pipeline.STACIngestionPipeline(config, processor)
+
+         # Test that pipeline initializes with config
+         assert pipeline.config.max_memory_per_partition_mb > 0
+
+         processor.close()
+
+
+ class TestBatchModeConfig:
+     """Test batch mode configuration and processing mode selection."""
+
+     def setup_method(self):
+         """Set up test fixtures."""
+         self.temp_dir = tempfile.mkdtemp()
+
+     def teardown_method(self):
+         """Clean up test fixtures."""
+         shutil.rmtree(self.temp_dir, ignore_errors=True)
+
+     def test_batch_mode_config_defaults(self):
+         """Test default batch mode configuration values."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="./output",
+             scratch_location="./scratch",
+         )
+
+         assert config.batch_threshold == 10000
+         assert config.distributed is None  # Auto mode by default
+         assert config.large_batch_confirm_threshold == 20000
+
+     def test_batch_mode_config_custom(self):
+         """Test custom batch mode configuration values."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="./output",
+             scratch_location="./scratch",
+             batch_threshold=5000,
+             distributed=True,
+             large_batch_confirm_threshold=15000,
+         )
+
+         assert config.batch_threshold == 5000
+         assert config.distributed is True
+         assert config.large_batch_confirm_threshold == 15000
+
+     def test_batch_mode_in_to_dict(self):
+         """Test that batch mode options are included in to_dict."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file="test.parquet",
+             output_catalog="./output",
+             scratch_location="./scratch",
+             batch_threshold=7500,
+             distributed=False,
+         )
+
+         data = config.to_dict()
+         assert data["batch_threshold"] == 7500
+         assert data["distributed"] is False
+         assert "large_batch_confirm_threshold" in data
+
+     def test_batch_mode_from_dict(self):
+         """Test restoring batch mode options from dict."""
+         data = {
+             "input_file": "test.parquet",
+             "output_catalog": "./output",
+             "scratch_location": "./scratch",
+             "batch_threshold": 8000,
+             "distributed": True,
+             "large_batch_confirm_threshold": 25000,
+         }
+
+         config = ingestion_pipeline.ProcessingConfig.from_dict(data)
+         assert config.batch_threshold == 8000
+         assert config.distributed is True
+         assert config.large_batch_confirm_threshold == 25000
+
+     def test_should_use_distributed_force_true(self):
+         """Test _should_use_distributed with forced distributed mode."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             distributed=True,
+         )
+
+         # Create minimal input file
+         df = pd.DataFrame({"url": ["url1"]})
+         df.to_parquet(config.input_file, index=False)
+
+         processor = ingestion_pipeline.LocalProcessor(n_workers=1)
+         pipeline = ingestion_pipeline.STACIngestionPipeline(config, processor)
+
+         # Should use distributed even with 1 URL (well below threshold)
+         assert pipeline._should_use_distributed(1) is True
+         assert pipeline._should_use_distributed(100) is True
+         processor.close()
+
+     def test_should_use_distributed_force_false(self):
+         """Test _should_use_distributed with forced local mode."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             distributed=False,
+         )
+
+         # Create minimal input file
+         df = pd.DataFrame({"url": ["url1"]})
+         df.to_parquet(config.input_file, index=False)
+
+         processor = ingestion_pipeline.LocalProcessor(n_workers=1)
+         pipeline = ingestion_pipeline.STACIngestionPipeline(config, processor)
+
+         # Should use local even with many URLs (above threshold)
+         assert pipeline._should_use_distributed(1) is False
+         assert pipeline._should_use_distributed(100000) is False
+         processor.close()
+
+     def test_should_use_distributed_auto_mode(self):
+         """Test _should_use_distributed with auto mode."""
+         config = ingestion_pipeline.ProcessingConfig(
+             input_file=str(Path(self.temp_dir) / "input.parquet"),
+             output_catalog=str(Path(self.temp_dir) / "catalog"),
+             scratch_location=str(Path(self.temp_dir) / "scratch"),
+             batch_threshold=1000,
+             distributed=None,  # Auto mode
+         )
+
+         # Create minimal input file
+         df = pd.DataFrame({"url": ["url1"]})
+         df.to_parquet(config.input_file, index=False)
+
+         processor = ingestion_pipeline.LocalProcessor(n_workers=1)
+         pipeline = ingestion_pipeline.STACIngestionPipeline(config, processor)
+
+         # Below threshold -> local
+         assert pipeline._should_use_distributed(500) is False
+         assert pipeline._should_use_distributed(999) is False
+
+         # At or above threshold -> distributed
+         assert pipeline._should_use_distributed(1000) is True
+         assert pipeline._should_use_distributed(5000) is True
+         processor.close()
+
+
+ if __name__ == "__main__":
+     pytest.main([__file__])
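
Usage note (editorial, not part of the package diff): the fixtures above construct the pipeline the same way throughout: a ProcessingConfig pointing at a parquet file with a 'url' column, a LocalProcessor, and a STACIngestionPipeline built from both. The sketch below is inferred solely from those fixtures; the temporary paths and example URL are placeholders, the top-level ingestion entry point of earthcatalog 0.2.0 is not shown in this file, and only construction plus the helper exercised by test_read_input_urls is demonstrated.

# Editorial sketch based only on the test fixtures above, not on earthcatalog
# documentation. Paths and the URL are placeholders.
import tempfile
from pathlib import Path

import pandas as pd

from earthcatalog import ingestion_pipeline

work_dir = Path(tempfile.mkdtemp())

# The tests always feed the pipeline a parquet file with a 'url' column.
pd.DataFrame({"url": ["https://example.com/item1.json"]}).to_parquet(
    work_dir / "input.parquet", index=False
)

config = ingestion_pipeline.ProcessingConfig(
    input_file=str(work_dir / "input.parquet"),
    output_catalog=str(work_dir / "catalog"),
    scratch_location=str(work_dir / "scratch"),
    max_workers=2,
)

processor = ingestion_pipeline.LocalProcessor(n_workers=2)
pipeline = ingestion_pipeline.STACIngestionPipeline(config, processor)

# test_read_input_urls shows this helper returning the URLs from the input file.
print(pipeline._read_input_urls())

processor.close()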