oshell 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- oshell/__init__.py +3 -0
- oshell/agent.py +146 -0
- oshell/cli.py +836 -0
- oshell/config.py +91 -0
- oshell/history.py +57 -0
- oshell/media.py +188 -0
- oshell/media_agent.py +55 -0
- oshell/personas.py +90 -0
- oshell/providers.py +275 -0
- oshell/retrieval.py +180 -0
- oshell/storyboard.py +163 -0
- oshell/tools.py +150 -0
- oshell-0.1.1.dist-info/METADATA +286 -0
- oshell-0.1.1.dist-info/RECORD +17 -0
- oshell-0.1.1.dist-info/WHEEL +4 -0
- oshell-0.1.1.dist-info/entry_points.txt +3 -0
- oshell-0.1.1.dist-info/licenses/LICENSE +21 -0
oshell/config.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Configuration loading and persistence for oshell.
|
|
2
|
+
|
|
3
|
+
Settings are resolved in this order (highest priority first):
|
|
4
|
+
1. Command-line flags
|
|
5
|
+
2. Environment variables
|
|
6
|
+
3. Config file (config.json in the platformdirs user config directory, e.g. ~/.config/oshell/config.json on Linux)
|
|
7
|
+
4. Built-in defaults
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import os
|
|
14
|
+
from dataclasses import asdict, dataclass
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
|
|
17
|
+
from platformdirs import user_config_dir
|
|
18
|
+
|
|
19
|
+
APP_NAME = "oshell"

# platformdirs chooses the per-user configuration directory for this app.
CONFIG_DIR = Path(user_config_dir(APP_NAME))
# Single JSON file read by Config.load() and written by Config.save().
CONFIG_PATH = CONFIG_DIR.joinpath("config.json")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class Config:
    """Resolved runtime configuration.

    Build instances with :meth:`load`. It starts from the field defaults
    below, overlays the JSON config file, then overlays environment
    variables.
    """

    provider: str = "ollama"  # ollama | openai | anthropic | groq | gemini
    model: str = "llama3.2"
    system_prompt: str = "You are a helpful, concise assistant."
    persona: str = ""  # named preset; overrides system_prompt when set
    temperature: float = 0.7

    # Provider endpoints / credentials
    ollama_host: str = "http://localhost:11434"
    openai_base_url: str = "https://api.openai.com/v1"
    openai_api_key: str = ""
    anthropic_api_key: str = ""
    groq_api_key: str = ""
    gemini_api_key: str = ""

    # Retrieval (chat-with-your-files)
    embed_model: str = "nomic-embed-text"  # ollama embedding model

    # Media generation
    image_model: str = "gpt-image-1"
    image_size: str = "1024x1024"
    video_model: str = "minimax/video-01"
    replicate_api_token: str = ""
    # Empty -> the media module falls back to "<user pictures dir>/oshell".
    media_output_dir: str = ""

    # (environment variable, field name) pairs applied by ``load``.
    # Deliberately un-annotated so @dataclass does not treat it as a field:
    # it must not be accepted from the config file nor written by ``save``.
    _ENV_OVERRIDES = (
        ("OPENAI_API_KEY", "openai_api_key"),
        ("OPENAI_BASE_URL", "openai_base_url"),
        ("OLLAMA_HOST", "ollama_host"),
        ("OSHELL_PROVIDER", "provider"),
        ("OSHELL_MODEL", "model"),
        ("REPLICATE_API_TOKEN", "replicate_api_token"),
        ("ANTHROPIC_API_KEY", "anthropic_api_key"),
        ("GROQ_API_KEY", "groq_api_key"),
        ("GEMINI_API_KEY", "gemini_api_key"),
    )

    @classmethod
    def load(cls) -> "Config":
        """Load config from file + environment, applying defaults.

        A missing, unreadable, or malformed config file is ignored, and the
        built-in defaults are used instead. Unknown keys in the file are
        dropped. Environment variables override file values only when they
        are set to a non-empty string.
        """
        data: object = {}
        if CONFIG_PATH.exists():
            try:
                data = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, OSError):
                data = {}

        cfg = cls._from_mapping(data)
        cfg._apply_env()
        return cfg

    @classmethod
    def _from_mapping(cls, data: object) -> "Config":
        """Build a Config from parsed JSON, keeping only known field names.

        Anything other than a JSON object (e.g. a list or a string) yields
        the defaults instead of crashing.
        """
        if not isinstance(data, dict):
            return cls()
        known = cls.__annotations__
        return cls(**{k: v for k, v in data.items() if k in known})

    def _apply_env(self, environ=None) -> None:
        """Override fields from environment variables.

        ``environ`` defaults to ``os.environ``. A variable that is unset *or*
        set to an empty string leaves the current value untouched, so an
        accidental ``export OPENAI_API_KEY=`` does not erase a key stored in
        the config file.
        """
        env = os.environ if environ is None else environ
        for var, attr in self._ENV_OVERRIDES:
            value = env.get(var)
            if value:
                setattr(self, attr, value)

    def save(self) -> Path:
        """Persist current config to disk and return the file's path.

        The file may contain API keys, so it is created with owner-only
        (0600) permissions where the platform honours POSIX modes.
        """
        # NOTE(review): load() merges environment variables into the
        # instance, so saving a loaded Config also writes env-supplied API
        # keys to this file — confirm that is intended.
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(asdict(self), indent=2)
        fd = os.open(CONFIG_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
        # os.open's mode only applies when the file is created; tighten a
        # pre-existing file as well (best effort, e.g. on exotic filesystems).
        try:
            os.chmod(CONFIG_PATH, 0o600)
        except OSError:
            pass
        return CONFIG_PATH
|
oshell/history.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Conversation history persistence (JSONL session files)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import List
|
|
9
|
+
|
|
10
|
+
from platformdirs import user_data_dir
|
|
11
|
+
|
|
12
|
+
from .config import APP_NAME
|
|
13
|
+
from .providers import Message
|
|
14
|
+
|
|
15
|
+
# Session transcripts live in a "sessions" folder under the per-user data dir.
HISTORY_DIR = Path(user_data_dir(APP_NAME)).joinpath("sessions")
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _ensure_dir() -> None:
    """Create the sessions directory, including missing parents, if absent."""
    HISTORY_DIR.mkdir(exist_ok=True, parents=True)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def new_session_path() -> Path:
    """Return a fresh, timestamped session file path.

    The name is ``chat-YYYYmmdd-HHMMSS.jsonl``. If a file with that name
    already exists (two sessions started within the same second), a numeric
    suffix (``-2``, ``-3``, ...) is added so the new session does not append
    onto an earlier transcript. The file itself is not created here.
    """
    _ensure_dir()
    stamp = time.strftime("%Y%m%d-%H%M%S")
    candidate = HISTORY_DIR / f"chat-{stamp}.jsonl"
    counter = 1
    # NOTE(review): only detects files that already exist on disk; two
    # processes racing before either writes can still pick the same name.
    while candidate.exists():
        counter += 1
        candidate = HISTORY_DIR / f"chat-{stamp}-{counter}.jsonl"
    return candidate
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def append(path: Path, message: Message) -> None:
    """Append a single message to a session file as one JSON line.

    The parent directory of ``path`` is created if needed, so the session
    file does not have to live in the default sessions directory.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as fh:
        fh.write(json.dumps(message) + "\n")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load(path: Path) -> List[Message]:
    """Load all messages from a session file.

    Returns an empty list when the file does not exist. Blank lines are
    ignored, and so are lines that are not valid JSON. Such a line can be
    left behind if the process died mid-``append``; skipping it keeps one
    damaged line from making the whole session unreadable.
    """
    messages: List[Message] = []
    try:
        fh = path.open(encoding="utf-8")
    except FileNotFoundError:
        return messages
    with fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            try:
                messages.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    return messages
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def list_sessions() -> List[Path]:
    """List saved session files, most recently modified first."""
    if HISTORY_DIR.exists():
        found = list(HISTORY_DIR.glob("chat-*.jsonl"))
        found.sort(key=lambda session: session.stat().st_mtime, reverse=True)
        return found
    return []
|
oshell/media.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Media generation backends: images (OpenAI) and video (Replicate).
|
|
2
|
+
|
|
3
|
+
Both return the path(s) to the saved file(s). Outputs are written to a
|
|
4
|
+
timestamped folder so repeated runs never clobber each other.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import base64
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import List
|
|
14
|
+
|
|
15
|
+
import httpx
|
|
16
|
+
from platformdirs import user_pictures_dir
|
|
17
|
+
|
|
18
|
+
from .config import APP_NAME, Config
|
|
19
|
+
|
|
20
|
+
# Polling bounds for Replicate video jobs, in seconds: give up after
# VIDEO_POLL_TIMEOUT, and sleep VIDEO_POLL_INTERVAL between status checks.
VIDEO_POLL_TIMEOUT = 10 * 60
VIDEO_POLL_INTERVAL = 3
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class MediaResult:
    """Outcome of a generation call.

    ``paths`` lists the files written to disk; ``prompt`` is the exact text
    that was sent to the backend.
    """

    paths: List[Path]  # saved output files
    prompt: str  # prompt actually submitted to the model
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _output_dir(cfg: Config, kind: str) -> Path:
|
|
34
|
+
base = (
|
|
35
|
+
Path(cfg.media_output_dir)
|
|
36
|
+
if cfg.media_output_dir
|
|
37
|
+
else Path(user_pictures_dir()) / APP_NAME
|
|
38
|
+
)
|
|
39
|
+
stamp = time.strftime("%Y%m%d-%H%M%S")
|
|
40
|
+
target = base / kind / stamp
|
|
41
|
+
target.mkdir(parents=True, exist_ok=True)
|
|
42
|
+
return target
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _slug(text: str, limit: int = 40) -> str:
|
|
46
|
+
keep = [c if c.isalnum() else "-" for c in text.lower()]
|
|
47
|
+
slug = "".join(keep).strip("-")
|
|
48
|
+
while "--" in slug:
|
|
49
|
+
slug = slug.replace("--", "-")
|
|
50
|
+
return slug[:limit] or "media"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ImageGenerator:
    """Generates images with the OpenAI Images API (gpt-image-1 / DALL·E 3)."""

    def __init__(self, cfg: Config) -> None:
        """Capture endpoint, credentials and model settings from ``cfg``.

        Raises:
            ValueError: if no OpenAI API key is configured.
        """
        if not cfg.openai_api_key:
            raise ValueError(
                "OpenAI API key not set. Export OPENAI_API_KEY or run "
                "`oshell config set openai_api_key <key>`."
            )
        self.base_url = cfg.openai_base_url.rstrip("/")
        self.api_key = cfg.openai_api_key
        self.model = cfg.image_model
        self.size = cfg.image_size
        self.cfg = cfg

    def generate(self, prompt: str, n: int = 1) -> MediaResult:
        """Request ``n`` images for ``prompt`` and save them to a new folder.

        Each returned item is written from its ``b64_json`` payload or
        downloaded from its ``url``; items with neither are skipped.

        Raises:
            httpx.HTTPStatusError: if the API or an image download fails.
            RuntimeError: if the response contains no usable image data.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": self.model,
            "prompt": prompt,
            "n": n,
            "size": self.size,
        }
        resp = httpx.post(
            f"{self.base_url}/images/generations",
            json=payload,
            headers=headers,
            timeout=180,
        )
        resp.raise_for_status()
        # `or []` also covers an explicit "data": null in the response.
        data = resp.json().get("data") or []

        # Check before creating the output folder so a failed request does
        # not leave an empty directory behind.
        if not any(item.get("b64_json") or item.get("url") for item in data):
            raise RuntimeError("Image API returned no usable data.")

        out_dir = _output_dir(self.cfg, "images")
        slug = _slug(prompt)
        paths: List[Path] = []
        for i, item in enumerate(data):
            suffix = "" if len(data) == 1 else f"-{i + 1}"
            # NOTE(review): always saved as .png; assumes the API's default
            # PNG output — revisit if an output_format option is added.
            dest = out_dir / f"{slug}{suffix}.png"
            if item.get("b64_json"):
                dest.write_bytes(base64.b64decode(item["b64_json"]))
            elif item.get("url"):
                # httpx does not follow redirects by default, and
                # raise_for_status() treats a 3xx as an error.
                img = httpx.get(item["url"], timeout=120, follow_redirects=True)
                img.raise_for_status()
                dest.write_bytes(img.content)
            else:
                continue
            paths.append(dest)

        return MediaResult(paths=paths, prompt=prompt)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class VideoGenerator:
    """Generates video via Replicate's prediction API.

    Works with any text-to-video model on Replicate (default
    ``minimax/video-01``); just change ``video_model`` in config.
    """

    def __init__(self, cfg: Config) -> None:
        """Capture the Replicate token and model from ``cfg``.

        Raises:
            ValueError: if no Replicate API token is configured.
        """
        if not cfg.replicate_api_token:
            raise ValueError(
                "Replicate token not set. Export REPLICATE_API_TOKEN or run "
                "`oshell config set replicate_api_token <token>`."
            )
        self.token = cfg.replicate_api_token
        self.model = cfg.video_model
        self.cfg = cfg

    def _headers(self) -> dict:
        """Auth + JSON headers for Replicate API calls."""
        return {
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str) -> MediaResult:
        """Run ``prompt`` through the configured model and download the video(s).

        Raises:
            httpx.HTTPStatusError: on an API or download error status.
            TimeoutError: if the job does not finish within the poll timeout.
            RuntimeError: if the job fails, is canceled, or yields no URL.
        """
        # Use the model-scoped predictions endpoint so we don't need a version hash.
        url = f"https://api.replicate.com/v1/models/{self.model}/predictions"
        resp = httpx.post(
            url,
            json={"input": {"prompt": prompt}},
            # "Prefer: wait" asks Replicate to hold the request open, so short
            # jobs may already be finished in this first response.
            headers={**self._headers(), "Prefer": "wait"},
            timeout=120,
        )
        resp.raise_for_status()

        prediction = self._await_completion(resp.json())
        urls = self._collect_urls(prediction.get("output"))
        if not urls:
            raise RuntimeError(
                f"Video job finished with status '{prediction.get('status')}' "
                "but produced no output URL."
            )

        out_dir = _output_dir(self.cfg, "videos")
        slug = _slug(prompt)
        paths: List[Path] = []
        for i, video_url in enumerate(urls):
            suffix = "" if len(urls) == 1 else f"-{i + 1}"
            dest = out_dir / f"{slug}{suffix}.mp4"
            self._download(video_url, dest)
            paths.append(dest)
        return MediaResult(paths=paths, prompt=prompt)

    @staticmethod
    def _download(url: str, dest: Path) -> None:
        """Stream ``url`` into ``dest`` without buffering the whole file.

        Redirects are followed (httpx does not by default). A partially
        written file is removed if the transfer fails.
        """
        try:
            with httpx.stream("GET", url, timeout=300, follow_redirects=True) as resp:
                resp.raise_for_status()
                with dest.open("wb") as fh:
                    for chunk in resp.iter_bytes():
                        fh.write(chunk)
        except BaseException:
            dest.unlink(missing_ok=True)
            raise

    def _await_completion(self, prediction: dict) -> dict:
        """Poll ``prediction`` until it reaches a terminal status.

        If the prediction carries no polling URL, it is returned as-is.

        Raises:
            TimeoutError: after VIDEO_POLL_TIMEOUT seconds without finishing.
            RuntimeError: if the job ends as ``failed`` or ``canceled``.
        """
        terminal = {"succeeded", "failed", "canceled"}
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + VIDEO_POLL_TIMEOUT
        while prediction.get("status") not in terminal:
            if time.monotonic() > deadline:
                raise TimeoutError("Video generation timed out.")
            # `or {}` guards against an explicit "urls": null.
            poll_url = (prediction.get("urls") or {}).get("get")
            if not poll_url:
                break
            time.sleep(VIDEO_POLL_INTERVAL)
            r = httpx.get(poll_url, headers=self._headers(), timeout=60)
            r.raise_for_status()
            prediction = r.json()

        status = prediction.get("status")
        if status == "failed":
            raise RuntimeError(f"Video generation failed: {prediction.get('error')}")
        if status == "canceled":
            raise RuntimeError("Video generation was canceled.")
        return prediction

    @staticmethod
    def _collect_urls(output) -> List[str]:
        """Normalize a prediction ``output`` (str, list, or empty) to URL strings."""
        if not output:
            return []
        if isinstance(output, str):
            return [output]
        if isinstance(output, list):
            return [u for u in output if isinstance(u, str)]
        return []
|
oshell/media_agent.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""The media agent: turns a short brief into rich media.
|
|
2
|
+
|
|
3
|
+
It uses the chat LLM to *enhance* a user's brief into a detailed, model-ready
|
|
4
|
+
prompt (composition, lighting, style, camera, mood…), then hands that prompt
|
|
5
|
+
to the image or video generator. This is the "sophisticated" part — you write
|
|
6
|
+
"a fox in a forest" and get a cinematic, well-specified prompt.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from .media import ImageGenerator, MediaResult, VideoGenerator
|
|
12
|
+
from .providers import Message, Provider
|
|
13
|
+
|
|
14
|
+
_IMAGE_ENHANCE = """\
|
|
15
|
+
You are a prompt engineer for text-to-image models. Rewrite the user's brief \
|
|
16
|
+
into ONE vivid, detailed image prompt. Include subject, composition, lighting, \
|
|
17
|
+
color palette, art style, and mood. Keep it under 80 words. Output ONLY the \
|
|
18
|
+
prompt text — no quotes, no preamble, no explanation."""
|
|
19
|
+
|
|
20
|
+
_VIDEO_ENHANCE = """\
|
|
21
|
+
You are a prompt engineer for text-to-video models. Rewrite the user's brief \
|
|
22
|
+
into ONE detailed video prompt describing the scene, subject motion, camera \
|
|
23
|
+
movement, lighting, and mood. Keep it under 80 words. Output ONLY the prompt \
|
|
24
|
+
text — no quotes, no preamble, no explanation."""
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MediaAgent:
    """Enhances a brief with the LLM, then generates media."""

    # Quote characters stripped from both ends of the model's reply: straight
    # double quotes plus typographic “ ” that models often add anyway.
    _WRAPPING_QUOTES = '"\u201c\u201d'

    def __init__(self, provider: Provider) -> None:
        self.provider = provider

    @classmethod
    def _strip_wrapping(cls, text: str) -> str:
        """Trim whitespace and any surrounding double quotes (straight or curly)."""
        return text.strip().strip(cls._WRAPPING_QUOTES).strip()

    def enhance(self, brief: str, kind: str) -> str:
        """Expand ``brief`` into a detailed prompt for the given media kind.

        ``kind == "image"`` selects the image system prompt; any other value
        uses the video one. Returns ``brief`` unchanged if the model's reply
        is empty after cleanup.
        """
        system = _IMAGE_ENHANCE if kind == "image" else _VIDEO_ENHANCE
        messages: list[Message] = [
            {"role": "system", "content": system},
            {"role": "user", "content": brief},
        ]
        enhanced = self._strip_wrapping("".join(self.provider.stream_chat(messages)))
        return enhanced or brief

    def make_image(
        self, brief: str, generator: ImageGenerator, n: int = 1, enhance: bool = True
    ) -> MediaResult:
        """Optionally enhance ``brief``, then generate ``n`` images from it."""
        prompt = self.enhance(brief, "image") if enhance else brief
        return generator.generate(prompt, n=n)

    def make_video(
        self, brief: str, generator: VideoGenerator, enhance: bool = True
    ) -> MediaResult:
        """Optionally enhance ``brief``, then generate a video from it."""
        prompt = self.enhance(brief, "video") if enhance else brief
        return generator.generate(prompt)
|
oshell/personas.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Named personas (system-prompt presets).
|
|
2
|
+
|
|
3
|
+
Built-in personas live here; users can add their own with
|
|
4
|
+
``oshell persona add <name> "<system prompt>"`` which writes to a small JSON
|
|
5
|
+
file next to the config.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Dict
|
|
13
|
+
|
|
14
|
+
from .config import CONFIG_DIR
|
|
15
|
+
|
|
16
|
+
# User-defined personas are stored in the same directory as config.json.
PERSONA_PATH = CONFIG_DIR.joinpath("personas.json")
|
|
17
|
+
|
|
18
|
+
# Personas that ship with oshell, keyed by persona name (the value matched
# against ``Config.persona``). Custom personas with the same name win.
BUILTIN: Dict[str, str] = {
    "default": "You are a helpful, concise assistant.",
    "reviewer": (
        "You are a meticulous senior code reviewer. "
        "Point out bugs, security issues, edge cases, and style problems. "
        "Be direct and specific, cite line-level concerns, and suggest concrete fixes."
    ),
    "teacher": (
        "You are a patient programming teacher. "
        "Explain concepts step by step with simple analogies and short examples. "
        "Check understanding and avoid jargon unless you define it."
    ),
    "shell": (
        "You are a command-line expert. "
        "Prefer giving the exact command(s) to run, with a one-line explanation. "
        "Assume a competent user."
    ),
    "rubber-duck": (
        "You are a rubber-duck debugging partner. "
        "Ask probing questions that help the user reason through their problem "
        "rather than giving the answer outright."
    ),
    "concise": "Answer in as few words as possible. No preamble. No filler.",
    "pirate": "You are a witty pirate. Answer correctly, but talk like a pirate.",
}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _load_custom() -> Dict[str, str]:
    """Return user-defined personas stored in PERSONA_PATH.

    A missing, unreadable, or malformed file yields ``{}``. Only entries
    whose prompt is a string are kept. A hand-edited file holding e.g. a
    JSON list or a non-string value therefore cannot corrupt ``all_personas``
    or ``get``; ``add``/``remove`` will rewrite the file without such entries.
    """
    try:
        raw = json.loads(PERSONA_PATH.read_text(encoding="utf-8"))
    except (ValueError, OSError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError;
        # OSError covers FileNotFoundError (no custom personas yet).
        return {}
    if not isinstance(raw, dict):
        return {}
    return {
        name: prompt
        for name, prompt in raw.items()
        if isinstance(name, str) and isinstance(prompt, str)
    }
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def all_personas() -> Dict[str, str]:
    """Return every known persona; a custom entry replaces a same-named built-in."""
    return {**BUILTIN, **_load_custom()}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get(name: str) -> str | None:
    """Look up persona ``name``'s system prompt; ``None`` if no such persona exists."""
    personas = all_personas()
    return personas[name] if name in personas else None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def add(name: str, prompt: str) -> Path:
    """Create or overwrite custom persona ``name``; return the file written."""
    updated = {**_load_custom(), name: prompt}
    CONFIG_DIR.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(updated, indent=2)
    PERSONA_PATH.write_text(serialized, encoding="utf-8")
    return PERSONA_PATH
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def remove(name: str) -> bool:
    """Delete custom persona ``name``; return False when it was not a custom persona."""
    custom = _load_custom()
    if name in custom:
        del custom[name]
        PERSONA_PATH.write_text(json.dumps(custom, indent=2), encoding="utf-8")
        return True
    return False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def resolve_system_prompt(cfg) -> str:
    """Choose the system prompt to send.

    Uses the selected persona's prompt when ``cfg.persona`` names a persona
    with a non-empty prompt; otherwise falls back to ``cfg.system_prompt``.
    """
    persona_prompt = get(cfg.persona) if cfg.persona else None
    return persona_prompt or cfg.system_prompt
|