oshell 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
oshell/providers.py ADDED
@@ -0,0 +1,275 @@
1
+ """LLM provider layer.
2
+
3
+ Both providers expose a single ``stream_chat`` method that yields text
4
+ chunks as they arrive, so the CLI can render tokens live.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from typing import Iterator, List, Protocol
11
+
12
+ import httpx
13
+
14
+ from .config import Config
15
+
16
+ Message = dict # {"role": "user" | "assistant" | "system", "content": str}
17
+
18
+
19
class Provider(Protocol):
    """Structural interface shared by every LLM backend.

    Being a ``Protocol``, it is satisfied by any object exposing the
    attribute and the two methods below; subclassing is not required.
    """

    # Short identifier of the backend.
    name: str

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Stream the reply to ``messages`` as successive text chunks."""
        ...

    def list_models(self) -> List[str]:
        """List the model names the backend offers (best effort)."""
        ...
31
+
32
+
33
class OllamaProvider:
    """Talks to a local Ollama server (https://ollama.com).

    The host, model name and sampling temperature are read from the config
    object passed to the constructor.
    """

    name = "ollama"

    def __init__(self, cfg: Config) -> None:
        # Drop trailing slashes so endpoint paths can be appended verbatim.
        self.host = cfg.ollama_host.rstrip("/")
        self.model = cfg.model
        self.temperature = cfg.temperature

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks from ``POST {host}/api/chat``.

        The body is consumed line by line, each line being one JSON object.
        Lines that are empty, unparsable, or not JSON objects are skipped.

        Raises:
            httpx.HTTPStatusError: if the server replies with an error status.
        """
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "options": {"temperature": self.temperature},
        }
        # timeout=None: generation may legitimately pause between tokens.
        with httpx.stream(
            "POST", f"{self.host}/api/chat", json=payload, timeout=None
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # One bad line should not abort the whole reply; the
                    # SSE-based providers in this module skip them as well.
                    continue
                if not isinstance(data, dict):
                    continue
                chunk = (data.get("message") or {}).get("content", "")
                if chunk:
                    yield chunk
                # Check ``done`` only after emitting, so text carried on the
                # final record is not lost.
                if data.get("done"):
                    break

    def list_models(self) -> List[str]:
        """Return the model names from ``GET {host}/api/tags``.

        Best effort: any transport error, non-JSON body, or unexpected
        payload shape results in an empty list rather than an exception.
        """
        try:
            resp = httpx.get(f"{self.host}/api/tags", timeout=10)
            resp.raise_for_status()
            return [m["name"] for m in resp.json().get("models", [])]
        except (httpx.HTTPError, ValueError, KeyError, TypeError, AttributeError):
            return []
71
+
72
+
73
class OpenAICompatibleProvider:
    """Base for any service speaking the OpenAI Chat Completions API.

    OpenAI, Groq, and Gemini (via its OpenAI-compatible endpoint) all share
    this exact wire format, so they differ only in base URL, key, and the
    error message shown when the key is missing.
    """

    name = "openai"
    _key_hint = "OPENAI_API_KEY"

    def __init__(self, base_url: str, api_key: str, model: str, temperature: float):
        """Store the connection settings.

        Raises:
            ValueError: if ``api_key`` is empty.
        """
        if not api_key:
            raise ValueError(
                f"{self.name.title()} API key not set. Export {self._key_hint} or run "
                f"`oshell config set {self.name}_api_key <key>`."
            )
        # Drop trailing slashes so endpoint paths can be appended verbatim.
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.model = model
        self.temperature = temperature

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the payload of an SSE ``data:`` line, or None otherwise.

        The space after the colon is optional in the event-stream format, so
        both ``data: x`` and ``data:x`` are accepted.
        """
        if not line.startswith("data:"):
            return None
        return line[len("data:") :].strip()

    @staticmethod
    def _delta_text(data_str: str) -> str | None:
        """Extract ``choices[0].delta.content`` from one JSON payload.

        Returns None when the payload is not valid JSON, has no choices, or
        carries no text (including a null ``delta`` or null ``content``).
        """
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return None
        if not isinstance(data, dict):
            return None
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        if not isinstance(first, dict):
            return None
        delta = first.get("delta")
        if not isinstance(delta, dict):
            return None
        return delta.get("content") or None

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks from ``POST {base_url}/chat/completions``.

        Stops at the ``[DONE]`` sentinel; non-data lines and payloads
        without text are skipped.

        Raises:
            httpx.HTTPStatusError: if the server replies with an error status.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "temperature": self.temperature,
        }
        # timeout=None: generation may legitimately pause between tokens.
        with httpx.stream(
            "POST",
            f"{self.base_url}/chat/completions",
            json=payload,
            headers=headers,
            timeout=None,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                data_str = self._sse_data(line)
                if data_str is None:
                    continue
                if data_str == "[DONE]":
                    break
                chunk = self._delta_text(data_str)
                if chunk:
                    yield chunk

    def list_models(self) -> List[str]:
        """Return the sorted model ids from ``GET {base_url}/models``.

        Best effort: any transport error, non-JSON body, or unexpected
        payload shape results in an empty list rather than an exception.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            resp = httpx.get(
                f"{self.base_url}/models", headers=headers, timeout=10
            )
            resp.raise_for_status()
            return sorted(m["id"] for m in resp.json().get("data", []))
        except (httpx.HTTPError, ValueError, KeyError, TypeError, AttributeError):
            return []
138
+
139
+
140
class OpenAIProvider(OpenAICompatibleProvider):
    """Provider for OpenAI's own Chat Completions endpoint."""

    name = "openai"
    _key_hint = "OPENAI_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # Base URL, key, model and temperature all come straight from config.
        super().__init__(
            base_url=cfg.openai_base_url,
            api_key=cfg.openai_api_key,
            model=cfg.model,
            temperature=cfg.temperature,
        )
150
+
151
+
152
class GroqProvider(OpenAICompatibleProvider):
    """Groq — extremely fast inference, OpenAI-compatible API."""

    name = "groq"
    _key_hint = "GROQ_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # An empty/unset model in the config selects the built-in default.
        chosen_model = cfg.model or "llama-3.3-70b-versatile"
        super().__init__(
            base_url="https://api.groq.com/openai/v1",
            api_key=cfg.groq_api_key,
            model=chosen_model,
            temperature=cfg.temperature,
        )
166
+
167
+
168
class GeminiProvider(OpenAICompatibleProvider):
    """Google Gemini, reached through its OpenAI-compatible endpoint."""

    name = "gemini"
    _key_hint = "GEMINI_API_KEY"

    def __init__(self, cfg: Config) -> None:
        # An empty/unset model in the config selects the built-in default.
        chosen_model = cfg.model or "gemini-1.5-flash"
        super().__init__(
            base_url="https://generativelanguage.googleapis.com/v1beta/openai",
            api_key=cfg.gemini_api_key,
            model=chosen_model,
            temperature=cfg.temperature,
        )
182
+
183
+
184
class AnthropicProvider:
    """Anthropic Claude via the native Messages API (SSE streaming)."""

    name = "anthropic"

    def __init__(self, cfg: Config) -> None:
        """Read key, model and temperature from ``cfg``.

        Raises:
            ValueError: if no Anthropic API key is configured.
        """
        if not cfg.anthropic_api_key:
            raise ValueError(
                "Anthropic API key not set. Export ANTHROPIC_API_KEY or run "
                "`oshell config set anthropic_api_key <key>`."
            )
        self.api_key = cfg.anthropic_api_key
        # An empty/unset model in the config selects the built-in default.
        self.model = cfg.model if cfg.model else "claude-3-5-sonnet-latest"
        self.temperature = cfg.temperature

    @staticmethod
    def _split(messages: List[Message]):
        """Separate the system prompt from the conversation turns.

        The request built by ``stream_chat`` carries the system prompt in its
        own field, so system messages are removed from the turn list. When
        several system messages are present their contents are joined with a
        blank line instead of keeping only the last one.

        Returns:
            A ``(system, turns)`` tuple; ``system`` is ``""`` when the
            conversation has no system message.
        """
        system_parts: List[str] = []
        turns: List[Message] = []
        for m in messages:
            if m["role"] == "system":
                if m["content"]:
                    system_parts.append(m["content"])
            else:
                turns.append({"role": m["role"], "content": m["content"]})
        return "\n\n".join(system_parts), turns

    def stream_chat(self, messages: List[Message]) -> Iterator[str]:
        """Yield response text chunks from the Messages endpoint.

        Only ``content_block_delta`` events that carry text are surfaced;
        every other event is ignored.

        Raises:
            httpx.HTTPStatusError: if the server replies with an error status.
        """
        system, turns = self._split(messages)
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": turns,
            "max_tokens": 4096,
            "temperature": self.temperature,
            "stream": True,
        }
        if system:
            payload["system"] = system
        # timeout=None: generation may legitimately pause between tokens.
        with httpx.stream(
            "POST",
            "https://api.anthropic.com/v1/messages",
            json=payload,
            headers=headers,
            timeout=None,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                # The space after "data:" is optional in the SSE format;
                # json.loads tolerates the leading whitespace either way.
                if not line or not line.startswith("data:"):
                    continue
                try:
                    data = json.loads(line[len("data:") :])
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    continue
                if data.get("type") == "content_block_delta":
                    # ``or {}`` guards against an explicit null delta.
                    chunk = (data.get("delta") or {}).get("text")
                    if chunk:
                        yield chunk

    def list_models(self) -> List[str]:
        """Return the sorted model ids from the models endpoint.

        Best effort: any transport error, non-JSON body, or unexpected
        payload shape results in an empty list rather than an exception.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
        }
        try:
            resp = httpx.get(
                "https://api.anthropic.com/v1/models", headers=headers, timeout=10
            )
            resp.raise_for_status()
            return sorted(m["id"] for m in resp.json().get("data", []))
        except (httpx.HTTPError, ValueError, KeyError, TypeError, AttributeError):
            return []
260
+
261
+
262
def get_provider(cfg: Config) -> Provider:
    """Instantiate the provider class selected by ``cfg.provider``.

    Raises:
        ValueError: if ``cfg.provider`` is not a known provider name.
    """
    registry = {
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "anthropic": AnthropicProvider,
        "groq": GroqProvider,
        "gemini": GeminiProvider,
    }
    provider_cls = registry.get(cfg.provider)
    if provider_cls is not None:
        return provider_cls(cfg)
    # Unknown name: list the accepted ones alphabetically in the error.
    choices = ", ".join(sorted(registry))
    raise ValueError(f"Unknown provider: {cfg.provider!r} (use one of: {choices})")
oshell/retrieval.py ADDED
@@ -0,0 +1,180 @@
1
+ """Chat-with-your-files: a tiny local RAG index.
2
+
3
+ Embeddings come from Ollama (default ``nomic-embed-text``) or OpenAI, so we
4
+ keep the "no heavy dependencies" promise — vectors are stored as JSON and
5
+ similarity search is plain Python. Good enough for personal knowledge bases.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import json
11
+ import math
12
+ import time
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+ from typing import Iterable, List, Tuple
16
+
17
+ import httpx
18
+ from platformdirs import user_data_dir
19
+
20
+ from .config import APP_NAME, Config
21
+
22
+ INDEX_DIR = Path(user_data_dir(APP_NAME)) / "index"
23
+ INDEX_PATH = INDEX_DIR / "store.json"
24
+
25
+ TEXT_SUFFIXES = {
26
+ ".txt", ".md", ".markdown", ".rst", ".py", ".js", ".ts", ".tsx", ".jsx",
27
+ ".json", ".yaml", ".yml", ".toml", ".cfg", ".ini", ".html", ".css",
28
+ ".java", ".go", ".rs", ".rb", ".php", ".c", ".cpp", ".h", ".hpp", ".sh",
29
+ ".sql", ".csv",
30
+ }
31
+ MAX_FILE_BYTES = 500_000
32
+ CHUNK_CHARS = 1200
33
+ CHUNK_OVERLAP = 150
34
+
35
+
36
@dataclass
class Chunk:
    """One piece of text together with its embedding vector."""

    # NOTE(review): field names match the dict records written by
    # ``build_index``; presumably ``source`` is the file's path relative to
    # the indexed root — confirm against callers.
    source: str
    text: str
    vector: List[float]
41
+
42
+
43
+ def _chunk_text(text: str) -> Iterable[str]:
44
+ step = CHUNK_CHARS - CHUNK_OVERLAP
45
+ for start in range(0, len(text), step):
46
+ piece = text[start : start + CHUNK_CHARS].strip()
47
+ if piece:
48
+ yield piece
49
+
50
+
51
+ def _cosine(a: List[float], b: List[float]) -> float:
52
+ dot = sum(x * y for x, y in zip(a, b))
53
+ na = math.sqrt(sum(x * x for x in a))
54
+ nb = math.sqrt(sum(y * y for y in b))
55
+ if na == 0 or nb == 0:
56
+ return 0.0
57
+ return dot / (na * nb)
58
+
59
+
60
class Embedder:
    """Turns text into an embedding vector using OpenAI or Ollama."""

    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg
        # OpenAI embeddings are used only when OpenAI is the selected
        # provider AND a key is configured; everything else goes to Ollama.
        if cfg.provider == "openai":
            self.use_openai = bool(cfg.openai_api_key)
        else:
            self.use_openai = False

    def embed(self, text: str) -> List[float]:
        """Embed ``text`` with whichever backend was selected at init."""
        backend = self._embed_openai if self.use_openai else self._embed_ollama
        return backend(text)

    def _embed_ollama(self, text: str) -> List[float]:
        """Embed via Ollama, trying ``/api/embed`` then ``/api/embeddings``.

        The second endpoint is used when the first answers with an HTTP
        error status or returns a body without an embedding.
        """
        base = self.cfg.ollama_host.rstrip("/")
        # Newer Ollama exposes /api/embed; older ones use /api/embeddings.
        try:
            primary = httpx.post(
                f"{base}/api/embed",
                json={"model": self.cfg.embed_model, "input": text},
                timeout=120,
            )
            if primary.status_code == 404:
                raise httpx.HTTPStatusError(
                    "no /api/embed", request=primary.request, response=primary
                )
            primary.raise_for_status()
            body = primary.json()
            batch = body.get("embeddings")
            if batch:
                return batch[0]
            if body.get("embedding"):
                return body["embedding"]
        except httpx.HTTPStatusError:
            # Fall through to the legacy endpoint below.
            pass

        legacy = httpx.post(
            f"{base}/api/embeddings",
            json={"model": self.cfg.embed_model, "prompt": text},
            timeout=120,
        )
        legacy.raise_for_status()
        return legacy.json()["embedding"]

    def _embed_openai(self, text: str) -> List[float]:
        """Embed via ``{openai_base_url}/embeddings`` (text-embedding-3-small)."""
        endpoint = f"{self.cfg.openai_base_url.rstrip('/')}/embeddings"
        auth = {"Authorization": f"Bearer {self.cfg.openai_api_key}"}
        resp = httpx.post(
            endpoint,
            json={"model": "text-embedding-3-small", "input": text},
            headers=auth,
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json()["data"][0]["embedding"]
111
+
112
+
113
def _iter_files(root: Path) -> Iterable[Path]:
    """Yield the indexable text files under ``root`` in sorted order.

    A path is yielded when it is a regular file, its lower-cased suffix is
    in ``TEXT_SUFFIXES``, it is no larger than ``MAX_FILE_BYTES``, and none
    of its directory components *below* ``root`` is a dependency/VCS/cache
    directory. Files whose size cannot be read are skipped.
    """
    skip_dirs = {".git", "node_modules", "__pycache__", ".venv"}
    for path in sorted(root.rglob("*")):
        # Cheap string test first; is_file() and stat() touch the disk.
        if path.suffix.lower() not in TEXT_SUFFIXES:
            continue
        if not path.is_file():
            continue
        # Only look at components relative to root: a root that itself
        # lives inside e.g. ".venv" must not cause every file to be skipped.
        if skip_dirs.intersection(path.relative_to(root).parts):
            continue
        try:
            if path.stat().st_size > MAX_FILE_BYTES:
                continue
        except OSError:
            continue
        yield path
127
+
128
+
129
def build_index(root: Path, cfg: Config, progress=None) -> int:
    """Embed every text file below ``root`` and write the index to disk.

    Args:
        root: Directory to index.
        cfg: Configuration handed to the ``Embedder``.
        progress: Optional callable; invoked with each file's path relative
            to ``root`` once that file's chunks have been embedded.

    Returns:
        The number of chunks stored in the index.
    """
    embedder = Embedder(cfg)
    records: List[dict] = []
    for file_path in _iter_files(root):
        try:
            content = file_path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            # Unreadable file: leave it out of the index.
            continue
        rel_name = str(file_path.relative_to(root))
        for piece in _chunk_text(content):
            records.append(
                {"source": rel_name, "text": piece, "vector": embedder.embed(piece)}
            )
        if progress:
            progress(rel_name)

    INDEX_DIR.mkdir(parents=True, exist_ok=True)
    store = {
        "root": str(root.resolve()),
        "created": time.time(),
        "provider": "openai" if embedder.use_openai else "ollama",
        "chunks": records,
    }
    INDEX_PATH.write_text(json.dumps(store), encoding="utf-8")
    return len(records)
158
+
159
+
160
def index_exists() -> bool:
    """Report whether the index file at ``INDEX_PATH`` is present on disk."""
    present = INDEX_PATH.exists()
    return present
162
+
163
+
164
def search(query: str, cfg: Config, top_k: int = 5) -> List[Tuple[str, str, float]]:
    """Return the top_k (source, text, score) matches for ``query``.

    The query is embedded with the currently configured backend and compared
    by cosine similarity against every stored chunk. Chunks whose vector
    length differs from the query vector are skipped: they come from a
    different embedding model, and comparing them would give meaningless
    scores.

    Returns:
        Up to ``top_k`` tuples ordered by descending score; an empty list
        when the index holds no chunks or ``top_k`` is not positive.

    Raises:
        FileNotFoundError: if no index file has been built yet.
    """
    if not INDEX_PATH.exists():
        raise FileNotFoundError(
            "No index found. Run `oshell index <folder>` first."
        )
    store = json.loads(INDEX_PATH.read_text(encoding="utf-8"))
    chunks = store.get("chunks", [])
    if not chunks or top_k <= 0:
        # Nothing to rank; avoid a pointless embedding request.
        return []

    qvec = Embedder(cfg).embed(query)
    dim = len(qvec)
    scored = [
        (c["source"], c["text"], _cosine(qvec, c["vector"]))
        for c in chunks
        if len(c["vector"]) == dim
    ]
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:top_k]
oshell/storyboard.py ADDED
@@ -0,0 +1,163 @@
1
+ """Storyboard mode: brief → planned scenes → image per scene → stitched video.
2
+
3
+ The LLM acts as a director: it breaks a one-line brief into a sequence of
4
+ scenes (each with an image prompt and a caption). We generate an image for
5
+ every scene, then stitch them into an MP4 slideshow with ffmpeg.
6
+
7
+ ffmpeg is only required for the final stitch — the per-scene images are saved
8
+ regardless, so the command still produces useful output if ffmpeg is missing.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import re
15
+ import shutil
16
+ import subprocess
17
+ import time
18
+ from dataclasses import dataclass
19
+ from pathlib import Path
20
+ from typing import Callable, List, Optional
21
+
22
+ from .config import APP_NAME, Config
23
+ from .media import ImageGenerator, _output_dir, _slug
24
+ from .providers import Message, Provider
25
+
26
+ _DIRECTOR_SYSTEM = """\
27
+ You are a creative director planning a short visual story. Given a brief, break \
28
+ it into a sequence of distinct scenes. Respond with ONLY a JSON array; each \
29
+ element is an object: {{"prompt": "<detailed text-to-image prompt with \
30
+ composition, lighting, style and mood>", "caption": "<short on-screen \
31
+ caption>"}}. Produce exactly {n} scenes. No markdown, no code fences, no extra \
32
+ text."""
33
+
34
+
35
@dataclass
class Scene:
    """One planned storyboard scene."""

    # Text-to-image prompt for this scene.
    prompt: str
    # Short caption for the scene; may be an empty string.
    caption: str
    # Path of the rendered image; stays None until one has been written.
    image: Optional[Path] = None
40
+
41
+
42
@dataclass
class StoryboardResult:
    """Outcome of rendering a storyboard."""

    # The scenes that were planned, in order.
    scenes: List[Scene]
    # Path of the stitched video, or None when no video was produced.
    video: Optional[Path]
    # Directory holding the per-scene images.
    image_dir: Path
47
+
48
+
49
def _parse_scenes(raw: str, fallback_brief: str, n: int) -> List[Scene]:
    """Turn the model's raw reply into a list of scenes.

    Args:
        raw: The model output, expected to contain a JSON array of objects
            with ``prompt`` and optional ``caption`` keys.
        fallback_brief: Brief used to synthesize scenes when nothing usable
            can be parsed.
        n: Number of scenes requested.

    Returns:
        The parsed scenes, capped at ``n`` when ``n`` is positive. If no
        valid scene is found, ``n`` generic scenes built from
        ``fallback_brief`` are returned instead.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Strip a Markdown code fence wrapped around the payload.
        text = re.sub(r"^```[a-zA-Z]*\n?", "", text)
        text = re.sub(r"\n?```$", "", text).strip()
    # Keep only the span from the first "[" to the last "]" so prose around
    # the array does not break parsing.
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1:
        text = text[start : end + 1]
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        data = []
    if not isinstance(data, list):
        # Valid JSON that is not an array (number, string, object, ...)
        # cannot be iterated as scenes; treat it as a parse failure.
        data = []

    scenes: List[Scene] = []
    for item in data:
        if isinstance(item, dict) and item.get("prompt"):
            caption = item.get("caption")
            scenes.append(
                Scene(
                    prompt=str(item["prompt"]),
                    # A JSON null caption must not become the text "None".
                    caption="" if caption is None else str(caption),
                )
            )
    # Fallback: if parsing failed, make simple scenes from the brief.
    if not scenes:
        return [
            Scene(prompt=f"{fallback_brief}, scene {i + 1}", caption="")
            for i in range(n)
        ]
    # The model was asked for exactly n scenes; drop any extras so no more
    # images are generated than were requested.
    return scenes[:n] if n > 0 else scenes
77
+
78
+
79
class StoryboardAgent:
    """Plans scenes with the LLM and renders them into a video."""

    # Fit each frame inside a 1024x1024 box (keeping its aspect ratio), then
    # centre it on a 1024x1024 canvas. Both scale dimensions must be bounded:
    # with a free height a portrait image stays taller than the canvas and
    # the pad filter rejects it.
    _VIDEO_FILTER = (
        "scale=1024:1024:force_original_aspect_ratio=decrease,"
        "pad=1024:1024:(ow-iw)/2:(oh-ih)/2,format=yuv420p"
    )

    def __init__(self, provider: Provider, image_gen: ImageGenerator, cfg: Config):
        self.provider = provider
        self.image_gen = image_gen
        self.cfg = cfg

    def plan(self, brief: str, n: int) -> List[Scene]:
        """Ask the LLM to break ``brief`` into ``n`` scenes and parse the reply."""
        system = _DIRECTOR_SYSTEM.format(n=n)
        messages: List[Message] = [
            {"role": "system", "content": system},
            {"role": "user", "content": brief},
        ]
        # The reply must be parsed as a whole, so drain the stream first.
        raw = "".join(self.provider.stream_chat(messages))
        return _parse_scenes(raw, brief, n)

    def render(
        self,
        brief: str,
        n: int = 4,
        seconds_per_scene: float = 2.5,
        on_scene: Optional[Callable[[int, Scene], None]] = None,
    ) -> StoryboardResult:
        """Plan the scenes, generate one image per scene, and stitch a video.

        Args:
            brief: One-line description of the story.
            n: Number of scenes to plan.
            seconds_per_scene: How long each image is shown in the video.
            on_scene: Optional callback invoked as ``on_scene(index, scene)``
                after each scene has been processed.

        Returns:
            A ``StoryboardResult``; its ``video`` is None when stitching was
            not possible. A scene for which the generator returned no image
            keeps ``image`` as None and is left out of the video.
        """
        scenes = self.plan(brief, n)

        out_dir = _output_dir(self.cfg, "storyboards")
        for i, scene in enumerate(scenes):
            result = self.image_gen.generate(scene.prompt, n=1)
            if result.paths:
                # Copy the generated image into the storyboard folder under
                # an ordered name.
                src = result.paths[0]
                dest = out_dir / f"scene-{i + 1:02d}.png"
                dest.write_bytes(src.read_bytes())
                scene.image = dest
            if on_scene:
                on_scene(i, scene)

        video = self._stitch(scenes, out_dir, seconds_per_scene)
        return StoryboardResult(scenes=scenes, video=video, image_dir=out_dir)

    @staticmethod
    def ffmpeg_available() -> bool:
        """Return True when an ``ffmpeg`` executable is found on PATH."""
        return shutil.which("ffmpeg") is not None

    @staticmethod
    def _concat_list(images: List[Path], seconds_per_scene: float) -> str:
        """Build the concat-demuxer script showing each image for a fixed time.

        Single quotes inside a path are escaped (``'`` -> ``'\\''``) so they
        cannot terminate the quoted ``file`` argument early.
        """

        def quoted(img: Path) -> str:
            return "'" + img.as_posix().replace("'", "'\\''") + "'"

        lines = []
        for img in images:
            lines.append(f"file {quoted(img)}")
            lines.append(f"duration {seconds_per_scene}")
        # The concat demuxer needs the last file repeated (no trailing duration).
        lines.append(f"file {quoted(images[-1])}")
        return "\n".join(lines)

    def _stitch(
        self, scenes: List[Scene], out_dir: Path, seconds_per_scene: float
    ) -> Optional[Path]:
        """Combine the scene images into ``storyboard.mp4`` inside ``out_dir``.

        Returns:
            The video path, or None when there are no images, ffmpeg is not
            installed, ffmpeg fails or times out, or no output file appears.
        """
        images = [s.image for s in scenes if s.image]
        if not images or not self.ffmpeg_available():
            return None

        list_path = out_dir / "scenes.txt"
        list_path.write_text(
            self._concat_list(images, seconds_per_scene), encoding="utf-8"
        )

        video_path = out_dir / "storyboard.mp4"
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "concat",
            "-safe",
            "0",
            "-i",
            str(list_path),
            "-vf",
            self._VIDEO_FILTER,
            "-r",
            "30",
            str(video_path),
        ]
        try:
            proc = subprocess.run(
                cmd, capture_output=True, text=True, timeout=300
            )
        except (subprocess.TimeoutExpired, OSError):
            return None
        return video_path if proc.returncode == 0 and video_path.exists() else None