agent-shared-core 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_shared_core-0.1.0/PKG-INFO +9 -0
- agent_shared_core-0.1.0/__init__.py +3 -0
- agent_shared_core-0.1.0/agent_shared_core.egg-info/PKG-INFO +9 -0
- agent_shared_core-0.1.0/agent_shared_core.egg-info/SOURCES.txt +22 -0
- agent_shared_core-0.1.0/agent_shared_core.egg-info/dependency_links.txt +1 -0
- agent_shared_core-0.1.0/agent_shared_core.egg-info/requires.txt +4 -0
- agent_shared_core-0.1.0/agent_shared_core.egg-info/top_level.txt +1 -0
- agent_shared_core-0.1.0/core/__init__.py +2 -0
- agent_shared_core-0.1.0/core/auth.py +79 -0
- agent_shared_core-0.1.0/core/checkpoints.py +58 -0
- agent_shared_core-0.1.0/core/config.py +37 -0
- agent_shared_core-0.1.0/core/content.py +215 -0
- agent_shared_core-0.1.0/core/firestore.py +31 -0
- agent_shared_core-0.1.0/core/image_gen.py +183 -0
- agent_shared_core-0.1.0/core/prompts.py +64 -0
- agent_shared_core-0.1.0/core/wallet.py +127 -0
- agent_shared_core-0.1.0/mktg/__init__.py +2 -0
- agent_shared_core-0.1.0/mktg/brands.py +119 -0
- agent_shared_core-0.1.0/mktg/compositions.py +66 -0
- agent_shared_core-0.1.0/mktg/platform_specs.py +158 -0
- agent_shared_core-0.1.0/mktg/themes.py +155 -0
- agent_shared_core-0.1.0/pyproject.toml +24 -0
- agent_shared_core-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: agent-shared-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared Compass core and marketing domain logic for agents
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: google-cloud-firestore
|
|
7
|
+
Requires-Dist: google-cloud-storage
|
|
8
|
+
Requires-Dist: google-auth
|
|
9
|
+
Requires-Dist: trafilatura
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: agent-shared-core
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Shared Compass core and marketing domain logic for agents
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Requires-Dist: google-cloud-firestore
|
|
7
|
+
Requires-Dist: google-cloud-storage
|
|
8
|
+
Requires-Dist: google-auth
|
|
9
|
+
Requires-Dist: trafilatura
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
__init__.py
|
|
2
|
+
pyproject.toml
|
|
3
|
+
./__init__.py
|
|
4
|
+
agent_shared_core.egg-info/PKG-INFO
|
|
5
|
+
agent_shared_core.egg-info/SOURCES.txt
|
|
6
|
+
agent_shared_core.egg-info/dependency_links.txt
|
|
7
|
+
agent_shared_core.egg-info/requires.txt
|
|
8
|
+
agent_shared_core.egg-info/top_level.txt
|
|
9
|
+
core/__init__.py
|
|
10
|
+
core/auth.py
|
|
11
|
+
core/checkpoints.py
|
|
12
|
+
core/config.py
|
|
13
|
+
core/content.py
|
|
14
|
+
core/firestore.py
|
|
15
|
+
core/image_gen.py
|
|
16
|
+
core/prompts.py
|
|
17
|
+
core/wallet.py
|
|
18
|
+
mktg/__init__.py
|
|
19
|
+
mktg/brands.py
|
|
20
|
+
mktg/compositions.py
|
|
21
|
+
mktg/platform_specs.py
|
|
22
|
+
mktg/themes.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
agent_shared_core
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Role resolution via Google Directory API with TTL cache.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from agent_shared_core.core.auth import resolve_role
|
|
5
|
+
role = resolve_role("user@company.com", "mktg") # returns "admin" or "exec"
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from typing import Dict, Tuple
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
# Cache: key = (email, dept), value = (role, timestamp)
|
|
16
|
+
_role_cache: Dict[Tuple[str, str], Tuple[str, float]] = {}
|
|
17
|
+
|
|
18
|
+
# Cache TTL in seconds (5 minutes)
|
|
19
|
+
_CACHE_TTL = 300
|
|
20
|
+
|
|
21
|
+
# Read domain from environment. Set per customer deployment.
|
|
22
|
+
# Default to company.com for the dev sandbox (concretio-compass-sb).
|
|
23
|
+
_GROUP_DOMAIN = os.environ.get("COMPASS_GROUP_DOMAIN", "company.com")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _check_group_membership(email: str, group_email: str) -> bool:
    """Check if *email* is a member of the Google Group *group_email*.

    Uses the Google Admin SDK Directory API (``members().hasMember``).

    Args:
        email: The user's email address to test.
        group_email: Full email address of the Google Group.

    Returns:
        True if the user is a member; False if not, or on any API error
        (a failed lookup is treated as "not a member").
    """
    from googleapiclient.discovery import build
    import google.auth

    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/admin.directory.group.member.readonly"]
    )
    service = build("admin", "directory_v1", credentials=credentials)

    try:
        result = service.members().hasMember(
            groupKey=group_email, memberKey=email
        ).execute()
        return result.get("isMember", False)
    except Exception:
        # Log instead of silently swallowing: permission/quota/API errors
        # during membership checks should be visible in logs, not invisible.
        logger.warning(
            "hasMember lookup failed for %s in group %s, treating as non-member",
            email, group_email, exc_info=True,
        )
        return False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def resolve_role(email: str, dept: str) -> str:
    """Resolve a user's role from Google Group membership.

    Returns "admin" when the user belongs to {dept}-admins@<domain>,
    otherwise "exec". Any Directory API failure also yields "exec"
    (the safe default). Results are cached per (email, dept) pair for
    _CACHE_TTL seconds.
    """
    key = (email, dept)

    # Serve from cache while the entry is still fresh.
    cached = _role_cache.get(key)
    if cached is not None:
        cached_role, cached_at = cached
        if time.time() - cached_at < _CACHE_TTL:
            return cached_role

    # Cache miss or stale entry: ask the Directory API.
    try:
        admins_group = f"{dept}-admins@{_GROUP_DOMAIN}"
        role = "admin" if _check_group_membership(email, admins_group) else "exec"
    except Exception:
        logger.warning(
            "Directory API failure for %s in dept %s, defaulting to exec",
            email, dept,
        )
        role = "exec"

    _role_cache[key] = (role, time.time())
    return role
|
|
79
|
+
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Workflow checkpoints via ADK session state (no external deps).
|
|
2
|
+
|
|
3
|
+
Agents use checkpoints to persist intermediate progress within a session.
|
|
4
|
+
Data is stored in tool_context.state["checkpoint"] as a flat dict.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from agent_shared_core.core.checkpoints import save_checkpoint, read_checkpoint
|
|
8
|
+
|
|
9
|
+
# Save progress after a step completes
|
|
10
|
+
save_checkpoint(tool_context, {"step": "theme_selected", "theme_id": "abc123"})
|
|
11
|
+
|
|
12
|
+
# Read current checkpoint (returns None if no checkpoint exists)
|
|
13
|
+
cp = read_checkpoint(tool_context)
|
|
14
|
+
if cp:
|
|
15
|
+
print(cp["step"]) # "theme_selected"
|
|
16
|
+
|
|
17
|
+
# Subsequent saves merge into the existing checkpoint
|
|
18
|
+
save_checkpoint(tool_context, {"step": "preview_generated", "image_url": "https://..."})
|
|
19
|
+
# checkpoint now has: step, theme_id, image_url, last_updated
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import logging
|
|
23
|
+
from datetime import datetime, timezone
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def save_checkpoint(tool_context, data: dict) -> None:
    """Merge *data* into the session checkpoint and stamp last_updated.

    Reads the existing checkpoint from ``tool_context.state``, merges the
    new keys (overwriting on conflict), records the current UTC time under
    ``last_updated``, and writes the result back.

    A copy of the stored dict is mutated, not the stored dict itself, so
    state backends that detect changes on assignment always observe a
    fresh object being written.

    Args:
        tool_context: ADK ToolContext with a ``.state`` dict.
        data: Key-value pairs to merge into the checkpoint.
    """
    # Defensive copy: never mutate the dict already held inside state.
    checkpoint = dict(tool_context.state.get("checkpoint") or {})
    checkpoint.update(data)
    checkpoint["last_updated"] = datetime.now(timezone.utc).isoformat()
    tool_context.state["checkpoint"] = checkpoint

    logger.info("Checkpoint saved: step=%s", data.get("step", "unknown"))
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def read_checkpoint(tool_context) -> dict | None:
|
|
49
|
+
"""Read the current checkpoint from session state.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
tool_context: ADK ToolContext with a .state dict.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
The checkpoint dict, or None if no checkpoint exists.
|
|
56
|
+
"""
|
|
57
|
+
return tool_context.state.get("checkpoint") or None
|
|
58
|
+
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Runtime configuration from Firestore /config/{agent_name}.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from agent_shared_core.core.config import get_config
|
|
5
|
+
|
|
6
|
+
config = get_config("mktg_banner_generator")
|
|
7
|
+
preview_count = config.get("preview_count", 3)
|
|
8
|
+
model_id = config.get("model_id", "imagen-4-fast")
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_config(agent_name: str) -> dict:
    """Read agent configuration from Firestore /config/{agent_name}.

    Returns the document as a dict. Returns an empty dict when the
    document does not exist or on any failure — including import or
    client-initialization errors — so callers can always rely on the
    ``.get(key, default)`` pattern for missing keys.

    Args:
        agent_name: Document id under the ``config`` collection.
    """
    try:
        # Imported lazily inside the try so a missing or broken Firestore
        # dependency degrades to the documented empty-dict fallback instead
        # of raising ImportError at the caller.
        from agent_shared_core.core.firestore import get_client

        doc = get_client().collection("config").document(agent_name).get()
        return doc.to_dict() if doc.exists else {}
    except Exception:
        logger.warning(
            "Failed to read config for %s, returning empty dict",
            agent_name,
            exc_info=True,
        )
        return {}
|
|
37
|
+
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Content extraction and context analysis for agent workflows.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from agent_shared_core.core.content import fetch_content, extract_content_context
|
|
5
|
+
|
|
6
|
+
# Fetch from URL
|
|
7
|
+
result = fetch_content("https://example.com/article")
|
|
8
|
+
# => {"source": "https://...", "content": "...", "content_type": "url", "success": True}
|
|
9
|
+
|
|
10
|
+
# Pass raw text
|
|
11
|
+
result = fetch_content("Some plain text content here.")
|
|
12
|
+
# => {"source": "Some plain...", "content": "Some plain...", "content_type": "text", "success": True}
|
|
13
|
+
|
|
14
|
+
# Pass markdown
|
|
15
|
+
result = fetch_content("# Heading\n\n**Bold** text with [links](url)")
|
|
16
|
+
# => {"source": "# Heading...", "content": "Heading\nBold text with links", "content_type": "markdown", "success": True}
|
|
17
|
+
|
|
18
|
+
# Extract structured context via Gemini Flash
|
|
19
|
+
context = extract_content_context("Article text here...")
|
|
20
|
+
# => {"title": "...", "vertical": "...", "key_themes": [...], ...}
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import json
|
|
24
|
+
import logging
|
|
25
|
+
import os
|
|
26
|
+
import re
|
|
27
|
+
|
|
28
|
+
import trafilatura
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
_MAX_CONTENT_LENGTH = 15_000
|
|
33
|
+
|
|
34
|
+
_EXTRACTION_PROMPT = """Extract structured information from the following content.
|
|
35
|
+
Return ONLY valid JSON with these keys:
|
|
36
|
+
- title (string): the main title or topic
|
|
37
|
+
- vertical (string): the industry or domain vertical
|
|
38
|
+
- key_themes (list of strings): 3-5 main themes
|
|
39
|
+
- mood (string): overall tone or mood
|
|
40
|
+
- visual_cues (list of strings): 2-4 visual elements suggested by the content
|
|
41
|
+
- color_hints (list of strings): 2-4 colors associated with the content's mood or vertical
|
|
42
|
+
- summary (string): 2-3 sentence summary
|
|
43
|
+
|
|
44
|
+
Content:
|
|
45
|
+
{content}
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def fetch_content(source: str) -> dict:
    """Fetch and normalize content from a URL, markdown, or plain text.

    URLs (http/https) are fetched via trafilatura; markdown is stripped
    down to plain text. All content is truncated to _MAX_CONTENT_LENGTH
    (15,000) characters.

    Args:
        source: A URL (http/https), a markdown string, or plain text.

    Returns:
        Dict with keys: source, content, content_type, success, and
        error (present only on failure).
    """
    if source.startswith(("http://", "https://")):
        return _fetch_url(source)

    if _looks_like_markdown(source):
        content_type = "markdown"
        normalized = _strip_markdown(source)
    else:
        content_type = "text"
        normalized = source

    return {
        "source": source,
        "content": normalized[:_MAX_CONTENT_LENGTH],
        "content_type": content_type,
        "success": True,
    }
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def extract_content_context(text: str) -> dict:
    """Extract structured context from *text* using Gemini 2.0 Flash.

    Sends a structured extraction prompt to the model via Vertex AI,
    parses the JSON response, and logs a cost estimate through
    agent_shared_core.core.wallet.log_cost().

    Args:
        text: The content to analyze (truncated to _MAX_CONTENT_LENGTH).

    Returns:
        Dict with title, vertical, key_themes, mood, visual_cues,
        color_hints, summary. On any failure — including SDK import or
        client-initialization errors — returns {"error": "..."}.
    """
    project = os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb")
    location = os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1")

    try:
        # Lazy import inside the try: a missing or broken google-genai SDK
        # degrades to the documented {"error": ...} result instead of
        # raising ImportError at the caller.
        from google import genai

        client = genai.Client(
            vertexai=True,
            project=project,
            location=location,
        )

        prompt = _EXTRACTION_PROMPT.format(content=text[:_MAX_CONTENT_LENGTH])
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=prompt,
        )

        raw = response.text.strip()
        # Strip markdown code fences if the model wraps its output
        if raw.startswith("```"):
            raw = re.sub(r"^```(?:json)?\s*", "", raw)
            raw = re.sub(r"\s*```$", "", raw)

        result = json.loads(raw)

        # Log a flat cost estimate for this extraction call.
        from agent_shared_core.core.wallet import log_cost
        log_cost(
            agent="content_extraction",
            user_email="system",
            model="gemini-2.0-flash",
            cost=0.001,
            phase="extract_context",
        )

        logger.info("Content context extracted: title=%s", result.get("title", "unknown"))
        return result

    except json.JSONDecodeError as exc:
        logger.error("Failed to parse Gemini response as JSON: %s", exc)
        return {"error": f"JSON parse error: {exc}"}
    except Exception as exc:
        logger.error("Content context extraction failed: %s", exc, exc_info=True)
        return {"error": str(exc)}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# --- Internal helpers ---
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _fetch_url(url: str) -> dict:
    """Download and extract article text from *url* via trafilatura."""

    def _failure(message: str) -> dict:
        # All failure results share the same shape; only the message varies.
        return {
            "source": url,
            "content": "",
            "content_type": "url",
            "success": False,
            "error": message,
        }

    try:
        downloaded = trafilatura.fetch_url(url)
        if downloaded is None:
            return _failure(f"Could not fetch URL: {url}")

        text = trafilatura.extract(downloaded)
        if text is None:
            return _failure(f"Could not extract content from: {url}")

        return {
            "source": url,
            "content": text[:_MAX_CONTENT_LENGTH],
            "content_type": "url",
            "success": True,
        }
    except Exception as exc:
        logger.error("URL fetch failed for %s: %s", url, exc)
        return _failure(f"Extraction failed: {exc}")
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _looks_like_markdown(text: str) -> bool:
|
|
177
|
+
"""Detect if text contains common markdown patterns."""
|
|
178
|
+
md_patterns = [
|
|
179
|
+
r"^#{1,6}\s", # headings
|
|
180
|
+
r"\*\*.+\*\*", # bold
|
|
181
|
+
r"\[.+\]\(.+\)", # links
|
|
182
|
+
r"^[-*]\s", # unordered lists
|
|
183
|
+
r"^\d+\.\s", # ordered lists
|
|
184
|
+
r"^>\s", # blockquotes
|
|
185
|
+
r"```", # code blocks
|
|
186
|
+
]
|
|
187
|
+
for pattern in md_patterns:
|
|
188
|
+
if re.search(pattern, text, re.MULTILINE):
|
|
189
|
+
return True
|
|
190
|
+
return False
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _strip_markdown(text: str) -> str:
|
|
194
|
+
"""Remove common markdown formatting, returning plain text."""
|
|
195
|
+
# Remove code blocks
|
|
196
|
+
text = re.sub(r"```[\s\S]*?```", "", text)
|
|
197
|
+
# Remove inline code
|
|
198
|
+
text = re.sub(r"`([^`]+)`", r"\1", text)
|
|
199
|
+
# Remove images
|
|
200
|
+
text = re.sub(r"!\[([^\]]*)\]\([^)]+\)", r"\1", text)
|
|
201
|
+
# Remove links, keep text
|
|
202
|
+
text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
|
|
203
|
+
# Remove headings markers
|
|
204
|
+
text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE)
|
|
205
|
+
# Remove bold/italic markers
|
|
206
|
+
text = re.sub(r"\*{1,3}([^*]+)\*{1,3}", r"\1", text)
|
|
207
|
+
text = re.sub(r"_{1,3}([^_]+)_{1,3}", r"\1", text)
|
|
208
|
+
# Remove blockquote markers
|
|
209
|
+
text = re.sub(r"^>\s?", "", text, flags=re.MULTILINE)
|
|
210
|
+
# Remove horizontal rules
|
|
211
|
+
text = re.sub(r"^[-*_]{3,}\s*$", "", text, flags=re.MULTILINE)
|
|
212
|
+
# Collapse multiple blank lines
|
|
213
|
+
text = re.sub(r"\n{3,}", "\n\n", text)
|
|
214
|
+
return text.strip()
|
|
215
|
+
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Singleton Firestore client for the Compass Agent Framework.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from agent_shared_core.core.firestore import get_client
|
|
5
|
+
db = get_client()
|
|
6
|
+
doc = db.collection("brands").document("abc").get()
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
from google.cloud import firestore
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
_client = None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_client() -> firestore.Client:
    """Return a lazy-initialized singleton Firestore client.

    The project id comes from GOOGLE_CLOUD_PROJECT, defaulting to
    concretio-compass-sb (the sandbox project).
    """
    global _client
    if _client is not None:
        return _client

    project = os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb")
    _client = firestore.Client(project=project)
    logger.info("Firestore client initialized for project: %s", project)
    return _client
|
|
31
|
+
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
"""Image generation and upscaling via Vertex AI (Imagen 4 + Gemini 3 Pro).
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
from agent_shared_core.core.image_gen import generate_images, upscale_image
|
|
5
|
+
|
|
6
|
+
# Generate preview images
|
|
7
|
+
results = generate_images("a glass panel banner...", resolution="LinkedIn post", count=3)
|
|
8
|
+
# => [{"image_bytes": b"...", "aspect_ratio": "16:9"}, ...]
|
|
9
|
+
|
|
10
|
+
# Upscale from bytes
|
|
11
|
+
upscaled = upscale_image(source_bytes, target_size="2K")
|
|
12
|
+
# => b"..." (upscaled image bytes)
|
|
13
|
+
|
|
14
|
+
GCS upload is optional and separate. For local dev, tools return base64
|
|
15
|
+
and save as ADK artifacts. For production, a Cloud Run wrapper can handle
|
|
16
|
+
GCS upload via this same module.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import base64
|
|
20
|
+
import logging
|
|
21
|
+
import os
|
|
22
|
+
|
|
23
|
+
from google import genai
|
|
24
|
+
from google.genai import types
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# --- Model IDs ---
|
|
29
|
+
|
|
30
|
+
GENERATE_MODEL_ID = "imagen-4.0-fast-generate-001"
|
|
31
|
+
UPSCALE_MODEL_ID = "gemini-3.0-pro-image-generate-001"
|
|
32
|
+
|
|
33
|
+
# --- Resolution table ---
|
|
34
|
+
|
|
35
|
+
RESOLUTION_TABLE: dict[str, tuple[int, int]] = {
|
|
36
|
+
"LinkedIn post": (1200, 628),
|
|
37
|
+
"LinkedIn banner": (1584, 396),
|
|
38
|
+
"Twitter/X header": (1500, 500),
|
|
39
|
+
"Twitter/X post": (1200, 675),
|
|
40
|
+
"Facebook cover": (851, 315),
|
|
41
|
+
"Facebook post": (1200, 630),
|
|
42
|
+
"Instagram square": (1080, 1080),
|
|
43
|
+
"Instagram story": (1080, 1920),
|
|
44
|
+
"YouTube thumbnail": (1280, 720),
|
|
45
|
+
"YouTube banner": (2560, 1440),
|
|
46
|
+
"Blog universal": (1200, 628),
|
|
47
|
+
"Blog OG social": (1200, 630),
|
|
48
|
+
"WordPress featured image": (1200, 628),
|
|
49
|
+
"Medium header": (1400, 788),
|
|
50
|
+
"Presentation slide 16:9": (1920, 1080),
|
|
51
|
+
"Presentation slide 4:3": (1440, 1080),
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
DEFAULT_RESOLUTION = (1200, 628)
|
|
55
|
+
|
|
56
|
+
# --- Aspect ratio mapping for Imagen 4 ---
|
|
57
|
+
|
|
58
|
+
SUPPORTED_ASPECT_RATIOS = ["1:1", "9:16", "16:9", "3:4", "4:3"]
|
|
59
|
+
|
|
60
|
+
_ASPECT_RATIO_VALUES = {
|
|
61
|
+
"1:1": 1.0,
|
|
62
|
+
"9:16": 9 / 16,
|
|
63
|
+
"16:9": 16 / 9,
|
|
64
|
+
"3:4": 3 / 4,
|
|
65
|
+
"4:3": 4 / 3,
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
# --- Upscale target dimensions ---
|
|
69
|
+
|
|
70
|
+
UPSCALE_TARGETS = {
|
|
71
|
+
"2K": (2048, 2048),
|
|
72
|
+
"4K": (4096, 4096),
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _get_client() -> genai.Client:
    """Create a google-genai client configured for Vertex AI."""
    return genai.Client(
        vertexai=True,
        project=os.environ.get("GOOGLE_CLOUD_PROJECT", "concretio-compass-sb"),
        location=os.environ.get("GOOGLE_CLOUD_LOCATION", "us-central1"),
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _closest_aspect_ratio(width: int, height: int) -> str:
    """Map width/height to the nearest supported Imagen 4 aspect ratio."""
    ratio = width / height

    def distance(candidate: str) -> float:
        # Absolute difference between the candidate ratio and the target.
        return abs(_ASPECT_RATIO_VALUES[candidate] - ratio)

    return min(SUPPORTED_ASPECT_RATIOS, key=distance)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def generate_images(
    prompt: str,
    resolution: str = "LinkedIn post",
    count: int = 1,
) -> list[dict]:
    """Generate images using Imagen 4 Fast via Vertex AI.

    Args:
        prompt: The image generation prompt.
        resolution: Platform resolution key (e.g. "LinkedIn post");
            unknown keys fall back to DEFAULT_RESOLUTION.
        count: Number of images to generate (1-4).

    Returns:
        List of dicts, each with "image_bytes" (raw bytes) and
        "image_b64" (base64-encoded string for ADK artifact display).

    Raises:
        Exception on API errors (safety rejection, quota, etc.).
    """
    width, height = RESOLUTION_TABLE.get(resolution, DEFAULT_RESOLUTION)
    aspect_ratio = _closest_aspect_ratio(width, height)

    generation_config = types.GenerateImagesConfig(
        number_of_images=count,
        aspect_ratio=aspect_ratio,
        output_mime_type="image/jpeg",
        safety_filter_level="BLOCK_MEDIUM_AND_ABOVE",
        person_generation="ALLOW_ADULT",
    )
    response = _get_client().models.generate_images(
        model=GENERATE_MODEL_ID,
        prompt=prompt,
        config=generation_config,
    )

    results = [
        {
            "image_bytes": generated.image.image_bytes,
            "image_b64": base64.b64encode(generated.image.image_bytes).decode("utf-8"),
        }
        for generated in response.generated_images
    ]

    logger.info(
        "Generated %d images with %s, resolution=%s, aspect_ratio=%s",
        len(results), GENERATE_MODEL_ID, resolution, aspect_ratio,
    )
    return results
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def upscale_image(
    source_bytes: bytes,
    target_size: str = "2K",
) -> bytes:
    """Upscale an image using Gemini 3 Pro Image via Vertex AI.

    Args:
        source_bytes: Raw bytes of the source image.
        target_size: Target resolution, "2K" or "4K" (unknown values fall
            back to "2K").

    Returns:
        Upscaled image bytes (JPEG).

    Raises:
        ValueError: If the response contains no image part.
        Exception: On API errors.
    """
    width, height = UPSCALE_TARGETS.get(target_size, UPSCALE_TARGETS["2K"])

    instruction = (
        f"Upscale this image to {width}x{height} resolution. "
        "Preserve all details, colors, and composition exactly. "
        "Enhance clarity and sharpness without altering the content."
    )
    response = _get_client().models.generate_content(
        model=UPSCALE_MODEL_ID,
        contents=[
            types.Part.from_bytes(data=source_bytes, mime_type="image/jpeg"),
            instruction,
        ],
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE", "TEXT"],
        ),
    )

    # Return the first image part the model produced.
    for part in response.candidates[0].content.parts:
        inline = part.inline_data
        if inline and inline.mime_type.startswith("image/"):
            logger.info("Upscaled image to %s using %s", target_size, UPSCALE_MODEL_ID)
            return inline.data

    raise ValueError("Upscale response did not contain an image")
|
|
183
|
+
|