napsack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- label/__init__.py +0 -0
- label/__main__.py +210 -0
- label/caption_matching.py +150 -0
- label/clients/__init__.py +28 -0
- label/clients/bigquery.py +171 -0
- label/clients/client.py +28 -0
- label/clients/gemini.py +84 -0
- label/clients/vllm.py +137 -0
- label/discovery.py +116 -0
- label/models.py +533 -0
- label/processor.py +642 -0
- label/prompts/default.txt +98 -0
- label/prompts/screenshots_only.txt +87 -0
- label/video.py +323 -0
- label/visualizer.py +280 -0
- napsack/__init__.py +2 -0
- napsack-0.1.0.dist-info/METADATA +105 -0
- napsack-0.1.0.dist-info/RECORD +43 -0
- napsack-0.1.0.dist-info/WHEEL +5 -0
- napsack-0.1.0.dist-info/entry_points.txt +3 -0
- napsack-0.1.0.dist-info/top_level.txt +3 -0
- record/__init__.py +0 -0
- record/__main__.py +413 -0
- record/constants.py +112 -0
- record/handlers/__init__.py +7 -0
- record/handlers/accessibility.py +227 -0
- record/handlers/input_event.py +269 -0
- record/handlers/screenshot.py +87 -0
- record/models/__init__.py +16 -0
- record/models/aggregation.py +51 -0
- record/models/event.py +35 -0
- record/models/event_queue.py +503 -0
- record/models/image.py +23 -0
- record/models/image_queue.py +118 -0
- record/monitor/__init__.py +9 -0
- record/monitor/reader.py +101 -0
- record/monitor/summary.py +402 -0
- record/monitor/viewer.py +393 -0
- record/sanitize.py +224 -0
- record/workers/__init__.py +10 -0
- record/workers/aggregation.py +157 -0
- record/workers/save.py +104 -0
- record/workers/screenshot.py +136 -0
label/__init__.py
ADDED
|
File without changes
|
label/__main__.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import argparse
|
|
3
|
+
from dotenv import load_dotenv
|
|
4
|
+
|
|
5
|
+
from label.discovery import discover_sessions, discover_screenshots_sessions, create_single_config
|
|
6
|
+
from label.clients import create_client
|
|
7
|
+
from label.processor import Processor
|
|
8
|
+
from label.visualizer import Visualizer
|
|
9
|
+
|
|
10
|
+
load_dotenv()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_args():
    """Parse command-line arguments for the labeling pipeline.

    Returns:
        argparse.Namespace with defaults resolved: ``model`` is filled in
        per client when empty, and ``prompt_file`` falls back to the bundled
        default prompt (screenshots variant when --screenshots-only is set).
    """
    p = argparse.ArgumentParser(description="Process session recordings with VLM")

    session_group = p.add_mutually_exclusive_group(required=True)
    session_group.add_argument("--session", type=Path)
    session_group.add_argument("--sessions-root", type=Path)

    p.add_argument("--chunk-duration", type=int, default=60, help="Chunk duration in seconds")
    p.add_argument("--fps", type=int, default=1, help="Frames per second for video processing")

    p.add_argument("--screenshots-only", action="store_true", help="Process screenshots folder only without aggregations or annotations")
    p.add_argument("--image-extensions", nargs="+", default=[".jpg", ".jpeg", ".png"], help="Image file extensions to consider")
    # Help text kept in sync with the actual default (it previously claimed 120 s).
    p.add_argument("--max-time-gap", type=float, default=300.0, help="Maximum time gap (seconds) between images before forcing a video split (default: 300 = 5 minutes)")
    p.add_argument("--prompt-file", default=None, help="Path to prompt file (default: prompts/default.txt or prompts/screenshots_only.txt if screenshots only)")
    p.add_argument("--hash-cache", type=str, default=None, help="Path to hash_cache.json for deduplicating consecutive similar images")
    p.add_argument("--dedupe-threshold", type=int, default=1, help="Hamming distance threshold for deduplication (drop if <= threshold, default: 1)")
    p.add_argument("--annotate", action="store_true", help="Annotate videos with cursor positions and clicks (only for standard processing)")
    p.add_argument("--skip-existing", action="store_true", help="Skip sessions that have already been processed")
    p.add_argument("--visualize", action="store_true", help="Create annotated video visualizations after processing")
    p.add_argument("--encode-only", action="store_true", help="Only encode videos (create chunks), skip labeling. Useful for pre-processing before running the full pipeline.")

    p.add_argument("--client", choices=["gemini", "vllm", "bigquery"], default="gemini")
    p.add_argument("--model", default="")
    p.add_argument("--encode-workers", type=int, default=8, help="Number of parallel workers for video encoding")
    p.add_argument("--label-workers", type=int, default=4, help="Number of parallel workers for VLM labeling")

    vllm_group = p.add_argument_group("vLLM Options")
    vllm_group.add_argument("--vllm-url")

    bq_group = p.add_argument_group("BigQuery Options")
    bq_group.add_argument("--bq-project", help="GCP project ID for AI Platform endpoint")
    bq_group.add_argument("--bq-bucket-name", help="GCS bucket name for uploading videos")
    bq_group.add_argument("--bq-gcs-prefix", default="video_chunks", help="Prefix/folder path in GCS bucket")
    bq_group.add_argument("--bq-object-table-location", default="us", help="Object table location (e.g., 'us' or 'us.screenomics-gemini')")

    args = p.parse_args()

    # Fail fast on per-client required options instead of crashing later with
    # an AttributeError when the vLLM URL is built.
    if args.client == 'vllm' and not args.vllm_url:
        p.error("--vllm-url is required when --client vllm")

    # Per-client default models when --model is not given.
    if not args.model:
        if args.client == 'gemini':
            args.model = 'gemini-3-flash-preview'
        elif args.client == 'vllm':
            args.model = 'Qwen/Qwen3-VL-8B-Thinking-FP8'
        elif args.client == 'bigquery':
            args.model = 'dataset.model'  # Placeholder - user must provide full model reference
    if not args.prompt_file:
        args.prompt_file = "prompts/screenshots_only.txt" if args.screenshots_only else "prompts/default.txt"

    return args
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def setup_configs(args):
    """Build the list of session configs from parsed CLI arguments.

    A single --session yields exactly one config; otherwise sessions are
    discovered under --sessions-root (screenshots-only discovery when
    requested). Returns an empty list when nothing is found.
    """
    extensions = tuple(args.image_extensions)

    if args.session:
        # Explicit single session: no discovery needed.
        return [create_single_config(
            args.session,
            args.chunk_duration,
            args.screenshots_only,
            extensions,
        )]

    if args.screenshots_only:
        configs = discover_screenshots_sessions(
            args.sessions_root,
            args.chunk_duration,
            extensions,
        )
    else:
        configs = discover_sessions(
            args.sessions_root,
            args.chunk_duration,
            args.skip_existing,
        )

    if not configs:
        print(f"No sessions found in {args.sessions_root}")
        return []

    return configs
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def process_with_gemini(args, configs):
    """Process the given session configs with the Gemini API backend."""
    gemini_client = create_client(
        'gemini',
        model_name=args.model,
    )

    pipeline = Processor(
        client=gemini_client,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    # Annotation only applies to standard (non-screenshot) processing.
    return pipeline.process_sessions(
        configs,
        fps=args.fps,
        annotate=args.annotate and not args.screenshots_only,
        encode_only=args.encode_only,
    )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def process_with_vllm(args, configs):
    """Process the given session configs with a vLLM OpenAI-compatible server.

    Raises:
        ValueError: if --vllm-url was not provided (argparse does not enforce
            it because the option is only required for this client; without
            this guard the URL normalization below would raise AttributeError
            on None).
    """
    if not args.vllm_url:
        raise ValueError("--vllm-url is required when using the vllm client")

    # Normalize to the OpenAI-compatible /v1 endpoint.
    base_url = args.vllm_url if args.vllm_url.endswith('/v1') else f"{args.vllm_url}/v1"

    client = create_client(
        'vllm',
        base_url=base_url,
        model_name=args.model
    )

    processor = Processor(
        client=client,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    # Annotation only applies to standard (non-screenshot) processing.
    return processor.process_sessions(
        configs,
        fps=args.fps,
        annotate=args.annotate and not args.screenshots_only,
        encode_only=args.encode_only,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def process_with_bigquery(args, configs):
    """Process the given session configs with the BigQuery AI backend."""
    bq_client = create_client(
        'bigquery',
        model_name=args.model,
        bucket_name=args.bq_bucket_name,
        gcs_prefix=args.bq_gcs_prefix,
        object_table_location=args.bq_object_table_location,
        project_id=args.bq_project,
    )

    pipeline = Processor(
        client=bq_client,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    # Annotation only applies to standard (non-screenshot) processing.
    return pipeline.process_sessions(
        configs,
        fps=args.fps,
        annotate=args.annotate and not args.screenshots_only,
        encode_only=args.encode_only,
    )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def main():
    """CLI entry point: parse args, process sessions, optionally visualize."""
    args = parse_args()

    configs = setup_configs(args)
    if not configs:
        return

    print(f"Processing {len(configs)} sessions")

    # Dispatch to the backend-specific processing function.
    handlers = {
        'gemini': process_with_gemini,
        'vllm': process_with_vllm,
        'bigquery': process_with_bigquery,
    }
    handler = handlers.get(args.client)
    if handler is None:
        raise ValueError(f"Unknown client: {args.client}")
    results = handler(args, configs)

    print(f"✓ Processed {len(results)} sessions")

    if not args.visualize:
        return

    print("\nCreating visualizations...")
    visualizer = Visualizer(args.annotate)

    for config in configs:
        # Visualization needs the matched captions produced by processing.
        if not config.matched_captions_jsonl.exists():
            print(f"Skipping Visualizing {config.session_id}: no data.jsonl")
            continue

        try:
            output = config.session_folder / "annotated.mp4"
            visualizer.visualize(config.session_folder, output, args.fps)
            print(f"✓ {config.session_id}: {output}")
        except Exception as e:
            print(f"✗ {config.session_id}: {e}")
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Allow running directly (e.g. python -m label).
if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import List, Dict, Any, Optional
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def match_captions_with_events(
    captions_path: Path,
    aggregations_path: Path,
    output_path: Path,
    fps: int = 1
) -> List[Dict[str, Any]]:
    """
    Match captions with aggregated events based on timestamps.

    Caption time offsets (``start_seconds``/``end_seconds``) are mapped to
    aggregation indices via ``index = seconds * fps`` and clamped into range;
    all events from the matched aggregations are concatenated per caption.

    Args:
        captions_path: Path to captions.jsonl
        aggregations_path: Path to aggregations.jsonl
        output_path: Path to save matched_captions.jsonl
        fps: Frames per second used in video creation

    Returns:
        List of matched caption-event objects
    """
    # Load captions (one JSON object per non-blank line)
    captions = []
    with open(captions_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                captions.append(json.loads(line))

    # Load aggregations (same JSONL layout)
    aggregations = []
    with open(aggregations_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                agg = json.loads(line)
                aggregations.append(agg)

    # Sort aggregations by timestamp so positional indexing matches video order
    aggregations.sort(key=lambda x: x.get('timestamp', 0))

    if not aggregations:
        print("[Matcher] Warning: No aggregations found")
        return []

    # Get first aggregation timestamp (video start time)
    first_timestamp = aggregations[0].get('timestamp', 0)

    print(f"[Matcher] Video start time: {first_timestamp}")
    print(f"[Matcher] Total aggregations: {len(aggregations)}")
    print(f"[Matcher] FPS: {fps}")

    # Match captions with events
    matched_data = []

    for caption in captions:
        # Caption offsets are already numeric seconds relative to video start
        start_seconds = caption['start_seconds']
        end_seconds = caption['end_seconds']

        # Convert video time to aggregation indices
        # Each aggregation represents 1 frame, so index = seconds * fps
        start_index = int(start_seconds * fps)
        end_index = int(end_seconds * fps)

        # Clamp to valid range (end never precedes start)
        start_index = max(0, min(start_index, len(aggregations) - 1))
        end_index = max(start_index, min(end_index, len(aggregations) - 1))

        print(f"[Matcher] Caption '{caption['caption'][:50]}...' -> indices [{start_index}, {end_index}]")

        # Get aggregations in this range (inclusive on both ends)
        matched_aggs = aggregations[start_index:end_index + 1]

        if not matched_aggs:
            # No events matched, but still save the caption.
            # Times are reconstructed from the video start timestamp.
            matched_entry = {
                "start_time": first_timestamp + start_seconds,
                "end_time": first_timestamp + end_seconds,
                "start_index": start_index,
                "end_index": end_index,
                "img": None,
                "caption": caption['caption'],
                "raw_events": [],
                "num_aggregations": 0,
                "start_formatted": caption['start'],
                "end_formatted": caption['end'],
            }
        else:
            # Get first and last aggregation for time and image
            first_agg = matched_aggs[0]
            last_agg = matched_aggs[-1]

            # Concatenate all events from matched aggregations
            all_events = []
            for agg in matched_aggs:
                events = agg.get('events', [])
                all_events.extend(events)

            matched_entry = {
                "start_time": first_agg.get('timestamp'),
                "end_time": last_agg.get('timestamp'),
                "start_index": start_index,
                "end_index": end_index,
                "img": first_agg.get('screenshot_path'),
                "caption": caption['caption'],
                "raw_events": all_events,
                "num_aggregations": len(matched_aggs),
                "start_formatted": caption['start'],
                "end_formatted": caption['end'],
            }

        matched_data.append(matched_entry)

    # Save matched data as JSONL
    with open(output_path, 'w', encoding='utf-8') as f:
        for entry in matched_data:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"[Matcher] Saved {len(matched_data)} matched entries to {output_path}")

    return matched_data
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def create_matched_captions_for_session(session_dir: Path, fps: int = 1) -> Optional[Path]:
    """
    Create matched_captions.jsonl for a session directory.

    Args:
        session_dir: Path to session directory
        fps: Frames per second used in video creation

    Returns:
        Path to created matched_captions.jsonl or None if failed
    """
    captions_path = session_dir / "captions.jsonl"
    aggregations_path = session_dir / "aggregations.jsonl"
    output_path = session_dir / "matched_captions.jsonl"

    # Both inputs must exist before matching; warn and bail out otherwise.
    for required in (captions_path, aggregations_path):
        if not required.exists():
            print(f"[Matcher] Warning: {required} not found")
            return None

    match_captions_with_events(captions_path, aggregations_path, output_path, fps)

    return output_path
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from label.clients.client import VLMClient, CAPTION_SCHEMA
|
|
2
|
+
from label.clients.gemini import GeminiClient, GeminiResponse
|
|
3
|
+
from label.clients.vllm import VLLMClient, VLLMResponse
|
|
4
|
+
from label.clients.bigquery import BigQueryClient, BigQueryResponse
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def create_client(client_type: str, **kwargs) -> VLMClient:
    """Factory for VLM client backends.

    Args:
        client_type: One of 'gemini', 'vllm', 'bigquery'.
        **kwargs: Forwarded verbatim to the chosen client's constructor.

    Raises:
        ValueError: for an unrecognized client_type.
    """
    if client_type == 'gemini':
        return GeminiClient(**kwargs)
    if client_type == 'vllm':
        return VLLMClient(**kwargs)
    if client_type == 'bigquery':
        return BigQueryClient(**kwargs)
    raise ValueError(f"Unknown client type: {client_type}")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Public API of label.clients: client implementations, their response
# wrappers, the shared caption schema, and the factory function.
__all__ = [
    "VLMClient",
    "GeminiClient",
    "GeminiResponse",
    "VLLMClient",
    "VLLMResponse",
    "BigQueryClient",
    "BigQueryResponse",
    "CAPTION_SCHEMA",
    "create_client",
]
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from typing import Optional, Any, Dict
|
|
2
|
+
import json
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from google.cloud import storage
|
|
6
|
+
from google.cloud import bigquery
|
|
7
|
+
|
|
8
|
+
from label.clients.client import VLMClient, CAPTION_SCHEMA # still imported, but we don't rely on it by default
|
|
9
|
+
|
|
10
|
+
# Example call:
|
|
11
|
+
# uv run -m label \
|
|
12
|
+
# --session /home/jupyter/Omar/downloads/test@gmail.com \
|
|
13
|
+
# --screenshots-only \
|
|
14
|
+
# --client bigquery \
|
|
15
|
+
# --model gemini-3-pro-preview \
|
|
16
|
+
# --bq-project hs-nero-phi-reeves-haitech \
|
|
17
|
+
# --bq-bucket-name hs-nero-phi-reeves-haitech-project \
|
|
18
|
+
# --bq-gcs-prefix Shaikh_Omar \
|
|
19
|
+
# --bq-object-table-location us.screenomics-gemini
|
|
20
|
+
|
|
21
|
+
class BigQueryResponse:
    """Wrapper for a BigQuery AI.GENERATE result string.

    Exposes the raw model output as ``text`` and a lazily parsed
    JSON view as ``json``.
    """

    def __init__(self, result_row):
        # Raw result string pulled from the query row.
        self.result_row = result_row
        # Memoized parse of the result string.
        self._json = None

    @property
    def text(self) -> str:
        """Raw model output as returned by BigQuery."""
        return self.result_row

    @property
    def json(self):
        """Model output parsed as JSON (parsed at most once)."""
        if self._json is None:
            self._json = json.loads(self.result_row)
        return self._json
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class BigQueryClient(VLMClient):
    """VLM client that uploads media to GCS and labels it via BigQuery AI.GENERATE."""

    def __init__(
        self,
        model_name: str,
        bucket_name: str,
        gcs_prefix: str = "video_chunks",
        object_table_location: str = "us",  # e.g., "us.screenomics-gemini"
        temperature: float = 0.0,
        max_output_tokens: int = 65535,
        project_id: Optional[str] = None,
    ):
        """
        Initialize BigQuery client for ML.GENERATE_TEXT with video analysis.

        Args:
            model_name: Full BigQuery model reference (e.g., "dataset.model" or "project.dataset.model")
            bucket_name: GCS bucket name for uploading videos
            gcs_prefix: Prefix/folder path in GCS bucket
            object_table_location: Object table location (e.g., "us.screenomics-gemini")
            temperature: Model temperature parameter
            max_output_tokens: Maximum number of tokens in the generated response
            project_id: Optional GCP project ID (if not provided, uses default credentials)
        """
        self.model_name = model_name
        self.bucket_name = bucket_name
        self.gcs_prefix = gcs_prefix
        self.object_table_location = object_table_location
        # NOTE(review): temperature and max_output_tokens are stored but never
        # forwarded into the AI.GENERATE query below — confirm whether they
        # should be part of model_params.
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        self.project_id = project_id

        # Separate clients: GCS for uploads, BigQuery for running the query.
        self.storage_client = storage.Client(project=project_id)
        self.bq_client = bigquery.Client(project=project_id)

    @staticmethod
    def _escape_for_bq_single_quoted_string(s: str) -> str:
        r"""
        Escape a Python string for use in a BigQuery single-quoted string literal.

        BigQuery-style escaping:
        - Backslash: \ -> \\
        - Newline: actual newline -> \n (two chars)
        - Carriage ret: actual \r -> \r
        - Single quote: ' -> \'
        """
        # Escape backslashes first so we don't re-escape ones we add later
        s = s.replace("\\", "\\\\")
        # Encode newlines and carriage returns as literal escape sequences
        s = s.replace("\r", "\\r")
        s = s.replace("\n", "\\n")
        # Escape single quotes for BigQuery
        s = s.replace("'", "\\'")
        return s

    def upload_file(self, path: str, session_id: str = None) -> str:
        """
        Upload file to GCS and return the GCS URI.

        Args:
            path: Local file path
            session_id: Optional session identifier for namespacing uploads

        Returns:
            GCS URI (gs://bucket/path/to/file)
        """
        file_path = Path(path)
        # Namespace by session when provided so concurrent sessions don't collide.
        if session_id:
            destination_blob_name = f"{self.gcs_prefix}/{session_id}/{file_path.name}"
        else:
            destination_blob_name = f"{self.gcs_prefix}/{file_path.name}"

        bucket = self.storage_client.bucket(self.bucket_name)
        blob = bucket.blob(destination_blob_name)

        # Upload the file
        blob.upload_from_filename(path)

        gcs_uri = f"gs://{self.bucket_name}/{destination_blob_name}"
        print(f"Uploaded {path} to {gcs_uri}")

        return gcs_uri

    def generate(
        self,
        prompt: str,
        file_descriptor: Optional[Any] = None,
        schema: Optional[Dict] = None,
    ) -> BigQueryResponse:
        """Run AI.GENERATE over a previously uploaded GCS object.

        Args:
            prompt: Instruction text sent alongside the media reference.
            file_descriptor: GCS URI returned by upload_file (required).
            schema: Accepted for interface compatibility.
                NOTE(review): schema is currently ignored — the query only
                requests application/json output without a response schema.

        Returns:
            BigQueryResponse wrapping the first result row.

        Raises:
            ValueError: if no file_descriptor is given.
            RuntimeError: if the query returns no rows.
        """
        if not file_descriptor:
            raise ValueError("file_descriptor (GCS URI) is required")

        gcs_uri = file_descriptor

        # Escape everything that will go inside single-quoted SQL string literals
        escaped_prompt = self._escape_for_bq_single_quoted_string(prompt)
        escaped_gcs_uri = self._escape_for_bq_single_quoted_string(gcs_uri)
        escaped_location = self._escape_for_bq_single_quoted_string(
            self.object_table_location
        )

        # Generation config passed to the model as JSON model_params.
        response_params = {
            "generation_config": {
                "media_resolution": "MEDIA_RESOLUTION_HIGH",
                "response_mime_type": "application/json"
            }
        }

        response_params_json = json.dumps(response_params)
        escaped_response_params = self._escape_for_bq_single_quoted_string(response_params_json)

        # The model is addressed directly via its AI Platform endpoint;
        # the media is referenced through an object-table ref.
        query = f"""
        SELECT
        AI.GENERATE(
            (
            '{escaped_prompt}',
            OBJ.FETCH_METADATA(
                OBJ.MAKE_REF('{escaped_gcs_uri}', '{escaped_location}')
            )
            ),
            endpoint => 'https://aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/global/publishers/google/models/{self.model_name}',
            model_params => JSON '{escaped_response_params}'
        ) AS gen;
        """

        job_config = bigquery.QueryJobConfig()

        query_job = self.bq_client.query(query, job_config=job_config)
        results = query_job.result()

        # Only the first row is consumed; the query yields a single AI.GENERATE value.
        for row in results:
            print(row[0]["result"])
            return BigQueryResponse(row[0]["result"])

        raise RuntimeError("No results returned from BigQuery")
|
|
171
|
+
|
label/clients/client.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Dict, Optional, List, Union
|
|
4
|
+
|
|
5
|
+
# Default structured-output schema shared by all clients: an array of
# {start, end, caption} objects. start/end are time strings whose exact
# format is set by the prompt (presumably MM:SS — confirm against prompts/).
CAPTION_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "start": {"type": "string"},
            "end": {"type": "string"},
            "caption": {"type": "string"}
        },
        "required": ["start", "end", "caption"]
    }
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class VLMClient(ABC):
    """Abstract interface implemented by all VLM backends (Gemini, vLLM, BigQuery)."""

    @abstractmethod
    def upload_file(self, path: str, session_id: str = None) -> Any:
        """Upload a local file; return a backend-specific file descriptor."""
        pass

    @abstractmethod
    def generate(self, prompt: Union[str, List[str]],
                 file_descriptor: Optional[Union[Any, List[Any]]] = None,
                 schema: Optional[Dict] = None) -> Union[Any, List[Any]]:
        """Generate a response for the prompt, optionally grounded in an uploaded file.

        schema, when given, constrains the structured output (see CAPTION_SCHEMA).
        """
        pass
|
label/clients/gemini.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from typing import Optional, Any, Dict
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import time
|
|
5
|
+
from label.clients.client import VLMClient, CAPTION_SCHEMA
|
|
6
|
+
|
|
7
|
+
from google import genai
|
|
8
|
+
from google.genai import types
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GeminiResponse:
    """Wraps a google-genai response, adding lazy JSON parsing of its text."""

    def __init__(self, response):
        # Underlying SDK response object.
        self.response = response
        # Memoized parse of response.text.
        self._json = None

    @property
    def text(self) -> str:
        """Raw text of the model response."""
        return self.response.text

    @property
    def json(self):
        """Response text parsed as JSON (parsed at most once)."""
        if self._json is None:
            self._json = json.loads(self.response.text)
        return self._json
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GeminiClient(VLMClient):
    """VLM client backed by the google-genai SDK (Gemini API)."""

    def __init__(self, api_key: Optional[str] = None, model_name: str = "gemini-2.5-flash"):
        # NOTE(review): genai is imported unconditionally at module top, so this
        # check never fires; it looks like a leftover from a guarded import.
        if genai is None:
            raise RuntimeError("google-genai not installed")

        # Fall back to the environment when no key is passed explicitly.
        api_key = api_key or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY not set")

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

    def upload_file(self, path: str, session_id: str = None) -> Any:
        """Upload a file to the Gemini Files API and wait until it is ready.

        session_id is accepted for interface compatibility but unused here.
        Raises RuntimeError if Gemini reports the file as FAILED.
        """
        video_file = self.client.files.upload(file=path)

        # Poll until the file leaves the PROCESSING state.
        while True:
            video_file = self.client.files.get(name=video_file.name)
            state = getattr(getattr(video_file, "state", None), "name", None)

            if state == "PROCESSING":
                time.sleep(2)
            elif state == "FAILED":
                raise RuntimeError("Gemini failed processing file")
            elif state == "ACTIVE":
                break
            else:
                # Unknown/missing state: stop polling rather than spin forever.
                break

        return video_file

    def generate(self, prompt: str, file_descriptor: Optional[Any] = None,
                 schema: Optional[Dict] = None) -> GeminiResponse:
        """Generate structured JSON output for the prompt (and optional file).

        Falls back to CAPTION_SCHEMA when no schema is given. Gemini 3 models
        get a thinking config and high media resolution; older models get
        temperature 0 instead.
        """
        inputs = [prompt]
        if file_descriptor:
            inputs.append(file_descriptor)

        if "gemini-3" in self.model_name:
            config = types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=schema or CAPTION_SCHEMA,
                thinking_config=types.ThinkingConfig(thinking_budget=-1),
                media_resolution=types.MediaResolution.MEDIA_RESOLUTION_HIGH,
            )
        else:
            config = types.GenerateContentConfig(
                response_mime_type="application/json",
                temperature=0.0,
                response_schema=schema or CAPTION_SCHEMA
            )

        res = self.client.models.generate_content(
            model=self.model_name,
            contents=inputs,
            config=config
        )

        return GeminiResponse(res)
|