caption_flow-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in the public registry.
@@ -0,0 +1,365 @@
1
+ """Chunk tracking for persistent state across restarts."""
2
+
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import Set, Dict, List, Optional, Any
7
+ from datetime import datetime
8
+ from dataclasses import dataclass, asdict
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ @dataclass
14
+ class ChunkState:
15
+ """State of a chunk."""
16
+
17
+ chunk_id: str
18
+ shard_name: str
19
+ start_index: int
20
+ chunk_size: int
21
+ status: str # pending, assigned, completed, failed
22
+ completed_at: Optional[datetime] = None
23
+ assigned_to: Optional[str] = None
24
+ assigned_at: Optional[datetime] = None
25
+
26
+ def to_dict(self):
27
+ """Convert to dictionary for JSON serialization."""
28
+ d = asdict(self)
29
+ # Convert datetime objects to ISO format strings
30
+ if d["completed_at"]:
31
+ d["completed_at"] = d["completed_at"].isoformat()
32
+ if d["assigned_at"]:
33
+ d["assigned_at"] = d["assigned_at"].isoformat()
34
+ return d
35
+
36
+ @classmethod
37
+ def from_dict(cls, d: Dict):
38
+ """Create from dictionary."""
39
+ # Convert ISO format strings back to datetime objects
40
+ if d.get("completed_at"):
41
+ d["completed_at"] = datetime.fromisoformat(d["completed_at"])
42
+ if d.get("assigned_at"):
43
+ d["assigned_at"] = datetime.fromisoformat(d["assigned_at"])
44
+ return cls(**d)
45
+
46
+
47
+ class ChunkTracker:
48
+ """Tracks chunk processing state persistently."""
49
+
50
+ def __init__(self, checkpoint_file: Path):
51
+ self.checkpoint_file = checkpoint_file
52
+ self.chunks: Dict[str, ChunkState] = {}
53
+ self.completed_chunks: Set[str] = set()
54
+ self._load_checkpoint()
55
+
56
+ def _load_checkpoint(self):
57
+ """Load checkpoint from disk."""
58
+ if self.checkpoint_file.exists():
59
+ try:
60
+ with open(self.checkpoint_file, "r") as f:
61
+ data = json.load(f)
62
+
63
+ # Load chunk states
64
+ for chunk_id, chunk_data in data.get("chunks", {}).items():
65
+ chunk_state = ChunkState.from_dict(chunk_data)
66
+ self.chunks[chunk_id] = chunk_state
67
+ if chunk_state.status == "completed":
68
+ self.completed_chunks.add(chunk_id)
69
+
70
+ logger.info(
71
+ f"Loaded {len(self.chunks)} chunks from checkpoint, "
72
+ f"{len(self.completed_chunks)} completed"
73
+ )
74
+
75
+ except Exception as e:
76
+ logger.error(f"Error loading chunk checkpoint: {e}")
77
+ self.chunks = {}
78
+ self.completed_chunks = set()
79
+
80
+ def save_checkpoint(self):
81
+ """Save checkpoint to disk."""
82
+ try:
83
+ # Ensure parent directory exists
84
+ self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
85
+
86
+ # Convert chunks to serializable format
87
+ data = {
88
+ "chunks": {chunk_id: chunk.to_dict() for chunk_id, chunk in self.chunks.items()},
89
+ "updated_at": datetime.utcnow().isoformat(),
90
+ }
91
+
92
+ # Write atomically with absolute paths
93
+ tmp_file = self.checkpoint_file.parent / f"{self.checkpoint_file.name}.tmp"
94
+
95
+ # Write to temp file
96
+ with open(tmp_file, "w") as f:
97
+ json.dump(data, f, indent=2)
98
+
99
+ # Ensure temp file was created
100
+ if not tmp_file.exists():
101
+ raise IOError(f"Failed to create temporary file: {tmp_file}")
102
+
103
+ # Move atomically (use rename for same filesystem)
104
+ import shutil
105
+
106
+ shutil.move(str(tmp_file), str(self.checkpoint_file))
107
+
108
+ logger.debug(
109
+ f"Saved chunk checkpoint with {len(self.chunks)} chunks to {self.checkpoint_file}"
110
+ )
111
+
112
+ except Exception as e:
113
+ logger.error(f"Error saving chunk checkpoint: {e}", exc_info=True)
114
+ # Try direct write as fallback
115
+ try:
116
+ with open(self.checkpoint_file, "w") as f:
117
+ json.dump(data, f, indent=2)
118
+ logger.info("Saved checkpoint using fallback direct write")
119
+ except Exception as fallback_error:
120
+ logger.error(f"Fallback save also failed: {fallback_error}")
121
+
122
+ def add_chunk(self, chunk_id: str, shard_name: str, start_index: int, chunk_size: int) -> bool:
123
+ """Add a new chunk. Returns False if chunk already exists and is completed."""
124
+ if chunk_id in self.completed_chunks:
125
+ logger.debug(f"Chunk {chunk_id} already completed, skipping")
126
+ return False
127
+
128
+ if chunk_id not in self.chunks:
129
+ self.chunks[chunk_id] = ChunkState(
130
+ chunk_id=chunk_id,
131
+ shard_name=shard_name,
132
+ start_index=start_index,
133
+ chunk_size=chunk_size,
134
+ status="pending",
135
+ )
136
+ self.save_checkpoint()
137
+
138
+ return True
139
+
140
+ def mark_assigned(self, chunk_id: str, worker_id: str):
141
+ """Mark chunk as assigned."""
142
+ if chunk_id in self.chunks:
143
+ chunk = self.chunks[chunk_id]
144
+ chunk.status = "assigned"
145
+ chunk.assigned_to = worker_id
146
+ chunk.assigned_at = datetime.utcnow()
147
+ self.save_checkpoint()
148
+
149
+ def mark_completed(self, chunk_id: str):
150
+ """Mark chunk as completed."""
151
+ if chunk_id in self.chunks:
152
+ chunk = self.chunks[chunk_id]
153
+ chunk.status = "completed"
154
+ chunk.completed_at = datetime.utcnow()
155
+ self.completed_chunks.add(chunk_id)
156
+ self.save_checkpoint()
157
+ logger.info(f"Chunk {chunk_id} marked as completed")
158
+
159
+ def mark_failed(self, chunk_id: str):
160
+ """Mark chunk as failed."""
161
+ if chunk_id in self.chunks:
162
+ chunk = self.chunks[chunk_id]
163
+ chunk.status = "pending" # Reset to pending for retry
164
+ chunk.assigned_to = None
165
+ chunk.assigned_at = None
166
+ self.save_checkpoint()
167
+
168
+ def mark_pending(self, chunk_id: str):
169
+ """Mark chunk as pending (for manual reset)."""
170
+ if chunk_id in self.chunks:
171
+ chunk = self.chunks[chunk_id]
172
+ chunk.status = "pending"
173
+ chunk.assigned_to = None
174
+ chunk.assigned_at = None
175
+ self.save_checkpoint()
176
+
177
+ def release_worker_chunks(self, worker_id: str):
178
+ """Release all chunks assigned to a worker."""
179
+ released_chunks = []
180
+ for chunk_id, chunk in self.chunks.items():
181
+ if chunk.assigned_to == worker_id and chunk.status == "assigned":
182
+ chunk.status = "pending"
183
+ chunk.assigned_to = None
184
+ chunk.assigned_at = None
185
+ released_chunks.append(chunk_id)
186
+ self.save_checkpoint()
187
+ return released_chunks
188
+
189
+ def get_pending_chunks(self, shard_name: Optional[str] = None) -> List[str]:
190
+ """Get list of pending chunk IDs, optionally filtered by shard."""
191
+ pending = []
192
+ for chunk_id, chunk in self.chunks.items():
193
+ if chunk.status == "pending":
194
+ if shard_name is None or chunk.shard_name == shard_name:
195
+ pending.append(chunk_id)
196
+ return pending
197
+
198
+ def is_shard_complete(self, shard_name: str) -> bool:
199
+ """Check if all chunks for a shard are complete."""
200
+ shard_chunks = [chunk for chunk in self.chunks.values() if chunk.shard_name == shard_name]
201
+
202
+ if not shard_chunks:
203
+ return False
204
+
205
+ return all(chunk.status == "completed" for chunk in shard_chunks)
206
+
207
+ def get_stats(self) -> Dict[str, int]:
208
+ """Get chunk statistics."""
209
+ stats = {
210
+ "total": len(self.chunks),
211
+ "pending": sum(1 for c in self.chunks.values() if c.status == "pending"),
212
+ "assigned": sum(1 for c in self.chunks.values() if c.status == "assigned"),
213
+ "completed": len(self.completed_chunks),
214
+ "failed": sum(1 for c in self.chunks.values() if c.status == "failed"),
215
+ }
216
+ return stats
217
+
218
+ def get_shards_summary(self) -> Dict[str, Dict[str, Any]]:
219
+ """Get summary of all shards and their chunk status."""
220
+ shards = {}
221
+
222
+ for chunk_id, chunk_state in self.chunks.items():
223
+ shard_name = chunk_state.shard_name
224
+ if shard_name not in shards:
225
+ shards[shard_name] = {
226
+ "total_chunks": 0,
227
+ "completed_chunks": 0,
228
+ "pending_chunks": 0,
229
+ "assigned_chunks": 0,
230
+ "failed_chunks": 0,
231
+ "is_complete": True,
232
+ "chunks": [],
233
+ }
234
+
235
+ shards[shard_name]["chunks"].append(chunk_state)
236
+ shards[shard_name]["total_chunks"] += 1
237
+
238
+ if chunk_state.status == "completed":
239
+ shards[shard_name]["completed_chunks"] += 1
240
+ elif chunk_state.status == "pending":
241
+ shards[shard_name]["pending_chunks"] += 1
242
+ shards[shard_name]["is_complete"] = False
243
+ elif chunk_state.status == "assigned":
244
+ shards[shard_name]["assigned_chunks"] += 1
245
+ shards[shard_name]["is_complete"] = False
246
+ elif chunk_state.status == "failed":
247
+ shards[shard_name]["failed_chunks"] += 1
248
+ shards[shard_name]["is_complete"] = False
249
+
250
+ return shards
251
+
252
+ def get_incomplete_shards(self) -> Set[str]:
253
+ """Get set of shard names that have incomplete chunks."""
254
+ incomplete = set()
255
+ for chunk_id, chunk_state in self.chunks.items():
256
+ if chunk_state.status != "completed":
257
+ incomplete.add(chunk_state.shard_name)
258
+ return incomplete
293
+
294
+ async def sync_with_storage(self, storage_manager):
295
+ """Sync chunk state with storage to detect already-processed chunks."""
296
+ logger.info("Syncing chunk state with storage...")
297
+
298
+ # Get all existing captions from storage
299
+ if storage_manager.captions_path.exists():
300
+ import pyarrow.parquet as pq
301
+
302
+ # Read only the columns we need; older caption files may not have a chunk_id column,
+ # and pyarrow raises if a requested column is missing, so check the schema first
+ available_columns = pq.read_schema(storage_manager.captions_path).names
+ read_columns = [c for c in ("job_id", "chunk_id") if c in available_columns]
+ table = pq.read_table(storage_manager.captions_path, columns=read_columns)
304
+ existing_job_ids = set(table["job_id"].to_pylist())
305
+
306
+ # Also get chunk_ids if available
307
+ if "chunk_id" in table.column_names:
308
+ existing_chunk_ids = set(
309
+ cid for cid in table["chunk_id"].to_pylist() if cid is not None
310
+ )
311
+
312
+ # Mark existing chunks as completed
313
+ for chunk_id in existing_chunk_ids:
314
+ if chunk_id in self.chunks:
315
+ self.mark_completed(chunk_id)
316
+ else:
317
+ # Create chunk entry for already-processed chunks
318
+ # Extract shard name from chunk_id (format: shard_chunk_index)
319
+ parts = chunk_id.rsplit("_chunk_", 1)
320
+ if len(parts) == 2:
321
+ shard_name = parts[0]
322
+ try:
323
+ start_idx = int(parts[1])
324
+ # We don't know exact chunk size, but mark it as completed
325
+ self.chunks[chunk_id] = ChunkState(
326
+ chunk_id=chunk_id,
327
+ shard_name=shard_name,
328
+ start_index=start_idx,
329
+ chunk_size=1000, # Default chunk size
330
+ status="completed",
331
+ completed_at=datetime.utcnow(),
332
+ )
333
+ self.completed_chunks.add(chunk_id)
334
+ except ValueError:
335
+ logger.warning(f"Could not parse chunk_id: {chunk_id}")
336
+
337
+ logger.info(f"Found {len(existing_chunk_ids)} completed chunks in storage")
338
+
339
+ # Also check by job_id pattern if chunk_id column doesn't exist
340
+ else:
341
+ for job_id in existing_job_ids:
342
+ # Extract chunk_id from job_id (format: chunk_id_item_key)
343
+ if "_chunk_" in job_id:
344
+ parts = job_id.split("_")
345
+ # Find the chunk part
346
+ for i, part in enumerate(parts):
347
+ if part == "chunk" and i + 1 < len(parts):
348
+ try:
349
+ # Reconstruct chunk_id
350
+ chunk_idx = int(parts[i + 1])
351
+ shard_parts = parts[:i]
352
+ chunk_id = f"{'_'.join(shard_parts)}_chunk_{chunk_idx}"
353
+
354
+ if chunk_id not in self.completed_chunks:
355
+ self.completed_chunks.add(chunk_id)
356
+ logger.debug(
357
+ f"Marked chunk {chunk_id} as completed from job_id"
358
+ )
359
+ break
360
+ except ValueError:
361
+ continue
362
+
363
+ logger.info(f"Inferred {len(self.completed_chunks)} completed chunks from job_ids")
364
+
365
+ self.save_checkpoint()
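
For orientation, a minimal usage sketch of ChunkTracker as defined above (the import path, checkpoint location, and chunk/worker IDs are assumptions for illustration, not part of the package):

from pathlib import Path
from caption_flow.chunk_tracker import ChunkTracker  # import path is a guess

tracker = ChunkTracker(Path("checkpoints/chunks.json"))

# Register a 1000-item chunk starting at index 0 of shard "shard-00000".
if tracker.add_chunk("shard-00000_chunk_0", "shard-00000", start_index=0, chunk_size=1000):
    tracker.mark_assigned("shard-00000_chunk_0", worker_id="worker-1")
    # ... worker processes the chunk ...
    tracker.mark_completed("shard-00000_chunk_0")

print(tracker.get_stats())
# {'total': 1, 'pending': 0, 'assigned': 0, 'completed': 1, 'failed': 0}
print(tracker.is_shard_complete("shard-00000"))  # True

Because every state change calls save_checkpoint(), the same chunk IDs are skipped after a restart: add_chunk() returns False for chunks already recorded as completed.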
@@ -0,0 +1,186 @@
1
+ """Dataset loading utilities for WebDataset and HuggingFace."""
2
+
3
+ import asyncio
4
+ import shlex
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import List, Dict, Any, Generator, Optional, Tuple
8
+ import json
9
+
10
+ import webdataset as wds
11
+ from huggingface_hub import HfFileSystem, get_token, hf_hub_url
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class DatasetLoader:
17
+ """Handles loading datasets from various sources."""
18
+
19
+ def __init__(self, dataset_path: str, dataset_type: str = "huggingface"):
20
+ """
21
+ Initialize dataset loader.
22
+
23
+ Args:
24
+ dataset_path: Path to dataset (HF repo, local dir, etc.)
25
+ dataset_type: Type of dataset ("huggingface", "webdataset", "local")
26
+ """
27
+ self.dataset_path = dataset_path
28
+ self.dataset_type = dataset_type
29
+ self.token = get_token()
30
+
31
+ if not self.token and dataset_type == "huggingface":
32
+ logger.warning("No HuggingFace token found; run `huggingface-cli login`")
33
+
34
+ def get_shard_list(self) -> List[str]:
35
+ """Get list of all shards in the dataset."""
36
+ if self.dataset_type == "huggingface":
37
+ return self._get_hf_shards()
38
+ elif self.dataset_type == "local":
39
+ return self._get_local_shards()
40
+ else:
41
+ raise ValueError(f"Unknown dataset type: {self.dataset_type}")
42
+
43
+ def _get_hf_shards(self) -> List[str]:
44
+ """Get shard URLs from HuggingFace dataset."""
45
+ logger.info(f"Getting shard list from HuggingFace: {self.dataset_path}")
46
+
47
+ fs = HfFileSystem()
48
+ files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
49
+
50
+ urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
51
+
52
+ logger.info(f"Found {len(urls)} shards")
53
+ return sorted(urls)
54
+
55
+ def _get_local_shards(self) -> List[str]:
56
+ """Get shard files from local directory."""
57
+ path = Path(self.dataset_path)
58
+ if not path.exists():
59
+ raise ValueError(f"Local dataset path does not exist: {path}")
60
+
61
+ shards = list(path.glob("*.tar"))
62
+ logger.info(f"Found {len(shards)} local shards")
63
+ return [str(s) for s in sorted(shards)]
64
+
65
+ def load_shard(self, shard_url: str, processed_keys: Optional[set] = None) -> wds.DataPipeline:
66
+ """
67
+ Load a single shard as a WebDataset pipeline.
68
+
69
+ Args:
70
+ shard_url: URL or path to the shard
71
+ processed_keys: Set of already processed keys to skip
72
+ """
73
+ if processed_keys is None:
74
+ processed_keys = set()
75
+
76
+ if self.dataset_type == "huggingface":
77
+ # Use curl with auth token for HuggingFace
78
+ url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
79
+ ds = wds.DataPipeline(
80
+ wds.SimpleShardList(url_cmd),
81
+ wds.tarfile_to_samples(),
82
+ wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
83
+ wds.select(lambda x: x[0] not in processed_keys),
84
+ )
85
+ else:
86
+ # Local file access
87
+ ds = wds.DataPipeline(
88
+ wds.SimpleShardList(shard_url),
89
+ wds.tarfile_to_samples(),
90
+ wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
91
+ wds.select(lambda x: x[0] not in processed_keys),
92
+ )
93
+
94
+ return ds
95
+
96
+ def iterate_shard(
97
+ self, shard_url: str, processed_keys: Optional[set] = None
98
+ ) -> Generator[Tuple[str, str, bytes], None, None]:
99
+ """
100
+ Iterate over items in a shard.
101
+
102
+ Yields:
103
+ Tuple of (key, url, image_bytes)
104
+ """
105
+ ds = self.load_shard(shard_url, processed_keys)
106
+
107
+ for key, url, image_data in ds:
108
+ yield key, url, image_data
109
+
110
+ def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
111
+ """Count items in a shard (can be slow for large shards)."""
112
+ count = 0
113
+ try:
114
+ for _ in self.iterate_shard(shard_url, processed_keys):
115
+ count += 1
116
+ except Exception as e:
117
+ logger.error(f"Error counting shard {shard_url}: {e}")
118
+ return count
119
+
120
+
121
+ class ShardTracker:
122
+ """Tracks shard processing progress."""
123
+
124
+ def __init__(self, checkpoint_path: Path):
125
+ """Initialize shard tracker with checkpoint file."""
126
+ self.checkpoint_path = checkpoint_path
127
+ self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
128
+
129
+ self.completed_shards: set = set()
130
+ self.partial_shards: Dict[str, Dict[str, Any]] = {}
131
+ self.load()
132
+
133
+ def load(self):
134
+ """Load checkpoint from disk."""
135
+ if self.checkpoint_path.exists():
136
+ try:
137
+ data = json.loads(self.checkpoint_path.read_text())
138
+ self.completed_shards = set(data.get("completed_shards", []))
139
+ self.partial_shards = data.get("partial_shards", {})
140
+ logger.info(
141
+ f"Loaded checkpoint: {len(self.completed_shards)} completed, "
142
+ f"{len(self.partial_shards)} partial shards"
143
+ )
144
+ except Exception as e:
145
+ logger.error(f"Failed to load checkpoint: {e}")
146
+
147
+ def save(self):
148
+ """Save checkpoint to disk."""
149
+ data = {
150
+ "completed_shards": list(self.completed_shards),
151
+ "partial_shards": self.partial_shards,
152
+ }
153
+
154
+ tmp = self.checkpoint_path.with_suffix(".tmp")
155
+ tmp.write_text(json.dumps(data, indent=2))
156
+ tmp.replace(self.checkpoint_path)
157
+
158
+ def mark_complete(self, shard_name: str):
159
+ """Mark a shard as complete."""
160
+ self.completed_shards.add(shard_name)
161
+ if shard_name in self.partial_shards:
162
+ del self.partial_shards[shard_name]
163
+ self.save()
164
+
165
+ def update_partial(self, shard_name: str, processed_keys: List[str]):
166
+ """Update partial progress for a shard."""
167
+ self.partial_shards[shard_name] = {"keys": processed_keys, "count": len(processed_keys)}
168
+ self.save()
169
+
170
+ def get_processed_keys(self, shard_name: str) -> set:
171
+ """Get set of processed keys for a shard."""
172
+ if shard_name in self.completed_shards:
173
+ return set() # All done
174
+
175
+ if shard_name in self.partial_shards:
176
+ return set(self.partial_shards[shard_name].get("keys", []))
177
+
178
+ return set()
179
+
180
+ def is_complete(self, shard_name: str) -> bool:
181
+ """Check if a shard is complete."""
182
+ return shard_name in self.completed_shards
183
+
184
+ def get_remaining_shards(self, all_shards: List[str]) -> List[str]:
185
+ """Get list of shards that still need processing."""
186
+ return [s for s in all_shards if Path(s).stem not in self.completed_shards]
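
A minimal resume-aware iteration sketch combining DatasetLoader and ShardTracker (the import path, dataset name, and checkpoint path are illustrative assumptions). For the "huggingface" dataset type this relies on a token from `huggingface-cli login`, as the loader warns above:

from pathlib import Path
from caption_flow.dataset import DatasetLoader, ShardTracker  # import path is a guess

loader = DatasetLoader("someuser/some-webdataset", dataset_type="huggingface")
tracker = ShardTracker(Path("checkpoints/shards.json"))

for shard_url in tracker.get_remaining_shards(loader.get_shard_list()):
    shard_name = Path(shard_url).stem
    done_keys = tracker.get_processed_keys(shard_name)
    processed = list(done_keys)
    for key, url, image_bytes in loader.iterate_shard(shard_url, processed_keys=done_keys):
        # ... caption image_bytes here ...
        processed.append(key)
        if len(processed) % 100 == 0:
            tracker.update_partial(shard_name, processed)
    tracker.mark_complete(shard_name)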
@@ -0,0 +1,51 @@
1
+ """Image preprocessing utilities."""
2
+
3
+ import asyncio
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from pathlib import Path
6
+ from typing import List, Any
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+
11
+
12
+ class ImageProcessor:
13
+ """Handles image loading and preprocessing."""
14
+
15
+ def __init__(self, num_workers: int = 4):
16
+ self.executor = ProcessPoolExecutor(max_workers=num_workers)
17
+
18
+ async def process_batch(self, image_paths: List[Path]) -> List[np.ndarray]:
19
+ """Process a batch of images in parallel."""
20
+ loop = asyncio.get_running_loop()
21
+
22
+ tasks = []
23
+ for path in image_paths:
24
+ task = loop.run_in_executor(self.executor, self._process_image, path)
25
+ tasks.append(task)
26
+
27
+ return await asyncio.gather(*tasks)
28
+
29
+ @staticmethod
30
+ def _process_image(path: Path) -> np.ndarray:
31
+ """Process a single image."""
32
+ img = Image.open(path)
33
+
34
+ # Resize to standard size
35
+ img = img.resize((224, 224), Image.Resampling.LANCZOS)
36
+
37
+ # Convert to RGB if needed
38
+ if img.mode != "RGB":
39
+ img = img.convert("RGB")
40
+
41
+ # Convert to numpy array
42
+ arr = np.array(img, dtype=np.float32)
43
+
44
+ # Normalize
45
+ arr = arr / 255.0
46
+
47
+ return arr
48
+
49
+ def shutdown(self):
50
+ """Shutdown the executor."""
51
+ self.executor.shutdown(wait=True)
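
A short sketch of driving ImageProcessor from an async entry point (the image paths are placeholders and the import path is a guess). Because it uses a ProcessPoolExecutor, the call should live under a __main__ guard:

import asyncio
from pathlib import Path
from caption_flow.image_processor import ImageProcessor  # import path is a guess

async def main():
    processor = ImageProcessor(num_workers=4)
    try:
        paths = [Path("images/cat.jpg"), Path("images/dog.png")]
        arrays = await processor.process_batch(paths)
        for path, arr in zip(paths, arrays):
            print(path.name, arr.shape)  # (224, 224, 3), float32 scaled to [0, 1]
    finally:
        processor.shutdown()

if __name__ == "__main__":
    asyncio.run(main())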
@@ -0,0 +1,41 @@
1
+ """Job queue management."""
2
+
3
+ import asyncio
4
+ from typing import Optional
5
+ from collections import deque
6
+
7
+ from ..models import Job
8
+
9
+
10
+ class JobQueue:
11
+ """Priority job queue with backpressure."""
12
+
13
+ def __init__(self):
14
+ self.queue = deque()
15
+ self.processing = set()
16
+ self.lock = asyncio.Lock()
17
+
18
+ async def add(self, job: Job):
19
+ """Add job to queue."""
20
+ async with self.lock:
21
+ self.queue.append(job)
22
+
23
+ async def get_next(self) -> Optional[Job]:
24
+ """Get next available job."""
25
+ async with self.lock:
26
+ if self.queue:
27
+ job = self.queue.popleft()
28
+ self.processing.add(job.job_id)
29
+ return job
30
+ return None
31
+
32
+ async def complete(self, job_id: str):
33
+ """Mark job as complete."""
34
+ async with self.lock:
35
+ self.processing.discard(job_id)
36
+
37
+ async def requeue(self, job: Job):
38
+ """Requeue a job (for failures)."""
39
+ async with self.lock:
40
+ self.processing.discard(job.job_id)
41
+ self.queue.appendleft(job) # Priority requeue
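
Finally, a hedged sketch of the JobQueue flow (the Job constructor fields are not shown in this diff, so the Job(...) call and import paths below are assumptions):

import asyncio
from caption_flow.models import Job      # import path is a guess
from caption_flow.queue import JobQueue  # import path is a guess

async def main():
    queue = JobQueue()
    await queue.add(Job(job_id="shard-00000_chunk_0_item_00042"))  # field name assumed

    job = await queue.get_next()
    if job is not None:
        try:
            # ... process the job ...
            await queue.complete(job.job_id)
        except Exception:
            # Failed jobs go back to the front of the queue for retry.
            await queue.requeue(job)

asyncio.run(main())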