caption-flow 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: caption-flow
- Version: 0.2.3
+ Version: 0.3.1
  Summary: Self-contained distributed community captioning system
  Author-email: bghira <bghira@users.github.com>
  License: MIT
@@ -33,6 +33,9 @@ Requires-Dist: arrow<2.0.0,>=1.3.0
  Requires-Dist: datasets<5.0.0,>=4.0.0
  Requires-Dist: boto3<2.0.0,>=1.40.11
  Requires-Dist: torchdata<0.12.0,>=0.11.0
+ Requires-Dist: textual<6.0.0,>=5.3.0
+ Requires-Dist: urwid<4.0.0,>=3.0.2
+ Requires-Dist: webshart<0.5.0,>=0.4.0
  Provides-Extra: dev
  Requires-Dist: pytest>=7.4.0; extra == "dev"
  Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
@@ -44,12 +47,13 @@ Dynamic: license-file
 
  # CaptionFlow
 
- scalable, fault-tolerant **vLLM-powered image captioning**. this "first round" focuses on a fast websocket orchestrator plus lightweight gpu workers that batch requests through vLLM.
+ scalable, fault-tolerant **vLLM-powered image captioning**.
+
+ a fast websocket-based orchestrator paired with lightweight gpu workers achieves high throughput by batching requests through vLLM.
 
  * **orchestrator**: hands out work in chunked shards, collects captions, checkpoints progress, and keeps simple stats.
  * **workers (vLLM)**: connect to the orchestrator, stream in image samples, batch them, and generate 1..N captions per image using prompts supplied by the orchestrator.
  * **config-driven**: all components read YAML config; flags can override.
- * **tui monitor (optional)**: a monitor client is wired into the CLI; ship a `monitor` module to enable it.
 
  > no conda. just `venv` + `pip`.
 
@@ -103,6 +107,25 @@ caption-flow worker --config my-worker.yaml --server ws://your.hostname.address:
  caption-flow monitor --config my-monitor.yaml
  ```
 
+ 5. export the data
+
+ ```bash
+ % caption-flow export --help
+ Usage: caption-flow export [OPTIONS]
+
+   Export caption data to various formats.
+
+ Options:
+   --format [jsonl|json|csv|txt|huggingface_hub|all]  Export format (default: jsonl)
+ ```
+
+ * **jsonl**: creates a JSON Lines file at the specified `--output` path
+ * **csv**: exports CSV-compatible data columns to the `--output` path; the metadata is incomplete
+ * **json**: creates a `.json` file for each sample inside the `--output` subdirectory with **complete** metadata; useful for webdatasets
+ * **txt**: creates a `.txt` file for each sample inside the `--output` subdirectory containing ONLY captions
+ * **huggingface_hub**: creates a dataset on the Hugging Face Hub; pass `--private` and `--nsfw` where appropriate
+ * **all**: creates all export formats in the specified `--output` directory
+
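A hedged sketch of consuming a `jsonl` export in Python (the field names `item_key` and `caption` are assumptions drawn from the storage schema, not a documented contract; `load_captions` is a hypothetical helper):

```python
import json

def load_captions(path):
    """Read a JSON Lines export: one JSON object per line."""
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines between records
                records.append(json.loads(line))
    return records

# hypothetical usage with a file produced by `caption-flow export --format jsonl`:
# for rec in load_captions("captions.jsonl"):
#     print(rec.get("item_key"), rec.get("caption"))
```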
  ---
 
  ## how it’s wired
@@ -111,20 +134,11 @@ caption-flow monitor --config my-monitor.yaml
 
  * **websocket server** (default `0.0.0.0:8765`) with three client roles: workers, data-feeders, and admin.
  * **dataset control**: the orchestrator centrally defines the dataset (`huggingface` or `local`) and version/name. it chunk-slices shards and assigns work.
+ * **data serving to remote workers**: local files are automatically served over HTTP to remote workers that don't have access to them.
  * **vLLM config broadcast**: model, tp size, dtype, max seq len, memory targets, batching, sampling params, and **inference prompts** are all pushed to workers; workers can apply many changes without a model reload.
  * **storage + checkpoints**: captions buffer to disk with periodic checkpoints. chunk state is tracked so restarts don’t double-work.
  * **auth**: token lists for `worker`, `monitor`, and `admin` roles.
 
- start flags you’ll likely use:
-
- ```text
- --config PATH            # yaml config for the orchestrator
- --port INT, --host STR   # bind controls
- --data-dir PATH          # overrides storage.data_dir
- --cert PATH, --key PATH  # enable TLS (or use --no-ssl for ws:// in dev)
- --vllm                   # use the vLLM-style orchestrator (webdataset/hf)
- ```
-
  ### vLLM worker
 
  * **one process per gpu**. select the device with `--gpu-id` (or `worker.gpu_id` in YAML).
@@ -132,27 +146,15 @@ start flags you’ll likely use:
  * **resilient**: detects disconnects, abandons the current chunk cleanly, clears queues, reconnects, and resumes.
  * **batched generate()**: images are resized down for consistent batching; each image can get multiple captions (one per prompt).
 
- start flags you’ll likely use:
-
- ```text
- --config PATH                 # yaml for the worker
- --server URL                  # ws(s)://host:port
- --token STR                   # must match an allowed worker token on the orchestrator
- --name STR                    # display name
- --batch-size INT              # override vLLM batch size
- --vllm                        # use the vLLM worker implementation
- --gpu-id INT                  # which gpu to use
- --precision STR, --model STR  # optional overrides for dtype/model
- --no-verify-ssl               # accept self-signed certs in dev
- ```
-
- ### (optional) monitor
+ ---
 
- * a CLI entry exists for a TUI monitor; wire in a `monitor` module to enable it. config lives in `monitor.yaml` or inside `orchestrator.yaml` under `monitor:`.
+ ## dataset formats
 
- ---
+ * huggingface hub or local URL-list datasets compatible with the `datasets` library
+ * webdataset shards containing full image data; these can also be hosted on the hub
+ * a local folder of images; the orchestrator serves the data to workers
 
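For orientation, the 0.2.x example config (shown later in this diff, removed in 0.3.x) declared the dataset like this; the 0.3.x schema may differ, so treat the field names as a sketch only:

```yaml
dataset:
  type: huggingface              # or: local
  path: <hf-dataset-or-local-path>
  name: <logical-name>
  version: "1.0"
```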
- ## configuration
+ ## configuration path
 
  ### config discovery order
 
@@ -166,98 +168,6 @@ for any component, the CLI looks for config in this order (first match wins):
  6. any `$XDG_CONFIG_DIRS` entries under `caption-flow/`
  7. `./examples/<component>.yaml` (fallback)
 
- ### orchestrator.yaml (highlights)
-
- ```yaml
- orchestrator:
-   host: 0.0.0.0
-   port: 8765
-   # ssl:
-   #   cert: /path/fullchain.pem
-   #   key: /path/privkey.pem
-
- dataset:
-   type: huggingface
-   path: <hf-dataset-or-local-path>
-   name: <logical-name>
-   version: "1.0"
-
- vllm:
-   model: Qwen/Qwen2.5-VL-3B-Instruct
-   tensor_parallel_size: 1
-   max_model_len: 16384
-   dtype: float16
-   gpu_memory_utilization: 0.92
-   enforce_eager: true
-   disable_mm_preprocessor_cache: true
-   limit_mm_per_prompt: { image: 1 }
-
-   batch_size: 8
-
-   sampling:
-     temperature: 0.7
-     top_p: 0.95
-     max_tokens: 256
-     repetition_penalty: 1.05
-     skip_special_tokens: true
-     stop: ["<|end|>", "<|endoftext|>", "<|im_end|>"]
-
-   inference_prompts:
-     - "describe this image in detail"
-     - "provide a comprehensive description of the visual content"
-     - "what are the key elements in this image?"
-
- storage:
-   data_dir: ./caption_data
-   checkpoint_dir: ./checkpoints
-   caption_buffer_size: 100
-   checkpoint_interval: 1000
-
- # chunking/queueing
- chunk_size: 1000
- chunks_per_request: 2
- chunk_buffer_multiplier: 3
- min_chunk_buffer: 10
-
- auth:
-   worker_tokens:
-     - { token: "example-worker-token", name: "Example Worker" }
-   monitor_tokens:
-     - { token: "letmein", name: "Default monitor" }
-   admin_tokens:
-     - { token: "admin-secret-2024", name: "Admin" }
- ```
-
- ### worker.yaml (highlights)
-
- ```yaml
- worker:
-   server: ws://localhost:8765   # use wss:// in prod
-   token: example-worker-token
-   name: local-gpu
-   gpu_id: 0
-   vllm: true
-
-   # local queues
-   readahead_size: 256
-   inference_queue_size: 128
- ```
-
- ### monitor.yaml (optional)
-
- ```yaml
- monitor:
-   server: ws://localhost:8765
-   token: letmein
-   refresh_rate: 1.0
-   show_contributors: true
-   show_quality_metrics: true
-   max_activity_items: 20
-   show_chunk_progress: true
-   show_worker_queues: true
-   show_throughput_graph: true
- ```
-
  ---
 
  ## tls / certificates
@@ -300,66 +210,24 @@ PRs welcome. keep it simple and fast.
  ```
  ┌─────────────┐     WebSocket      ┌─────────────┐
  │   Worker    │◄──────────────────►│             │
- └─────────────┘                    │             │     ┌──────────────┐
-                                    │ Orchestrator│────►│Arrow/Parquet │
- ┌─────────────┐                    │             │     │   Storage    │
- │   Worker    │◄──────────────────►│             │     └──────────────┘
- └─────────────┘                    └─────────────┘
+ │             │                    │             │     ┌──────────────┐
+ │             │◄───────────────────│             │────►│Arrow/Parquet │
+ └─────────────┘  HTTP (img data)   │ Orchestrator│     │   Storage    │
+                                    │             │     └──────────────┘
+ ┌─────────────┐                    │             │
+ │   Worker    │◄──────────────────►│             │
+ │             │                    │             │
+ │             │◄───────────────────│             │
+ └─────────────┘  HTTP (img data)   └─────────────┘
 
  ┌─────────────┐                           │
  │   Monitor   │◄──────────────────────────┘
  └─────────────┘
  ```
 
- ## Storage Schema
-
- ### captions.parquet
-
- * `job_id`: Unique job identifier
- * `dataset`: Dataset name
- * `shard`: Shard identifier
- * `item_key`: Item within shard
- * `caption`: Generated caption text
- * `contributor_id`: Worker who generated it
- * `timestamp`: Generation time
- * `quality_score`: Optional quality metric
-
- ### jobs.parquet
-
- * `job_id`: Unique identifier
- * `dataset`: Dataset name
- * `shard`: Shard identifier
- * `status`: pending/processing/completed/failed
- * `assigned_to`: Worker ID
- * `timestamp`: Status change time
-
- ### contributors.parquet
-
- * `contributor_id`: Unique identifier
- * `name`: Display name
- * `total_captions`: Lifetime count
- * `trust_level`: Quality tier (0-5)
-
- ## Development
-
- ```bash
- # Install with dev dependencies
- pip install -e ".[dev]"
-
- # Run tests
- pytest
-
- # Format code
- black src/
- ruff --fix src/
-
- # Type checking
- mypy src/
- ```
-
- ## Community Contribution
+ ## Community Clusters
 
- To contribute compute:
+ To contribute compute to a cluster:
 
  1. Install caption-flow: `pip install caption-flow`
  2. Get a worker token from the project maintainer
@@ -369,4 +237,4 @@ Your contributions will be tracked and attributed in the final dataset!
 
  ## License
 
- MIT
+ AGPLv3
@@ -0,0 +1,33 @@
+ caption_flow/__init__.py,sha256=pL77m-1slbrkzValJF7YfHpcp3yol6iTvSyjHpjJFOA,303
+ caption_flow/cli.py,sha256=t_cYCxJE7f5UtB3br2Es51JjO5KPsWM1JTdDXAxM_Lw,41371
+ caption_flow/models.py,sha256=2n6iphTEL62xK2FFcJM6axMsaE8KwsUv5Ak_cCF-TdQ,5652
+ caption_flow/monitor.py,sha256=bAt9EJqfPgT_KdbknGdCxwBRH002pRDgyUmYIj6Dyso,7885
+ caption_flow/orchestrator.py,sha256=34gZvaW14YZ7a7LagYOO3VKKwlbuS4aw0yoP1L8gwf0,36192
+ caption_flow/viewer.py,sha256=HxO98eHR1xtivG0dEdYC2U9T_RgeRfJqqTK-37u9bNM,20471
+ caption_flow/processors/__init__.py,sha256=hvq-OuAJWQe6hFglKe7QmkS8473k20FmxZDSxfXpCrg,423
+ caption_flow/processors/base.py,sha256=JlTqCHo5HRXrXMVzgle_6pNwh4HGHsF7jLF6PeSnWr0,6783
+ caption_flow/processors/huggingface.py,sha256=Q1PNQRXZT4NzEzGKtF1A1e8K_5-hgeM4G4lz_CZYuN4,41203
+ caption_flow/processors/local_filesystem.py,sha256=EYmsImbkqsIU7UZL2FijL0hotKLtPOtkzfwernQDSxA,27860
+ caption_flow/processors/webdataset.py,sha256=1JS3TmQe-fComBKzLPUMhUHx1T0Wf7m9nFkusM7tTXI,26152
+ caption_flow/storage/__init__.py,sha256=IVnzcSCPpPuyp-QLlgJirRZ9Sb3tR0F4sfuF5u2cNMk,36
+ caption_flow/storage/exporter.py,sha256=mFJqMDQ61cP-qcXe118_-oL1TUqULdQZ8LdjSTym44I,19697
+ caption_flow/storage/manager.py,sha256=KPExcKPuFVQSsBnfCBdne5PO4PwN4NTfd-EJQk13OY0,47459
+ caption_flow/utils/__init__.py,sha256=bDcO5uR455TKCQ2hX-_XcdTnRXDBaT8Yn4jWqWzfFsE,120
+ caption_flow/utils/auth.py,sha256=UrxX2n8OEEcfMD1Ey27TxGfrJFmUCpC59x-SCrQJoVE,2253
+ caption_flow/utils/caption_utils.py,sha256=esUMAdcCkNjRroZ0Bhxv0_yKlLtMf0XeDCTt-5k6bik,5309
+ caption_flow/utils/certificates.py,sha256=eu4blQZEkL9NRaY1ynQWg1asvDorRYhGRZea7STonJE,4635
+ caption_flow/utils/checkpoint_tracker.py,sha256=-nN5gLvXyMdKOCT2SNNL2Km6UYm2Hii9wuXeezWhwx4,3339
+ caption_flow/utils/chunk_tracker.py,sha256=lyso_V-ckYUVrDmlCCsaZKF9E_sR4ipef5W6BiVAS5M,19944
+ caption_flow/utils/image_processor.py,sha256=wmOExkVfM7OeuLfX3AwMefsH-TxL8TNcn22gp0NmJKY,1541
+ caption_flow/utils/json_utils.py,sha256=IiZYn8uCM-3pYmyIbX2fmaOIyutArn67SqAyp0ggNpU,5396
+ caption_flow/utils/prompt_template.py,sha256=AKp0diSZqNBMwZkpiTNjw8-bbQwHStr7QZTOJ7o1dC4,4345
+ caption_flow/utils/vllm_config.py,sha256=TC7Rmjk0zRKbBXbWUXrFL4Z58hzax_-4L0pXZn09hdM,6019
+ caption_flow/workers/base.py,sha256=2AGWERC5hbmO-0V_A1MUbgRVvRNN3blqGPyDokvvzmM,7575
+ caption_flow/workers/caption.py,sha256=4nETqDmHgb2dVgT7_zxzr3bcrTtWSxr3FSdB811boEw,38436
+ caption_flow/workers/data.py,sha256=0Tg8NE0wdONeMlivYQ4nvbcfWdLuU51O7vR8_YSnJgo,14813
+ caption_flow-0.3.1.dist-info/licenses/LICENSE,sha256=hIahDEOTzuHCU5J2nd07LWwkLW7Hko4UFO__ffsvB-8,34523
+ caption_flow-0.3.1.dist-info/METADATA,sha256=Bc8LSEqMhK1rmzhyu9-P-amdGpjML_AVWorr93jrYGo,9708
+ caption_flow-0.3.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ caption_flow-0.3.1.dist-info/entry_points.txt,sha256=KnVlyrGKZj6p2zNyuEnCx4Y6jvJ4V-mcfN0lddPKTlQ,55
+ caption_flow-0.3.1.dist-info/top_level.txt,sha256=_bXpKRutqded0FQ80dCChIz26ETV7tL4d4e2E_Y1FXs,13
+ caption_flow-0.3.1.dist-info/RECORD,,
@@ -1,222 +0,0 @@
- """Dataset loading utilities for WebDataset and HuggingFace."""
-
- import asyncio
- import shlex
- import logging
- from pathlib import Path
- from typing import List, Dict, Any, Generator, Optional, Tuple
- import json
-
- import webdataset as wds
- from huggingface_hub import HfFileSystem, get_token, hf_hub_url
-
- logger = logging.getLogger(__name__)
-
-
- class DatasetLoader:
-     """Handles loading datasets from various sources."""
-
-     def __init__(
-         self,
-         dataset_path: str,
-         dataset_type: str = "huggingface",
-         split: str = "train",
-         image_column: str = "image",
-         cache_dir: Optional[Path] = None,
-     ):
-         """
-         Initialize dataset loader.
-
-         Args:
-             dataset_path: Path to dataset (HF repo, local dir, etc.)
-             dataset_type: Type of dataset ("huggingface", "webdataset", "local")
-             split: Split to use for HuggingFace datasets (default: "train")
-             image_column: Column name containing image data or URLs (default: "image")
-         """
-         self.dataset_path = dataset_path
-         self.dataset_type = dataset_type
-         self.split = split
-         self.image_column = image_column
-         self.token = get_token()
-         self.dataset_format = None  # Will be detected: "webdataset" or "huggingface_datasets"
-
-         if not self.token and dataset_type == "huggingface":
-             logger.warning("No HuggingFace token found; run `huggingface-cli login`")
-
-         # Detect the actual format if it's a HuggingFace dataset
-         if dataset_type == "huggingface":
-             self.dataset_format = self._detect_dataset_format()
-             logger.info(f"Detected dataset format: {self.dataset_format}")
-
-     def _detect_dataset_format(self) -> str:
-         """Detect whether it's WebDataset or HuggingFace datasets format."""
-         fs = HfFileSystem(token=self.token)
-
-         # Check for .tar files (WebDataset)
-         tar_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar"))
-         if tar_files:
-             return "webdataset"
-
-         # Check for .parquet files (Huggingface Arrow DB)
-         parquet_files = list(fs.glob(f"hf://datasets/{self.dataset_path}/**/*.parquet"))
-         if parquet_files:
-             return "huggingface_datasets"
-
-         raise AssertionError(f"Could not detect dataset format for {self.dataset_path}")
-
-     def get_shard_list(self) -> List[str]:
-         """Get list of all shards in the dataset."""
-         if self.dataset_type == "huggingface":
-             if self.dataset_format == "webdataset":
-                 return self._get_hf_webdataset_shards()
-             else:
-                 logger.error(f"Unknown dataset format: {self.dataset_format}")
-                 return []
-         elif self.dataset_type == "local":
-             return self._get_local_shards()
-         else:
-             raise ValueError(f"Unknown dataset type: {self.dataset_type}")
-
-     def _get_hf_webdataset_shards(self) -> List[str]:
-         """Get shard URLs from HuggingFace WebDataset."""
-         logger.info(f"Getting WebDataset shard list from HuggingFace: {self.dataset_path}")
-
-         fs = HfFileSystem(token=self.token)
-         files = [fs.resolve_path(p) for p in fs.glob(f"hf://datasets/{self.dataset_path}/**/*.tar")]
-
-         urls = [hf_hub_url(f.repo_id, f.path_in_repo, repo_type="dataset") for f in files]
-
-         logger.info(f"Found {len(urls)} WebDataset shards")
-         return sorted(urls)
-
-     def _get_local_shards(self) -> List[str]:
-         """Get shard files from local directory."""
-         path = Path(self.dataset_path)
-         if not path.exists():
-             raise ValueError(f"Local dataset path does not exist: {path}")
-
-         shards = list(path.glob("*.tar"))
-         logger.info(f"Found {len(shards)} local shards")
-         return [str(s) for s in sorted(shards)]
-
-     def load_shard(self, shard_url: str, processed_keys: Optional[set] = None) -> wds.DataPipeline:
-         """
-         Load a single shard as a WebDataset pipeline.
-
-         Args:
-             shard_url: URL or path to the shard
-             processed_keys: Set of already processed keys to skip
-         """
-         if processed_keys is None:
-             processed_keys = set()
-
-         if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
-             # Use curl with auth token for HuggingFace
-             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(url_cmd),
-                 wds.tarfile_to_samples(),
-                 wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
-                 wds.select(lambda x: x[0] not in processed_keys),
-             )
-         else:
-             # Local file access
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(shard_url),
-                 wds.tarfile_to_samples(),
-                 wds.to_tuple("__key__", "__url__", "jpg;png;jpeg;webp;jxl"),
-                 wds.select(lambda x: x[0] not in processed_keys),
-             )
-
-         return ds
-
-     def iterate_shard(
-         self,
-         shard_url: str,
-         processed_keys: Optional[set] = None,
-         unprocessed_ranges: Optional[List[Tuple[int, int]]] = None,
-     ) -> Generator[Dict[str, Any], None, None]:
-         """
-         Iterate over items in a shard, returning full sample dictionaries.
-
-         Args:
-             shard_url: URL or identifier of the shard
-             processed_keys: Set of already processed keys to skip
-             unprocessed_ranges: Specific ranges to process (for range-based processing)
-
-         Yields:
-             Dictionary containing the full WebDataset sample
-         """
-         if processed_keys is None:
-             processed_keys = set()
-
-         if self.dataset_type == "huggingface" and self.dataset_format == "webdataset":
-             # Use curl with auth token for HuggingFace
-             url_cmd = f"pipe:curl -s -L -H 'Authorization:Bearer {shlex.quote(self.token)}' {shlex.quote(shard_url)} || true"
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(url_cmd),
-                 wds.tarfile_to_samples(),
-                 wds.select(lambda x: x.get("__key__", "") not in processed_keys),
-             )
-         else:
-             # Local file access
-             ds = wds.DataPipeline(
-                 wds.SimpleShardList(shard_url),
-                 wds.tarfile_to_samples(),
-                 wds.select(lambda x: x.get("__key__", "") not in processed_keys),
-             )
-
-         # Return full samples as dictionaries
-         for sample in ds:
-             # Ensure it's a dict and has required fields
-             if isinstance(sample, dict) and "__key__" in sample:
-                 yield sample
-
-     def count_shard_items(self, shard_url: str, processed_keys: Optional[set] = None) -> int:
-         """Count items in a shard (can be slow for large shards)."""
-         count = 0
-         try:
-             for _ in self.iterate_shard(shard_url, processed_keys):
-                 count += 1
-         except Exception as e:
-             logger.error(f"Error counting shard {shard_url}: {e}")
-         return count
-
-     def get_dataset_info(self) -> Dict[str, Any]:
-         """Get information about the dataset."""
-         info = {
-             "dataset_path": self.dataset_path,
-             "dataset_type": self.dataset_type,
-             "dataset_format": self.dataset_format,
-         }
-
-         if self.dataset_format == "huggingface_datasets":
-             # Include cached metadata if available
-             if hasattr(self, "_hf_metadata"):
-                 info.update(self._hf_metadata)
-             else:
-
-                 try:
-                     # Try to get more info about the dataset
-                     dataset_info = load_dataset(
-                         self.dataset_path, split=self.split, streaming=True, token=self.token
-                     )
-                     # Get features info
-                     if hasattr(dataset_info, "features"):
-                         info["features"] = str(dataset_info.features)
-
-                     # Try to get total size (might not work for all datasets)
-                     try:
-                         # This might be expensive for large datasets
-                         total_examples = len(
-                             load_dataset(self.dataset_path, split=self.split, token=self.token)
-                         )
-                         info["total_examples"] = total_examples
-                         self._hf_total_items = total_examples
-                     except:
-                         info["total_examples"] = "unknown"
-
-                 except Exception as e:
-                     logger.error(f"Error getting dataset info: {e}")
-
-         return info
@@ -1,67 +0,0 @@
- """Dataset metadata caching for efficient HuggingFace dataset handling."""
-
- import json
- import logging
- from pathlib import Path
- from typing import Dict, Any, Optional, List
- from datetime import datetime
-
- logger = logging.getLogger(__name__)
-
-
- class DatasetMetadataCache:
-     """Caches dataset metadata to avoid repeated full iterations."""
-
-     def __init__(self, cache_dir: Path):
-         self.cache_dir = Path(cache_dir)
-         self.cache_dir.mkdir(parents=True, exist_ok=True)
-         self.cache_file = self.cache_dir / "dataset_metadata.json"
-         self.metadata: Dict[str, Any] = {}
-         self._load_cache()
-
-     def _load_cache(self):
-         """Load cached metadata from disk."""
-         if self.cache_file.exists():
-             try:
-                 with open(self.cache_file, "r") as f:
-                     self.metadata = json.load(f)
-                 logger.info(f"Loaded dataset metadata cache with {len(self.metadata)} datasets")
-             except Exception as e:
-                 logger.error(f"Failed to load metadata cache: {e}")
-                 self.metadata = {}
-
-     def _save_cache(self):
-         """Save metadata cache to disk."""
-         try:
-             with open(self.cache_file, "w") as f:
-                 json.dump(self.metadata, f, indent=2)
-             logger.debug("Saved dataset metadata cache")
-         except Exception as e:
-             logger.error(f"Failed to save metadata cache: {e}")
-
-     def get_dataset_key(self, dataset_path: str, split: str) -> str:
-         """Generate a unique key for a dataset+split combination."""
-         return f"{dataset_path}:{split}"
-
-     def get_metadata(self, dataset_path: str, split: str) -> Optional[Dict[str, Any]]:
-         """Get cached metadata for a dataset."""
-         key = self.get_dataset_key(dataset_path, split)
-         return self.metadata.get(key)
-
-     def set_metadata(self, dataset_path: str, split: str, metadata: Dict[str, Any]):
-         """Cache metadata for a dataset."""
-         key = self.get_dataset_key(dataset_path, split)
-         metadata["cached_at"] = datetime.utcnow().isoformat()
-         metadata["dataset_path"] = dataset_path
-         metadata["split"] = split
-         self.metadata[key] = metadata
-         self._save_cache()
-         logger.info(f"Cached metadata for {key}: {metadata.get('total_items', 0)} items")
-
-     def invalidate(self, dataset_path: str, split: str):
-         """Remove cached metadata for a dataset."""
-         key = self.get_dataset_key(dataset_path, split)
-         if key in self.metadata:
-             del self.metadata[key]
-             self._save_cache()
-             logger.info(f"Invalidated metadata cache for {key}")
@@ -1,41 +0,0 @@
- """Job queue management."""
-
- import asyncio
- from typing import Optional
- from collections import deque
-
- from ..models import Job
-
-
- class JobQueue:
-     """Priority job queue with backpressure."""
-
-     def __init__(self):
-         self.queue = deque()
-         self.processing = set()
-         self.lock = asyncio.Lock()
-
-     async def add(self, job: Job):
-         """Add job to queue."""
-         async with self.lock:
-             self.queue.append(job)
-
-     async def get_next(self) -> Optional[Job]:
-         """Get next available job."""
-         async with self.lock:
-             if self.queue:
-                 job = self.queue.popleft()
-                 self.processing.add(job.job_id)
-                 return job
-             return None
-
-     async def complete(self, job_id: str):
-         """Mark job as complete."""
-         async with self.lock:
-             self.processing.discard(job_id)
-
-     async def requeue(self, job: Job):
-         """Requeue a job (for failures)."""
-         async with self.lock:
-             self.processing.discard(job.job_id)
-             self.queue.appendleft(job)  # Priority requeue
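To illustrate the requeue-at-front behavior of the removed `JobQueue`, here is a condensed, self-contained copy of its logic with a stand-in `Job` namedtuple (the real `Job` lives in `caption_flow.models`); a failed job re-enters at the head of the queue and is retried before newer work:

```python
import asyncio
from collections import deque, namedtuple

# stand-in for caption_flow.models.Job (illustrative only)
Job = namedtuple("Job", ["job_id"])

class JobQueue:
    """Condensed copy of the removed queue: failures requeue at the front."""
    def __init__(self):
        self.queue = deque()
        self.processing = set()
        self.lock = asyncio.Lock()

    async def add(self, job):
        async with self.lock:
            self.queue.append(job)

    async def get_next(self):
        async with self.lock:
            if self.queue:
                job = self.queue.popleft()
                self.processing.add(job.job_id)
                return job
            return None

    async def requeue(self, job):
        async with self.lock:
            self.processing.discard(job.job_id)
            self.queue.appendleft(job)  # failed work jumps the line

async def demo():
    q = JobQueue()
    await q.add(Job("a"))
    await q.add(Job("b"))
    first = await q.get_next()  # dequeues "a"
    await q.requeue(first)      # simulate a failure: "a" returns to the front
    return [(await q.get_next()).job_id, (await q.get_next()).job_id]

order = asyncio.run(demo())  # "a" is retried before "b"
```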