caption-flow 0.3.1__tar.gz → 0.3.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {caption_flow-0.3.1/src/caption_flow.egg-info → caption_flow-0.3.3}/PKG-INFO +1 -1
- {caption_flow-0.3.1 → caption_flow-0.3.3}/pyproject.toml +1 -1
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/__init__.py +1 -1
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/orchestrator.py +2 -1
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/processors/base.py +3 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/processors/huggingface.py +1 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/processors/local_filesystem.py +2 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/processors/webdataset.py +173 -56
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/chunk_tracker.py +1 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/workers/caption.py +29 -11
- {caption_flow-0.3.1 → caption_flow-0.3.3/src/caption_flow.egg-info}/PKG-INFO +1 -1
- {caption_flow-0.3.1 → caption_flow-0.3.3}/LICENSE +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/README.md +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/setup.cfg +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/cli.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/models.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/monitor.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/processors/__init__.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/storage/__init__.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/storage/exporter.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/storage/manager.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/__init__.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/auth.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/caption_utils.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/certificates.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/checkpoint_tracker.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/image_processor.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/json_utils.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/prompt_template.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/utils/vllm_config.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/viewer.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/workers/base.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow/workers/data.py +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow.egg-info/SOURCES.txt +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow.egg-info/dependency_links.txt +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow.egg-info/entry_points.txt +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow.egg-info/requires.txt +0 -0
- {caption_flow-0.3.1 → caption_flow-0.3.3}/src/caption_flow.egg-info/top_level.txt +0 -0
@@ -124,13 +124,14 @@ class Orchestrator:
|
|
124
124
|
|
125
125
|
# Initialize storage
|
126
126
|
await self.storage.initialize()
|
127
|
-
await self.update_unprocessed_ranges()
|
128
127
|
|
129
128
|
# Start background tasks
|
130
129
|
asyncio.create_task(self._heartbeat_loop())
|
131
130
|
asyncio.create_task(self._checkpoint_loop())
|
132
131
|
asyncio.create_task(self._stats_update_loop())
|
133
132
|
|
133
|
+
await self.update_unprocessed_ranges()
|
134
|
+
|
134
135
|
# Start WebSocket server
|
135
136
|
websocket_logger = logging.getLogger("websockets")
|
136
137
|
websocket_logger.setLevel(logging.WARNING)
|
@@ -14,6 +14,7 @@ class WorkUnit:
|
|
14
14
|
unit_id: str # usually, but not always, the chunk id
|
15
15
|
chunk_id: str # always the chunk id
|
16
16
|
source_id: str # the shard name
|
17
|
+
unit_size: int # how many elements are in the workunit
|
17
18
|
data: Dict[str, Any]
|
18
19
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
19
20
|
priority: int = 0
|
@@ -44,6 +45,7 @@ class WorkAssignment:
|
|
44
45
|
"unit_id": u.unit_id,
|
45
46
|
"source_id": u.source_id,
|
46
47
|
"chunk_id": u.chunk_id,
|
48
|
+
"unit_size": u.unit_size,
|
47
49
|
"data": u.data,
|
48
50
|
"metadata": u.metadata,
|
49
51
|
"priority": u.priority,
|
@@ -62,6 +64,7 @@ class WorkAssignment:
|
|
62
64
|
unit_id=u["unit_id"],
|
63
65
|
chunk_id=u["chunk_id"],
|
64
66
|
source_id=u["source_id"],
|
67
|
+
unit_size=u["unit_size"],
|
65
68
|
data=u["data"],
|
66
69
|
metadata=u.get("metadata", {}),
|
67
70
|
priority=u.get("priority", 0),
|
@@ -251,6 +251,7 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
|
|
251
251
|
unit_id=chunk_id,
|
252
252
|
chunk_id=chunk_id,
|
253
253
|
source_id="local",
|
254
|
+
unit_size=chunk_state.chunk_size,
|
254
255
|
data={
|
255
256
|
"start_index": chunk_state.start_index,
|
256
257
|
"chunk_size": chunk_state.chunk_size,
|
@@ -319,6 +320,7 @@ class LocalFilesystemOrchestratorProcessor(OrchestratorProcessor):
|
|
319
320
|
unit_id=unit_id,
|
320
321
|
chunk_id=unit_id,
|
321
322
|
source_id="local",
|
323
|
+
unit_size=chunk_size,
|
322
324
|
data={
|
323
325
|
"start_index": self.current_index,
|
324
326
|
"chunk_size": chunk_size,
|
@@ -12,6 +12,7 @@ from datetime import datetime
|
|
12
12
|
from PIL import Image
|
13
13
|
import io
|
14
14
|
|
15
|
+
from caption_flow.models import JobId
|
15
16
|
from caption_flow.storage import StorageManager
|
16
17
|
from .base import OrchestratorProcessor, WorkerProcessor, ProcessorConfig, WorkUnit, WorkResult
|
17
18
|
from ..utils import ChunkTracker
|
@@ -21,6 +22,7 @@ import cv2
|
|
21
22
|
import numpy as np
|
22
23
|
|
23
24
|
logger = logging.getLogger(__name__)
|
25
|
+
logger.setLevel(logging.INFO)
|
24
26
|
|
25
27
|
|
26
28
|
class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
@@ -108,52 +110,86 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
108
110
|
return self.shard_info_cache[shard_idx]
|
109
111
|
|
110
112
|
def _restore_state(self, storage: StorageManager) -> None:
|
111
|
-
"""Restore state from chunk tracker."""
|
112
|
-
logger.
|
113
|
+
"""Restore state from chunk tracker and synchronize with storage."""
|
114
|
+
logger.info("Restoring state from chunk tracker and synchronizing with storage")
|
113
115
|
if not self.chunk_tracker:
|
114
116
|
return
|
115
117
|
|
118
|
+
# First, update chunk tracker from storage
|
119
|
+
processed_job_ids = storage.get_all_processed_job_ids()
|
120
|
+
if processed_job_ids:
|
121
|
+
logger.info(
|
122
|
+
f"Synchronizing chunk tracker with {len(processed_job_ids)} processed items from storage"
|
123
|
+
)
|
124
|
+
self.update_from_storage(processed_job_ids)
|
125
|
+
|
126
|
+
# Then restore work units from chunk tracker
|
116
127
|
shards_summary = self.chunk_tracker.get_shards_summary()
|
128
|
+
logger.info(f"Restoring work units from chunk tracker: {len(shards_summary)} shards")
|
117
129
|
|
118
130
|
with self.lock:
|
131
|
+
restored_count = 0
|
119
132
|
for shard_name, shard_info in shards_summary.items():
|
120
133
|
chunks = shard_info.get("chunks", [])
|
121
134
|
for chunk_state in chunks:
|
122
135
|
# Only add incomplete chunks
|
123
|
-
if chunk_state.status
|
124
|
-
logger.debug(f"
|
136
|
+
if chunk_state.status == "completed":
|
137
|
+
logger.debug(f"Skipping completed chunk {chunk_state.chunk_id}")
|
138
|
+
continue
|
125
139
|
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
140
|
+
# Get unprocessed ranges
|
141
|
+
unprocessed_ranges = chunk_state.get_unprocessed_ranges()
|
142
|
+
if not unprocessed_ranges:
|
143
|
+
logger.debug(
|
144
|
+
f"Chunk {chunk_state.chunk_id} has no unprocessed ranges, marking as completed"
|
145
|
+
)
|
146
|
+
self.chunk_tracker.mark_completed(chunk_state.chunk_id)
|
147
|
+
continue
|
130
148
|
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
abs_start = chunk_state.start_index + start
|
135
|
-
abs_end = chunk_state.start_index + end
|
136
|
-
absolute_ranges.append((abs_start, abs_end))
|
149
|
+
logger.info(
|
150
|
+
f"Restoring chunk {chunk_state.chunk_id} with unprocessed ranges: {unprocessed_ranges}"
|
151
|
+
)
|
137
152
|
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
153
|
+
# Convert relative ranges to absolute file indices
|
154
|
+
absolute_ranges = []
|
155
|
+
for start, end in unprocessed_ranges:
|
156
|
+
abs_start = chunk_state.start_index + start
|
157
|
+
abs_end = chunk_state.start_index + end
|
158
|
+
absolute_ranges.append((abs_start, abs_end))
|
159
|
+
|
160
|
+
# Get shard index if available
|
161
|
+
shard_idx = None
|
162
|
+
if self.dataset:
|
163
|
+
for idx in range(self.dataset.num_shards):
|
164
|
+
shard_info = self._get_shard_info_cached(idx)
|
165
|
+
if shard_info and shard_info["name"] == shard_name:
|
166
|
+
shard_idx = idx
|
167
|
+
break
|
168
|
+
|
169
|
+
unit = WorkUnit(
|
170
|
+
unit_id=chunk_state.chunk_id,
|
171
|
+
chunk_id=chunk_state.chunk_id,
|
172
|
+
source_id=shard_name,
|
173
|
+
unit_size=chunk_state.chunk_size,
|
174
|
+
data={
|
175
|
+
"shard_url": chunk_state.shard_url,
|
176
|
+
"shard_name": shard_name,
|
177
|
+
"shard_idx": shard_idx,
|
178
|
+
"start_index": chunk_state.start_index,
|
179
|
+
"chunk_size": chunk_state.chunk_size,
|
180
|
+
"unprocessed_ranges": absolute_ranges,
|
181
|
+
},
|
182
|
+
metadata={
|
183
|
+
"shard_name": shard_name,
|
184
|
+
"chunk_index": chunk_state.start_index // self.chunk_size,
|
185
|
+
},
|
186
|
+
)
|
154
187
|
|
155
|
-
|
156
|
-
|
188
|
+
self.work_units[unit.unit_id] = unit
|
189
|
+
self.pending_units.append(unit.unit_id)
|
190
|
+
restored_count += 1
|
191
|
+
|
192
|
+
logger.info(f"Restored {restored_count} incomplete work units")
|
157
193
|
|
158
194
|
def _create_units_background(self) -> None:
|
159
195
|
"""Background thread to create work units on demand."""
|
@@ -201,7 +237,13 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
201
237
|
|
202
238
|
# Create chunk for current position
|
203
239
|
chunk_size = min(self.chunk_size, shard_files - current_file_idx)
|
204
|
-
|
240
|
+
self.current_chunk_index = current_file_idx // self.chunk_size
|
241
|
+
job_id_obj = JobId(
|
242
|
+
shard_id=shard_name,
|
243
|
+
chunk_id=self.current_chunk_index,
|
244
|
+
sample_id=current_file_idx,
|
245
|
+
)
|
246
|
+
chunk_id = job_id_obj.get_chunk_str()
|
205
247
|
|
206
248
|
with self.lock:
|
207
249
|
# Skip if already exists
|
@@ -224,6 +266,7 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
224
266
|
unit_id=chunk_id,
|
225
267
|
chunk_id=chunk_id,
|
226
268
|
source_id=shard_name,
|
269
|
+
unit_size=chunk_size,
|
227
270
|
data={
|
228
271
|
"shard_url": shard_url,
|
229
272
|
"shard_name": shard_name,
|
@@ -268,6 +311,25 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
268
311
|
unit = self.work_units.get(unit_id)
|
269
312
|
|
270
313
|
if unit:
|
314
|
+
# Update unprocessed ranges from chunk tracker before assigning
|
315
|
+
if self.chunk_tracker and unit_id in self.chunk_tracker.chunks:
|
316
|
+
chunk_state = self.chunk_tracker.chunks[unit_id]
|
317
|
+
relative_unprocessed = chunk_state.get_unprocessed_ranges()
|
318
|
+
|
319
|
+
# Convert relative to absolute indices
|
320
|
+
absolute_ranges = []
|
321
|
+
for start, end in relative_unprocessed:
|
322
|
+
abs_start = chunk_state.start_index + start
|
323
|
+
abs_end = chunk_state.start_index + end
|
324
|
+
absolute_ranges.append((abs_start, abs_end))
|
325
|
+
|
326
|
+
# Update the work unit's unprocessed ranges
|
327
|
+
unit.data["unprocessed_ranges"] = absolute_ranges
|
328
|
+
|
329
|
+
logger.debug(
|
330
|
+
f"Updated unit {unit_id} with unprocessed ranges: {absolute_ranges}"
|
331
|
+
)
|
332
|
+
|
271
333
|
self.assigned_units[worker_id].add(unit_id)
|
272
334
|
assigned.append(unit)
|
273
335
|
|
@@ -373,26 +435,72 @@ class WebDatasetOrchestratorProcessor(OrchestratorProcessor):
|
|
373
435
|
# Group by chunk
|
374
436
|
processed_by_chunk = defaultdict(set)
|
375
437
|
|
376
|
-
for
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
continue
|
438
|
+
for job_id_str in processed_job_ids:
|
439
|
+
try:
|
440
|
+
# Use JobId to parse the job ID string
|
441
|
+
job_id = JobId.from_str(job_id_str)
|
442
|
+
chunk_id = job_id.get_chunk_str()
|
443
|
+
sample_idx = int(job_id.sample_id)
|
444
|
+
processed_by_chunk[chunk_id].add(sample_idx)
|
445
|
+
except ValueError as e:
|
446
|
+
logger.warning(f"Invalid job ID format: {job_id_str} - {e}")
|
447
|
+
continue
|
387
448
|
|
388
449
|
# Update chunk tracker with processed items
|
389
450
|
if self.chunk_tracker:
|
390
451
|
for chunk_id, indices in processed_by_chunk.items():
|
391
452
|
if indices:
|
453
|
+
# Get or create chunk state
|
454
|
+
chunk_state = self.chunk_tracker.chunks.get(chunk_id)
|
455
|
+
if not chunk_state:
|
456
|
+
# Parse chunk_id using JobId to get shard info
|
457
|
+
try:
|
458
|
+
# chunk_id format: "shard_id:chunk:chunk_idx"
|
459
|
+
parts = chunk_id.split(":")
|
460
|
+
if len(parts) >= 3:
|
461
|
+
shard_name = parts[0]
|
462
|
+
chunk_idx = int(parts[2])
|
463
|
+
# Infer start index from chunk index and size
|
464
|
+
start_index = chunk_idx * self.chunk_size
|
465
|
+
# Create chunk state
|
466
|
+
self.chunk_tracker.add_chunk(
|
467
|
+
chunk_id,
|
468
|
+
shard_name,
|
469
|
+
f"{shard_name}.tar",
|
470
|
+
start_index,
|
471
|
+
self.chunk_size,
|
472
|
+
)
|
473
|
+
logger.info(f"Created missing chunk state for {chunk_id}")
|
474
|
+
except (ValueError, IndexError) as e:
|
475
|
+
logger.error(f"Failed to create chunk state for {chunk_id}: {e}")
|
476
|
+
continue
|
477
|
+
|
392
478
|
# Sort indices and convert to ranges
|
393
479
|
sorted_indices = sorted(indices)
|
394
|
-
|
395
|
-
|
480
|
+
if not sorted_indices:
|
481
|
+
continue
|
482
|
+
|
483
|
+
# Condense into contiguous ranges
|
484
|
+
ranges = []
|
485
|
+
start_range = sorted_indices[0]
|
486
|
+
end_range = sorted_indices[0]
|
487
|
+
|
488
|
+
for i in range(1, len(sorted_indices)):
|
489
|
+
if sorted_indices[i] == end_range + 1:
|
490
|
+
end_range = sorted_indices[i]
|
491
|
+
else:
|
492
|
+
ranges.append((start_range, end_range))
|
493
|
+
start_range = sorted_indices[i]
|
494
|
+
end_range = sorted_indices[i]
|
495
|
+
ranges.append((start_range, end_range))
|
496
|
+
|
497
|
+
# Mark each contiguous range as processed
|
498
|
+
logger.info(f"Marking ranges {ranges} as processed in chunk {chunk_id}")
|
499
|
+
for start_idx, end_idx in ranges:
|
500
|
+
self.chunk_tracker.mark_items_processed(chunk_id, start_idx, end_idx)
|
501
|
+
|
502
|
+
# Save checkpoint after updating
|
503
|
+
self.chunk_tracker.save()
|
396
504
|
|
397
505
|
def get_stats(self) -> Dict[str, Any]:
|
398
506
|
"""Get processor statistics."""
|
@@ -488,7 +596,7 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
488
596
|
|
489
597
|
def process_unit(self, unit: WorkUnit, context: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
|
490
598
|
"""Process a work unit by iterating specified ranges."""
|
491
|
-
logger.debug(f"Processing unit: {unit
|
599
|
+
logger.debug(f"Processing unit: {unit}")
|
492
600
|
|
493
601
|
shard_name = unit.data["shard_name"]
|
494
602
|
shard_idx = unit.data.get("shard_idx")
|
@@ -502,7 +610,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
502
610
|
# Generate mock results for unprocessed ranges
|
503
611
|
for start_idx, end_idx in unprocessed_ranges:
|
504
612
|
for idx in range(start_idx, end_idx + 1):
|
505
|
-
|
613
|
+
# Use JobId to create consistent job ID
|
614
|
+
job_id = JobId.from_values(
|
615
|
+
shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
|
616
|
+
)
|
617
|
+
job_id_str = job_id.get_sample_str()
|
506
618
|
|
507
619
|
yield {
|
508
620
|
"image": self._create_mock_image(idx),
|
@@ -512,10 +624,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
512
624
|
"metadata": {
|
513
625
|
"_item_index": idx,
|
514
626
|
"_chunk_relative_index": idx - unit.data["start_index"],
|
515
|
-
"_job_id":
|
627
|
+
"_job_id": job_id_str,
|
516
628
|
"_mock": True,
|
629
|
+
"_processed_indices": processed_indices,
|
517
630
|
},
|
518
|
-
"job_id":
|
631
|
+
"job_id": job_id_str,
|
519
632
|
}
|
520
633
|
|
521
634
|
processed_indices.append(idx)
|
@@ -560,8 +673,11 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
560
673
|
f"Error decoding image {entry.path} with cv2: {img_e}"
|
561
674
|
)
|
562
675
|
|
563
|
-
# Generate job ID
|
564
|
-
job_id =
|
676
|
+
# Generate job ID using JobId class
|
677
|
+
job_id = JobId.from_values(
|
678
|
+
shard_id=shard_name, chunk_id=str(chunk_index), sample_id=str(idx)
|
679
|
+
)
|
680
|
+
job_id_str = job_id.get_sample_str()
|
565
681
|
|
566
682
|
yield {
|
567
683
|
"image": image,
|
@@ -571,11 +687,12 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
571
687
|
"metadata": {
|
572
688
|
"_item_index": idx,
|
573
689
|
"_chunk_relative_index": idx - unit.data["start_index"],
|
574
|
-
"_job_id":
|
690
|
+
"_job_id": job_id_str,
|
575
691
|
"_filename": entry.path,
|
576
692
|
"_file_size": entry.size,
|
693
|
+
"_processed_indices": processed_indices,
|
577
694
|
},
|
578
|
-
"job_id":
|
695
|
+
"job_id": job_id_str,
|
579
696
|
}
|
580
697
|
|
581
698
|
processed_indices.append(idx)
|
@@ -605,8 +722,8 @@ class WebDatasetWorkerProcessor(WorkerProcessor):
|
|
605
722
|
result = super().prepare_result(unit, outputs, processing_time_ms)
|
606
723
|
|
607
724
|
# Add processed indices for chunk tracker
|
608
|
-
if
|
609
|
-
result.metadata["item_indices"] =
|
725
|
+
if hasattr(self, "_last_context") and "_processed_indices" in self._last_context:
|
726
|
+
result.metadata["item_indices"] = self._last_context["_processed_indices"]
|
610
727
|
|
611
728
|
return result
|
612
729
|
|
@@ -565,7 +565,8 @@ class CaptionWorker(BaseWorker):
|
|
565
565
|
batch = []
|
566
566
|
batch_size = self.vllm_config.get("batch_size", 8)
|
567
567
|
context = {}
|
568
|
-
|
568
|
+
self.items_processed = 0
|
569
|
+
self.items_failed = 0
|
569
570
|
# Collect items for batching
|
570
571
|
for item_data in self.processor.process_unit(unit, context):
|
571
572
|
if self.should_stop_processing.is_set() or not self.connected.is_set():
|
@@ -604,16 +605,33 @@ class CaptionWorker(BaseWorker):
|
|
604
605
|
self._process_batch(batch)
|
605
606
|
|
606
607
|
# Notify orchestrator that unit is complete
|
607
|
-
if
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
608
|
+
# Check if the number of processed items matches the expected count for the unit.
|
609
|
+
# The context dictionary holds the count of items yielded by the processor.
|
610
|
+
total_items_in_unit = unit.unit_size
|
611
|
+
|
612
|
+
if (
|
613
|
+
not self.should_stop_processing.is_set()
|
614
|
+
and self.connected.is_set()
|
615
|
+
and self.items_failed == 0
|
616
|
+
and self.items_processed >= total_items_in_unit
|
617
|
+
):
|
618
|
+
if self.websocket:
|
619
|
+
try:
|
620
|
+
asyncio.run_coroutine_threadsafe(
|
621
|
+
self.websocket.send(
|
622
|
+
json.dumps({"type": "work_complete", "unit_id": unit.unit_id})
|
623
|
+
),
|
624
|
+
self.main_loop,
|
625
|
+
).result(timeout=5)
|
626
|
+
logger.info(
|
627
|
+
f"Unit {unit.unit_id} fully processed ({self.items_processed}/{total_items_in_unit}) and marked complete."
|
628
|
+
)
|
629
|
+
except Exception as e:
|
630
|
+
logger.warning(f"Could not notify work complete for unit {unit.unit_id}: {e}")
|
631
|
+
else:
|
632
|
+
logger.warning(
|
633
|
+
f"Processing of unit {unit.unit_id} was incomplete ({self.items_processed}/{total_items_in_unit}). Not marking as complete."
|
634
|
+
)
|
617
635
|
|
618
636
|
def _process_batch(self, batch: List[ProcessingItem]):
|
619
637
|
"""Process a batch of items through all stages."""
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|