caption-flow 0.4.0__tar.gz → 0.4.1__tar.gz

This diff shows the differences between two publicly available versions of a package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (58)
  1. {caption_flow-0.4.0/src/caption_flow.egg-info → caption_flow-0.4.1}/PKG-INFO +5 -1
  2. {caption_flow-0.4.0 → caption_flow-0.4.1}/README.md +5 -1
  3. {caption_flow-0.4.0 → caption_flow-0.4.1}/pyproject.toml +1 -1
  4. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/__init__.py +1 -1
  5. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/cli.py +2 -27
  6. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/processors/huggingface.py +25 -2
  7. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/workers/caption.py +39 -3
  8. {caption_flow-0.4.0 → caption_flow-0.4.1/src/caption_flow.egg-info}/PKG-INFO +5 -1
  9. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_cli.py +26 -0
  10. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_processors.py +109 -0
  11. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_worker_caption.py +189 -0
  12. {caption_flow-0.4.0 → caption_flow-0.4.1}/LICENSE +0 -0
  13. {caption_flow-0.4.0 → caption_flow-0.4.1}/setup.cfg +0 -0
  14. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/models.py +0 -0
  15. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/monitor.py +0 -0
  16. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/orchestrator.py +0 -0
  17. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/processors/__init__.py +0 -0
  18. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/processors/base.py +0 -0
  19. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/processors/local_filesystem.py +0 -0
  20. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/processors/webdataset.py +0 -0
  21. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/storage/__init__.py +0 -0
  22. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/storage/exporter.py +0 -0
  23. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/storage/manager.py +0 -0
  24. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/__init__.py +0 -0
  25. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/auth.py +0 -0
  26. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/caption_utils.py +0 -0
  27. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/certificates.py +0 -0
  28. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/checkpoint_tracker.py +0 -0
  29. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/chunk_tracker.py +0 -0
  30. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/image_processor.py +0 -0
  31. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/json_utils.py +0 -0
  32. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/prompt_template.py +0 -0
  33. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/utils/vllm_config.py +0 -0
  34. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/viewer.py +0 -0
  35. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/workers/base.py +0 -0
  36. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow/workers/data.py +0 -0
  37. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow.egg-info/SOURCES.txt +0 -0
  38. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow.egg-info/dependency_links.txt +0 -0
  39. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow.egg-info/entry_points.txt +0 -0
  40. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow.egg-info/requires.txt +0 -0
  41. {caption_flow-0.4.0 → caption_flow-0.4.1}/src/caption_flow.egg-info/top_level.txt +0 -0
  42. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_caption_utils.py +0 -0
  43. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_certificates.py +0 -0
  44. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_config_reload.py +0 -0
  45. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_duplicate_job_assignments.py +0 -0
  46. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_exporter.py +0 -0
  47. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_fix_verification.py +0 -0
  48. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_huggingface_ranges.py +0 -0
  49. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_json_utils.py +0 -0
  50. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_main.py +0 -0
  51. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_monitor.py +0 -0
  52. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_range_level_distribution.py +0 -0
  53. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_storage_components.py +0 -0
  54. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_viewer.py +0 -0
  55. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_vllm_config.py +0 -0
  56. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_webdataset_ranges.py +0 -0
  57. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_worker_reconnection_complete.py +0 -0
  58. {caption_flow-0.4.0 → caption_flow-0.4.1}/tests/test_worker_reconnection_sequence.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.4.0
3
+ Version: 0.4.1
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
@@ -48,6 +48,10 @@ Dynamic: license-file
48
48
 
49
49
  # CaptionFlow
50
50
 
51
+ <!-- [![Tests](https://github.com/bghira/CaptionFlow/workflows/tests/badge.svg)](https://github.com/bghira/CaptionFlow/actions/workflows/tests.yml) -->
52
+ [![codecov](https://codecov.io/github/bghira/CaptionFlow/graph/badge.svg?token=PRAQPNGYAS)](https://codecov.io/github/bghira/CaptionFlow)
53
+ [![PyPI version](https://badge.fury.io/py/caption-flow.svg)](https://badge.fury.io/py/caption-flow)
54
+
51
55
  scalable, fault-tolerant **vLLM-powered image captioning**.
52
56
 
53
57
  a fast websocket-based orchestrator paired with lightweight gpu workers achieves exceptional performance for batched requests through vLLM.
@@ -1,5 +1,9 @@
1
1
  # CaptionFlow
2
2
 
3
+ <!-- [![Tests](https://github.com/bghira/CaptionFlow/workflows/tests/badge.svg)](https://github.com/bghira/CaptionFlow/actions/workflows/tests.yml) -->
4
+ [![codecov](https://codecov.io/github/bghira/CaptionFlow/graph/badge.svg?token=PRAQPNGYAS)](https://codecov.io/github/bghira/CaptionFlow)
5
+ [![PyPI version](https://badge.fury.io/py/caption-flow.svg)](https://badge.fury.io/py/caption-flow)
6
+
3
7
  scalable, fault-tolerant **vLLM-powered image captioning**.
4
8
 
5
9
  a fast websocket-based orchestrator paired with lightweight gpu workers achieves exceptional performance for batched requests through vLLM.
@@ -190,4 +194,4 @@ Your contributions will be tracked and attributed in the final dataset!
190
194
 
191
195
  ## License
192
196
 
193
- AGPLv3
197
+ AGPLv3
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "caption-flow"
3
- version = "0.4.0"
3
+ version = "0.4.1"
4
4
  description = "Self-contained distributed community captioning system"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11,<3.13"
@@ -1,6 +1,6 @@
1
1
  """CaptionFlow - Distributed community captioning system."""
2
2
 
3
- __version__ = "0.4.0"
3
+ __version__ = "0.4.1"
4
4
 
5
5
  from .monitor import Monitor
6
6
  from .orchestrator import Orchestrator
@@ -1276,33 +1276,6 @@ async def _export_single_format(
1276
1276
  console.print(f" • {shard_name}: {count:,} items")
1277
1277
 
1278
1278
 
1279
- @main.command()
1280
- @click.option("--data-dir", default="./caption_data", help="Storage directory")
1281
- @click.option(
1282
- "--format",
1283
- type=click.Choice(
1284
- ["jsonl", "json", "csv", "txt", "parquet", "lance", "huggingface_hub", "all"],
1285
- case_sensitive=False,
1286
- ),
1287
- default="jsonl",
1288
- help="Export format (default: jsonl)",
1289
- )
1290
- @click.option("--output", "-o", help="Output path (file for jsonl/csv, directory for json/txt)")
1291
- @click.option("--limit", type=int, help="Limit number of rows to export")
1292
- @click.option("--columns", help="Comma-separated list of columns to export (default: all)")
1293
- @click.option("--export-column", default="captions", help="Column to export for txt format")
1294
- @click.option("--filename-column", default="filename", help="Column containing filenames")
1295
- @click.option("--shard", help="Specific shard to export (e.g., data-0001)")
1296
- @click.option("--shards", help="Comma-separated list of shards to export")
1297
- @click.option("--include-empty", is_flag=True, help="Include rows with empty export column")
1298
- @click.option("--stats-only", is_flag=True, help="Show statistics without exporting")
1299
- @click.option("--optimize", is_flag=True, help="Optimize storage before export")
1300
- @click.option("--verbose", is_flag=True, help="Show detailed export progress")
1301
- @click.option("--hf-dataset", help="Dataset name on HF Hub (e.g., username/dataset-name)")
1302
- @click.option("--license", default="apache-2.0", help="License for the dataset")
1303
- @click.option("--private", is_flag=True, help="Make HF dataset private")
1304
- @click.option("--nsfw", is_flag=True, help="Add not-for-all-audiences tag")
1305
- @click.option("--tags", help="Comma-separated tags for HF dataset")
1306
1279
  def _validate_export_setup(data_dir):
1307
1280
  """Validate export setup and create storage manager."""
1308
1281
  from .storage import StorageManager
@@ -1333,6 +1306,7 @@ async def _run_export_process(
1333
1306
  tags,
1334
1307
  stats_only,
1335
1308
  optimize,
1309
+ include_empty,
1336
1310
  ):
1337
1311
  """Execute the main export process."""
1338
1312
  from .storage.exporter import LanceStorageExporter
@@ -1448,6 +1422,7 @@ def export(
1448
1422
  tags,
1449
1423
  stats_only,
1450
1424
  optimize,
1425
+ include_empty,
1451
1426
  )
1452
1427
  )
1453
1428
  except ExportError as e:
@@ -1195,7 +1195,18 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
1195
1195
 
1196
1196
  # Still extract URL if available for metadata
1197
1197
  if self.url_column and self.url_column in item:
1198
- image_url = item[self.url_column]
1198
+ url_value = item[self.url_column]
1199
+ if (
1200
+ url_value
1201
+ and str(url_value).strip()
1202
+ and str(url_value).strip().lower() != "none"
1203
+ ):
1204
+ image_url = str(url_value).strip()
1205
+ else:
1206
+ logger.debug(
1207
+ f"Invalid or None URL for item {global_idx}: {url_value}"
1208
+ )
1209
+ image_url = None
1199
1210
 
1200
1211
  # Create dummy image with metadata context
1201
1212
  image = self._create_dummy_image(
@@ -1209,7 +1220,19 @@ class HuggingFaceDatasetWorkerProcessor(WorkerProcessor):
1209
1220
  # Normal processing - load real images
1210
1221
  if self.url_column:
1211
1222
  if self.url_column in item:
1212
- image_url = item[self.url_column]
1223
+ url_value = item[self.url_column]
1224
+ if (
1225
+ url_value
1226
+ and str(url_value).strip()
1227
+ and str(url_value).strip().lower() != "none"
1228
+ ):
1229
+ image_url = str(url_value).strip()
1230
+ else:
1231
+ logger.debug(
1232
+ f"Skipping invalid or None URL for item {global_idx}: {url_value}"
1233
+ )
1234
+ continue # Skip this item entirely
1235
+
1213
1236
  try:
1214
1237
  max_retries = 3
1215
1238
  backoff_factor = 2
@@ -137,6 +137,19 @@ class MultiStageVLLMManager:
137
137
 
138
138
  def get_model_for_stage(self, stage_name: str, model_name: str) -> Tuple[Any, Any, Any, Any]:
139
139
  """Get model components for a stage."""
140
+ if model_name not in self.models:
141
+ raise KeyError(
142
+ f"Model '{model_name}' not found in loaded models. Available models: {list(self.models.keys())}"
143
+ )
144
+ if model_name not in self.processors:
145
+ raise KeyError(f"Processor for model '{model_name}' not found")
146
+ if model_name not in self.tokenizers:
147
+ raise KeyError(f"Tokenizer for model '{model_name}' not found")
148
+ if stage_name not in self.sampling_params:
149
+ raise KeyError(
150
+ f"Sampling params for stage '{stage_name}' not found. Available stages: {list(self.sampling_params.keys())}"
151
+ )
152
+
140
153
  return (
141
154
  self.models[model_name],
142
155
  self.processors[model_name],
@@ -489,7 +502,19 @@ class CaptionWorker(BaseWorker):
489
502
  return True
490
503
  except Exception as e:
491
504
  logger.error(f"Failed to reload vLLM: {e}")
505
+ # Restore previous state
492
506
  self.vllm_config = old_config
507
+ self.stages = self._parse_stages_config(old_config)
508
+ self.stage_order = self._topological_sort_stages(self.stages)
509
+ # Attempt to restore previous models
510
+ try:
511
+ self._setup_vllm()
512
+ except Exception as restore_error:
513
+ logger.error(f"Failed to restore previous vLLM state: {restore_error}")
514
+ # Clean up broken state
515
+ if self.model_manager:
516
+ self.model_manager.cleanup()
517
+ self.model_manager = None
493
518
  return False
494
519
  else:
495
520
  # Clean up models if switching to mock mode
@@ -886,10 +911,21 @@ class CaptionWorker(BaseWorker):
886
911
  stage = next(s for s in self.stages if s.name == stage_name)
887
912
  logger.debug(f"Processing batch through stage: {stage_name}")
888
913
 
914
+ # Check if model manager is properly initialized
915
+ if not self.model_manager:
916
+ logger.error("Model manager not initialized")
917
+ self.items_failed += len(batch)
918
+ return []
919
+
889
920
  # Get model components
890
- llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
891
- stage_name, stage.model
892
- )
921
+ try:
922
+ llm, processor, tokenizer, sampling_params = self.model_manager.get_model_for_stage(
923
+ stage_name, stage.model
924
+ )
925
+ except KeyError as e:
926
+ logger.error(f"Model not found during batch processing: {e}")
927
+ self.items_failed += len(batch)
928
+ return []
893
929
 
894
930
  # Validate batch before processing
895
931
  processable_batch, too_long_items = self._validate_and_split_batch(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: caption-flow
3
- Version: 0.4.0
3
+ Version: 0.4.1
4
4
  Summary: Self-contained distributed community captioning system
5
5
  Author-email: bghira <bghira@users.github.com>
6
6
  License: MIT
@@ -48,6 +48,10 @@ Dynamic: license-file
48
48
 
49
49
  # CaptionFlow
50
50
 
51
+ <!-- [![Tests](https://github.com/bghira/CaptionFlow/workflows/tests/badge.svg)](https://github.com/bghira/CaptionFlow/actions/workflows/tests.yml) -->
52
+ [![codecov](https://codecov.io/github/bghira/CaptionFlow/graph/badge.svg?token=PRAQPNGYAS)](https://codecov.io/github/bghira/CaptionFlow)
53
+ [![PyPI version](https://badge.fury.io/py/caption-flow.svg)](https://badge.fury.io/py/caption-flow)
54
+
51
55
  scalable, fault-tolerant **vLLM-powered image captioning**.
52
56
 
53
57
  a fast websocket-based orchestrator paired with lightweight gpu workers achieves exceptional performance for batched requests through vLLM.
@@ -395,6 +395,32 @@ class TestExportCommand:
395
395
  assert result.exit_code == 0
396
396
  assert "export" in result.output.lower()
397
397
 
398
+ def test_export_command_no_duplicate_registration(self, runner):
399
+ """Regression test: Ensure export command is only registered once.
400
+
401
+ This test prevents the bug where @main.command() was incorrectly
402
+ applied to _validate_export_setup() causing duplicate command
403
+ registration and argument parsing errors.
404
+ """
405
+ # Get all registered commands
406
+ commands = list(main.commands.keys())
407
+
408
+ # Count occurrences of 'export'
409
+ export_count = commands.count("export")
410
+
411
+ # Should be exactly one export command
412
+ assert export_count == 1, f"Expected 1 export command, found {export_count}: {commands}"
413
+
414
+ # Test that export command can handle basic arguments without parsing errors
415
+ result = runner.invoke(main, ["export", "--help"])
416
+ assert result.exit_code == 0
417
+ assert "Got unexpected extra arguments" not in result.output
418
+
419
+ # Test with a data directory argument (the one that was causing issues)
420
+ result = runner.invoke(main, ["export", "--data-dir", "caption_data", "--stats-only"])
421
+ # Should not get parsing errors (though it may fail for other reasons like missing files)
422
+ assert "Got unexpected extra arguments" not in result.output
423
+
398
424
  @patch("caption_flow.storage.StorageManager")
399
425
  @patch("caption_flow.cli.asyncio.run")
400
426
  def test_export_stats_only(self, mock_asyncio_run, mock_storage_class, runner, tmp_path):
@@ -1034,5 +1034,114 @@ class TestProcessorIntegration(ProcessorTestBase):
1034
1034
  assert any(uid in available_ids for uid in unit_ids[1:])
1035
1035
 
1036
1036
 
1037
+ class TestHuggingFaceURLValidation:
1038
+ """Test URL validation in HuggingFace processor."""
1039
+
1040
+ @pytest.fixture
1041
+ def temp_dir(self):
1042
+ """Create a temporary directory for testing."""
1043
+ temp_dir = tempfile.mkdtemp()
1044
+ yield Path(temp_dir)
1045
+ shutil.rmtree(temp_dir)
1046
+
1047
+ @pytest.fixture
1048
+ def mock_parquet_data_with_invalid_urls(self, temp_dir):
1049
+ """Create mock parquet data with invalid URLs."""
1050
+ # Create test data with various invalid URL scenarios
1051
+ data = {
1052
+ "id": [1, 2, 3, 4, 5],
1053
+ "url": [
1054
+ "https://example.com/valid.jpg", # Valid URL
1055
+ None, # None URL
1056
+ "", # Empty string
1057
+ "None", # String "None"
1058
+ " ", # Whitespace only
1059
+ ],
1060
+ "caption": [
1061
+ "A valid image",
1062
+ "Image with None URL",
1063
+ "Image with empty URL",
1064
+ "Image with None string URL",
1065
+ "Image with whitespace URL",
1066
+ ],
1067
+ }
1068
+
1069
+ # Create parquet file
1070
+ table = pa.table(data)
1071
+ parquet_file = temp_dir / "test_data.parquet"
1072
+ pq.write_table(table, parquet_file)
1073
+
1074
+ return str(parquet_file)
1075
+
1076
+ def test_url_validation_skips_invalid_urls(self):
1077
+ """Test URL validation logic that skips invalid URLs."""
1078
+ # Test the actual validation logic used in the processor
1079
+ invalid_urls = [None, "", " ", "None", "NONE", "none"]
1080
+ valid_urls = ["https://example.com/image.jpg", "http://test.com/pic.png"]
1081
+
1082
+ # Track which URLs would be processed (not skipped)
1083
+ processed_urls = []
1084
+
1085
+ for url_value in invalid_urls + valid_urls:
1086
+ # This matches the exact logic from the processor
1087
+ if url_value and str(url_value).strip() and str(url_value).strip().lower() != "none":
1088
+ processed_urls.append(str(url_value).strip())
1089
+
1090
+ # Should only process valid URLs
1091
+ assert len(processed_urls) == 2
1092
+ assert "https://example.com/image.jpg" in processed_urls
1093
+ assert "http://test.com/pic.png" in processed_urls
1094
+
1095
+ def test_url_validation_in_mock_mode(self):
1096
+ """Test URL validation logic preserves valid URLs for metadata."""
1097
+ test_urls = {
1098
+ "valid": "https://example.com/valid.jpg",
1099
+ "none": None,
1100
+ "empty": "",
1101
+ "none_string": "None",
1102
+ "whitespace": " ",
1103
+ }
1104
+
1105
+ # Simulate extraction and validation for metadata
1106
+ extracted_urls = {}
1107
+ for key, url_value in test_urls.items():
1108
+ if url_value and str(url_value).strip() and str(url_value).strip().lower() != "none":
1109
+ extracted_urls[key] = str(url_value).strip()
1110
+ else:
1111
+ extracted_urls[key] = None
1112
+
1113
+ # Only valid URL should be extracted
1114
+ assert extracted_urls["valid"] == "https://example.com/valid.jpg"
1115
+ assert extracted_urls["none"] is None
1116
+ assert extracted_urls["empty"] is None
1117
+ assert extracted_urls["none_string"] is None
1118
+ assert extracted_urls["whitespace"] is None
1119
+
1120
+ def test_url_validation_edge_cases(self):
1121
+ """Test edge cases for URL validation."""
1122
+ processor = HuggingFaceDatasetWorkerProcessor()
1123
+
1124
+ # Test different invalid URL values
1125
+ test_cases = [
1126
+ (None, False),
1127
+ ("", False),
1128
+ (" ", False),
1129
+ ("None", False),
1130
+ ("NONE", False),
1131
+ ("none", False),
1132
+ ("https://valid.com/image.jpg", True),
1133
+ ("http://valid.com/image.jpg", True),
1134
+ (" https://valid.com/image.jpg ", True), # Should be stripped
1135
+ ]
1136
+
1137
+ for url_value, should_be_valid in test_cases:
1138
+ # Simulate the validation logic from the processor (matches the actual code)
1139
+ is_valid = bool(
1140
+ url_value and str(url_value).strip() and str(url_value).strip().lower() != "none"
1141
+ )
1142
+
1143
+ assert is_valid == should_be_valid, f"URL validation failed for: {url_value!r}"
1144
+
1145
+
1037
1146
  if __name__ == "__main__":
1038
1147
  pytest.main([__file__, "-v", "-s"])
@@ -879,5 +879,194 @@ class TestCaptionWorkerProcessors:
879
879
  assert worker.processor is not None
880
880
 
881
881
 
882
+ class TestCaptionWorkerConfigReload:
883
+ """Test CaptionWorker config reload functionality."""
884
+
885
+ @pytest.fixture
886
+ def worker_config(self):
887
+ return {
888
+ "name": "test_worker",
889
+ "token": "test_token",
890
+ "server": "ws://localhost:8765",
891
+ "server_url": "ws://localhost:8765",
892
+ "gpu_id": 0,
893
+ }
894
+
895
+ @pytest.fixture
896
+ def initial_vllm_config(self):
897
+ return {
898
+ "model": "test-model-v1",
899
+ "batch_size": 4,
900
+ "max_model_len": 16384,
901
+ "stages": [
902
+ {
903
+ "name": "caption",
904
+ "model": "test-model-v1",
905
+ "prompts": ["describe this image"],
906
+ "output_field": "captions",
907
+ "requires": [],
908
+ }
909
+ ],
910
+ }
911
+
912
+ def test_config_reload_failure_restores_state(self, worker_config, initial_vllm_config):
913
+ """Test that config reload failure properly restores previous state."""
914
+ worker = CaptionWorker(worker_config)
915
+
916
+ # Set up initial state
917
+ worker.vllm_config = initial_vllm_config
918
+ worker.stages = worker._parse_stages_config(initial_vllm_config)
919
+ worker.stage_order = worker._topological_sort_stages(worker.stages)
920
+ worker.mock_mode = False
921
+
922
+ # Mock model manager with working models
923
+ mock_model_manager = Mock()
924
+ mock_model_manager.models = {"test-model-v1": "loaded_model"}
925
+ mock_model_manager.processors = {"test-model-v1": "loaded_processor"}
926
+ mock_model_manager.tokenizers = {"test-model-v1": "loaded_tokenizer"}
927
+ mock_model_manager.sampling_params = {"caption": "loaded_sampling"}
928
+ worker.model_manager = mock_model_manager
929
+
930
+ # New config that will cause setup failure
931
+ new_config = {
932
+ "model": "test-model-v2",
933
+ "batch_size": 8,
934
+ "stages": [
935
+ {
936
+ "name": "caption",
937
+ "model": "test-model-v2",
938
+ "prompts": ["analyze this image"],
939
+ "output_field": "captions",
940
+ "requires": [],
941
+ }
942
+ ],
943
+ }
944
+
945
+ # Mock _setup_vllm to fail on first call (new config) but succeed on second call (restore)
946
+ setup_call_count = 0
947
+
948
+ def mock_setup_vllm():
949
+ nonlocal setup_call_count
950
+ setup_call_count += 1
951
+ if setup_call_count == 1:
952
+ raise Exception("Failed to load new model")
953
+ # Second call succeeds (restoration)
954
+ return
955
+
956
+ with patch.object(worker, "_setup_vllm", side_effect=mock_setup_vllm):
957
+ # Attempt config update
958
+ result = worker._handle_vllm_config_update(new_config)
959
+
960
+ # Should return False due to failure
961
+ assert result is False
962
+
963
+ # Should have restored original config
964
+ assert worker.vllm_config == initial_vllm_config
965
+
966
+ # Should have restored original stages
967
+ assert len(worker.stages) == 1
968
+ assert worker.stages[0].model == "test-model-v1"
969
+ assert worker.stages[0].prompts == ["describe this image"]
970
+
971
+ # Should have called _setup_vllm twice (once for new config, once for restore)
972
+ assert setup_call_count == 2
973
+
974
+ # Model manager cleanup should have been called
975
+ mock_model_manager.cleanup.assert_called()
976
+
977
+ def test_model_manager_get_model_for_stage_keyerror_handling(self):
978
+ """Test that get_model_for_stage provides helpful error messages."""
979
+ from caption_flow.workers.caption import MultiStageVLLMManager
980
+
981
+ manager = MultiStageVLLMManager()
982
+
983
+ # Test missing model
984
+ with pytest.raises(KeyError) as exc_info:
985
+ manager.get_model_for_stage("caption", "missing-model")
986
+ assert "Model 'missing-model' not found" in str(exc_info.value)
987
+ assert "Available models: []" in str(exc_info.value)
988
+
989
+ # Add a model but missing stage
990
+ manager.models["test-model"] = Mock()
991
+ manager.processors["test-model"] = Mock()
992
+ manager.tokenizers["test-model"] = Mock()
993
+
994
+ with pytest.raises(KeyError) as exc_info:
995
+ manager.get_model_for_stage("missing-stage", "test-model")
996
+ assert "Sampling params for stage 'missing-stage' not found" in str(exc_info.value)
997
+ assert "Available stages: []" in str(exc_info.value)
998
+
999
+ def test_process_batch_handles_missing_model_manager(self, worker_config):
1000
+ """Test that batch processing handles missing model manager gracefully."""
1001
+ worker = CaptionWorker(worker_config)
1002
+
1003
+ # Create a mock processing item
1004
+ mock_image = Image.new("RGB", (100, 100), "red")
1005
+ item = ProcessingItem(
1006
+ unit_id="test-unit",
1007
+ job_id="test-job",
1008
+ chunk_id="test-chunk",
1009
+ item_key="test-item",
1010
+ item_index=0,
1011
+ image=mock_image,
1012
+ image_data=b"fake_data",
1013
+ metadata={},
1014
+ )
1015
+
1016
+ # Set up worker state without model manager
1017
+ worker.vllm_config = {"max_model_len": 16384}
1018
+ mock_stage = Mock()
1019
+ mock_stage.name = "test-stage"
1020
+ worker.stages = [mock_stage]
1021
+ worker.stage_order = ["test-stage"]
1022
+ worker.model_manager = None # Simulate missing model manager
1023
+
1024
+ # Process batch should handle missing model manager
1025
+ result = worker._process_batch_multi_stage([item])
1026
+
1027
+ # Should return empty results and increment failed items
1028
+ assert result == []
1029
+ assert worker.items_failed == 1
1030
+
1031
+ def test_process_batch_handles_model_keyerror(self, worker_config):
1032
+ """Test that batch processing handles KeyError from get_model_for_stage."""
1033
+ worker = CaptionWorker(worker_config)
1034
+
1035
+ # Create a mock processing item
1036
+ mock_image = Image.new("RGB", (100, 100), "red")
1037
+ item = ProcessingItem(
1038
+ unit_id="test-unit",
1039
+ job_id="test-job",
1040
+ chunk_id="test-chunk",
1041
+ item_key="test-item",
1042
+ item_index=0,
1043
+ image=mock_image,
1044
+ image_data=b"fake_data",
1045
+ metadata={},
1046
+ )
1047
+
1048
+ # Set up worker state
1049
+ worker.vllm_config = {"max_model_len": 16384}
1050
+
1051
+ # Create mock stage
1052
+ mock_stage = Mock()
1053
+ mock_stage.name = "test-stage"
1054
+ mock_stage.model = "missing-model"
1055
+ worker.stages = [mock_stage]
1056
+ worker.stage_order = ["test-stage"]
1057
+
1058
+ # Mock model manager that raises KeyError
1059
+ mock_model_manager = Mock()
1060
+ mock_model_manager.get_model_for_stage.side_effect = KeyError("Model not found")
1061
+ worker.model_manager = mock_model_manager
1062
+
1063
+ # Process batch should handle KeyError gracefully
1064
+ result = worker._process_batch_multi_stage([item])
1065
+
1066
+ # Should return empty results and increment failed items
1067
+ assert result == []
1068
+ assert worker.items_failed == 1
1069
+
1070
+
882
1071
  if __name__ == "__main__":
883
1072
  pytest.main([__file__, "-v", "-s"])
File without changes
File without changes