oshell 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oshell/__init__.py +3 -0
- oshell/agent.py +146 -0
- oshell/cli.py +836 -0
- oshell/config.py +91 -0
- oshell/history.py +57 -0
- oshell/media.py +188 -0
- oshell/media_agent.py +55 -0
- oshell/personas.py +90 -0
- oshell/providers.py +275 -0
- oshell/retrieval.py +180 -0
- oshell/storyboard.py +163 -0
- oshell/tools.py +150 -0
- oshell-0.1.1.dist-info/METADATA +286 -0
- oshell-0.1.1.dist-info/RECORD +17 -0
- oshell-0.1.1.dist-info/WHEEL +4 -0
- oshell-0.1.1.dist-info/entry_points.txt +3 -0
- oshell-0.1.1.dist-info/licenses/LICENSE +21 -0
oshell/providers.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""LLM provider layer.
|
|
2
|
+
|
|
3
|
+
Both providers expose a single ``stream_chat`` method that yields text
|
|
4
|
+
chunks as they arrive, so the CLI can render tokens live.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
from typing import Iterator, List, Protocol
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from .config import Config
|
|
15
|
+
|
|
16
|
+
# Alias used in signatures for a single chat message; at runtime it is a plain dict.
Message = dict  # {"role": "user" | "assistant" | "system", "content": str}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Provider(Protocol):
    """Structural interface shared by the LLM backends in this module."""

    # Short identifier of the backend; the concrete classes set e.g. "ollama".
    name: str

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Stream the reply to ``messages`` as successive text chunks."""

    def list_models(self) -> List[str]:
        """List the model names the backend offers (best effort)."""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class OllamaProvider:
    """Talks to a local Ollama server (https://ollama.com).

    The host, model name and sampling temperature are read from the config
    object once, at construction time.
    """

    name = "ollama"

    def __init__(self, cfg: Config) -> None:
        # A trailing slash is removed so endpoint paths can be appended directly.
        self.host = cfg.ollama_host.rstrip("/")
        self.model = cfg.model
        self.temperature = cfg.temperature

    @staticmethod
    def _iter_chunks(lines: Iterator[str]) -> Iterator[str]:
        """Yield the text carried by each line of a streamed ``/api/chat`` reply.

        Every non-empty line is expected to be one JSON object. Lines that are
        not valid JSON objects are skipped, matching how the other providers in
        this module treat malformed stream lines. Iteration stops after the
        first object whose ``done`` field is truthy.
        """
        for line in lines:
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            chunk = (data.get("message") or {}).get("content", "")
            # Emit before checking ``done`` so text on the final line is kept.
            if chunk:
                yield chunk
            if data.get("done"):
                break

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks for ``messages`` as they arrive.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "options": {"temperature": self.temperature},
        }
        # timeout=None: no client-side timeout is applied to the stream.
        with httpx.stream(
            "POST", f"{self.host}/api/chat", json=payload, timeout=None
        ) as resp:
            resp.raise_for_status()
            yield from self._iter_chunks(resp.iter_lines())

    def list_models(self) -> List[str]:
        """Return the model names reported by ``/api/tags``.

        Best effort: an HTTP failure or an unexpected response body yields an
        empty list instead of an exception.
        """
        try:
            resp = httpx.get(f"{self.host}/api/tags", timeout=10)
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except (httpx.HTTPError, ValueError, KeyError):
            return []
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class OpenAICompatibleProvider:
    """Base for any service speaking the OpenAI Chat Completions API.

    OpenAI, Groq, and Gemini (via its OpenAI-compatible endpoint) all share
    this exact wire format, so they differ only in base URL, key, and the
    error message shown when the key is missing.
    """

    name = "openai"
    _key_hint = "OPENAI_API_KEY"

    def __init__(self, base_url: str, api_key: str, model: str, temperature: float):
        """Store the connection settings.

        Raises:
            ValueError: if ``api_key`` is empty; the message names the
                environment variable and config key of the concrete provider.
        """
        if not api_key:
            raise ValueError(
                f"{self.name.title()} API key not set. Export {self._key_hint} or run "
                f"`oshell config set {self.name}_api_key <key>`."
            )
        # A trailing slash is removed so endpoint paths can be appended directly.
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.temperature = temperature

    def _auth_headers(self) -> dict:
        """Return the bearer-token header sent with every request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @staticmethod
    def _iter_chunks(lines: Iterator[str]) -> Iterator[str]:
        """Yield the ``delta.content`` text of each server-sent event line.

        Only ``data`` lines are considered; the space after ``data:`` is
        optional in the event-stream format, so both spellings are accepted.
        Iteration stops at the ``[DONE]`` sentinel. Payloads that are not
        valid JSON objects, or that carry no choices or no text, are skipped.
        """
        for line in lines:
            if not line or not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            choices = data.get("choices") or []
            if not choices:
                continue
            # ``delta`` may be present but null; treat that like an empty delta.
            delta = choices[0].get("delta") or {}
            chunk = delta.get("content")
            if chunk:
                yield chunk

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks for ``messages`` as they arrive.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "temperature": self.temperature,
        }
        # timeout=None: no client-side timeout is applied to the stream.
        with httpx.stream(
            "POST",
            f"{self.base_url}/chat/completions",
            json=payload,
            headers=self._auth_headers(),
            timeout=None,
        ) as resp:
            resp.raise_for_status()
            yield from self._iter_chunks(resp.iter_lines())

    def list_models(self) -> List[str]:
        """Return the sorted model ids reported by ``/models``.

        Best effort: an HTTP failure or an unexpected response body yields an
        empty list instead of an exception.
        """
        try:
            resp = httpx.get(
                f"{self.base_url}/models", headers=self._auth_headers(), timeout=10
            )
            resp.raise_for_status()
            return sorted(m["id"] for m in resp.json().get("data", []))
        except (httpx.HTTPError, ValueError, KeyError):
            return []
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class OpenAIProvider(OpenAICompatibleProvider):
    """Provider for the OpenAI Chat Completions API, configured from ``Config``."""

    name = "openai"
    _key_hint = "OPENAI_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # Base URL, key, model and temperature all come straight from the config.
        super().__init__(
            base_url=cfg.openai_base_url,
            api_key=cfg.openai_api_key,
            model=cfg.model,
            temperature=cfg.temperature,
        )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class GroqProvider(OpenAICompatibleProvider):
    """Groq's OpenAI-compatible chat endpoint."""

    name = "groq"
    _key_hint = "GROQ_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # An empty/unset model in the config selects the built-in default.
        super().__init__(
            base_url="https://api.groq.com/openai/v1",
            api_key=cfg.groq_api_key,
            model=cfg.model or "llama-3.3-70b-versatile",
            temperature=cfg.temperature,
        )
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class GeminiProvider(OpenAICompatibleProvider):
    """Google Gemini, reached through its OpenAI-compatible endpoint."""

    name = "gemini"
    _key_hint = "GEMINI_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # An empty/unset model in the config selects the built-in default.
        super().__init__(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
            api_key=cfg.gemini_api_key,
            model=cfg.model or "gemini-1.5-flash",
            temperature=cfg.temperature,
        )
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class AnthropicProvider:
    """Anthropic Claude via the native Messages API (SSE streaming)."""

    name = "anthropic"

    def __init__(self, cfg: Config) -> None:
        """Read key, model and temperature from ``cfg``.

        Raises:
            ValueError: if no Anthropic API key is configured.
        """
        if not cfg.anthropic_api_key:
            raise ValueError(
                "Anthropic API key not set. Export ANTHROPIC_API_KEY or run "
                "`oshell config set anthropic_api_key <key>`."
            )
        self.api_key = cfg.anthropic_api_key
        # An empty/unset model in the config selects the built-in default.
        self.model = cfg.model if cfg.model else "claude-3-5-sonnet-latest"
        self.temperature = cfg.temperature

    @staticmethod
    def _split(messages: List[Message]):
        """Separate the system prompt from the conversation turns.

        Anthropic takes the system prompt as its own request field. When the
        conversation holds several system messages their contents are joined
        with blank lines, so none of them is lost. Returns
        ``(system, turns)``; ``system`` is ``""`` when there is no non-empty
        system message.
        """
        system_parts = []
        turns: List[Message] = []
        for m in messages:
            if m["role"] == "system":
                if m["content"]:
                    system_parts.append(m["content"])
            else:
                turns.append({"role": m["role"], "content": m["content"]})
        if not system_parts:
            system = ""
        elif len(system_parts) == 1:
            # A single system message is passed through untouched.
            system = system_parts[0]
        else:
            system = "\n\n".join(str(part) for part in system_parts)
        return system, turns

    @staticmethod
    def _iter_chunks(lines: Iterator[str]) -> Iterator[str]:
        """Yield the text of every ``content_block_delta`` event in ``lines``.

        Only ``data`` lines are considered (with or without a space after the
        colon). Payloads that are not valid JSON objects and events of any
        other type are skipped.
        """
        for line in lines:
            if not line or not line.startswith("data:"):
                continue
            try:
                data = json.loads(line[len("data:"):])
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            if data.get("type") == "content_block_delta":
                chunk = (data.get("delta") or {}).get("text")
                if chunk:
                    yield chunk

    def _headers(self, with_body: bool = False) -> dict:
        """Return the request headers; ``with_body`` adds the JSON content type."""
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
        }
        if with_body:
            headers["content-type"] = "application/json"
        return headers

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks for ``messages`` as they arrive.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        system, turns = self._split(messages)
        payload = {
            "model": self.model,
            "messages": turns,
            "max_tokens": 4096,
            "temperature": self.temperature,
            "stream": True,
        }
        # The system field is only sent when there is a system prompt.
        if system:
            payload["system"] = system
        # timeout=None: no client-side timeout is applied to the stream.
        with httpx.stream(
            "POST",
            "https://api.anthropic.com/v1/messages",
            json=payload,
            headers=self._headers(with_body=True),
            timeout=None,
        ) as resp:
            resp.raise_for_status()
            yield from self._iter_chunks(resp.iter_lines())

    def list_models(self) -> List[str]:
        """Return the sorted model ids reported by ``/v1/models``.

        Best effort: an HTTP failure or an unexpected response body yields an
        empty list instead of an exception.
        """
        try:
            resp = httpx.get(
                "https://api.anthropic.com/v1/models",
                headers=self._headers(),
                timeout=10,
            )
            resp.raise_for_status()
            return sorted(m["id"] for m in resp.json().get("data", []))
        except (httpx.HTTPError, ValueError, KeyError):
            return []
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def get_provider(cfg: Config) -> Provider:
    """Build the provider selected by ``cfg.provider``.

    Raises:
        ValueError: if ``cfg.provider`` is not one of the known provider names;
            the message lists the valid names in sorted order.
    """
    registry = {
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "anthropic": AnthropicProvider,
        "groq": GroqProvider,
        "gemini": GeminiProvider,
    }
    chosen = cfg.provider
    if chosen not in registry:
        valid = ", ".join(sorted(registry))
        raise ValueError(f"Unknown provider: {chosen!r} (use one of: {valid})")
    # Every provider class takes the config object as its only argument.
    return registry[chosen](cfg)
|
oshell/retrieval.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Chat-with-your-files: a tiny local RAG index.
|
|
2
|
+
|
|
3
|
+
Embeddings come from Ollama (default ``nomic-embed-text``) or OpenAI, so we
|
|
4
|
+
keep the "no heavy dependencies" promise — vectors are stored as JSON and
|
|
5
|
+
similarity search is plain Python. Good enough for personal knowledge bases.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import math
|
|
12
|
+
import time
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Iterable, List, Tuple
|
|
16
|
+
|
|
17
|
+
import httpx
|
|
18
|
+
from platformdirs import user_data_dir
|
|
19
|
+
|
|
20
|
+
from .config import APP_NAME, Config
|
|
21
|
+
|
|
22
|
+
# The index is a single JSON file inside the per-user data directory.
INDEX_DIR = Path(user_data_dir(APP_NAME)) / "index"
INDEX_PATH = INDEX_DIR / "store.json"

# Lower-case file suffixes that _iter_files accepts as indexable text.
TEXT_SUFFIXES = {
    ".txt", ".md", ".markdown", ".rst", ".py", ".js", ".ts", ".tsx", ".jsx",
    ".json", ".yaml", ".yml", ".toml", ".cfg", ".ini", ".html", ".css",
    ".java", ".go", ".rs", ".rb", ".php", ".c", ".cpp", ".h", ".hpp", ".sh",
    ".sql", ".csv",
}
# Files whose size exceeds this many bytes are skipped by _iter_files.
MAX_FILE_BYTES = 500_000
# Maximum number of characters in one chunk produced by _chunk_text.
CHUNK_CHARS = 1200
# Number of characters by which consecutive chunk windows overlap.
CHUNK_OVERLAP = 150
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class Chunk:
    """One indexed piece of a file together with its embedding."""

    # Origin of the text; build_index stores the file path relative to the root.
    source: str
    # The chunk's text content.
    text: str
    # Embedding vector for ``text``.
    vector: List[float]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _chunk_text(text: str) -> Iterable[str]:
|
|
44
|
+
step = CHUNK_CHARS - CHUNK_OVERLAP
|
|
45
|
+
for start in range(0, len(text), step):
|
|
46
|
+
piece = text[start : start + CHUNK_CHARS].strip()
|
|
47
|
+
if piece:
|
|
48
|
+
yield piece
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _cosine(a: List[float], b: List[float]) -> float:
|
|
52
|
+
dot = sum(x * y for x, y in zip(a, b))
|
|
53
|
+
na = math.sqrt(sum(x * x for x in a))
|
|
54
|
+
nb = math.sqrt(sum(y * y for y in b))
|
|
55
|
+
if na == 0 or nb == 0:
|
|
56
|
+
return 0.0
|
|
57
|
+
return dot / (na * nb)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Embedder:
    """Turn text into an embedding vector using Ollama or OpenAI.

    OpenAI is used only when ``cfg.provider`` is ``"openai"`` and an OpenAI
    key is configured; every other configuration embeds through Ollama.
    """

    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg
        on_openai = cfg.provider == "openai"
        # The key is only inspected when the provider is OpenAI.
        self.use_openai = on_openai and bool(cfg.openai_api_key)

    def embed(self, text: str) -> List[float]:
        """Return the embedding of ``text`` from the selected backend."""
        backend = self._embed_openai if self.use_openai else self._embed_ollama
        return backend(text)

    def _embed_ollama(self, text: str) -> List[float]:
        """Embed via Ollama, trying ``/api/embed`` before ``/api/embeddings``.

        The older endpoint is used when the newer one answers with an error
        status or returns a body without a usable vector.
        """
        base = self.cfg.ollama_host.rstrip("/")
        try:
            modern = httpx.post(
                f"{base}/api/embed",
                json={"model": self.cfg.embed_model, "input": text},
                timeout=120,
            )
            if modern.status_code == 404:
                # Servers without /api/embed: jump to the legacy request below.
                raise httpx.HTTPStatusError(
                    "no /api/embed", request=modern.request, response=modern
                )
            modern.raise_for_status()
            body = modern.json()
            batch = body.get("embeddings")
            if batch:
                return batch[0]
            single = body.get("embedding")
            if single:
                return single
        except httpx.HTTPStatusError:
            pass

        legacy = httpx.post(
            f"{base}/api/embeddings",
            json={"model": self.cfg.embed_model, "prompt": text},
            timeout=120,
        )
        legacy.raise_for_status()
        return legacy.json()["embedding"]

    def _embed_openai(self, text: str) -> List[float]:
        """Embed via the ``/embeddings`` endpoint under ``cfg.openai_base_url``."""
        endpoint = f"{self.cfg.openai_base_url.rstrip('/')}/embeddings"
        reply = httpx.post(
            endpoint,
            json={"model": "text-embedding-3-small", "input": text},
            headers={"Authorization": f"Bearer {self.cfg.openai_api_key}"},
            timeout=120,
        )
        reply.raise_for_status()
        first = reply.json()["data"][0]
        return first["embedding"]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _iter_files(root: Path) -> Iterable[Path]:
|
|
114
|
+
for path in sorted(root.rglob("*")):
|
|
115
|
+
if not path.is_file():
|
|
116
|
+
continue
|
|
117
|
+
if path.suffix.lower() not in TEXT_SUFFIXES:
|
|
118
|
+
continue
|
|
119
|
+
if any(part in {".git", "node_modules", "__pycache__", ".venv"} for part in path.parts):
|
|
120
|
+
continue
|
|
121
|
+
try:
|
|
122
|
+
if path.stat().st_size > MAX_FILE_BYTES:
|
|
123
|
+
continue
|
|
124
|
+
except OSError:
|
|
125
|
+
continue
|
|
126
|
+
yield path
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def build_index(root: Path, cfg: Config, progress=None) -> int:
    """Index every text file under ``root`` and return the chunk count.

    Each file found by ``_iter_files`` is read as UTF-8 (undecodable bytes are
    ignored), split with ``_chunk_text`` and embedded chunk by chunk. Files
    that cannot be read are skipped. The result replaces any existing index at
    ``INDEX_PATH``.

    Args:
        root: Directory to index.
        cfg: Configuration used to construct the ``Embedder``.
        progress: Optional callable invoked with the relative path of each
            file after its chunks were embedded.
    """
    embedder = Embedder(cfg)
    chunks: List[dict] = []
    for path in _iter_files(root):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue
        rel = str(path.relative_to(root))
        for piece in _chunk_text(text):
            vector = embedder.embed(piece)
            chunks.append({"source": rel, "text": piece, "vector": vector})
        if progress:
            progress(rel)

    payload = json.dumps(
        {
            "root": str(root.resolve()),
            "created": time.time(),
            "provider": "openai" if embedder.use_openai else "ollama",
            "chunks": chunks,
        }
    )
    INDEX_DIR.mkdir(parents=True, exist_ok=True)
    # Write to a sibling file and rename it into place, so an interrupted
    # write cannot leave a truncated store.json behind.
    tmp_path = INDEX_PATH.with_name(INDEX_PATH.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(INDEX_PATH)
    return len(chunks)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def index_exists() -> bool:
    """Return True if the index file at ``INDEX_PATH`` exists."""
    return INDEX_PATH.exists()
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def search(query: str, cfg: Config, top_k: int = 5) -> List[Tuple[str, str, float]]:
    """Return the top_k (source, text, score) matches for ``query``.

    The query is embedded with an ``Embedder`` built from ``cfg`` and compared
    to every stored chunk by cosine similarity; results are ordered from the
    highest score down. Stored vectors whose length differs from the query
    vector are skipped, because they cannot be compared meaningfully (this
    happens when the index was built with a different embedding model). An
    empty index, or a non-positive ``top_k``, gives an empty list.

    Raises:
        FileNotFoundError: if no index file exists yet.
    """
    if not INDEX_PATH.exists():
        raise FileNotFoundError(
            "No index found. Run `oshell index <folder>` first."
        )
    store = json.loads(INDEX_PATH.read_text(encoding="utf-8"))
    chunks = store.get("chunks", [])
    if not chunks:
        return []

    qvec = Embedder(cfg).embed(query)
    dim = len(qvec)
    scored = [
        (c["source"], c["text"], _cosine(qvec, c["vector"]))
        for c in chunks
        if len(c["vector"]) == dim
    ]
    scored.sort(key=lambda match: match[2], reverse=True)
    # Clamp so a negative top_k cannot slice from the wrong end of the list.
    return scored[: max(top_k, 0)]
|
oshell/storyboard.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""Storyboard mode: brief → planned scenes → image per scene → stitched video.
|
|
2
|
+
|
|
3
|
+
The LLM acts as a director: it breaks a one-line brief into a sequence of
|
|
4
|
+
scenes (each with an image prompt and a caption). We generate an image for
|
|
5
|
+
every scene, then stitch them into an MP4 slideshow with ffmpeg.
|
|
6
|
+
|
|
7
|
+
ffmpeg is only required for the final stitch — the per-scene images are saved
|
|
8
|
+
regardless, so the command still produces useful output if ffmpeg is missing.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import json
|
|
14
|
+
import re
|
|
15
|
+
import shutil
|
|
16
|
+
import subprocess
|
|
17
|
+
import time
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Callable, List, Optional
|
|
21
|
+
|
|
22
|
+
from .config import APP_NAME, Config
|
|
23
|
+
from .media import ImageGenerator, _output_dir, _slug
|
|
24
|
+
from .providers import Message, Provider
|
|
25
|
+
|
|
26
|
+
_DIRECTOR_SYSTEM = """\
|
|
27
|
+
You are a creative director planning a short visual story. Given a brief, break \
|
|
28
|
+
it into a sequence of distinct scenes. Respond with ONLY a JSON array; each \
|
|
29
|
+
element is an object: {{"prompt": "<detailed text-to-image prompt with \
|
|
30
|
+
composition, lighting, style and mood>", "caption": "<short on-screen \
|
|
31
|
+
caption>"}}. Produce exactly {n} scenes. No markdown, no code fences, no extra \
|
|
32
|
+
text."""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class Scene:
    """One planned storyboard scene."""

    # Text-to-image prompt used to generate the scene's picture.
    prompt: str
    # Caption for the scene; may be an empty string.
    caption: str
    # Path of the rendered image; None until the scene has been rendered.
    image: Optional[Path] = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass
class StoryboardResult:
    """Outcome of rendering a storyboard."""

    # The scenes, in order, each with its image path filled in by render().
    scenes: List[Scene]
    # Path of the stitched video, or None when no video was produced.
    video: Optional[Path]
    # Directory that holds the per-scene images.
    image_dir: Path
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _parse_scenes(raw: str, fallback_brief: str, n: int) -> List[Scene]:
    """Turn the model's reply into a list of scenes.

    The reply is expected to be a JSON array of objects with a ``prompt`` and
    an optional ``caption``. A surrounding code fence is removed, and when the
    reply contains extra prose only the span from the first ``[`` to the last
    ``]`` is parsed. Array elements that are not objects or have no prompt are
    ignored. At most ``n`` scenes are kept, since every scene costs one image
    generation.

    If nothing usable can be parsed (invalid JSON, a non-array value, or no
    valid elements), ``n`` placeholder scenes are built from
    ``fallback_brief`` with empty captions.
    """
    text = raw.strip()
    if text.startswith("```"):
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text).strip()
    # Pull out the first JSON array if the model added prose.
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1:
        text = text[start : end + 1]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = []
    # Valid JSON that is not an array (null, a number, an object, ...) cannot
    # hold scenes; treat it like a parse failure instead of iterating it.
    if not isinstance(data, list):
        data = []

    scenes: List[Scene] = [
        Scene(prompt=str(item["prompt"]), caption=str(item.get("caption", "")))
        for item in data
        if isinstance(item, dict) and item.get("prompt")
    ]
    # The model was asked for exactly n scenes; drop any surplus.
    if n > 0:
        scenes = scenes[:n]
    # Fallback: if parsing failed, make simple scenes from the brief.
    if not scenes:
        scenes = [
            Scene(prompt=f"{fallback_brief}, scene {i + 1}", caption="")
            for i in range(n)
        ]
    return scenes
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class StoryboardAgent:
    """Plans scenes with the LLM and renders them into a video."""

    def __init__(self, provider: Provider, image_gen: ImageGenerator, cfg: Config):
        self.provider = provider
        self.image_gen = image_gen
        self.cfg = cfg

    def plan(self, brief: str, n: int) -> List[Scene]:
        """Ask the LLM to break ``brief`` into ``n`` scenes and parse its reply."""
        system = _DIRECTOR_SYSTEM.format(n=n)
        messages: List[Message] = [
            {"role": "system", "content": system},
            {"role": "user", "content": brief},
        ]
        # The streamed chunks are joined because the reply is parsed as a whole.
        raw = "".join(self.provider.stream_chat(messages))
        return _parse_scenes(raw, brief, n)

    def render(
        self,
        brief: str,
        n: int = 4,
        seconds_per_scene: float = 2.5,
        on_scene: Optional[Callable[[int, Scene], None]] = None,
    ) -> StoryboardResult:
        """Plan the scenes, generate one image per scene and stitch a video.

        Args:
            brief: The one-line story brief.
            n: Number of scenes to plan.
            seconds_per_scene: How long each image is shown in the video.
            on_scene: Optional callback invoked with the zero-based scene
                index and the scene once its image has been saved.

        Returns:
            The scenes with their image paths, the video path (``None`` when
            stitching was not possible) and the image directory.
        """
        scenes = self.plan(brief, n)

        out_dir = _output_dir(self.cfg, "storyboards")
        for i, scene in enumerate(scenes):
            result = self.image_gen.generate(scene.prompt, n=1)
            # Copy the generated image into the storyboard folder under an
            # ordered name; the generator's own file is left where it is.
            src = result.paths[0]
            dest = out_dir / f"scene-{i + 1:02d}.png"
            dest.write_bytes(src.read_bytes())
            scene.image = dest
            if on_scene:
                on_scene(i, scene)

        video = self._stitch(scenes, out_dir, seconds_per_scene)
        return StoryboardResult(scenes=scenes, video=video, image_dir=out_dir)

    @staticmethod
    def ffmpeg_available() -> bool:
        """Return True if an ``ffmpeg`` executable is found on the PATH."""
        return shutil.which("ffmpeg") is not None

    @staticmethod
    def _concat_entry(path: Path) -> str:
        """Return the concat-list ``file`` line for ``path``.

        The path is single-quoted; a quote inside it is written as ``'\\''``
        (close the quote, escaped quote, reopen) so that paths containing an
        apostrophe do not break the list file.
        """
        escaped = path.as_posix().replace("'", "'\\''")
        return f"file '{escaped}'"

    def _stitch(
        self, scenes: List[Scene], out_dir: Path, seconds_per_scene: float
    ) -> Optional[Path]:
        """Combine the scene images into ``storyboard.mp4`` inside ``out_dir``.

        Returns the video path, or ``None`` when there are no images, ffmpeg
        is not installed, ffmpeg times out or cannot be started, or it exits
        with a non-zero status.
        """
        images = [s.image for s in scenes if s.image]
        if not images or not self.ffmpeg_available():
            return None

        # Build a concat demuxer list with a per-image duration.
        list_path = out_dir / "scenes.txt"
        lines = []
        for img in images:
            lines.append(self._concat_entry(img))
            lines.append(f"duration {seconds_per_scene}")
        # The concat demuxer needs the last file repeated (no trailing duration).
        lines.append(self._concat_entry(images[-1]))
        list_path.write_text("\n".join(lines), encoding="utf-8")

        video_path = out_dir / "storyboard.mp4"
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "concat",
            "-safe",
            "0",
            "-i",
            str(list_path),
            "-vf",
            # Fit each image inside 1024x1024 (both bounds given, so portrait
            # images shrink too), then centre it on a 1024x1024 canvas.
            "scale=1024:1024:force_original_aspect_ratio=decrease,"
            "pad=1024:1024:(ow-iw)/2:(oh-ih)/2,format=yuv420p",
            "-r",
            "30",
            str(video_path),
        ]
        try:
            proc = subprocess.run(
                cmd, capture_output=True, text=True, timeout=300
            )
        except (subprocess.TimeoutExpired, OSError):
            return None
        return video_path if proc.returncode == 0 and video_path.exists() else None
|