caption-flow 0.3.4__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. caption_flow/__init__.py +3 -3
  2. caption_flow/cli.py +921 -427
  3. caption_flow/models.py +45 -3
  4. caption_flow/monitor.py +2 -3
  5. caption_flow/orchestrator.py +153 -104
  6. caption_flow/processors/__init__.py +3 -3
  7. caption_flow/processors/base.py +8 -7
  8. caption_flow/processors/huggingface.py +463 -68
  9. caption_flow/processors/local_filesystem.py +24 -28
  10. caption_flow/processors/webdataset.py +28 -22
  11. caption_flow/storage/exporter.py +420 -339
  12. caption_flow/storage/manager.py +636 -756
  13. caption_flow/utils/__init__.py +1 -1
  14. caption_flow/utils/auth.py +1 -1
  15. caption_flow/utils/caption_utils.py +1 -1
  16. caption_flow/utils/certificates.py +15 -8
  17. caption_flow/utils/checkpoint_tracker.py +30 -28
  18. caption_flow/utils/chunk_tracker.py +153 -56
  19. caption_flow/utils/image_processor.py +9 -9
  20. caption_flow/utils/json_utils.py +37 -20
  21. caption_flow/utils/prompt_template.py +24 -16
  22. caption_flow/utils/vllm_config.py +5 -4
  23. caption_flow/viewer.py +4 -12
  24. caption_flow/workers/base.py +5 -4
  25. caption_flow/workers/caption.py +303 -92
  26. caption_flow/workers/data.py +6 -8
  27. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/METADATA +9 -4
  28. caption_flow-0.4.1.dist-info/RECORD +33 -0
  29. caption_flow-0.3.4.dist-info/RECORD +0 -33
  30. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/WHEEL +0 -0
  31. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/entry_points.txt +0 -0
  32. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/licenses/LICENSE +0 -0
  33. {caption_flow-0.3.4.dist-info → caption_flow-0.4.1.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,41 @@
1
1
  """JSON serialization utilities for handling special types like datetime."""
2
2
 
3
3
  import json
4
- from datetime import datetime, date
5
- from decimal import Decimal
6
- from pathlib import Path
7
- from typing import Any, Dict, List, Union
8
4
  from dataclasses import asdict, is_dataclass
5
+ from datetime import date, datetime
6
+ from decimal import Decimal
9
7
  from enum import Enum
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Union
10
10
 
11
11
 
12
12
  def safe_json_dumps(obj: Any, **kwargs) -> str:
13
- """
14
- Safely serialize objects to JSON, handling special types.
13
+ """Safely serialize objects to JSON, handling special types.
15
14
 
16
15
  Args:
16
+ ----
17
17
  obj: Object to serialize
18
18
  **kwargs: Additional arguments to pass to json.dumps
19
19
 
20
20
  Returns:
21
+ -------
21
22
  JSON string representation
23
+
22
24
  """
23
25
  return json.dumps(obj, default=json_serializer, **kwargs)
24
26
 
25
27
 
26
28
  def safe_dict(obj: Any) -> Dict[str, Any]:
27
- """
28
- Convert an object to a dictionary, handling special types.
29
+ """Convert an object to a dictionary, handling special types.
29
30
 
30
31
  Args:
32
+ ----
31
33
  obj: Object to convert (dataclass, dict, etc.)
32
34
 
33
35
  Returns:
36
+ -------
34
37
  Dictionary with JSON-serializable values
38
+
35
39
  """
36
40
  if is_dataclass(obj):
37
41
  data = asdict(obj)
@@ -46,14 +50,16 @@ def safe_dict(obj: Any) -> Dict[str, Any]:
46
50
 
47
51
 
48
52
  def sanitize_dict(data: Dict[str, Any]) -> Dict[str, Any]:
49
- """
50
- Recursively sanitize a dictionary to ensure all values are JSON-serializable.
53
+ """Recursively sanitize a dictionary to ensure all values are JSON-serializable.
51
54
 
52
55
  Args:
56
+ ----
53
57
  data: Dictionary to sanitize
54
58
 
55
59
  Returns:
60
+ -------
56
61
  Sanitized dictionary
62
+
57
63
  """
58
64
  result = {}
59
65
 
@@ -83,14 +89,16 @@ def sanitize_dict(data: Dict[str, Any]) -> Dict[str, Any]:
83
89
 
84
90
 
85
91
  def sanitize_value(value: Any) -> Any:
86
- """
87
- Sanitize a single value for JSON serialization.
92
+ """Sanitize a single value for JSON serialization.
88
93
 
89
94
  Args:
95
+ ----
90
96
  value: Value to sanitize
91
97
 
92
98
  Returns:
99
+ -------
93
100
  JSON-serializable value
101
+
94
102
  """
95
103
  if value is None:
96
104
  return None
@@ -115,17 +123,20 @@ def sanitize_value(value: Any) -> Any:
115
123
 
116
124
 
117
125
  def json_serializer(obj: Any) -> Any:
118
- """
119
- Default JSON serializer for special types.
126
+ """Default JSON serializer for special types.
120
127
 
121
128
  Args:
129
+ ----
122
130
  obj: Object to serialize
123
131
 
124
132
  Returns:
133
+ -------
125
134
  JSON-serializable representation
126
135
 
127
136
  Raises:
137
+ ------
128
138
  TypeError: If object type is not supported
139
+
129
140
  """
130
141
  if isinstance(obj, (datetime, date)):
131
142
  return obj.isoformat()
@@ -146,14 +157,16 @@ def json_serializer(obj: Any) -> Any:
146
157
 
147
158
 
148
159
  def parse_datetime(dt_string: Union[str, datetime, None]) -> Union[datetime, None]:
149
- """
150
- Parse a datetime string or return existing datetime.
160
+ """Parse a datetime string or return existing datetime.
151
161
 
152
162
  Args:
163
+ ----
153
164
  dt_string: ISO format datetime string, datetime object, or None
154
165
 
155
166
  Returns:
167
+ -------
156
168
  datetime object or None
169
+
157
170
  """
158
171
  if dt_string is None:
159
172
  return None
@@ -171,31 +184,35 @@ def parse_datetime(dt_string: Union[str, datetime, None]) -> Union[datetime, Non
171
184
 
172
185
  # Convenience functions for common use cases
173
186
  def to_json_dict(obj: Any) -> Dict[str, Any]:
174
- """
175
- Convert any object to a JSON-serializable dictionary.
187
+ """Convert any object to a JSON-serializable dictionary.
176
188
 
177
189
  This is a convenience wrapper around safe_dict.
178
190
 
179
191
  Args:
192
+ ----
180
193
  obj: Object to convert
181
194
 
182
195
  Returns:
196
+ -------
183
197
  JSON-serializable dictionary
198
+
184
199
  """
185
200
  return safe_dict(obj)
186
201
 
187
202
 
188
203
  def to_json_string(obj: Any, indent: int = None) -> str:
189
- """
190
- Convert any object to a JSON string.
204
+ """Convert any object to a JSON string.
191
205
 
192
206
  This is a convenience wrapper around safe_json_dumps.
193
207
 
194
208
  Args:
209
+ ----
195
210
  obj: Object to convert
196
211
  indent: Number of spaces for indentation (None for compact)
197
212
 
198
213
  Returns:
214
+ -------
199
215
  JSON string
216
+
200
217
  """
201
218
  return safe_json_dumps(obj, indent=indent)
@@ -1,8 +1,8 @@
1
1
  """Prompt template system for dynamic column substitution."""
2
2
 
3
- import re
4
3
  import logging
5
- from typing import Dict, Any, List, Optional
4
+ import re
5
+ from typing import Any, Dict, List
6
6
 
7
7
  logger = logging.getLogger(__name__)
8
8
 
@@ -14,12 +14,13 @@ class PromptTemplate:
14
14
  COLUMN_PATTERN = re.compile(r"\{(?:column|col):([\w-]+)\}")
15
15
 
16
16
  def __init__(self, template: str):
17
- """
18
- Initialize with a prompt template.
17
+ """Initialize with a prompt template.
19
18
 
20
19
  Args:
20
+ ----
21
21
  template: Prompt template string, e.g.
22
22
  "describe this image. tags: {column:user_tags}"
23
+
23
24
  """
24
25
  self.template = template
25
26
  self.required_columns = self._extract_columns()
@@ -30,14 +31,16 @@ class PromptTemplate:
30
31
  return list(set(matches)) # Remove duplicates
31
32
 
32
33
  def format(self, item_data: Dict[str, Any]) -> str:
33
- """
34
- Format the template with actual column values.
34
+ """Format the template with actual column values.
35
35
 
36
36
  Args:
37
+ ----
37
38
  item_data: Dictionary containing column values from dataset
38
39
 
39
40
  Returns:
41
+ -------
40
42
  Formatted prompt string
43
+
41
44
  """
42
45
  prompt = self.template
43
46
 
@@ -64,11 +67,12 @@ class PromptTemplate:
64
67
  return prompt.strip()
65
68
 
66
69
  def validate_columns(self, available_columns: List[str]) -> List[str]:
67
- """
68
- Validate that required columns are available.
70
+ """Validate that required columns are available.
69
71
 
70
- Returns:
72
+ Returns
73
+ -------
71
74
  List of missing column names
75
+
72
76
  """
73
77
  missing = []
74
78
  for col in self.required_columns:
@@ -81,11 +85,12 @@ class PromptTemplateManager:
81
85
  """Manages multiple prompt templates."""
82
86
 
83
87
  def __init__(self, prompts: List[str]):
84
- """
85
- Initialize with list of prompt strings (which may contain templates).
88
+ """Initialize with list of prompt strings (which may contain templates).
86
89
 
87
90
  Args:
91
+ ----
88
92
  prompts: List of prompt strings
93
+
89
94
  """
90
95
  self.templates = [PromptTemplate(p) for p in prompts]
91
96
  self._all_required_columns = None
@@ -101,14 +106,16 @@ class PromptTemplateManager:
101
106
  return self._all_required_columns
102
107
 
103
108
  def format_all(self, item_data: Dict[str, Any]) -> List[str]:
104
- """
105
- Format all templates with item data.
109
+ """Format all templates with item data.
106
110
 
107
111
  Args:
112
+ ----
108
113
  item_data: Dictionary containing column values
109
114
 
110
115
  Returns:
116
+ -------
111
117
  List of formatted prompts
118
+
112
119
  """
113
120
  formatted = []
114
121
  for template in self.templates:
@@ -123,11 +130,12 @@ class PromptTemplateManager:
123
130
  return formatted
124
131
 
125
132
  def validate_all(self, available_columns: List[str]) -> Dict[str, List[str]]:
126
- """
127
- Validate all templates against available columns.
133
+ """Validate all templates against available columns.
128
134
 
129
- Returns:
135
+ Returns
136
+ -------
130
137
  Dict mapping template string to list of missing columns
138
+
131
139
  """
132
140
  issues = {}
133
141
  for template in self.templates:
@@ -1,8 +1,8 @@
1
1
  """vLLM configuration management utilities."""
2
2
 
3
3
  import logging
4
- from typing import Dict, Any, Optional, Tuple, List
5
4
  from dataclasses import dataclass, field
5
+ from typing import Any, Dict, List, Optional, Tuple
6
6
 
7
7
  logger = logging.getLogger(__name__)
8
8
 
@@ -140,11 +140,12 @@ class VLLMConfigManager:
140
140
  def update_runtime_config(
141
141
  self, vllm_instance, old_config: Dict[str, Any], new_config: Dict[str, Any]
142
142
  ) -> Tuple[bool, Optional[Any]]:
143
- """
144
- Update vLLM configuration at runtime without reload.
143
+ """Update vLLM configuration at runtime without reload.
145
144
 
146
- Returns:
145
+ Returns
146
+ -------
147
147
  Tuple of (success, new_sampling_params)
148
+
148
149
  """
149
150
  change = self.analyze_config_change(old_config, new_config)
150
151
 
caption_flow/viewer.py CHANGED
@@ -1,24 +1,16 @@
1
1
  """TUI viewer for browsing CaptionFlow datasets with image preview using Urwid."""
2
2
 
3
- import asyncio
4
- import io
5
- import json
6
3
  import logging
7
4
  import os
8
- import sys
9
- from datetime import datetime
5
+ import tempfile
10
6
  from pathlib import Path
11
- from typing import Dict, Any, List, Optional, Tuple
12
7
  from urllib.parse import urlparse
13
- import tempfile
14
8
 
15
- import aiohttp
16
9
  import pandas as pd
17
- import pyarrow.parquet as pq
18
10
  import urwid
19
11
 
20
12
  try:
21
- from term_image.image import from_file, from_url, BaseImage
13
+ from term_image.image import BaseImage, from_file, from_url
22
14
  from term_image.widget import UrwidImage, UrwidImageScreen
23
15
 
24
16
  TERM_IMAGE_AVAILABLE = True
@@ -232,7 +224,7 @@ class DatasetViewer:
232
224
  filename = filename[:17] + "..."
233
225
 
234
226
  # Count captions
235
- caption_count = self._count_captions(item)
227
+ self._count_captions(item)
236
228
 
237
229
  text = f"{idx:3d}. {filename}"
238
230
  if idx == self.current_item_idx:
@@ -411,8 +403,8 @@ class DatasetViewer:
411
403
 
412
404
  try:
413
405
  # Download image
414
- import urllib.request
415
406
  import urllib.error
407
+ import urllib.request
416
408
 
417
409
  # Create request with user agent to avoid 403 errors
418
410
  request = urllib.request.Request(
@@ -4,13 +4,12 @@ import asyncio
4
4
  import json
5
5
  import logging
6
6
  import ssl
7
- import time
8
7
  from abc import ABC, abstractmethod
9
- from typing import Dict, Any, Optional
10
8
  from threading import Event
9
+ from typing import Any, Dict, Optional
11
10
 
12
11
  import websockets
13
- from websockets.client import WebSocketClientProtocol
12
+ from websockets.asyncio.client import ClientConnection
14
13
 
15
14
  logger = logging.getLogger(__name__)
16
15
 
@@ -18,6 +17,8 @@ logger = logging.getLogger(__name__)
18
17
  class BaseWorker(ABC):
19
18
  """Base class for all WebSocket-based workers with common connection logic."""
20
19
 
20
+ gpu_id: Optional[int] = None
21
+
21
22
  def __init__(self, config: Dict[str, Any]):
22
23
  self.config = config
23
24
  self.server_url = config["server"]
@@ -29,7 +30,7 @@ class BaseWorker(ABC):
29
30
 
30
31
  # State
31
32
  self.worker_id: Optional[str] = None
32
- self.websocket: Optional[WebSocketClientProtocol] = None
33
+ self.websocket: Optional[ClientConnection] = None
33
34
  self.running = False
34
35
  self.connected = Event()
35
36
  self.main_loop: Optional[asyncio.AbstractEventLoop] = None