napsack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
label/__init__.py ADDED
File without changes
label/__main__.py ADDED
@@ -0,0 +1,210 @@
1
+ from pathlib import Path
2
+ import argparse
3
+ from dotenv import load_dotenv
4
+
5
+ from label.discovery import discover_sessions, discover_screenshots_sessions, create_single_config
6
+ from label.clients import create_client
7
+ from label.processor import Processor
8
+ from label.visualizer import Visualizer
9
+
10
+ load_dotenv()
11
+
12
+
13
def parse_args(argv=None):
    """Parse command-line arguments for the labeling pipeline.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            (standard argparse behavior), so existing callers are unaffected;
            passing an explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with all options resolved, including a
        client-specific default ``model`` and a default ``prompt_file``.
    """
    p = argparse.ArgumentParser(description="Process session recordings with VLM")

    # Exactly one of --session (single folder) or --sessions-root (batch scan).
    session_group = p.add_mutually_exclusive_group(required=True)
    session_group.add_argument("--session", type=Path)
    session_group.add_argument("--sessions-root", type=Path)

    p.add_argument("--chunk-duration", type=int, default=60, help="Chunk duration in seconds")
    p.add_argument("--fps", type=int, default=1, help="Frames per second for video processing")

    p.add_argument("--screenshots-only", action="store_true", help="Process screenshots folder only without aggregations or annotations")
    p.add_argument("--image-extensions", nargs="+", default=[".jpg", ".jpeg", ".png"], help="Image file extensions to consider")
    # Fixed: help text previously claimed "default: 120 = 2 minutes" while the
    # actual default is 300 seconds.
    p.add_argument("--max-time-gap", type=float, default=300.0, help="Maximum time gap (seconds) between images before forcing a video split (default: 300 = 5 minutes)")
    p.add_argument("--prompt-file", default=None, help="Path to prompt file (default: prompts/default.txt or prompts/screenshots_only.txt if screenshots only)")
    p.add_argument("--hash-cache", type=str, default=None, help="Path to hash_cache.json for deduplicating consecutive similar images")
    p.add_argument("--dedupe-threshold", type=int, default=1, help="Hamming distance threshold for deduplication (drop if <= threshold, default: 1)")
    p.add_argument("--annotate", action="store_true", help="Annotate videos with cursor positions and clicks (only for standard processing)")
    p.add_argument("--skip-existing", action="store_true", help="Skip sessions that have already been processed")
    p.add_argument("--visualize", action="store_true", help="Create annotated video visualizations after processing")
    p.add_argument("--encode-only", action="store_true", help="Only encode videos (create chunks), skip labeling. Useful for pre-processing before running the full pipeline.")

    p.add_argument("--client", choices=["gemini", "vllm", "bigquery"], default="gemini")
    p.add_argument("--model", default="")
    p.add_argument("--encode-workers", type=int, default=8, help="Number of parallel workers for video encoding")
    p.add_argument("--label-workers", type=int, default=4, help="Number of parallel workers for VLM labeling")

    vllm_group = p.add_argument_group("vLLM Options")
    vllm_group.add_argument("--vllm-url")

    bq_group = p.add_argument_group("BigQuery Options")
    bq_group.add_argument("--bq-project", help="GCP project ID for AI Platform endpoint")
    bq_group.add_argument("--bq-bucket-name", help="GCS bucket name for uploading videos")
    bq_group.add_argument("--bq-gcs-prefix", default="video_chunks", help="Prefix/folder path in GCS bucket")
    bq_group.add_argument("--bq-object-table-location", default="us", help="Object table location (e.g., 'us' or 'us.screenomics-gemini')")

    args = p.parse_args(argv)

    # Resolve a per-client default model when none was given.
    if not args.model:
        if args.client == 'gemini':
            args.model = 'gemini-3-flash-preview'
        elif args.client == 'vllm':
            args.model = 'Qwen/Qwen3-VL-8B-Thinking-FP8'
        elif args.client == 'bigquery':
            args.model = 'dataset.model'  # Placeholder - user must provide full model reference
    # Default prompt depends on the processing mode.
    if not args.prompt_file:
        args.prompt_file = "prompts/screenshots_only.txt" if args.screenshots_only else "prompts/default.txt"

    return args
61
+
62
+
63
def setup_configs(args):
    """Build the list of session configs from parsed CLI arguments.

    Prints a notice and returns an empty list when no sessions are found.
    """
    extensions = tuple(args.image_extensions)

    if args.session:
        # Single-session mode: wrap one config in a list.
        configs = [create_single_config(
            args.session,
            args.chunk_duration,
            args.screenshots_only,
            extensions,
        )]
    elif args.screenshots_only:
        configs = discover_screenshots_sessions(
            args.sessions_root,
            args.chunk_duration,
            extensions,
        )
    else:
        configs = discover_sessions(
            args.sessions_root,
            args.chunk_duration,
            args.skip_existing,
        )

    if configs:
        return configs

    print(f"No sessions found in {args.sessions_root}")
    return []
90
+
91
+
92
def process_with_gemini(args, configs):
    """Run the labeling pipeline using the Gemini API client."""
    gemini_client = create_client(
        'gemini',
        model_name=args.model,
    )

    pipeline = Processor(
        client=gemini_client,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    # Annotation is only meaningful for the standard (non-screenshot) flow.
    should_annotate = args.annotate and not args.screenshots_only
    return pipeline.process_sessions(
        configs,
        fps=args.fps,
        annotate=should_annotate,
        encode_only=args.encode_only,
    )
115
+
116
+
117
def process_with_vllm(args, configs):
    """Run the labeling pipeline against a vLLM server.

    Args:
        args: Parsed CLI namespace; must carry a non-empty ``vllm_url``.
        configs: Session configs produced by ``setup_configs``.

    Raises:
        ValueError: if ``--vllm-url`` was not supplied. (Previously this
            crashed with an AttributeError on ``None.endswith``.)
    """
    if not args.vllm_url:
        raise ValueError("--vllm-url is required when --client vllm is selected")

    # Normalize the base URL so it always ends with the OpenAI-compatible /v1 path.
    base_url = args.vllm_url if args.vllm_url.endswith('/v1') else f"{args.vllm_url}/v1"

    client = create_client(
        'vllm',
        base_url=base_url,
        model_name=args.model
    )

    processor = Processor(
        client=client,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    return processor.process_sessions(
        configs,
        fps=args.fps,
        # Annotation is only meaningful for the standard (non-screenshot) flow.
        annotate=args.annotate and not args.screenshots_only,
        encode_only=args.encode_only,
    )
141
+
142
+
143
def process_with_bigquery(args, configs):
    """Run the labeling pipeline using the BigQuery AI.GENERATE client."""
    bq_client = create_client(
        'bigquery',
        model_name=args.model,
        bucket_name=args.bq_bucket_name,
        gcs_prefix=args.bq_gcs_prefix,
        object_table_location=args.bq_object_table_location,
        project_id=args.bq_project,
    )

    pipeline = Processor(
        client=bq_client,
        encode_workers=args.encode_workers,
        label_workers=args.label_workers,
        screenshots_only=args.screenshots_only,
        prompt_file=args.prompt_file,
        max_time_gap=args.max_time_gap,
        hash_cache_path=args.hash_cache,
        dedupe_threshold=args.dedupe_threshold,
    )

    # Annotation is only meaningful for the standard (non-screenshot) flow.
    should_annotate = args.annotate and not args.screenshots_only
    return pipeline.process_sessions(
        configs,
        fps=args.fps,
        annotate=should_annotate,
        encode_only=args.encode_only,
    )
170
+
171
+
172
def main():
    """CLI entry point: discover sessions, process them, optionally visualize."""
    args = parse_args()

    configs = setup_configs(args)
    if not configs:
        return

    print(f"Processing {len(configs)} sessions")

    # Dispatch to the backend-specific runner.
    handlers = {
        'gemini': process_with_gemini,
        'vllm': process_with_vllm,
        'bigquery': process_with_bigquery,
    }
    handler = handlers.get(args.client)
    if handler is None:
        raise ValueError(f"Unknown client: {args.client}")
    results = handler(args, configs)

    print(f"✓ Processed {len(results)} sessions")

    if not args.visualize:
        return

    print("\nCreating visualizations...")
    visualizer = Visualizer(args.annotate)

    for config in configs:
        # Visualization needs the matched-captions output from processing.
        if not config.matched_captions_jsonl.exists():
            print(f"Skipping Visualizing {config.session_id}: no data.jsonl")
            continue

        try:
            output = config.session_folder / "annotated.mp4"
            visualizer.visualize(config.session_folder, output, args.fps)
            print(f"✓ {config.session_id}: {output}")
        except Exception as e:
            # Best-effort: a failed visualization should not abort the rest.
            print(f"✗ {config.session_id}: {e}")


if __name__ == '__main__':
    main()
@@ -0,0 +1,150 @@
1
+ from pathlib import Path
2
+ from typing import List, Dict, Any, Optional
3
+ import json
4
+
5
+
6
def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load every non-blank line of a JSONL file as a dict."""
    records: List[Dict[str, Any]] = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                records.append(json.loads(line))
    return records


def match_captions_with_events(
    captions_path: Path,
    aggregations_path: Path,
    output_path: Path,
    fps: int = 1
) -> List[Dict[str, Any]]:
    """
    Match captions with aggregated events based on timestamps.

    Args:
        captions_path: Path to captions.jsonl
        aggregations_path: Path to aggregations.jsonl
        output_path: Path to save matched_captions.jsonl
        fps: Frames per second used in video creation

    Returns:
        List of matched caption-event objects
    """
    captions = _load_jsonl(captions_path)
    aggregations = _load_jsonl(aggregations_path)

    # Sort aggregations by timestamp so index arithmetic below is meaningful.
    aggregations.sort(key=lambda x: x.get('timestamp', 0))

    if not aggregations:
        print("[Matcher] Warning: No aggregations found")
        return []

    # First aggregation timestamp is treated as the video start time.
    first_timestamp = aggregations[0].get('timestamp', 0)

    print(f"[Matcher] Video start time: {first_timestamp}")
    print(f"[Matcher] Total aggregations: {len(aggregations)}")
    print(f"[Matcher] FPS: {fps}")

    matched_data = []

    for caption in captions:
        # Caption offsets are already seconds into the video (the old comment
        # claiming MM:SS conversion was stale -- no conversion happens here).
        start_seconds = caption['start_seconds']
        end_seconds = caption['end_seconds']

        # Each aggregation represents one frame, so index = seconds * fps.
        start_index = int(start_seconds * fps)
        end_index = int(end_seconds * fps)

        # Clamp both indices to the valid range; end never precedes start.
        start_index = max(0, min(start_index, len(aggregations) - 1))
        end_index = max(start_index, min(end_index, len(aggregations) - 1))

        print(f"[Matcher] Caption '{caption['caption'][:50]}...' -> indices [{start_index}, {end_index}]")

        matched_aggs = aggregations[start_index:end_index + 1]

        if not matched_aggs:
            # Defensive: with the clamping above this branch should be
            # unreachable, but keep the caption even with no matched events.
            matched_entry = {
                "start_time": first_timestamp + start_seconds,
                "end_time": first_timestamp + end_seconds,
                "start_index": start_index,
                "end_index": end_index,
                "img": None,
                "caption": caption['caption'],
                "raw_events": [],
                "num_aggregations": 0,
                "start_formatted": caption['start'],
                "end_formatted": caption['end'],
            }
        else:
            # Use the first/last matched aggregation for timing and the image.
            first_agg = matched_aggs[0]
            last_agg = matched_aggs[-1]

            # Concatenate all events from the matched aggregations.
            all_events = []
            for agg in matched_aggs:
                all_events.extend(agg.get('events', []))

            matched_entry = {
                "start_time": first_agg.get('timestamp'),
                "end_time": last_agg.get('timestamp'),
                "start_index": start_index,
                "end_index": end_index,
                "img": first_agg.get('screenshot_path'),
                "caption": caption['caption'],
                "raw_events": all_events,
                "num_aggregations": len(matched_aggs),
                "start_formatted": caption['start'],
                "end_formatted": caption['end'],
            }

        matched_data.append(matched_entry)

    # Persist matched entries as JSONL.
    with open(output_path, 'w', encoding='utf-8') as f:
        for entry in matched_data:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print(f"[Matcher] Saved {len(matched_data)} matched entries to {output_path}")

    return matched_data
123
+
124
+
125
+ def create_matched_captions_for_session(session_dir: Path, fps: int = 1) -> Optional[Path]:
126
+ """
127
+ Create matched_captions.jsonl for a session directory.
128
+
129
+ Args:
130
+ session_dir: Path to session directory
131
+ fps: Frames per second used in video creation
132
+
133
+ Returns:
134
+ Path to created matched_captions.jsonl or None if failed
135
+ """
136
+ captions_path = session_dir / "captions.jsonl"
137
+ aggregations_path = session_dir / "aggregations.jsonl"
138
+ output_path = session_dir / "matched_captions.jsonl"
139
+
140
+ if not captions_path.exists():
141
+ print(f"[Matcher] Warning: {captions_path} not found")
142
+ return None
143
+
144
+ if not aggregations_path.exists():
145
+ print(f"[Matcher] Warning: {aggregations_path} not found")
146
+ return None
147
+
148
+ match_captions_with_events(captions_path, aggregations_path, output_path, fps)
149
+
150
+ return output_path
@@ -0,0 +1,28 @@
1
+ from label.clients.client import VLMClient, CAPTION_SCHEMA
2
+ from label.clients.gemini import GeminiClient, GeminiResponse
3
+ from label.clients.vllm import VLLMClient, VLLMResponse
4
+ from label.clients.bigquery import BigQueryClient, BigQueryResponse
5
+
6
+
7
def create_client(client_type: str, **kwargs) -> VLMClient:
    """Instantiate the VLM client matching *client_type*.

    Supported types: 'gemini', 'vllm', 'bigquery'. All keyword arguments
    are forwarded to the chosen client's constructor.

    Raises:
        ValueError: for an unrecognized client type.
    """
    if client_type == 'gemini':
        return GeminiClient(**kwargs)
    if client_type == 'vllm':
        return VLLMClient(**kwargs)
    if client_type == 'bigquery':
        return BigQueryClient(**kwargs)
    raise ValueError(f"Unknown client type: {client_type}")
16
+
17
+
18
# Public API re-exported by ``label.clients``: the abstract client interface,
# the concrete backends with their response wrappers, the shared caption
# schema, and the factory function.
__all__ = [
    "VLMClient",
    "GeminiClient",
    "GeminiResponse",
    "VLLMClient",
    "VLLMResponse",
    "BigQueryClient",
    "BigQueryResponse",
    "CAPTION_SCHEMA",
    "create_client",
]
@@ -0,0 +1,171 @@
1
+ from typing import Optional, Any, Dict
2
+ import json
3
+ from pathlib import Path
4
+
5
+ from google.cloud import storage
6
+ from google.cloud import bigquery
7
+
8
+ from label.clients.client import VLMClient, CAPTION_SCHEMA # still imported, but we don't rely on it by default
9
+
10
+ # Example call:
11
+ # uv run -m label \
12
+ # --session /home/jupyter/Omar/downloads/test@gmail.com \
13
+ # --screenshots-only \
14
+ # --client bigquery \
15
+ # --model gemini-3-pro-preview \
16
+ # --bq-project hs-nero-phi-reeves-haitech \
17
+ # --bq-bucket-name hs-nero-phi-reeves-haitech-project \
18
+ # --bq-gcs-prefix Shaikh_Omar \
19
+ # --bq-object-table-location us.screenomics-gemini
20
+
21
class BigQueryResponse:
    """Wraps the raw generated-text payload returned by a BigQuery
    AI.GENERATE query, mirroring the other backends' response wrappers:
    ``text`` exposes the raw string, ``json`` lazily parses and caches it.
    """

    def __init__(self, result_row):
        self.result_row = result_row
        self._json = None

    @property
    def text(self) -> str:
        # The stored "row" is already the generated text payload.
        return self.result_row

    @property
    def json(self):
        # Parse at most once; later accesses return the cached value.
        parsed = self._json
        if parsed is None:
            parsed = json.loads(self.text)
            self._json = parsed
        return parsed
35
+
36
+
37
class BigQueryClient(VLMClient):
    """VLM client that uploads media to GCS and calls BigQuery's AI.GENERATE
    function over an object reference, returning the generated text.

    NOTE(review): ``temperature`` and ``max_output_tokens`` are stored in
    __init__ but never referenced in ``generate`` — presumably intended for
    the model_params payload; confirm before relying on them.
    """

    def __init__(
        self,
        model_name: str,
        bucket_name: str,
        gcs_prefix: str = "video_chunks",
        object_table_location: str = "us",  # e.g., "us.screenomics-gemini"
        temperature: float = 0.0,
        max_output_tokens: int = 65535,
        project_id: Optional[str] = None,
    ):
        """
        Initialize BigQuery client for ML.GENERATE_TEXT with video analysis.

        Args:
            model_name: Full BigQuery model reference (e.g., "dataset.model" or "project.dataset.model")
            bucket_name: GCS bucket name for uploading videos
            gcs_prefix: Prefix/folder path in GCS bucket
            object_table_location: Object table location (e.g., "us.screenomics-gemini")
            temperature: Model temperature parameter
            max_output_tokens: Maximum number of tokens in the generated response
            project_id: Optional GCP project ID (if not provided, uses default credentials)
        """
        self.model_name = model_name
        self.bucket_name = bucket_name
        self.gcs_prefix = gcs_prefix
        self.object_table_location = object_table_location
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens
        self.project_id = project_id

        # Separate clients: one for GCS uploads, one for running SQL.
        self.storage_client = storage.Client(project=project_id)
        self.bq_client = bigquery.Client(project=project_id)

    @staticmethod
    def _escape_for_bq_single_quoted_string(s: str) -> str:
        r"""
        Escape a Python string for use in a BigQuery single-quoted string literal.

        BigQuery-style escaping:
        - Backslash: \ -> \\
        - Newline: actual newline -> \n (two chars)
        - Carriage return: actual \r -> \r (two chars)
        - Single quote: ' -> \'
        """
        # Escape backslashes first so we don't re-escape ones we add later
        s = s.replace("\\", "\\\\")
        # Encode newlines and carriage returns as literal escape sequences
        s = s.replace("\r", "\\r")
        s = s.replace("\n", "\\n")
        # Escape single quotes for BigQuery
        s = s.replace("'", "\\'")
        return s

    def upload_file(self, path: str, session_id: str = None) -> str:
        """
        Upload file to GCS and return the GCS URI.

        Args:
            path: Local file path
            session_id: Optional session identifier for namespacing uploads

        Returns:
            GCS URI (gs://bucket/path/to/file)
        """
        file_path = Path(path)
        # Namespace the destination under the session id when provided.
        if session_id:
            destination_blob_name = f"{self.gcs_prefix}/{session_id}/{file_path.name}"
        else:
            destination_blob_name = f"{self.gcs_prefix}/{file_path.name}"

        bucket = self.storage_client.bucket(self.bucket_name)
        blob = bucket.blob(destination_blob_name)

        # Upload the file
        blob.upload_from_filename(path)

        gcs_uri = f"gs://{self.bucket_name}/{destination_blob_name}"
        print(f"Uploaded {path} to {gcs_uri}")

        return gcs_uri

    def generate(
        self,
        prompt: str,
        file_descriptor: Optional[Any] = None,
        schema: Optional[Dict] = None,
    ) -> BigQueryResponse:
        """Run AI.GENERATE over the given prompt + GCS object and return the
        first result row wrapped in a BigQueryResponse.

        NOTE(review): the ``schema`` parameter is accepted for interface
        compatibility but currently ignored — the query always requests plain
        JSON output with no response schema. Confirm whether schema
        enforcement was intended here.

        Raises:
            ValueError: if no file_descriptor (GCS URI) is supplied.
            RuntimeError: if the query returns no rows.
        """
        if not file_descriptor:
            raise ValueError("file_descriptor (GCS URI) is required")

        gcs_uri = file_descriptor

        # Escape everything that will go inside single-quoted SQL string literals
        escaped_prompt = self._escape_for_bq_single_quoted_string(prompt)
        escaped_gcs_uri = self._escape_for_bq_single_quoted_string(gcs_uri)
        escaped_location = self._escape_for_bq_single_quoted_string(
            self.object_table_location
        )

        response_params = {
            "generation_config": {
                "media_resolution": "MEDIA_RESOLUTION_HIGH",
                "response_mime_type": "application/json"
            }
        }

        response_params_json = json.dumps(response_params)
        escaped_response_params = self._escape_for_bq_single_quoted_string(response_params_json)

        # NOTE(review): self.project_id and self.model_name are interpolated
        # into the endpoint URL unescaped; if project_id is None the URL will
        # literally contain "None" — verify callers always pass a project.
        query = f"""
        SELECT
          AI.GENERATE(
            (
              '{escaped_prompt}',
              OBJ.FETCH_METADATA(
                OBJ.MAKE_REF('{escaped_gcs_uri}', '{escaped_location}')
              )
            ),
            endpoint => 'https://aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/global/publishers/google/models/{self.model_name}',
            model_params => JSON '{escaped_response_params}'
          ) AS gen;
        """

        job_config = bigquery.QueryJobConfig()

        query_job = self.bq_client.query(query, job_config=job_config)
        results = query_job.result()

        # Only the first row is consumed; AI.GENERATE over a single object
        # produces a single row here.
        for row in results:
            print(row[0]["result"])
            return BigQueryResponse(row[0]["result"])

        raise RuntimeError("No results returned from BigQuery")
171
+
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, Optional, List, Union
4
+
5
# JSON schema for the expected model output: a list of caption segments, each
# with "start"/"end" timestamp strings and a free-text "caption".
# NOTE(review): the exact timestamp format (e.g. MM:SS) is not enforced by
# this schema — confirm against the prompt files.
CAPTION_SCHEMA = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "start": {"type": "string"},
            "end": {"type": "string"},
            "caption": {"type": "string"}
        },
        "required": ["start", "end", "caption"]
    }
}
17
+
18
+
19
class VLMClient(ABC):
    """Abstract interface implemented by all VLM backends (Gemini, vLLM,
    BigQuery). Implementations upload media and generate structured output.
    """

    @abstractmethod
    def upload_file(self, path: str, session_id: str = None) -> Any:
        """Upload a local file; return a backend-specific file descriptor
        that can later be passed to ``generate``.
        """
        pass

    @abstractmethod
    def generate(self, prompt: Union[str, List[str]],
                 file_descriptor: Optional[Union[Any, List[Any]]] = None,
                 schema: Optional[Dict] = None) -> Union[Any, List[Any]]:
        """Generate a response for ``prompt``, optionally grounded on an
        uploaded file descriptor and constrained by a JSON ``schema``.
        """
        pass
@@ -0,0 +1,84 @@
1
+ from typing import Optional, Any, Dict
2
+ import json
3
+ import os
4
+ import time
5
+ from label.clients.client import VLMClient, CAPTION_SCHEMA
6
+
7
+ from google import genai
8
+ from google.genai import types
9
+
10
+
11
class GeminiResponse:
    """Thin wrapper over a google-genai response: ``text`` returns the raw
    response text, ``json`` lazily parses and caches it.
    """

    def __init__(self, response):
        self.response = response
        self._json = None

    @property
    def text(self) -> str:
        return self.response.text

    @property
    def json(self):
        # Parse once; subsequent accesses reuse the cached value.
        parsed = self._json
        if parsed is None:
            parsed = json.loads(self.text)
            self._json = parsed
        return parsed
25
+
26
+
27
class GeminiClient(VLMClient):
    """VLM client backed by the Google Gemini API (google-genai SDK)."""

    def __init__(self, api_key: Optional[str] = None, model_name: str = "gemini-2.5-flash"):
        # NOTE(review): `genai` is imported unconditionally at module top, so a
        # missing package would already have raised ImportError before this
        # check runs; the guard is effectively dead but kept as-is.
        if genai is None:
            raise RuntimeError("google-genai not installed")

        # Fall back to the GEMINI_API_KEY environment variable.
        api_key = api_key or os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise RuntimeError("GEMINI_API_KEY not set")

        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

    def upload_file(self, path: str, session_id: str = None) -> Any:
        """Upload a file to Gemini and poll until processing completes.

        ``session_id`` is accepted for interface compatibility but unused here.

        Raises:
            RuntimeError: if Gemini reports the file as FAILED.
        """
        video_file = self.client.files.upload(file=path)

        # Poll every 2s until the file leaves the PROCESSING state.
        while True:
            video_file = self.client.files.get(name=video_file.name)
            state = getattr(getattr(video_file, "state", None), "name", None)

            if state == "PROCESSING":
                time.sleep(2)
            elif state == "FAILED":
                raise RuntimeError("Gemini failed processing file")
            elif state == "ACTIVE":
                break
            else:
                # Unknown or absent state: stop polling rather than spin forever.
                break

        return video_file

    def generate(self, prompt: str, file_descriptor: Optional[Any] = None,
                 schema: Optional[Dict] = None) -> GeminiResponse:
        """Request a structured JSON response for ``prompt`` (plus an optional
        previously-uploaded file), constrained by ``schema`` or the default
        CAPTION_SCHEMA.
        """
        inputs = [prompt]
        if file_descriptor:
            inputs.append(file_descriptor)

        # Gemini 3 models get dynamic thinking and high media resolution;
        # older models use a deterministic temperature instead.
        if "gemini-3" in self.model_name:
            config = types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=schema or CAPTION_SCHEMA,
                thinking_config=types.ThinkingConfig(thinking_budget=-1),
                media_resolution=types.MediaResolution.MEDIA_RESOLUTION_HIGH,
            )
        else:
            config = types.GenerateContentConfig(
                response_mime_type="application/json",
                temperature=0.0,
                response_schema=schema or CAPTION_SCHEMA
            )

        res = self.client.models.generate_content(
            model=self.model_name,
            contents=inputs,
            config=config
        )

        return GeminiResponse(res)