agent-shared-core 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: agent-shared-core
3
+ Version: 0.1.0
4
+ Summary: Shared Compass core and marketing domain logic for agents
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: google-cloud-firestore
7
+ Requires-Dist: google-cloud-storage
8
+ Requires-Dist: google-auth
9
+ Requires-Dist: trafilatura
@@ -0,0 +1,3 @@
1
+ # agent_shared_core: shared Python library for the Compass Agent Framework.
2
+ # Not an ADK agent (no root_agent). Provides core infra and domain logic.
3
+
@@ -0,0 +1,9 @@
1
+ Metadata-Version: 2.4
2
+ Name: agent-shared-core
3
+ Version: 0.1.0
4
+ Summary: Shared Compass core and marketing domain logic for agents
5
+ Requires-Python: >=3.11
6
+ Requires-Dist: google-cloud-firestore
7
+ Requires-Dist: google-cloud-storage
8
+ Requires-Dist: google-auth
9
+ Requires-Dist: trafilatura
@@ -0,0 +1,22 @@
1
+ __init__.py
2
+ pyproject.toml
3
+ ./__init__.py
4
+ agent_shared_core.egg-info/PKG-INFO
5
+ agent_shared_core.egg-info/SOURCES.txt
6
+ agent_shared_core.egg-info/dependency_links.txt
7
+ agent_shared_core.egg-info/requires.txt
8
+ agent_shared_core.egg-info/top_level.txt
9
+ core/__init__.py
10
+ core/auth.py
11
+ core/checkpoints.py
12
+ core/config.py
13
+ core/content.py
14
+ core/firestore.py
15
+ core/image_gen.py
16
+ core/prompts.py
17
+ core/wallet.py
18
+ mktg/__init__.py
19
+ mktg/brands.py
20
+ mktg/compositions.py
21
+ mktg/platform_specs.py
22
+ mktg/themes.py
@@ -0,0 +1,4 @@
1
+ google-cloud-firestore
2
+ google-cloud-storage
3
+ google-auth
4
+ trafilatura
@@ -0,0 +1 @@
1
+ agent_shared_core
@@ -0,0 +1,2 @@
1
+ # agent_shared_core.core: universal infrastructure (firestore, auth, wallet, config, prompts, checkpoints, content, image_gen)
2
+
@@ -0,0 +1,79 @@
1
+ """Role resolution via Google Directory API with TTL cache.
2
+
3
+ Usage:
4
+ from agent_shared_core.core.auth import resolve_role
5
+ role = resolve_role("user@company.com", "mktg") # returns "admin" or "exec"
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ from typing import Dict, Tuple
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Cache: key = (email, dept), value = (role, timestamp)
16
+ _role_cache: Dict[Tuple[str, str], Tuple[str, float]] = {}
17
+
18
+ # Cache TTL in seconds (5 minutes)
19
+ _CACHE_TTL = 300
20
+
21
+ # Read domain from environment. Set per customer deployment.
22
+ # Default to company.com for the dev sandbox (concretio-compass-sb).
23
+ _GROUP_DOMAIN = os.environ.get("COMPASS_GROUP_DOMAIN", "company.com")
24
+
25
+
26
def _check_group_membership(email: str, group_email: str) -> bool:
    """Return True if ``email`` is a member of the Google Group ``group_email``.

    Uses the Admin SDK Directory API ``members.hasMember`` call with
    application-default credentials scoped to read-only group membership.

    ``googleapiclient`` is imported lazily here and is not listed in this
    package's Requires-Dist, so it must be installed separately. Import and
    credential errors are raised to the caller; only the membership call
    itself is treated as best-effort.

    Args:
        email: Address of the member to check.
        group_email: Address of the group (e.g. ``mktg-admins@<domain>``).

    Returns:
        True if the API reports membership. False if it reports
        non-membership, or if the API call fails (the failure is logged).
    """
    from googleapiclient.discovery import build
    import google.auth

    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/admin.directory.group.member.readonly"]
    )
    service = build("admin", "directory_v1", credentials=credentials)

    try:
        result = service.members().hasMember(
            groupKey=group_email, memberKey=email
        ).execute()
    except Exception:
        # Best-effort: a failed lookup counts as "not a member", so callers
        # fall back to the least-privileged role. Log it rather than
        # swallowing it silently, so outages and misconfiguration stay visible.
        logger.warning(
            "Directory API hasMember failed for %s in group %s; treating as non-member",
            email,
            group_email,
            exc_info=True,
        )
        return False
    return bool(result.get("isMember", False))
46
+
47
+
48
def resolve_role(email: str, dept: str) -> str:
    """Map a user to a role for one department via Google Group membership.

    A user is "admin" when they belong to ``{dept}-admins@{_GROUP_DOMAIN}``.
    The domain comes from COMPASS_GROUP_DOMAIN and defaults to "company.com".
    Everyone else is "exec". If the membership lookup raises, a warning is
    logged and "exec" is used as the least-privileged fallback.

    Every (email, dept) outcome, including a fallback "exec", is memoized in
    the module-level ``_role_cache`` for ``_CACHE_TTL`` seconds (5 minutes).

    Args:
        email: User address to resolve.
        dept: Department prefix used to build the admin group address.

    Returns:
        "admin" or "exec".
    """
    key = (email, dept)

    hit = _role_cache.get(key)
    if hit is not None:
        role, stored_at = hit
        if time.time() - stored_at < _CACHE_TTL:
            return role

    try:
        group = f"{dept}-admins@{_GROUP_DOMAIN}"
        role = "admin" if _check_group_membership(email, group) else "exec"
    except Exception:
        logger.warning(
            "Directory API failure for %s in dept %s, defaulting to exec",
            email, dept,
        )
        role = "exec"

    _role_cache[key] = (role, time.time())
    return role
79
+
@@ -0,0 +1,58 @@
1
+ """Workflow checkpoints via ADK session state (no external deps).
2
+
3
+ Agents use checkpoints to persist intermediate progress within a session.
4
+ Data is stored in tool_context.state["checkpoint"] as a flat dict.
5
+
6
+ Usage:
7
+ from agent_shared_core.core.checkpoints import save_checkpoint, read_checkpoint
8
+
9
+ # Save progress after a step completes
10
+ save_checkpoint(tool_context, {"step": "theme_selected", "theme_id": "abc123"})
11
+
12
+ # Read current checkpoint (returns None if no checkpoint exists)
13
+ cp = read_checkpoint(tool_context)
14
+ if cp:
15
+ print(cp["step"]) # "theme_selected"
16
+
17
+ # Subsequent saves merge into the existing checkpoint
18
+ save_checkpoint(tool_context, {"step": "preview_generated", "image_url": "https://..."})
19
+ # checkpoint now has: step, theme_id, image_url, last_updated
20
+ """
21
+
22
+ import logging
23
+ from datetime import datetime, timezone
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
def save_checkpoint(tool_context, data: dict) -> None:
    """Merge ``data`` into the session checkpoint and stamp ``last_updated``.

    Steps:
    1. Copy the existing ``tool_context.state["checkpoint"]``. A missing or
       falsy value (e.g. None) is treated as empty.
    2. Merge ``data`` on top; new keys win on conflict.
    3. Set ``last_updated`` to the current UTC time in ISO-8601 format.
    4. Assign the new dict back to the state.

    The stored dict is never mutated in place. References previously
    returned by ``read_checkpoint`` therefore keep their old contents.

    Args:
        tool_context: ADK ToolContext (anything with a dict-like ``.state``).
        data: Key-value pairs to merge into the checkpoint.
    """
    checkpoint = dict(tool_context.state.get("checkpoint") or {})
    checkpoint.update(data)
    checkpoint["last_updated"] = datetime.now(timezone.utc).isoformat()
    tool_context.state["checkpoint"] = checkpoint

    logger.info("Checkpoint saved: step=%s", data.get("step", "unknown"))
46
+
47
+
48
+ def read_checkpoint(tool_context) -> dict | None:
49
+ """Read the current checkpoint from session state.
50
+
51
+ Args:
52
+ tool_context: ADK ToolContext with a .state dict.
53
+
54
+ Returns:
55
+ The checkpoint dict, or None if no checkpoint exists.
56
+ """
57
+ return tool_context.state.get("checkpoint") or None
58
+
@@ -0,0 +1,37 @@
1
+ """Runtime configuration from Firestore /config/{agent_name}.
2
+
3
+ Usage:
4
+ from agent_shared_core.core.config import get_config
5
+
6
+ config = get_config("mktg_banner_generator")
7
+ preview_count = config.get("preview_count", 3)
8
+ model_id = config.get("model_id", "imagen-4-fast")
9
+ """
10
+
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def get_config(agent_name: str) -> dict:
    """Load the runtime configuration document ``/config/{agent_name}``.

    Returns the document's fields as a dict. Returns an empty dict when the
    document does not exist, or when reading it from Firestore raises (the
    error is logged with its traceback). Callers are expected to read every
    setting with ``config.get(key, default)``.

    Args:
        agent_name: Document ID inside the ``config`` collection.

    Returns:
        The configuration dict, possibly empty.
    """
    from agent_shared_core.core.firestore import get_client

    try:
        snapshot = get_client().collection("config").document(agent_name).get()
        return snapshot.to_dict() if snapshot.exists else {}
    except Exception:
        logger.warning(
            "Failed to read config for %s, returning empty dict",
            agent_name,
            exc_info=True,
        )
        return {}
37
+
@@ -0,0 +1,215 @@
1
+ """Content extraction and context analysis for agent workflows.
2
+
3
+ Usage:
4
+ from agent_shared_core.core.content import fetch_content, extract_content_context
5
+
6
+ # Fetch from URL
7
+ result = fetch_content("https://example.com/article")
8
+ # => {"source": "https://...", "content": "...", "content_type": "url", "success": True}
9
+
10
+ # Pass raw text
11
+ result = fetch_content("Some plain text content here.")
12
+ # => {"source": "Some plain...", "content": "Some plain...", "content_type": "text", "success": True}
13
+
14
+ # Pass markdown
15
+ result = fetch_content("# Heading\n\n**Bold** text with [links](url)")
16
+ # => {"source": "# Heading...", "content": "Heading\n\nBold text with links", "content_type": "markdown", "success": True}
17
+
18
+ # Extract structured context via Gemini Flash
19
+ context = extract_content_context("Article text here...")
20
+ # => {"title": "...", "vertical": "...", "key_themes": [...], ...}
21
+ """
22
+
23
+ import json
24
+ import logging
25
+ import os
26
+ import re
27
+
28
+ import trafilatura
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ _MAX_CONTENT_LENGTH = 15_000
33
+
34
+ _EXTRACTION_PROMPT = """Extract structured information from the following content.
35
+ Return ONLY valid JSON with these keys:
36
+ - title (string): the main title or topic
37
+ - vertical (string): the industry or domain vertical
38
+ - key_themes (list of strings): 3-5 main themes
39
+ - mood (string): overall tone or mood
40
+ - visual_cues (list of strings): 2-4 visual elements suggested by the content
41
+ - color_hints (list of strings): 2-4 colors associated with the content's mood or vertical
42
+ - summary (string): 2-3 sentence summary
43
+
44
+ Content:
45
+ {content}
46
+ """
47
+
48
+
49
def fetch_content(source: str) -> dict:
    """Normalize a URL, markdown string, or plain text into a content dict.

    Dispatch rules:
    * If ``source``, ignoring surrounding whitespace, starts with ``http://``
      or ``https://`` (scheme matched case-insensitively), the stripped URL
      is downloaded and its main text extracted via trafilatura. Pasted URLs
      often carry a stray newline or leading space.
    * Otherwise, if it contains common markdown syntax, formatting is
      stripped.
    * Otherwise it is used as-is.

    Content is truncated to ``_MAX_CONTENT_LENGTH`` (15,000) characters.

    Args:
        source: a URL (http/https), markdown string, or plain text.

    Returns:
        Dict with keys ``source``, ``content``, ``content_type``
        ("url" | "markdown" | "text"), ``success``, and ``error`` on failure.
        For URLs, ``source`` is the whitespace-stripped URL; otherwise it is
        the input unchanged.
    """
    candidate = source.strip()
    # Only the scheme prefix needs case-folding; avoid lowering a long text.
    if candidate[:8].lower().startswith(("http://", "https://")):
        return _fetch_url(candidate)

    if _looks_like_markdown(source):
        text = _strip_markdown(source)[:_MAX_CONTENT_LENGTH]
        return {"source": source, "content": text, "content_type": "markdown", "success": True}

    text = source[:_MAX_CONTENT_LENGTH]
    return {"source": source, "content": text, "content_type": "text", "success": True}
70
+
71
+
72
def extract_content_context(text: str) -> dict:
    """Extract structured context from text using Gemini 2.0 Flash.

    Sends ``_EXTRACTION_PROMPT`` to the model on Vertex AI, with ``text``
    truncated to ``_MAX_CONTENT_LENGTH`` characters. The reply is parsed as a
    JSON object; a surrounding markdown code fence is tolerated.

    Cost logging:
    * A fixed cost estimate (0.001) is recorded via
      ``agent_shared_core.core.wallet.log_cost`` as soon as the model call
      returns. A completed call costs money whether or not its reply parses.
    * Logging is best-effort: if it fails, a warning is logged and the
      extraction result is still returned.

    Args:
        text: the content to analyze.

    Returns:
        Dict with title, vertical, key_themes, mood, visual_cues, color_hints
        and summary, as produced by the model (keys are not validated).
        On failure, returns {"error": "..."}. Failures include empty input,
        an API error, and an unparseable or non-object JSON reply.
    """
    if not text or not text.strip():
        # Nothing to analyze: skip the billed model call entirely.
        return {"error": "No content to analyze"}

    from google import genai

    model_id = "gemini-2.0-flash"
    project = os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb")
    location = os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")

    try:
        client = genai.Client(
            vertexai=True,
            project=project,
            location=location,
        )
        prompt = _EXTRACTION_PROMPT.format(content=text[:_MAX_CONTENT_LENGTH])
        response = client.models.generate_content(
            model=model_id,
            contents=prompt,
        )
    except Exception as exc:
        logger.error("Content context extraction failed: %s", exc, exc_info=True)
        return {"error": str(exc)}

    # Record the cost as soon as the call has completed. A wallet failure
    # must not discard an otherwise successful extraction.
    try:
        from agent_shared_core.core.wallet import log_cost
        log_cost(
            agent="content_extraction",
            user_email="system",
            model=model_id,
            cost=0.001,
            phase="extract_context",
        )
    except Exception:
        logger.warning("Failed to log cost for content extraction", exc_info=True)

    try:
        # response.text can be None (e.g. no text part); parse "" instead so
        # the failure is reported as a JSON parse error.
        raw = (response.text or "").strip()
        # Strip markdown code fences if the model wraps its output
        if raw.startswith("```"):
            raw = re.sub(r"^```(?:json)?\s*", "", raw)
            raw = re.sub(r"\s*```$", "", raw)
        result = json.loads(raw)
    except json.JSONDecodeError as exc:
        logger.error("Failed to parse Gemini response as JSON: %s", exc)
        return {"error": f"JSON parse error: {exc}"}
    except Exception as exc:
        logger.error("Content context extraction failed: %s", exc, exc_info=True)
        return {"error": str(exc)}

    if not isinstance(result, dict):
        kind = type(result).__name__
        logger.error("Gemini returned JSON %s, expected an object", kind)
        return {"error": f"Expected a JSON object, got {kind}"}

    logger.info("Content context extracted: title=%s", result.get("title", "unknown"))
    return result
130
+
131
+
132
+ # --- Internal helpers ---
133
+
134
+
135
def _fetch_url(url: str) -> dict:
    """Download ``url`` and return its main text via trafilatura.

    Returns the same dict shape as ``fetch_content``:
    * On success, ``success`` is True and ``content`` holds up to
      ``_MAX_CONTENT_LENGTH`` characters of extracted text.
    * On failure, ``success`` is False, ``content`` is empty, and ``error``
      explains why: the download returned nothing, extraction returned
      nothing, or an exception was raised (which is logged).
    """
    def failure(message: str) -> dict:
        return {
            "source": url,
            "content": "",
            "content_type": "url",
            "success": False,
            "error": message,
        }

    try:
        page = trafilatura.fetch_url(url)
        if page is None:
            return failure(f"Could not fetch URL: {url}")

        extracted = trafilatura.extract(page)
        if extracted is None:
            return failure(f"Could not extract content from: {url}")

        return {
            "source": url,
            "content": extracted[:_MAX_CONTENT_LENGTH],
            "content_type": "url",
            "success": True,
        }
    except Exception as exc:
        logger.error("URL fetch failed for %s: %s", url, exc)
        return failure(f"Extraction failed: {exc}")
174
+
175
+
176
+ def _looks_like_markdown(text: str) -> bool:
177
+ """Detect if text contains common markdown patterns."""
178
+ md_patterns = [
179
+ r"^#{1,6}\s", # headings
180
+ r"\*\*.+\*\*", # bold
181
+ r"\[.+\]\(.+\)", # links
182
+ r"^[-*]\s", # unordered lists
183
+ r"^\d+\.\s", # ordered lists
184
+ r"^>\s", # blockquotes
185
+ r"```", # code blocks
186
+ ]
187
+ for pattern in md_patterns:
188
+ if re.search(pattern, text, re.MULTILINE):
189
+ return True
190
+ return False
191
+
192
+
193
+ def _strip_markdown(text: str) -> str:
194
+ """Remove common markdown formatting, returning plain text."""
195
+ # Remove code blocks
196
+ text = re.sub(r"```[\s\S]*?```", "", text)
197
+ # Remove inline code
198
+ text = re.sub(r"`([^`]+)`", r"\1", text)
199
+ # Remove images
200
+ text = re.sub(r"!\[([^\]]*)\]\([^)]+\)", r"\1", text)
201
+ # Remove links, keep text
202
+ text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
203
+ # Remove headings markers
204
+ text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
205
+ # Remove bold/italic markers
206
+ text = re.sub(r"\*{1,3}([^*]+)\*{1,3}", r"\1", text)
207
+ text = re.sub(r"_{1,3}([^_]+)_{1,3}", r"\1", text)
208
+ # Remove blockquote markers
209
+ text = re.sub(r"^>\s?", "", text, flags=re.MULTILINE)
210
+ # Remove horizontal rules
211
+ text = re.sub(r"^[-*_]{3,}\s*$", "", text, flags=re.MULTILINE)
212
+ # Collapse multiple blank lines
213
+ text = re.sub(r"\n{3,}", "\n\n", text)
214
+ return text.strip()
215
+
@@ -0,0 +1,31 @@
1
+ """Singleton Firestore client for the Compass Agent Framework.
2
+
3
+ Usage:
4
+ from agent_shared_core.core.firestore import get_client
5
+ db = get_client()
6
+ doc = db.collection("brands").document("abc").get()
7
+ """
8
+
9
+ import logging
10
+ import os
11
+
12
+ from google.cloud import firestore
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ _client = None
17
+
18
+
19
def get_client() -> firestore.Client:
    """Return the shared Firestore client, creating it on first use.

    The project comes from GOOGLE_CLOUD_PROJECT and falls back to the
    sandbox project "concretio-compass-sb". The client is cached in the
    module-level ``_client`` and reused by every later call.
    """
    global _client
    if _client is not None:
        return _client

    project_id = os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb")
    _client = firestore.Client(project=project_id)
    logger.info("Firestore client initialized for project: %s", project_id)
    return _client
31
+
@@ -0,0 +1,183 @@
1
+ """Image generation and upscaling via Vertex AI (Imagen 4 + Gemini 3 Pro).
2
+
3
+ Usage:
4
+ from agent_shared_core.core.image_gen import generate_images, upscale_image
5
+
6
+ # Generate preview images
7
+ results = generate_images("a glass panel banner...", resolution="LinkedIn post", count=3)
8
+ # => [{"image_bytes": b"...", "image_b64": "..."}, ...]
9
+
10
+ # Upscale from bytes
11
+ upscaled = upscale_image(source_bytes, target_size="2K")
12
+ # => b"..." (upscaled image bytes)
13
+
14
+ GCS upload is optional and separate. For local dev, tools return base64
15
+ and save as ADK artifacts. For production, a Cloud Run wrapper can handle
16
+ GCS upload via this same module.
17
+ """
18
+
19
+ import base64
20
+ import logging
21
+ import os
22
+
23
+ from google import genai
24
+ from google.genai import types
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # --- Model IDs ---
29
+
30
+ GENERATE_MODEL_ID = "imagen-4.0-fast-generate-001"
31
+ UPSCALE_MODEL_ID = "gemini-3.0-pro-image-generate-001"
32
+
33
+ # --- Resolution table ---
34
+
35
+ RESOLUTION_TABLE: dict[str, tuple[int, int]] = {
36
+ "LinkedIn post": (1200, 628),
37
+ "LinkedIn banner": (1584, 396),
38
+ "Twitter/X header": (1500, 500),
39
+ "Twitter/X post": (1200, 675),
40
+ "Facebook cover": (851, 315),
41
+ "Facebook post": (1200, 630),
42
+ "Instagram square": (1080, 1080),
43
+ "Instagram story": (1080, 1920),
44
+ "YouTube thumbnail": (1280, 720),
45
+ "YouTube banner": (2560, 1440),
46
+ "Blog universal": (1200, 628),
47
+ "Blog OG social": (1200, 630),
48
+ "WordPress featured image": (1200, 628),
49
+ "Medium header": (1400, 788),
50
+ "Presentation slide 16:9": (1920, 1080),
51
+ "Presentation slide 4:3": (1440, 1080),
52
+ }
53
+
54
+ DEFAULT_RESOLUTION = (1200, 628)
55
+
56
+ # --- Aspect ratio mapping for Imagen 4 ---
57
+
58
+ SUPPORTED_ASPECT_RATIOS = ["1:1", "9:16", "16:9", "3:4", "4:3"]
59
+
60
+ _ASPECT_RATIO_VALUES = {
61
+ "1:1": 1.0,
62
+ "9:16": 9 / 16,
63
+ "16:9": 16 / 9,
64
+ "3:4": 3 / 4,
65
+ "4:3": 4 / 3,
66
+ }
67
+
68
+ # --- Upscale target dimensions ---
69
+
70
+ UPSCALE_TARGETS = {
71
+ "2K": (2048, 2048),
72
+ "4K": (4096, 4096),
73
+ }
74
+
75
+
76
def _get_client() -> genai.Client:
    """Build a google-genai client that talks to Vertex AI.

    Project and region come from GOOGLE_CLOUD_PROJECT and
    GOOGLE_CLOUD_LOCATION, defaulting to "concretio-compass-sb" and
    "us-central1". A new client is created on every call.
    """
    settings = {
        "project": os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb"),
        "location": os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1"),
    }
    return genai.Client(vertexai=True, **settings)
81
+
82
+
83
def _closest_aspect_ratio(width: int, height: int) -> str:
    """Pick the supported Imagen aspect ratio nearest to ``width / height``.

    Distance is the absolute difference between the numeric ratios. On a
    tie, the entry listed first in SUPPORTED_ASPECT_RATIOS wins (``min``
    semantics).
    """
    wanted = width / height

    def distance(label: str) -> float:
        return abs(_ASPECT_RATIO_VALUES[label] - wanted)

    return min(SUPPORTED_ASPECT_RATIOS, key=distance)
90
+
91
+
92
def generate_images(
    prompt: str,
    resolution: str = "LinkedIn post",
    count: int = 1,
) -> list[dict]:
    """Generate images from a text prompt with Imagen 4 Fast on Vertex AI.

    How the request is built:
    * ``resolution`` is looked up in RESOLUTION_TABLE. Unknown keys fall back
      to DEFAULT_RESOLUTION with a warning.
    * The resulting size is mapped to the nearest aspect ratio Imagen accepts.
    * Images are requested as JPEG, with ``BLOCK_MEDIUM_AND_ABOVE`` safety
      filtering and ``ALLOW_ADULT`` person generation.

    Args:
        prompt: The image generation prompt.
        resolution: Platform resolution key (e.g. "LinkedIn post").
        count: Number of images to request, 1-4.

    Returns:
        List of dicts, each with "image_bytes" (raw bytes) and "image_b64"
        (base64-encoded string for ADK artifact display). The list may be
        shorter than ``count``, or empty, when the service returns fewer
        usable images. Presumably this happens when the safety filter removes
        some; a warning is logged in that case.

    Raises:
        ValueError: if ``count`` is outside 1-4.
        Exception: on API errors (safety rejection, quota, etc.).
    """
    if not 1 <= count <= 4:
        raise ValueError(f"count must be between 1 and 4, got {count}")

    if resolution not in RESOLUTION_TABLE:
        logger.warning(
            "Unknown resolution %r, falling back to %dx%d",
            resolution, *DEFAULT_RESOLUTION,
        )
    width, height = RESOLUTION_TABLE.get(resolution, DEFAULT_RESOLUTION)
    aspect_ratio = _closest_aspect_ratio(width, height)

    client = _get_client()
    response = client.models.generate_images(
        model=GENERATE_MODEL_ID,
        prompt=prompt,
        config=types.GenerateImagesConfig(
            number_of_images=count,
            aspect_ratio=aspect_ratio,
            output_mime_type="image/jpeg",
            safety_filter_level="BLOCK_MEDIUM_AND_ABOVE",
            person_generation="ALLOW_ADULT",
        ),
    )

    results = []
    # generated_images is optional in the SDK response, and an entry may
    # carry no image payload. Skip those instead of crashing on them.
    for generated_image in response.generated_images or []:
        image = generated_image.image
        img_bytes = image.image_bytes if image is not None else None
        if not img_bytes:
            continue
        results.append({
            "image_bytes": img_bytes,
            "image_b64": base64.b64encode(img_bytes).decode("utf-8"),
        })

    if len(results) < count:
        logger.warning("Requested %d images but received %d", count, len(results))
    logger.info(
        "Generated %d images with %s, resolution=%s, aspect_ratio=%s",
        len(results), GENERATE_MODEL_ID, resolution, aspect_ratio,
    )
    return results
140
+
141
+
142
def upscale_image(
    source_bytes: bytes,
    target_size: str = "2K",
    mime_type: str = "image/jpeg",
) -> bytes:
    """Ask Gemini 3 Pro Image on Vertex AI to upscale an image.

    The model is prompted to reproduce the image at the UPSCALE_TARGETS
    dimensions for ``target_size``. The size of the returned image is
    whatever the model produces; it is not checked here.

    Args:
        source_bytes: Raw bytes of the source image.
        target_size: "2K" or "4K". Unknown values fall back to "2K" with a
            warning.
        mime_type: MIME type of ``source_bytes``. Defaults to "image/jpeg",
            the format ``generate_images`` requests.

    Returns:
        Bytes of the first image part in the first candidate of the response.
        The image format is the model's choice.

    Raises:
        ValueError: if ``source_bytes`` is empty, or the response contains no
            image (including a response with no candidates or no content).
        Exception: on API errors.
    """
    if not source_bytes:
        raise ValueError("source_bytes is empty")
    if target_size not in UPSCALE_TARGETS:
        logger.warning("Unknown upscale target %r, falling back to 2K", target_size)
    target_width, target_height = UPSCALE_TARGETS.get(target_size, UPSCALE_TARGETS["2K"])

    client = _get_client()
    source_image = types.Part.from_bytes(data=source_bytes, mime_type=mime_type)

    response = client.models.generate_content(
        model=UPSCALE_MODEL_ID,
        contents=[
            source_image,
            f"Upscale this image to {target_width}x{target_height} resolution. "
            "Preserve all details, colors, and composition exactly. "
            "Enhance clarity and sharpness without altering the content.",
        ],
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        ),
    )

    # A blocked or empty response may have no candidates, content or parts.
    # Report that as the documented ValueError rather than IndexError or
    # AttributeError.
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = (content.parts if content is not None else None) or []
    for part in parts:
        inline = part.inline_data
        if inline and (inline.mime_type or "").startswith("image/"):
            logger.info("Upscaled image to %s using %s", target_size, UPSCALE_MODEL_ID)
            return inline.data

    raise ValueError("Upscale response did not contain an image")
183
+