caption-flow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- caption_flow/cli.py +308 -0
- caption_flow/models.py +134 -1
- caption_flow/monitor.py +1 -1
- caption_flow/orchestrator.py +423 -1715
- caption_flow/processors/__init__.py +11 -0
- caption_flow/processors/base.py +219 -0
- caption_flow/processors/huggingface.py +832 -0
- caption_flow/processors/local_filesystem.py +683 -0
- caption_flow/processors/webdataset.py +782 -0
- caption_flow/storage/__init__.py +1 -0
- caption_flow/storage/exporter.py +550 -0
- caption_flow/{storage.py → storage/manager.py} +489 -401
- caption_flow/utils/checkpoint_tracker.py +2 -2
- caption_flow/utils/chunk_tracker.py +73 -32
- caption_flow/utils/dataset_loader.py +58 -298
- caption_flow/utils/dataset_metadata_cache.py +67 -0
- caption_flow/utils/image_processor.py +1 -4
- caption_flow/utils/shard_processor.py +5 -265
- caption_flow/utils/shard_tracker.py +1 -5
- caption_flow/viewer.py +594 -0
- caption_flow/workers/base.py +3 -3
- caption_flow/workers/caption.py +416 -792
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/METADATA +49 -180
- caption_flow-0.2.4.dist-info/RECORD +38 -0
- caption_flow-0.2.2.dist-info/RECORD +0 -29
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/WHEEL +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/entry_points.txt +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {caption_flow-0.2.2.dist-info → caption_flow-0.2.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,683 @@
|
|
1
|
+
"""Local filesystem datasets processor implementation."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import threading
|
5
|
+
import os
|
6
|
+
from typing import Dict, Any, List, Optional, Iterator, Set, Deque, Tuple
|
7
|
+
from collections import deque, defaultdict
|
8
|
+
from pathlib import Path
|
9
|
+
import json
|
10
|
+
import io
|
11
|
+
import mimetypes
|
12
|
+
from datetime import datetime
|
13
|
+
from PIL import Image
|
14
|
+
import aiofiles
|
15
|
+
from fastapi import FastAPI, HTTPException, Response
|
16
|
+
from fastapi.responses import StreamingResponse
|
17
|
+
import uvicorn
|
18
|
+
import asyncio
|
19
|
+
import requests
|
20
|
+
|
21
|
+
from caption_flow.storage import StorageManager
|
22
|
+
from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
|
23
|
+
from ..utils import ChunkTracker
|
24
|
+
from ..models import JobId
|
25
|
+
|
26
|
+
logger = logging.getLogger(__name__)
# NOTE(review): forcing DEBUG on a library logger means these debug records are
# created and passed to the root handlers regardless of the application's root
# level; consider leaving level configuration to the application.
logger.setLevel(logging.DEBUG)

# File extensions (lower-case, with leading dot) treated as images during
# discovery. Every entry must be decodable by Pillow, because workers load each
# discovered file with ``PIL.Image.open``. ``.svg`` is deliberately excluded:
# Pillow has no SVG decoder, so such files could never be processed and would
# leave their chunk permanently incomplete.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif", ".tiff", ".tif"}
|
31
|
+
|
32
|
+
|
33
|
+
class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
    """Orchestrator processor for local filesystem datasets.

    Discovers image files under a root directory, splits the sorted file list
    into fixed-size chunks (work units) that are handed out to workers, and
    serves the image bytes over a small FastAPI/uvicorn HTTP server for workers
    that cannot read the files directly.
    """

    def __init__(self):
        logger.debug("Initializing LocalFilesystemOrchestratorProcessor")
        self.dataset_path: Optional[Path] = None
        self.chunk_tracker: Optional[ChunkTracker] = None
        self.chunk_size: int = 1000
        self.recursive: bool = True
        self.follow_symlinks: bool = False

        # Work-unit buffer sizing; overridden from config in initialize().
        self.min_buffer: int = 10
        self.buffer_multiplier: int = 3

        # Image file tracking
        self.all_images: List[Tuple[Path, int]] = []  # (path, size_bytes)
        self.total_images: int = 0
        # Start index of the next chunk the background thread will consider.
        self.current_index: int = 0

        # Work unit management (guarded by self.lock)
        self.work_units: Dict[str, WorkUnit] = {}
        self.pending_units: Deque[str] = deque()
        self.assigned_units: Dict[str, Set[str]] = defaultdict(set)  # worker_id -> unit_ids
        self.lock = threading.Lock()

        # Background thread for creating work units
        self.unit_creation_thread: Optional[threading.Thread] = None
        self.stop_creation = threading.Event()

        # HTTP server for serving images
        self.http_app: Optional[FastAPI] = None
        self.http_server_task: Optional[asyncio.Task] = None
        self.http_bind_address: str = "0.0.0.0"
        self.http_public_address: str = "127.0.0.1"
        self.http_port: int = 8766

    def initialize(self, config: ProcessorConfig, storage: StorageManager) -> None:
        """Initialize local filesystem processor.

        Reads settings from ``config.config``, discovers images, restores
        unfinished chunks from the checkpoint, starts the image HTTP server and
        the background unit-creation thread.

        Raises:
            ValueError: if ``dataset.dataset_path`` does not exist.
        """
        logger.debug("Initializing orchestrator with config: %s", config.config)
        cfg = config.config

        # Dataset configuration
        dataset_cfg = cfg.get("dataset", {})
        self.dataset_path = Path(dataset_cfg.get("dataset_path", "."))

        if not self.dataset_path.exists():
            raise ValueError(f"Dataset path does not exist: {self.dataset_path}")

        self.recursive = dataset_cfg.get("recursive", True)
        self.follow_symlinks = dataset_cfg.get("follow_symlinks", False)

        # Chunk settings
        self.chunk_size = cfg.get("chunk_size", 1000)
        self.min_buffer = cfg.get("min_chunk_buffer", 10)
        self.buffer_multiplier = cfg.get("chunk_buffer_multiplier", 3)

        # HTTP server settings: bind locally, advertise the public address to workers.
        self.http_bind_address = dataset_cfg.get("http_bind_address", "0.0.0.0")
        self.http_public_address = dataset_cfg.get("public_address", "127.0.0.1")
        self.http_port = dataset_cfg.get("http_port", 8766)

        logger.info(f"Root path: {self.dataset_path}, recursive: {self.recursive}")

        # Initialize chunk tracking
        checkpoint_dir = Path(cfg.get("checkpoint_dir", "./checkpoints"))
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.chunk_tracker = ChunkTracker(checkpoint_dir / "chunks.json")

        # Discover images
        self._discover_images()

        # Restore existing state
        self._restore_state(storage)

        # Start HTTP server for image serving
        self._start_http_server()

        # Start background unit creation
        self.unit_creation_thread = threading.Thread(
            target=self._create_units_background, daemon=True
        )
        self.unit_creation_thread.start()
        logger.debug("Unit creation thread started")

    def _discover_images(self):
        """Discover all image files under ``dataset_path``.

        Populates ``all_images`` with ``(path, size_bytes)`` sorted by path
        string (so indices are stable across runs) and sets ``total_images``.
        Hidden directories are skipped during recursive walks; files that cannot
        be stat'ed are logged and ignored.
        """
        logger.info("Discovering images...")
        images: List[Tuple[Path, int]] = []

        if self.recursive:
            extensions = tuple(IMAGE_EXTENSIONS)
            for root, dirs, files in os.walk(self.dataset_path, followlinks=self.follow_symlinks):
                root_path = Path(root)

                # Prune hidden directories in place so os.walk does not descend into them.
                dirs[:] = [d for d in dirs if not d.startswith(".")]

                for file in files:
                    if not file.lower().endswith(extensions):
                        continue
                    file_path = root_path / file
                    try:
                        images.append((file_path, file_path.stat().st_size))
                    except OSError as e:
                        logger.warning(f"Cannot stat {file_path}: {e}")
        else:
            # Just scan root directory
            for file_path in self.dataset_path.iterdir():
                if not (file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS):
                    continue
                try:
                    images.append((file_path, file_path.stat().st_size))
                except OSError as e:
                    logger.warning(f"Cannot stat {file_path}: {e}")

        # Sort for consistent ordering
        images.sort(key=lambda x: str(x[0]))
        self.all_images = images
        self.total_images = len(images)

        logger.info(f"Found {self.total_images} images")

    def _http_url(self) -> str:
        """Base URL that workers use to fetch images from this orchestrator."""
        return f"http://{self.http_public_address}:{self.http_port}"

    def _start_http_server(self):
        """Start the HTTP server for serving images on a daemon thread.

        Routes:
            GET /image/{index}: streams the file at ``all_images[index]``.
            GET /info: dataset summary, including the advertised base URL.
        """
        self.http_app = FastAPI()

        @self.http_app.get("/image/{image_index:int}")
        async def get_image(image_index: int):
            """Serve an image by index."""
            if image_index < 0 or image_index >= len(self.all_images):
                raise HTTPException(status_code=404, detail="Image not found")

            file_path, _ = self.all_images[image_index]

            if not file_path.exists():
                raise HTTPException(status_code=404, detail="Image file not found")

            # Determine content type
            content_type = mimetypes.guess_type(str(file_path))[0] or "image/jpeg"

            # Stream file
            async def stream_file():
                async with aiofiles.open(file_path, "rb") as f:
                    while chunk := await f.read(1024 * 1024):  # 1MB chunks
                        yield chunk

            return StreamingResponse(
                stream_file(),
                media_type=content_type,
                headers={"Content-Disposition": f'inline; filename="{file_path.name}"'},
            )

        @self.http_app.get("/info")
        async def get_info():
            """Get dataset info."""
            return {
                "total_images": self.total_images,
                "root_path": str(self.dataset_path),
                "http_url": self._http_url(),
            }

        # Start server in background
        async def run_server():
            config = uvicorn.Config(
                app=self.http_app,
                host=self.http_bind_address,
                port=self.http_port,
                log_level="warning",
            )
            server = uvicorn.Server(config)
            await server.serve()

        # The task is scheduled on a private loop that the thread below runs forever.
        loop = asyncio.new_event_loop()
        self.http_server_task = loop.create_task(run_server())

        # Run in thread
        def run_loop():
            asyncio.set_event_loop(loop)
            loop.run_forever()

        threading.Thread(target=run_loop, daemon=True).start()
        logger.info(
            f"HTTP server started on {self.http_bind_address}:{self.http_port}, advertising hostname {self.http_public_address} to clients"
        )

    def _chunk_filenames(self, start_index: int, size: int) -> Dict[int, str]:
        """Map image index -> file basename for the existing indices in ``[start_index, start_index + size)``."""
        stop = min(start_index + size, len(self.all_images))
        return {idx: self.all_images[idx][0].name for idx in range(start_index, stop)}

    def _make_work_unit(
        self,
        unit_id: str,
        start_index: int,
        chunk_size: int,
        unprocessed_ranges: List[Tuple[int, int]],
        chunk_index: int,
    ) -> WorkUnit:
        """Build the WorkUnit for one chunk; ``unit_id`` doubles as the chunk id."""
        return WorkUnit(
            unit_id=unit_id,
            chunk_id=unit_id,
            source_id="local",
            data={
                "start_index": start_index,
                "chunk_size": chunk_size,
                "unprocessed_ranges": unprocessed_ranges,
                "http_url": self._http_url(),
                "filenames": self._chunk_filenames(start_index, chunk_size),
            },
            metadata={
                "dataset": str(self.dataset_path),
                "chunk_index": chunk_index,
            },
        )

    def _restore_state(self, storage: StorageManager) -> None:
        """Re-queue every tracked chunk that still has unprocessed items.

        For each chunk in the chunk tracker, the processed indices (per the
        tracker and ``storage.get_all_processed_job_ids()``) are subtracted from
        the chunk's full range; chunks with anything left become pending units.
        Fully processed chunks are not queued.
        """
        logger.debug("Restoring state from chunk tracker")
        if not self.chunk_tracker:
            return

        all_processed_jobs = storage.get_all_processed_job_ids()

        with self.lock:
            for chunk_id, chunk_state in self.chunk_tracker.chunks.items():
                start_index = chunk_state.start_index
                size = chunk_state.chunk_size
                chunk_range = (start_index, start_index + size - 1)

                processed_ranges = self.chunk_tracker.get_processed_indices_for_chunk(
                    chunk_id, all_processed_jobs
                )
                unprocessed_ranges = self._subtract_ranges([chunk_range], processed_ranges)
                if not unprocessed_ranges:
                    continue

                unit = self._make_work_unit(
                    unit_id=chunk_id,
                    start_index=start_index,
                    chunk_size=size,
                    unprocessed_ranges=unprocessed_ranges,
                    chunk_index=start_index // self.chunk_size,
                )
                self.work_units[unit.unit_id] = unit
                self.pending_units.append(unit.unit_id)

    def _create_units_background(self) -> None:
        """Background loop that keeps the pending-unit buffer topped up.

        Runs until ``stop_creation`` is set. Each pass targets
        ``max(min_buffer, workers * buffer_multiplier)`` outstanding
        (pending + assigned) units and creates new chunks from ``current_index``
        onward to reach it. When nothing can be created (buffer full or all
        images already chunked) it sleeps for up to 5 seconds, waking early if
        ``stop_creation`` is set.
        """
        logger.info("Starting work unit creation thread")

        while not self.stop_creation.is_set():
            # Check if we need more units
            with self.lock:
                pending_count = len(self.pending_units)
                assigned_count = sum(len(units) for units in self.assigned_units.values())
                worker_count = max(1, len(self.assigned_units))

            target_buffer = max(self.min_buffer, worker_count * self.buffer_multiplier)
            units_needed = max(0, target_buffer - (pending_count + assigned_count))

            units_created = 0
            while units_created < units_needed and self.current_index < self.total_images:
                chunk_start = self.current_index
                chunk_size = min(self.chunk_size, self.total_images - chunk_start)
                chunk_index = chunk_start // self.chunk_size
                # Advance first: whether this chunk is created or skipped, move on.
                self.current_index += self.chunk_size

                with self.lock:
                    unit_id = JobId(
                        shard_id="local", chunk_id=str(chunk_index), sample_id=str(chunk_start)
                    ).get_chunk_str()  # e.g. "local:chunk:0"

                    if unit_id in self.work_units:
                        continue

                    # A chunk the tracker already knows about but that is not in
                    # work_units was examined by _restore_state and found to have
                    # nothing left to process (completed or all items done).
                    # Re-creating it here would reset it to its full range and
                    # reprocess every item.
                    if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
                        continue

                    unit = self._make_work_unit(
                        unit_id=unit_id,
                        start_index=chunk_start,
                        chunk_size=chunk_size,
                        unprocessed_ranges=[(chunk_start, chunk_start + chunk_size - 1)],
                        chunk_index=chunk_index,
                    )
                    self.work_units[unit_id] = unit
                    self.pending_units.append(unit_id)

                    if self.chunk_tracker:
                        self.chunk_tracker.add_chunk(
                            unit_id, "local", str(self.dataset_path), chunk_start, chunk_size
                        )

                    units_created += 1

            if units_created > 0:
                logger.debug(f"Created {units_created} work units")
            else:
                # Buffer is full, or every image has already been chunked: wait
                # (interruptibly) instead of spinning the CPU in a tight loop.
                self.stop_creation.wait(5)

    def _subtract_ranges(
        self, total_ranges: List[Tuple[int, int]], processed_ranges: List[Tuple[int, int]]
    ) -> List[Tuple[int, int]]:
        """Subtract processed ranges from total ranges.

        All ranges are inclusive ``(start, end)`` pairs. Returns the maximal runs
        of indices inside ``total_ranges`` not covered by ``processed_ranges``.
        """
        if not processed_ranges:
            return total_ranges

        # Create a set of all processed indices
        processed_indices = set()
        for start, end in processed_ranges:
            processed_indices.update(range(start, end + 1))

        # Find unprocessed ranges
        unprocessed_ranges = []
        for start, end in total_ranges:
            current_start = None
            for i in range(start, end + 1):
                if i not in processed_indices:
                    if current_start is None:
                        current_start = i
                elif current_start is not None:
                    unprocessed_ranges.append((current_start, i - 1))
                    current_start = None

            if current_start is not None:
                unprocessed_ranges.append((current_start, end))

        return unprocessed_ranges

    @staticmethod
    def _indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
        """Collapse sorted integer indices into inclusive ``(start, end)`` runs.

        Duplicates are tolerated and merged into the run they belong to.
        """
        ranges: List[Tuple[int, int]] = []
        for idx in indices:
            if ranges and idx <= ranges[-1][1] + 1:
                ranges[-1] = (ranges[-1][0], max(ranges[-1][1], idx))
            else:
                ranges.append((idx, idx))
        return ranges

    def get_work_units(self, count: int, worker_id: str) -> List[WorkUnit]:
        """Pop up to ``count`` pending units and record them as assigned to ``worker_id``."""
        logger.debug("get_work_units called: count=%d worker_id=%s", count, worker_id)
        assigned = []

        with self.lock:
            while len(assigned) < count and self.pending_units:
                unit_id = self.pending_units.popleft()
                unit = self.work_units.get(unit_id)

                if unit:
                    self.assigned_units[worker_id].add(unit_id)
                    assigned.append(unit)
                    logger.debug("Assigning unit %s to worker %s", unit_id, worker_id)

                    if self.chunk_tracker:
                        self.chunk_tracker.mark_assigned(unit_id, worker_id)

        logger.debug("Returning %d work units to worker %s", len(assigned), worker_id)
        return assigned

    def mark_completed(self, unit_id: str, worker_id: str) -> None:
        """Mark a work unit as completed and drop it from the worker's assignments."""
        logger.debug("Marking unit %s as completed by worker %s", unit_id, worker_id)
        with self.lock:
            if unit_id in self.work_units:
                # .get() avoids creating an empty entry for an unknown worker,
                # which would otherwise inflate the worker count used for buffering.
                worker_units = self.assigned_units.get(worker_id)
                if worker_units is not None:
                    worker_units.discard(unit_id)

                if self.chunk_tracker:
                    self.chunk_tracker.mark_completed(unit_id)

    def mark_failed(self, unit_id: str, worker_id: str, error: str) -> None:
        """Mark a work unit as failed and put it back on the pending queue."""
        logger.debug("Marking unit %s as failed by worker %s, error: %s", unit_id, worker_id, error)
        with self.lock:
            if unit_id in self.work_units:
                worker_units = self.assigned_units.get(worker_id)
                if worker_units is not None:
                    worker_units.discard(unit_id)
                self.pending_units.append(unit_id)

                if self.chunk_tracker:
                    self.chunk_tracker.mark_failed(unit_id)

    def release_assignments(self, worker_id: str) -> None:
        """Release all assignments for a disconnected worker back to the pending queue."""
        logger.debug("Releasing assignments for worker %s", worker_id)
        with self.lock:
            unit_ids = list(self.assigned_units.get(worker_id, []))

            for unit_id in unit_ids:
                if unit_id in self.work_units:
                    self.pending_units.append(unit_id)

            self.assigned_units.pop(worker_id, None)

            if self.chunk_tracker:
                self.chunk_tracker.release_worker_chunks(worker_id)

    def update_from_storage(self, processed_job_ids: Set[str]) -> None:
        """Narrow each unit's ``unprocessed_ranges`` using already-stored job ids.

        Only job ids whose shard is ``"local"`` are considered. A job counts
        towards a unit when its chunk id equals the unit's ``chunk_index`` and its
        sample id lies inside the unit's index range.
        """
        logger.info(f"Updating work units from {len(processed_job_ids)} processed jobs")

        # Parse every job id once and bucket sample indices by chunk index,
        # instead of re-parsing the whole set for every unit under the lock.
        indices_by_chunk: Dict[int, List[int]] = defaultdict(list)
        for job_id in processed_job_ids:
            job_id_obj = JobId.from_str(job_id)
            if job_id_obj.shard_id == "local":
                indices_by_chunk[int(job_id_obj.chunk_id)].append(int(job_id_obj.sample_id))

        with self.lock:
            for unit_id, unit in self.work_units.items():
                start_index = unit.data["start_index"]
                chunk_size = unit.data["chunk_size"]
                end_index = start_index + chunk_size

                processed_indices = sorted(
                    idx
                    for idx in indices_by_chunk.get(unit.metadata["chunk_index"], ())
                    if start_index <= idx < end_index
                )
                if not processed_indices:
                    continue

                processed_ranges = self._indices_to_ranges(processed_indices)

                # Calculate unprocessed ranges
                total_range = [(start_index, end_index - 1)]
                unprocessed_ranges = self._subtract_ranges(total_range, processed_ranges)

                # Update unit
                unit.data["unprocessed_ranges"] = unprocessed_ranges

                logger.debug(
                    f"Updated unit {unit_id}: {len(processed_indices)} processed, "
                    f"unprocessed ranges: {unprocessed_ranges}"
                )

    def get_stats(self) -> Dict[str, Any]:
        """Get processor statistics."""
        with self.lock:
            stats = {
                "dataset": str(self.dataset_path),
                "total_units": len(self.work_units),
                "pending_units": len(self.pending_units),
                "assigned_units": sum(len(units) for units in self.assigned_units.values()),
                "total_images": self.total_images,
                "workers": len(self.assigned_units),
            }
            return stats

    def handle_result(self, result: WorkResult) -> Dict[str, Any]:
        """Handle result processing and record processed item ranges in the chunk tracker.

        If the result carries no ``item_indices`` metadata, it is populated from
        ``_item_index`` as before. Missing (``None``) indices are ignored so the
        tracker is never asked to mark a ``None`` range as processed.
        """
        base_result = super().handle_result(result)

        # Track processed items
        if self.chunk_tracker:
            if "item_indices" not in result.metadata:
                result.metadata["item_indices"] = [result.metadata.get("_item_index")]

            # Sorted copy: do not reorder the caller's metadata list in place.
            indices = sorted(
                idx for idx in (result.metadata["item_indices"] or []) if idx is not None
            )
            for start_idx, end_idx in self._indices_to_ranges(indices):
                self.chunk_tracker.mark_items_processed(result.chunk_id, start_idx, end_idx)

        return base_result

    def get_image_paths(self) -> List[Tuple[Path, int]]:
        """Get the list of discovered image paths and sizes."""
        return self.all_images
|
535
|
+
|
536
|
+
|
537
|
+
class LocalFilesystemWorkerProcessor(WorkerProcessor):
    """Worker processor for local filesystem datasets.

    Loads images directly from disk when ``worker.local_storage_path`` exists
    and image paths have been supplied via
    :meth:`set_image_paths_from_orchestrator`; otherwise it downloads each image
    from the orchestrator's HTTP server (``unit.data["http_url"]``).
    """

    def __init__(self):
        logger.debug("Initializing LocalFilesystemWorkerProcessor")
        self.dataset_path: Optional[Path] = None
        self.image_paths: Optional[List[Tuple[Path, int]]] = None
        self.dataset_config: Dict[str, Any] = {}

    def initialize(self, config: ProcessorConfig) -> None:
        """Initialize processor, enabling direct file access if the configured local path exists."""
        logger.debug("Initializing worker with config: %s", config.config)
        self.dataset_config = config.config.get("dataset", {})

        # Check if worker has local storage access
        worker_cfg = config.config.get("worker", {})
        local_path = worker_cfg.get("local_storage_path")

        self.dataset_path = None
        if local_path:
            self.dataset_path = Path(local_path)
            if self.dataset_path.exists():
                logger.info(f"Worker has local storage access at: {self.dataset_path}")
                # Could potentially cache image list here if needed
            else:
                logger.warning(f"Local storage path does not exist: {self.dataset_path}")
                self.dataset_path = None
        else:
            logger.info("Worker does not have local storage access, will use HTTP")

    def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
        """Process a work unit, yielding items to be captioned.

        Iterates the indices in ``unit.data["unprocessed_ranges"]`` (inclusive
        ranges) in ascending order. Images that cannot be loaded are logged and
        skipped. When the generator is exhausted, the list of yielded indices is
        stored in ``context["_processed_indices"]``.
        """
        logger.debug("Processing unit: %s", unit.unit_id)

        start_index = unit.data["start_index"]
        chunk_size = unit.data["chunk_size"]
        unprocessed_ranges = unit.data.get(
            "unprocessed_ranges", [(start_index, start_index + chunk_size - 1)]
        )
        http_url = unit.data.get("http_url")
        filenames = unit.data.get("filenames", {})

        logger.info(f"Processing unit {unit.unit_id} with ranges: {unprocessed_ranges}")

        # Create set of indices to process
        indices_to_process = set()
        for start, end in unprocessed_ranges:
            indices_to_process.update(range(start, end + 1))

        processed_indices = []

        for idx in sorted(indices_to_process):
            try:
                image = None
                # The orchestrator builds ``filenames`` with int keys; after a JSON
                # round-trip they become strings. Accept either form.
                filename = filenames.get(str(idx), filenames.get(idx, f"image_{idx}"))

                if self.dataset_path and self.image_paths:
                    # Direct file access
                    if 0 <= idx < len(self.image_paths):
                        file_path, _ = self.image_paths[idx]
                        if file_path.exists():
                            image = Image.open(file_path)
                            filename = file_path.name
                            logger.debug(f"Loaded image from local path: {file_path}")
                        else:
                            logger.warning(f"Local file not found: {file_path}")

                if image is None and http_url:
                    # HTTP fallback
                    image_url = f"{http_url}/image/{idx}"
                    try:
                        response = requests.get(image_url, timeout=30)
                        response.raise_for_status()
                        image = Image.open(io.BytesIO(response.content))
                        logger.debug(f"Loaded image via HTTP: {image_url}")
                    except Exception as e:
                        logger.error(f"Error downloading image from {image_url}: {e}")
                        continue

                if image is None:
                    logger.warning(f"Could not load image at index {idx}")
                    continue

                # Build job ID
                chunk_index = unit.metadata["chunk_index"]
                job_id_obj = JobId(shard_id="local", chunk_id=str(chunk_index), sample_id=str(idx))
                job_id = job_id_obj.get_sample_str()

                # Metadata
                clean_metadata = {
                    "_item_index": idx,
                    "_chunk_relative_index": idx - start_index,
                    "_job_id": job_id,
                    "_filename": filename,
                }

                yield {
                    "image": image,
                    "item_key": str(idx),
                    "item_index": idx,
                    "metadata": clean_metadata,
                    "job_id": job_id,
                }

                processed_indices.append(idx)

            except Exception as e:
                logger.error(f"Error processing item at index {idx}: {e}")

        # Store processed indices in context
        context["_processed_indices"] = processed_indices
        logger.debug("Processed indices for unit %s: %s", unit.unit_id, processed_indices)

    def prepare_result(
        self, unit: WorkUnit, outputs: List[Dict[str, Any]], processing_time_ms: float
    ) -> WorkResult:
        """Prepare result, copying ``_processed_indices`` into ``item_indices`` when present."""
        logger.debug("Preparing result for unit %s", unit.unit_id)
        result = super().prepare_result(unit, outputs, processing_time_ms)

        # NOTE(review): process_unit stores "_processed_indices" in its context
        # dict, not in item metadata, so this branch only fires if something
        # else copies it into outputs[0]["metadata"] — confirm against caller.
        if outputs and "_processed_indices" in outputs[0].get("metadata", {}):
            result.metadata["item_indices"] = outputs[0]["metadata"]["_processed_indices"]

        return result

    def get_dataset_info(self) -> Dict[str, Any]:
        """Get dataset information."""
        return {
            "dataset_path": self.dataset_config.get("dataset_path", "local"),
            "dataset_type": "local_filesystem",
            "has_local_access": self.dataset_path is not None,
        }

    def set_image_paths_from_orchestrator(self, image_paths: List[Tuple[str, int]]) -> None:
        """Set the image paths list from orchestrator (for local access mode).

        Each path is joined onto this worker's ``local_storage_path``. Does
        nothing when the worker has no local storage access.
        """
        if self.dataset_path:
            # NOTE(review): pathlib joining an absolute ``path_str`` discards
            # ``self.dataset_path``; this only remaps correctly if the
            # orchestrator sends paths relative to its root — confirm.
            self.image_paths = [
                (self.dataset_path / path_str, size) for path_str, size in image_paths
            ]
            logger.info(f"Set {len(self.image_paths)} image paths for local access")
|