earthcatalog-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- earthcatalog/__init__.py +164 -0
- earthcatalog/async_http_client.py +1006 -0
- earthcatalog/config.py +97 -0
- earthcatalog/engines/__init__.py +308 -0
- earthcatalog/engines/rustac_engine.py +142 -0
- earthcatalog/engines/stac_geoparquet_engine.py +126 -0
- earthcatalog/exceptions.py +471 -0
- earthcatalog/grid_systems.py +1114 -0
- earthcatalog/ingestion_pipeline.py +2281 -0
- earthcatalog/input_readers.py +603 -0
- earthcatalog/job_tracking.py +485 -0
- earthcatalog/pipeline.py +606 -0
- earthcatalog/schema_generator.py +911 -0
- earthcatalog/spatial_resolver.py +1207 -0
- earthcatalog/stac_hooks.py +754 -0
- earthcatalog/statistics.py +677 -0
- earthcatalog/storage_backends.py +548 -0
- earthcatalog/tests/__init__.py +1 -0
- earthcatalog/tests/conftest.py +76 -0
- earthcatalog/tests/test_all_grids.py +793 -0
- earthcatalog/tests/test_async_http.py +700 -0
- earthcatalog/tests/test_cli_and_storage.py +230 -0
- earthcatalog/tests/test_config.py +245 -0
- earthcatalog/tests/test_dask_integration.py +580 -0
- earthcatalog/tests/test_e2e_synthetic.py +1624 -0
- earthcatalog/tests/test_engines.py +272 -0
- earthcatalog/tests/test_exceptions.py +346 -0
- earthcatalog/tests/test_file_structure.py +245 -0
- earthcatalog/tests/test_input_readers.py +666 -0
- earthcatalog/tests/test_integration.py +200 -0
- earthcatalog/tests/test_integration_async.py +283 -0
- earthcatalog/tests/test_job_tracking.py +603 -0
- earthcatalog/tests/test_multi_file_input.py +336 -0
- earthcatalog/tests/test_passthrough_hook.py +196 -0
- earthcatalog/tests/test_pipeline.py +684 -0
- earthcatalog/tests/test_pipeline_components.py +665 -0
- earthcatalog/tests/test_schema_generator.py +506 -0
- earthcatalog/tests/test_spatial_resolver.py +413 -0
- earthcatalog/tests/test_stac_hooks.py +776 -0
- earthcatalog/tests/test_statistics.py +477 -0
- earthcatalog/tests/test_storage_backends.py +236 -0
- earthcatalog/tests/test_validation.py +435 -0
- earthcatalog/tests/test_workers.py +653 -0
- earthcatalog/validation.py +921 -0
- earthcatalog/workers.py +682 -0
- earthcatalog-0.2.0.dist-info/METADATA +333 -0
- earthcatalog-0.2.0.dist-info/RECORD +50 -0
- earthcatalog-0.2.0.dist-info/WHEEL +5 -0
- earthcatalog-0.2.0.dist-info/entry_points.txt +3 -0
- earthcatalog-0.2.0.dist-info/top_level.txt +1 -0
earthcatalog/tests/test_workers.py
@@ -0,0 +1,653 @@
# test_workers.py
"""Tests for workers module - serializable worker functions.

This module tests:
- process_url_batch: URL batch processing
- consolidate_partition: Partition consolidation
- Result types: DownloadResult, ConsolidationResult
"""

import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from earthcatalog.ingestion_pipeline import ProcessingConfig
from earthcatalog.statistics import IngestionStatistics
from earthcatalog.workers import (
    ConsolidationResult,
    DownloadResult,
    _download_stac_item,
    _get_partition_key,
    consolidate_partition,
    process_url_batch,
)


class TestDownloadResult:
    """Test DownloadResult dataclass."""

    def test_to_dict(self):
        """to_dict should serialize result."""
        stats = IngestionStatistics()
        stats.record_url_processed(success=True)
        result = DownloadResult(
            shards=[{"shard_path": "/path/to/shard.parquet"}],
            stats=stats,
            failed_urls=["http://failed.com"],
        )

        data = result.to_dict()

        assert data["shards"] == [{"shard_path": "/path/to/shard.parquet"}]
        assert data["failed_urls"] == ["http://failed.com"]
        # Stats contains url counts
        assert "urls_processed" in data["stats"]

    def test_from_dict(self):
        """from_dict should deserialize result."""
        data = {
            "shards": [{"shard_path": "/path/to/shard.parquet"}],
            "stats": {"items_processed": 10, "items_failed": 2},
            "failed_urls": ["http://failed.com"],
        }

        result = DownloadResult.from_dict(data)

        assert len(result.shards) == 1
        assert len(result.failed_urls) == 1


class TestConsolidationResult:
    """Test ConsolidationResult dataclass."""

    def test_to_dict(self):
        """to_dict should serialize result."""
        result = ConsolidationResult(
            partition_key="h3_82/2024/01",
            item_count=100,
            existing_count=50,
            new_count=50,
            duplicates_removed=5,
            final_path="/catalog/h3_82/2024/01/data.parquet",
        )

        data = result.to_dict()

        assert data["partition_key"] == "h3_82/2024/01"
        assert data["item_count"] == 100
        assert data["existing_count"] == 50
        assert data["new_count"] == 50
        assert data["duplicates_removed"] == 5


class TestDownloadStacItem:
    """Test _download_stac_item helper function."""

    def test_successful_download(self):
        """Should return parsed JSON on successful download."""
        mock_hook = MagicMock()
        mock_hook.fetch.return_value = {"id": "test-item", "type": "Feature"}

        with patch("earthcatalog.workers._get_stac_hook", return_value=mock_hook):
            result = _download_stac_item("http://example.com/item.json")

        assert result is not None
        assert result["id"] == "test-item"
        mock_hook.fetch.assert_called_once()

    def test_failed_download_returns_none(self):
        """Should return None after all retries fail."""
        mock_hook = MagicMock()
        mock_hook.fetch.return_value = None

        with patch("earthcatalog.workers._get_stac_hook", return_value=mock_hook):
            result = _download_stac_item("http://example.com/item.json", retry_attempts=1)

        assert result is None


class TestGetPartitionKey:
    """Test _get_partition_key helper function."""

    def test_partition_key_with_datetime(self):
        """Should generate correct partition key with datetime."""
        from earthcatalog.grid_systems import get_grid_system

        item = {
            "id": "test-item",
            "geometry": {"type": "Point", "coordinates": [-122.0, 37.0]},
            "properties": {"datetime": "2024-06-15T12:00:00Z"},
        }
        grid = get_grid_system("h3", resolution=2)

        key = _get_partition_key(
            item,
            grid_resolver=grid,
            temporal_bin="month",
            enable_global=False,
            global_threshold=50,
        )

        assert "2024" in key
        assert "06" in key

    def test_partition_key_with_year_bin(self):
        """Should generate year-level partition key."""
        from earthcatalog.grid_systems import get_grid_system

        item = {
            "id": "test-item",
            "geometry": {"type": "Point", "coordinates": [-122.0, 37.0]},
            "properties": {"datetime": "2024-06-15T12:00:00Z"},
        }
        grid = get_grid_system("h3", resolution=2)

        key = _get_partition_key(
            item,
            grid_resolver=grid,
            temporal_bin="year",
            enable_global=False,
            global_threshold=50,
        )

        assert "2024" in key
        # Month should not be in year-level binning
        assert key.count("/") == 1  # grid_cell/year


class TestProcessUrlBatch:
    """Test process_url_batch function."""

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    @pytest.fixture
    def config_dict(self, temp_dir):
        """Create a config dict for testing."""
        config = ProcessingConfig(
            input_file="test.parquet",
            output_catalog=f"{temp_dir}/catalog",
            scratch_location=f"{temp_dir}/scratch",
            enable_concurrent_http=False,  # Disable async for testing
            batch_size=10,
        )
        return config.to_dict()

    def test_process_empty_batch(self, config_dict):
        """Should handle empty URL list."""
        result = process_url_batch(
            urls=[],
            worker_id=0,
            config_dict=config_dict,
            job_id="test-job",
        )

        assert result["shards"] == []
        assert result["failed_urls"] == []

    def test_process_batch_with_failures(self, config_dict):
        """Should record failed URLs."""
        mock_hook = MagicMock()
        mock_hook.fetch.return_value = None  # Simulate failed fetch
        mock_hook.fetch_batch.return_value = [None, None]  # Batch also fails

        with patch("earthcatalog.workers._get_stac_hook", return_value=mock_hook):
            result = process_url_batch(
                urls=["http://bad1.com", "http://bad2.com"],
                worker_id=0,
                config_dict=config_dict,
                job_id="test-job",
            )

        assert len(result["failed_urls"]) == 2

    @pytest.mark.skip(reason="Requires complex mocking of STAC engine serialization")
    def test_process_batch_successful(self, config_dict):
        """Should process URLs and write shards."""
        mock_item = {
            "id": "test-item",
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [-122.0, 37.0]},
            "bbox": [-122.0, 37.0, -122.0, 37.0],
            "properties": {"datetime": "2024-06-15T12:00:00Z"},
            "stac_version": "1.0.0",
            "links": [],
        }

        mock_response = MagicMock()
        mock_response.json.return_value = mock_item
        mock_response.raise_for_status = MagicMock()

        with patch("earthcatalog.workers.requests.get", return_value=mock_response):
            result = process_url_batch(
                urls=["http://example.com/item1.json"],
                worker_id=0,
                config_dict=config_dict,
                job_id="test-job",
            )

        assert len(result["shards"]) >= 1
        assert result["failed_urls"] == []


class TestConsolidatePartition:
    """Test consolidate_partition function."""

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    @pytest.fixture
    def config_dict(self, temp_dir):
        """Create a config dict for testing."""
        config = ProcessingConfig(
            input_file="test.parquet",
            output_catalog=f"{temp_dir}/catalog",
            scratch_location=f"{temp_dir}/scratch",
        )
        return config.to_dict()

    def test_consolidate_empty_shards(self, config_dict):
        """Should handle empty shard list."""
        result = consolidate_partition(
            partition_key="h3_82/2024/01",
            shard_paths=[],
            config_dict=config_dict,
        )

        assert result["partition_key"] == "h3_82/2024/01"
        assert result["item_count"] == 0

    def test_consolidate_with_shards(self, config_dict, temp_dir):
        """Should consolidate shards into final partition."""
        import geopandas as gpd
        from shapely.geometry import Point

        # Create a test shard
        scratch_dir = Path(temp_dir) / "scratch" / "shards"
        scratch_dir.mkdir(parents=True)

        shard_path = str(scratch_dir / "test_shard.parquet")
        gdf = gpd.GeoDataFrame(
            {
                "id": ["item1", "item2"],
                "datetime": ["2024-06-15T12:00:00Z", "2024-06-16T12:00:00Z"],
            },
            geometry=gpd.GeoSeries([Point(-122, 37), Point(-122, 37)]),
        )
        gdf.to_parquet(shard_path)

        result = consolidate_partition(
            partition_key="h3_82/2024/06",
            shard_paths=[shard_path],
            config_dict=config_dict,
        )

        assert result["partition_key"] == "h3_82/2024/06"
        assert result["item_count"] == 2
        assert Path(result["final_path"]).exists()

    def test_consolidate_deduplicates_by_datetime(self, config_dict, temp_dir):
        """Should keep newer item when deduplicating by ID."""
        import geopandas as gpd
        from shapely.geometry import Point

        scratch_dir = Path(temp_dir) / "scratch" / "shards"
        scratch_dir.mkdir(parents=True)

        # Create two shards with same ID but different datetimes
        shard1_path = str(scratch_dir / "shard1.parquet")
        gdf1 = gpd.GeoDataFrame(
            {
                "id": ["item1"],
                "datetime": ["2024-06-15T12:00:00Z"],  # Older
            },
            geometry=gpd.GeoSeries([Point(-122, 37)]),
        )
        gdf1.to_parquet(shard1_path)

        shard2_path = str(scratch_dir / "shard2.parquet")
        gdf2 = gpd.GeoDataFrame(
            {
                "id": ["item1"],
                "datetime": ["2024-06-16T12:00:00Z"],  # Newer
            },
            geometry=gpd.GeoSeries([Point(-122, 37)]),
        )
        gdf2.to_parquet(shard2_path)

        result = consolidate_partition(
            partition_key="h3_82/2024/06",
            shard_paths=[shard1_path, shard2_path],
            config_dict=config_dict,
        )

        assert result["item_count"] == 1  # Deduplicated to 1
        assert result["duplicates_removed"] == 1

        # Verify the newer item was kept
        final_gdf = gpd.read_parquet(result["final_path"])
        assert len(final_gdf) == 1
        assert "2024-06-16" in final_gdf.iloc[0]["datetime"]


class TestPickleability:
    """Test that worker functions can be pickled for Dask."""

    def test_process_url_batch_is_pickleable(self):
        """process_url_batch should be pickleable."""
        import pickle

        # Should not raise
        pickled = pickle.dumps(process_url_batch)
        restored = pickle.loads(pickled)
        assert callable(restored)

    def test_consolidate_partition_is_pickleable(self):
        """consolidate_partition should be pickleable."""
        import pickle

        pickled = pickle.dumps(consolidate_partition)
        restored = pickle.loads(pickled)
        assert callable(restored)


class TestDaskIntegration:
    """Tests for Dask distributed integration with serializable workers.

    These tests verify that the worker functions can be properly used
    with Dask's distributed computing model.
    """

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    @pytest.fixture
    def config_dict(self, temp_dir):
        """Create a config dict for testing."""
        config = ProcessingConfig(
            input_file="test.parquet",
            output_catalog=f"{temp_dir}/catalog",
            scratch_location=f"{temp_dir}/scratch",
            enable_concurrent_http=False,
            batch_size=10,
        )
        return config.to_dict()

    def test_config_dict_roundtrip(self, config_dict):
        """Config dict should serialize and deserialize correctly."""
        restored = ProcessingConfig.from_dict(config_dict)

        assert restored.output_catalog == config_dict["output_catalog"]
        assert restored.scratch_location == config_dict["scratch_location"]
        assert restored.enable_concurrent_http == config_dict["enable_concurrent_http"]

    def test_process_url_batch_accepts_config_dict(self, config_dict):
        """process_url_batch should accept config_dict parameter."""
        # Should not raise
        result = process_url_batch(
            urls=[],
            worker_id=0,
            config_dict=config_dict,
            job_id="test-job-123",
        )

        assert "shards" in result
        assert "stats" in result
        assert "failed_urls" in result

    def test_consolidate_partition_accepts_config_dict(self, config_dict):
        """consolidate_partition should accept config_dict parameter."""
        result = consolidate_partition(
            partition_key="test/2024/01",
            shard_paths=[],
            config_dict=config_dict,
        )

        assert "partition_key" in result
        assert result["partition_key"] == "test/2024/01"
        assert result["item_count"] == 0

    def test_worker_functions_are_module_level(self):
        """Worker functions should be importable at module level for Dask."""
        from earthcatalog import workers

        # These should be module-level functions, not methods
        assert hasattr(workers, "process_url_batch")
        assert hasattr(workers, "consolidate_partition")
        assert callable(workers.process_url_batch)
        assert callable(workers.consolidate_partition)

    def test_worker_functions_have_correct_signature(self):
        """Worker functions should have correct signature for Dask usage."""
        import inspect

        # process_url_batch signature
        sig = inspect.signature(process_url_batch)
        params = list(sig.parameters.keys())
        assert "urls" in params
        assert "worker_id" in params
        assert "config_dict" in params
        assert "job_id" in params

        # consolidate_partition signature
        sig = inspect.signature(consolidate_partition)
        params = list(sig.parameters.keys())
        assert "partition_key" in params
        assert "shard_paths" in params
        assert "config_dict" in params

    def test_result_serialization_for_dask(self, config_dict):
        """Results should be serializable for Dask transmission."""
        import pickle

        result = process_url_batch(
            urls=[],
            worker_id=0,
            config_dict=config_dict,
            job_id="test",
        )

        # Should be pickleable (required for Dask)
        pickled = pickle.dumps(result)
        restored = pickle.loads(pickled)

        assert restored["shards"] == result["shards"]
        assert restored["failed_urls"] == result["failed_urls"]

    def test_consolidation_result_serialization(self, config_dict):
        """Consolidation results should be serializable."""
        import pickle

        result = consolidate_partition(
            partition_key="h3_82/2024/01",
            shard_paths=[],
            config_dict=config_dict,
        )

        pickled = pickle.dumps(result)
        restored = pickle.loads(pickled)

        assert restored["partition_key"] == result["partition_key"]
        assert restored["item_count"] == result["item_count"]


class TestDaskDistributedProcessorIntegration:
    """Tests for DaskDistributedProcessor using workers.py functions.

    These tests mock Dask to verify the correct functions are called.
    """

    @pytest.fixture
    def mock_dask_client(self):
        """Create a mock Dask client."""
        mock_client = MagicMock()
        mock_client.dashboard_link = "http://localhost:8787"
        mock_client.submit.return_value = MagicMock()
        mock_client.gather.return_value = [{"shards": [], "stats": {}, "failed_urls": []}]
        return mock_client

    def test_dask_processor_uses_workers_process_url_batch(self, mock_dask_client):
        """DaskDistributedProcessor should use workers.process_url_batch when config_dict provided."""
        # Create mock dd module
        mock_dd = MagicMock()
        mock_dd.Client.return_value = mock_dask_client

        with patch.dict("sys.modules", {"dask.distributed": mock_dd}):
            # Need to reimport to pick up mocked module
            from importlib import reload

            from earthcatalog import ingestion_pipeline

            reload(ingestion_pipeline)

            processor = ingestion_pipeline.DaskDistributedProcessor()

            config_dict = {
                "input_file": "test.parquet",
                "output_catalog": "./out",
                "scratch_location": "./scratch",
            }

            # Call process_urls with config_dict
            processor.process_urls(
                url_chunks=[["url1", "url2"]],
                process_fn=lambda x, y: None,  # Should be ignored
                config_dict=config_dict,
                job_id="test-job",
            )

            # Verify the submit call used process_url_batch
            assert mock_dask_client.submit.called
            call_args = mock_dask_client.submit.call_args
            # First positional arg should be the function
            called_func = call_args[0][0] if call_args[0] else call_args[1].get("func")
            # The function name should be process_url_batch
            assert called_func.__name__ == "process_url_batch"

    def test_dask_processor_uses_workers_consolidate_partition(self, mock_dask_client):
        """DaskDistributedProcessor should use workers.consolidate_partition when config_dict provided."""
        mock_dask_client.gather.return_value = [{"partition_key": "test", "item_count": 0}]
        mock_dd = MagicMock()
        mock_dd.Client.return_value = mock_dask_client

        with patch.dict("sys.modules", {"dask.distributed": mock_dd}):
            from importlib import reload

            from earthcatalog import ingestion_pipeline

            reload(ingestion_pipeline)

            processor = ingestion_pipeline.DaskDistributedProcessor()

            config_dict = {
                "input_file": "test.parquet",
                "output_catalog": "./out",
                "scratch_location": "./scratch",
            }

            processor.consolidate_shards(
                partition_items=[("test/2024/01", ["/path/shard1.parquet"])],
                consolidate_fn=lambda x, y: None,  # Should be ignored
                config_dict=config_dict,
            )

            assert mock_dask_client.submit.called
            call_args = mock_dask_client.submit.call_args
            called_func = call_args[0][0] if call_args[0] else call_args[1].get("func")
            assert called_func.__name__ == "consolidate_partition"

    def test_dask_processor_fallback_without_config_dict(self, mock_dask_client):
        """DaskDistributedProcessor should fall back to old behavior without config_dict."""
        mock_dask_client.gather.return_value = ["result1"]
        mock_dd = MagicMock()
        mock_dd.Client.return_value = mock_dask_client

        with patch.dict("sys.modules", {"dask.distributed": mock_dd}):
            from importlib import reload

            from earthcatalog import ingestion_pipeline

            reload(ingestion_pipeline)

            processor = ingestion_pipeline.DaskDistributedProcessor()

            custom_fn = MagicMock(return_value="custom_result")

            # Call without config_dict
            processor.process_urls(
                url_chunks=[["url1"]],
                process_fn=custom_fn,
                # No config_dict - should use custom_fn
            )

            assert mock_dask_client.submit.called
            call_args = mock_dask_client.submit.call_args
            called_func = call_args[0][0]
            # Should use the provided custom function, not process_url_batch
            assert called_func == custom_fn


class TestStacHookIntegrationWithWorkers:
    """Test STAC hook integration with workers.py."""

    @pytest.fixture
    def temp_dir(self):
        """Create a temporary directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    def test_hook_config_in_config_dict(self, temp_dir):
        """stac_hook should be included in config_dict."""
        config = ProcessingConfig(
            input_file="test.parquet",
            output_catalog=f"{temp_dir}/catalog",
            scratch_location=f"{temp_dir}/scratch",
            stac_hook="module:my_module:my_func",
        )

        config_dict = config.to_dict()
        assert "stac_hook" in config_dict
        assert config_dict["stac_hook"] == "module:my_module:my_func"

    def test_hook_config_restored_from_dict(self, temp_dir):
        """stac_hook should be restored from config_dict."""
        config_dict = {
            "input_file": "test.parquet",
            "output_catalog": f"{temp_dir}/catalog",
            "scratch_location": f"{temp_dir}/scratch",
            "stac_hook": "script:/path/to/script.py",
        }

        config = ProcessingConfig.from_dict(config_dict)
        assert config.stac_hook == "script:/path/to/script.py"

    def test_default_hook_in_workers(self, temp_dir):
        """Workers should use default hook when not specified."""
        config = ProcessingConfig(
            input_file="test.parquet",
            output_catalog=f"{temp_dir}/catalog",
            scratch_location=f"{temp_dir}/scratch",
        )

        config_dict = config.to_dict()

        # process_url_batch should work with default hook
        result = process_url_batch(
            urls=[],
            worker_id=0,
            config_dict=config_dict,
            job_id="test",
        )

        assert result["failed_urls"] == []