caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
caption_flow/viewer.py ADDED
@@ -0,0 +1,594 @@
1
+ """TUI viewer for browsing CaptionFlow datasets with image preview using Urwid."""
2
+
3
+ import asyncio
4
+ import io
5
+ import json
6
+ import logging
7
+ import os
8
+ import sys
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+ from typing import Dict, Any, List, Optional, Tuple
12
+ from urllib.parse import urlparse
13
+ import tempfile
14
+
15
+ import aiohttp
16
+ import pandas as pd
17
+ import pyarrow.parquet as pq
18
+ import urwid
19
+
20
+ try:
21
+ from term_image.image import from_file, from_url, BaseImage
22
+ from term_image.widget import UrwidImage, UrwidImageScreen
23
+
24
+ TERM_IMAGE_AVAILABLE = True
25
+ except ImportError:
26
+ TERM_IMAGE_AVAILABLE = False
27
+ logging.warning("term-image not available. Install with: pip install term-image")
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class SelectableListItem(urwid.WidgetWrap):
    """Focusable one-line list row that fires a callback when activated.

    The row is drawn with the ``"normal"`` palette entry and switches to
    ``"selected"`` while it holds focus. Enter or Space invokes ``on_select``
    (when one was supplied) and consumes the key. Every other key is returned
    so the enclosing container can handle it.
    """

    def __init__(self, content, on_select=None):
        self.content = content
        self.on_select = on_select
        label = urwid.Text(content)
        super().__init__(urwid.AttrMap(label, "normal", focus_map="selected"))

    def selectable(self):
        # Rows must be focusable so a ListBox can move its cursor onto them.
        return True

    def keypress(self, size, key):
        activated = key in ("enter", " ")
        if not (activated and self.on_select):
            return key
        self.on_select()
        return None
49
+
50
+
51
class DatasetViewer:
    """Interactive terminal browser for a CaptionFlow ``captions.parquet`` file.

    The screen is split into three parts:
    - a shard list;
    - an item list for the current shard;
    - a preview pane with the selected item's captions and, when
      ``term-image`` is installed and previews are enabled, its image
      downloaded from the item's ``url`` field.

    Navigation keys (arrows, ``h/j/k/l``, Space/``b``, PageUp/PageDown) are
    handled by a MainLoop input filter. They therefore behave the same whichever
    column currently holds urwid focus.
    """

    palette = [
        ("normal", "white", "black"),
        ("selected", "black", "light gray"),
        ("header", "white,bold", "dark blue"),
        ("footer", "white", "dark gray"),
        ("title", "yellow,bold", "black"),
        ("error", "light red", "black"),
        ("dim", "dark gray", "black"),
        ("caption_title", "light cyan,bold", "black"),
        ("caption_text", "white", "black"),
    ]

    # (column name, display heading) for every field treated as a caption.
    _CAPTION_FIELDS = (
        ("captions", "Captions"),
        ("descriptions", "Descriptions"),
        ("alt_text", "Alt Text"),
        ("long_caption", "Long Caption"),
        ("short_caption", "Short Caption"),
    )

    # Key -> cursor step within the current shard's items.
    _ITEM_STEPS = {
        "down": 1,
        "j": 1,
        "up": -1,
        "k": -1,
        " ": 10,
        "page down": 10,
        "b": -10,
        "page up": -10,
    }

    # Key -> step through the shard list.
    _SHARD_STEPS = {"left": -1, "h": -1, "right": 1, "l": 1}

    # These keys are intercepted before any widget sees them. Otherwise the
    # focused ListBox/Columns would consume the arrows (moving only a list
    # highlight or the column focus) and Space (activating a row).
    _NAVIGATION_KEYS = frozenset(_ITEM_STEPS) | frozenset(_SHARD_STEPS)

    def __init__(self, data_dir: Path):
        """Point the viewer at *data_dir*, which must contain ``captions.parquet``.

        Raises:
            FileNotFoundError: if ``<data_dir>/captions.parquet`` does not exist.
        """
        self.data_dir = Path(data_dir)
        self.captions_path = self.data_dir / "captions.parquet"

        if not self.captions_path.exists():
            raise FileNotFoundError(f"No captions file found at {self.captions_path}")

        # Data
        self.df = None
        self.shards = []
        self.current_shard_idx = 0
        self.current_item_idx = 0
        self.current_shard_items = []

        # UI components
        self.loop = None
        self.screen = None
        self.disable_images = False

        # Widgets
        self.shards_list = None
        self.shards_walker = None
        self.items_list = None
        self.items_walker = None
        self.caption_box = None
        self.caption_text = None
        self.image_box = None
        self.image_widget = None
        self.shards_box = None  # LineBox around the shards ListBox
        self.items_box = None  # LineBox around the items ListBox

        # Labels (without the "▶" marker) for the current shard's items. They
        # are cached so that moving the cursor only re-renders two rows.
        self._item_labels = []

        # Image handling
        self.current_image_url = None
        self.session = None
        self.temp_files = []

    # ------------------------------------------------------------------ data

    def load_data(self):
        """Read ``captions.parquet`` into ``self.df`` and select the first shard.

        If the file has no ``shard`` column, every row is placed in a single
        synthetic shard named ``"all"``.
        """
        logger.info("Loading dataset...")

        self.df = pd.read_parquet(self.captions_path)

        if "shard" in self.df.columns:
            self.shards = sorted(self.df["shard"].unique())
        else:
            self.shards = ["all"]
            self.df["shard"] = "all"

        self._load_shard(0)

        logger.info(f"Loaded {len(self.df)} items across {len(self.shards)} shards")

    def _load_shard(self, shard_idx: int):
        """Make shard *shard_idx* current and load its rows into ``current_shard_items``.

        Rows are ordered by ``item_index`` when that column exists, otherwise
        by ``job_id``. Out-of-range indices are ignored.
        """
        if 0 <= shard_idx < len(self.shards):
            self.current_shard_idx = shard_idx
            shard_name = self.shards[shard_idx]

            shard_df = self.df[self.df["shard"] == shard_name]

            if "item_index" in shard_df.columns:
                shard_df = shard_df.sort_values("item_index")
            elif "job_id" in shard_df.columns:
                shard_df = shard_df.sort_values("job_id")

            self.current_shard_items = shard_df.to_dict("records")
            self.current_item_idx = 0

            # Force the next preview to (re)load its image.
            self.current_image_url = None

    # -------------------------------------------------------------- helpers

    @staticmethod
    def _is_missing(value):
        """Return True for None, NaN and ``pd.NA``, which represent absent values."""
        import math

        if value is None or value is pd.NA:
            return True
        return isinstance(value, float) and math.isnan(value)

    def _as_text(self, value):
        """Return *value* as a string, or ``""`` when it is a missing value."""
        return "" if self._is_missing(value) else str(value)

    @staticmethod
    def _shard_label(shard, count):
        """Return the shard-list label ``"<name> (<count>)"``.

        Names longer than 8 characters keep only their last 5 characters,
        prefixed with ``"..."``. For numbered shards the tail is what tells
        them apart; a shared prefix would make every row look the same.
        """
        name = str(shard)
        if len(name) > 8:
            name = "..." + name[-5:]
        return f"{name} ({count})"

    def _item_label(self, idx, item):
        """Return ``"<idx>. <name>"`` for an item row, without the cursor marker.

        The name is the ``filename`` field. If that is empty, it falls back to
        the basename of the ``url`` path, then to ``item_<idx>``. Names longer
        than 20 characters are cut to 17 characters plus ``"..."``.
        """
        filename = self._as_text(item.get("filename"))
        if not filename:
            url = self._as_text(item.get("url"))
            if url:
                filename = os.path.basename(urlparse(url).path)
        if not filename:
            filename = f"item_{idx}"

        if len(filename) > 20:
            filename = filename[:17] + "..."

        return f"{idx:3d}. {filename}"

    # ------------------------------------------------------------------- UI

    def create_ui(self):
        """Build the widget tree (header, three-column body, footer) in ``self.main_widget``."""
        header = urwid.AttrMap(urwid.Text("CaptionFlow Dataset Viewer", align="center"), "header")

        self._create_shards_list()
        self.shards_box = urwid.LineBox(self.shards_list, title="Shards", title_attr="title")

        self._create_items_list()
        self.items_box = urwid.LineBox(
            self.items_list, title=f"Items ({len(self.current_shard_items)})", title_attr="title"
        )

        self.caption_text = urwid.Text("")
        self.caption_box = urwid.LineBox(
            urwid.Filler(self.caption_text, valign="top"), title="Captions", title_attr="title"
        )

        if TERM_IMAGE_AVAILABLE and not self.disable_images:
            self.image_placeholder = urwid.Text("No image loaded", align="center")
            self.image_filler = urwid.Filler(self.image_placeholder)
        else:
            msg = "Image preview disabled" if self.disable_images else "term-image not installed"
            self.image_filler = urwid.Filler(urwid.Text(msg, align="center"))

        self.image_box = urwid.LineBox(self.image_filler, title="Image Preview", title_attr="title")

        # Captions and image split the preview column equally.
        preview = urwid.Pile(
            [
                ("weight", 1, self.caption_box),
                ("weight", 1, self.image_box),
            ]
        )

        # Fixed-width lists on the left; the preview takes the remaining width.
        # _fit_image assumes these two widths (12 + 30).
        body = urwid.Columns(
            [
                (12, self.shards_box),
                (30, self.items_box),
                ("weight", 1, preview),
            ]
        )

        footer = urwid.AttrMap(
            urwid.Text(
                "↑/↓/j/k: Navigate | ←/→/h/l: Shards | Space/b: Page | q: Quit", align="center"
            ),
            "footer",
        )

        self.main_widget = urwid.Frame(body=body, header=header, footer=footer)

    def _create_shards_list(self):
        """Build the shard ListBox, or refresh its rows in place if it already exists."""
        # Count every shard in a single pass instead of filtering the frame once per shard.
        counts = self.df["shard"].value_counts()

        rows = []
        for idx, shard in enumerate(self.shards):
            text = self._shard_label(shard, int(counts.get(shard, 0)))
            if idx == self.current_shard_idx:
                text = f"▶ {text}"
            rows.append(SelectableListItem(text, on_select=lambda i=idx: self._select_shard(i)))

        if self.shards_walker is None:
            self.shards_walker = urwid.SimpleFocusListWalker(rows)
            self.shards_list = urwid.ListBox(self.shards_walker)
        else:
            # Replace the contents in place so the ListBox already wrapped by
            # the LineBox shows the refreshed rows.
            self.shards_walker[:] = rows

        if self.shards_walker:
            self.shards_walker.set_focus(self.current_shard_idx)

    def _make_item_row(self, idx):
        """Create the row widget for item *idx*, prefixed with "▶" if it is the current item."""
        text = self._item_labels[idx]
        if idx == self.current_item_idx:
            text = f"▶ {text}"
        return SelectableListItem(text, on_select=lambda i=idx: self._select_item(i))

    def _create_items_list(self):
        """Build the item ListBox for the current shard, or refresh its rows in place."""
        self._item_labels = [
            self._item_label(idx, item) for idx, item in enumerate(self.current_shard_items)
        ]
        rows = [self._make_item_row(idx) for idx in range(len(self._item_labels))]

        if self.items_walker is None:
            self.items_walker = urwid.SimpleFocusListWalker(rows)
            self.items_list = urwid.ListBox(self.items_walker)
        else:
            self.items_walker[:] = rows

        if self.items_walker and self.current_item_idx < len(self.items_walker):
            self.items_walker.set_focus(self.current_item_idx)

    def _refresh_item_row(self, idx):
        """Re-render the single item row *idx*, for example after the cursor moved."""
        if self.items_walker is not None and 0 <= idx < len(self.items_walker):
            self.items_walker[idx] = self._make_item_row(idx)

    def _move_item_cursor(self, new_idx):
        """Make *new_idx* the current item and update the "▶" marker and the list focus."""
        old_idx = self.current_item_idx
        self.current_item_idx = new_idx
        self._refresh_item_row(old_idx)
        self._refresh_item_row(new_idx)
        if self.items_walker is not None and new_idx < len(self.items_walker):
            self.items_walker.set_focus(new_idx)

    def _count_captions(self, item):
        """Return how many non-empty captions the preview shows for *item*."""
        return sum(
            len(self._extract_caption_values(item[field]))
            for field, _ in self._CAPTION_FIELDS
            if field in item
        )

    def _select_shard(self, idx):
        """Switch to shard *idx* and rebuild the lists and preview."""
        if idx != self.current_shard_idx:
            self._load_shard(idx)
            self._update_ui()

    def _select_item(self, idx):
        """Make item *idx* current, moving the marker and refreshing the preview."""
        if idx != self.current_item_idx:
            self._move_item_cursor(idx)
            self._update_preview()

    def _update_ui(self):
        """Refresh both lists, the item-count title and the preview after a shard change."""
        self._create_shards_list()
        self._create_items_list()
        self.items_box.set_title(f"Items ({len(self.current_shard_items)})")
        self._update_preview()

    def _update_preview(self):
        """Show the current item's captions and, when previews are enabled, its image."""
        if not self.current_shard_items:
            return

        item = self.current_shard_items[self.current_item_idx]
        self.caption_text.set_text(self._format_captions(item))

        if TERM_IMAGE_AVAILABLE and not self.disable_images:
            self._update_image(item)

    def _extract_caption_values(self, value):
        """Flatten one caption field into a list of non-blank caption strings.

        Accepts plain strings, string forms of lists such as ``"['a', 'b']"``,
        (nested) lists and tuples, and numpy arrays. None, NaN, ``pd.NA`` and
        blank entries are dropped, so missing parquet values do not appear as
        ``"nan"``.
        """
        # numpy arrays and scalars expose __array__/tolist; convert them to Python objects.
        if hasattr(value, "__array__") and hasattr(value, "tolist"):
            value = value.tolist()

        if self._is_missing(value):
            return []

        if isinstance(value, str):
            if not value.strip():
                return []
            if value.startswith("[") and value.endswith("]"):
                import ast

                try:
                    parsed = ast.literal_eval(value)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Not a valid Python literal: show the raw string.
                    return [value]
                if isinstance(parsed, list):
                    return [str(v) for v in parsed if v is not None and str(v).strip()]
                return [str(parsed)]
            return [value]

        if isinstance(value, (list, tuple)):
            results = []
            for entry in value:
                results.extend(self._extract_caption_values(entry))
            return results

        text = str(value)
        return [text] if text.strip() else []

    def _format_captions(self, item):
        """Return the caption pane text for *item*.

        The text starts with a metadata line (``Job: ... | By: ...``) when
        those fields are present. It is followed by one numbered section per
        non-empty caption field. If nothing is available, it returns
        ``"No captions available"``.
        """
        parts = []

        metadata = [
            f"{label}: {item[key]}"
            for key, label in (("job_id", "Job"), ("contributor_id", "By"))
            if key in item and not self._is_missing(item[key])
        ]
        if metadata:
            parts.append(" | ".join(metadata))

        for field_name, display_name in self._CAPTION_FIELDS:
            if field_name not in item:
                continue
            captions = self._extract_caption_values(item[field_name])
            if not captions:
                continue
            parts.append(f"\n{display_name}:")
            for i, caption in enumerate(captions, 1):
                wrapped = self._wrap_text(caption, 100)
                parts.append(f" {i}. {wrapped}")

        return "\n".join(parts) if parts else "No captions available"

    def _wrap_text(self, text, width):
        """Wrap *text* to *width* columns; continuation lines are indented by one space."""
        import textwrap

        lines = textwrap.wrap(text, width=width)
        return "\n ".join(lines)

    # ---------------------------------------------------------------- image

    def _update_image(self, item):
        """Show the image referenced by ``item["url"]`` in the preview pane.

        The image is downloaded synchronously with a 10 s timeout into a
        temporary file, which ``cleanup`` removes. Calling this again for the
        URL that is already shown does nothing. Download and render errors are
        displayed in the pane instead of being raised.
        """
        import urllib.error

        url = self._as_text(item.get("url"))

        if url == self.current_image_url:
            return
        self.current_image_url = url

        if not url:
            self.image_filler.body = urwid.Text("No URL available", align="center")
            return

        self.image_filler.body = urwid.Text("Loading image...", align="center")
        if self.loop is not None:
            # Repaint now; the download below blocks the event loop.
            self.loop.draw_screen()

        try:
            tmp_path = self._download_to_temp(url)
            image = from_file(tmp_path)
            self._fit_image(image)
            # upscale=False keeps small images at their native size.
            self.image_widget = UrwidImage(image, upscale=False)
            self.image_filler.body = urwid.Padding(self.image_widget, align="center")
        except urllib.error.HTTPError as e:
            self.image_filler.body = urwid.Text(f"HTTP Error {e.code}: {e.reason}", align="center")
        except urllib.error.URLError as e:
            self.image_filler.body = urwid.Text(f"URL Error: {str(e)}", align="center")
        except Exception as e:
            self.image_filler.body = urwid.Text(f"Error: {str(e)}", align="center")

    def _download_to_temp(self, url):
        """Download *url* into a new temporary file and return its path.

        The path is recorded in ``self.temp_files`` before writing, so
        ``cleanup`` removes it even if the write fails.
        """
        import urllib.request

        # A browser-like User-Agent avoids 403 responses from some image hosts.
        request = urllib.request.Request(
            url,
            headers={
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
            },
        )
        with urllib.request.urlopen(request, timeout=10) as response:
            image_data = response.read()

        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
            self.temp_files.append(tmp.name)
            tmp.write(image_data)
        return tmp.name

    def _fit_image(self, image):
        """Size *image* for the preview pane.

        Height is preferred so the image is not cropped vertically. If the
        loop or screen is not available, a fixed 60x30 size is used.
        """
        if not (self.loop and self.screen):
            image.set_size(60, 30)
            return

        cols, rows = self.screen.get_cols_rows()

        # The two left columns are 12 + 30 wide, plus borders and padding.
        available_width = max(30, cols - 46)

        # Header, footer and borders take roughly 6 rows; the image box gets
        # half of the rest, minus its own two border rows.
        image_height = (rows - 6) // 2
        available_height = max(20, image_height - 2)

        try:
            image.set_size(height=available_height)
        except Exception:
            try:
                image.set_size(width=available_width, height=available_height)
            except Exception:
                image.set_size(60, 30)

    # ---------------------------------------------------------------- input

    def handle_input(self, key):
        """Handle a key: ``q``/``Q`` quits, and navigation keys move the cursor.

        Raises:
            urwid.ExitMainLoop: on ``q``/``Q``, after removing temporary files.
        """
        if not isinstance(key, str):
            return  # mouse events and other non-key input

        if key in ("q", "Q"):
            self.cleanup()
            raise urwid.ExitMainLoop()

        if key in self._ITEM_STEPS:
            self._navigate_items(self._ITEM_STEPS[key])
        elif key in self._SHARD_STEPS:
            self._navigate_shards(self._SHARD_STEPS[key])

    def _filter_input(self, keys, raw):
        """MainLoop ``input_filter``: handle navigation keys before any widget sees them."""
        passthrough = []
        for key in keys:
            if isinstance(key, str) and key in self._NAVIGATION_KEYS:
                self.handle_input(key)
            else:
                passthrough.append(key)
        return passthrough

    def _navigate_items(self, delta):
        """Move the item cursor by *delta*, clamped to the current shard's items."""
        if not self.current_shard_items:
            return

        new_idx = max(0, min(self.current_item_idx + delta, len(self.current_shard_items) - 1))

        if new_idx != self.current_item_idx:
            self._move_item_cursor(new_idx)
            self._update_preview()

    def _navigate_shards(self, delta):
        """Move the shard selection by *delta*, clamped to the shard list."""
        new_idx = max(0, min(self.current_shard_idx + delta, len(self.shards) - 1))

        if new_idx != self.current_shard_idx:
            self._select_shard(new_idx)

    # ------------------------------------------------------------ lifecycle

    def cleanup(self):
        """Delete downloaded temporary images. Safe to call more than once."""
        while self.temp_files:
            path = self.temp_files.pop()
            try:
                os.unlink(path)
            except OSError:
                # Best effort: the file may already have been removed.
                pass

    def run(self):
        """Load the dataset, build the UI and run the urwid main loop until the user quits."""
        self.load_data()
        self.create_ui()

        if TERM_IMAGE_AVAILABLE:
            self.screen = UrwidImageScreen()
        else:
            self.screen = urwid.raw_display.Screen()

        self.loop = urwid.MainLoop(
            self.main_widget,
            palette=self.palette,
            screen=self.screen,
            unhandled_input=self.handle_input,
            input_filter=self._filter_input,
        )

        # Render the first item once the loop is running. The image path calls
        # draw_screen, so the screen should already be started by then.
        self.loop.set_alarm_in(0, lambda *_: self._update_preview())

        try:
            self.loop.run()
        finally:
            self.cleanup()
565
+
566
+
567
def main():
    """Command-line entry point: ``<data_dir> [--no-images]``.

    Parses the arguments, configures INFO-level logging and runs a
    ``DatasetViewer`` on the given directory. Ctrl-C exits quietly; any other
    exception is logged and then re-raised.
    """
    import argparse

    cli = argparse.ArgumentParser(description="View CaptionFlow dataset")
    cli.add_argument("data_dir", type=Path, help="Dataset directory")
    cli.add_argument("--no-images", action="store_true", help="Disable image preview")
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO)

    try:
        viewer = DatasetViewer(options.data_dir)
        if options.no_images:
            viewer.disable_images = True
        viewer.run()
    except KeyboardInterrupt:
        pass
    except Exception as exc:
        logger.error(f"Error: {exc}")
        raise


if __name__ == "__main__":
    # Running the module directly starts the viewer.
    main()