prompture 0.0.29.dev8__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- prompture/__init__.py +146 -23
- prompture/_version.py +34 -0
- prompture/aio/__init__.py +74 -0
- prompture/async_conversation.py +607 -0
- prompture/async_core.py +803 -0
- prompture/async_driver.py +169 -0
- prompture/cache.py +469 -0
- prompture/callbacks.py +55 -0
- prompture/cli.py +63 -4
- prompture/conversation.py +631 -0
- prompture/core.py +876 -263
- prompture/cost_mixin.py +51 -0
- prompture/discovery.py +164 -0
- prompture/driver.py +168 -5
- prompture/drivers/__init__.py +173 -69
- prompture/drivers/airllm_driver.py +109 -0
- prompture/drivers/async_airllm_driver.py +26 -0
- prompture/drivers/async_azure_driver.py +117 -0
- prompture/drivers/async_claude_driver.py +107 -0
- prompture/drivers/async_google_driver.py +132 -0
- prompture/drivers/async_grok_driver.py +91 -0
- prompture/drivers/async_groq_driver.py +84 -0
- prompture/drivers/async_hugging_driver.py +61 -0
- prompture/drivers/async_lmstudio_driver.py +79 -0
- prompture/drivers/async_local_http_driver.py +44 -0
- prompture/drivers/async_ollama_driver.py +125 -0
- prompture/drivers/async_openai_driver.py +96 -0
- prompture/drivers/async_openrouter_driver.py +96 -0
- prompture/drivers/async_registry.py +129 -0
- prompture/drivers/azure_driver.py +36 -9
- prompture/drivers/claude_driver.py +251 -34
- prompture/drivers/google_driver.py +107 -38
- prompture/drivers/grok_driver.py +29 -32
- prompture/drivers/groq_driver.py +27 -26
- prompture/drivers/hugging_driver.py +6 -6
- prompture/drivers/lmstudio_driver.py +26 -13
- prompture/drivers/local_http_driver.py +6 -6
- prompture/drivers/ollama_driver.py +157 -23
- prompture/drivers/openai_driver.py +178 -9
- prompture/drivers/openrouter_driver.py +31 -25
- prompture/drivers/registry.py +306 -0
- prompture/field_definitions.py +106 -96
- prompture/logging.py +80 -0
- prompture/model_rates.py +217 -0
- prompture/runner.py +49 -47
- prompture/scaffold/__init__.py +1 -0
- prompture/scaffold/generator.py +84 -0
- prompture/scaffold/templates/Dockerfile.j2 +12 -0
- prompture/scaffold/templates/README.md.j2 +41 -0
- prompture/scaffold/templates/config.py.j2 +21 -0
- prompture/scaffold/templates/env.example.j2 +8 -0
- prompture/scaffold/templates/main.py.j2 +86 -0
- prompture/scaffold/templates/models.py.j2 +40 -0
- prompture/scaffold/templates/requirements.txt.j2 +5 -0
- prompture/server.py +183 -0
- prompture/session.py +117 -0
- prompture/settings.py +18 -1
- prompture/tools.py +219 -267
- prompture/tools_schema.py +254 -0
- prompture/validator.py +3 -3
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/METADATA +117 -21
- prompture-0.0.35.dist-info/RECORD +66 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/WHEEL +1 -1
- prompture-0.0.29.dev8.dist-info/RECORD +0 -27
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/entry_points.txt +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/licenses/LICENSE +0 -0
- {prompture-0.0.29.dev8.dist-info → prompture-0.0.35.dist-info}/top_level.txt +0 -0
prompture/model_rates.py
ADDED
@@ -0,0 +1,217 @@

```python
"""Live model rates from models.dev API with local caching.

Fetches pricing and metadata for LLM models from https://models.dev/api.json,
caches locally with TTL-based auto-refresh, and provides lookup functions
used by drivers for cost calculations.
"""

import contextlib
import json
import logging
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Maps prompture provider names to models.dev provider names
PROVIDER_MAP: dict[str, str] = {
    "openai": "openai",
    "claude": "anthropic",
    "google": "google",
    "groq": "groq",
    "grok": "xai",
    "azure": "azure",
    "openrouter": "openrouter",
}

_API_URL = "https://models.dev/api.json"
_CACHE_DIR = Path.home() / ".prompture" / "cache"
_CACHE_FILE = _CACHE_DIR / "models_dev.json"
_META_FILE = _CACHE_DIR / "models_dev_meta.json"

_lock = threading.Lock()
_data: Optional[dict[str, Any]] = None
_loaded = False


def _get_ttl_days() -> int:
    """Get TTL from settings if available, otherwise default to 7."""
    try:
        from .settings import settings

        return getattr(settings, "model_rates_ttl_days", 7)
    except Exception:
        return 7


def _cache_is_valid() -> bool:
    """Check whether the local cache exists and is within TTL."""
    if not _CACHE_FILE.exists() or not _META_FILE.exists():
        return False
    try:
        meta = json.loads(_META_FILE.read_text(encoding="utf-8"))
        fetched_at = datetime.fromisoformat(meta["fetched_at"])
        ttl_days = meta.get("ttl_days", _get_ttl_days())
        age = datetime.now(timezone.utc) - fetched_at
        return age.total_seconds() < ttl_days * 86400
    except Exception:
        return False


def _write_cache(data: dict[str, Any]) -> None:
    """Write API data and metadata to local cache."""
    try:
        _CACHE_DIR.mkdir(parents=True, exist_ok=True)
        _CACHE_FILE.write_text(json.dumps(data), encoding="utf-8")
        meta = {
            "fetched_at": datetime.now(timezone.utc).isoformat(),
            "ttl_days": _get_ttl_days(),
        }
        _META_FILE.write_text(json.dumps(meta), encoding="utf-8")
    except Exception as exc:
        logger.debug("Failed to write model rates cache: %s", exc)


def _read_cache() -> Optional[dict[str, Any]]:
    """Read cached API data from disk."""
    try:
        return json.loads(_CACHE_FILE.read_text(encoding="utf-8"))
    except Exception:
        return None


def _fetch_from_api() -> Optional[dict[str, Any]]:
    """Fetch fresh data from models.dev API."""
    try:
        import requests

        resp = requests.get(_API_URL, timeout=15)
        resp.raise_for_status()
        return resp.json()
    except Exception as exc:
        logger.debug("Failed to fetch model rates from %s: %s", _API_URL, exc)
        return None


def _ensure_loaded() -> Optional[dict[str, Any]]:
    """Lazy-load data: use cache if valid, otherwise fetch from API."""
    global _data, _loaded
    if _loaded:
        return _data

    with _lock:
        # Double-check after acquiring lock
        if _loaded:
            return _data

        if _cache_is_valid():
            _data = _read_cache()
            if _data is not None:
                _loaded = True
                return _data

        # Cache missing or expired — fetch fresh
        fresh = _fetch_from_api()
        if fresh is not None:
            _data = fresh
            _write_cache(fresh)
        else:
            # Fetch failed — try stale cache as last resort
            _data = _read_cache()

        _loaded = True
        return _data


def _lookup_model(provider: str, model_id: str) -> Optional[dict[str, Any]]:
    """Find a model entry in the cached data.

    The API structure is ``{provider: {model_id: {...}, ...}, ...}``.
    """
    data = _ensure_loaded()
    if data is None:
        return None

    api_provider = PROVIDER_MAP.get(provider, provider)
    provider_data = data.get(api_provider)
    if not isinstance(provider_data, dict):
        return None

    return provider_data.get(model_id)


# ── Public API ──────────────────────────────────────────────────────────────


def get_model_rates(provider: str, model_id: str) -> Optional[dict[str, float]]:
    """Return pricing dict for a model, or ``None`` if unavailable.

    Returned keys mirror models.dev cost fields (per 1M tokens):
    ``input``, ``output``, and optionally ``cache_read``, ``cache_write``,
    ``reasoning``.
    """
    entry = _lookup_model(provider, model_id)
    if entry is None:
        return None

    cost = entry.get("cost")
    if not isinstance(cost, dict):
        return None

    rates: dict[str, float] = {}
    for key in ("input", "output", "cache_read", "cache_write", "reasoning"):
        val = cost.get(key)
        if val is not None:
            with contextlib.suppress(TypeError, ValueError):
                rates[key] = float(val)

    # Must have at least input and output to be useful
    if "input" in rates and "output" in rates:
        return rates
    return None


def get_model_info(provider: str, model_id: str) -> Optional[dict[str, Any]]:
    """Return full model metadata (cost, limits, capabilities), or ``None``."""
    return _lookup_model(provider, model_id)


def get_all_provider_models(provider: str) -> list[str]:
    """Return list of model IDs available for a provider."""
    data = _ensure_loaded()
    if data is None:
        return []

    api_provider = PROVIDER_MAP.get(provider, provider)
    provider_data = data.get(api_provider)
    if not isinstance(provider_data, dict):
        return []

    return list(provider_data.keys())


def refresh_rates_cache(force: bool = False) -> bool:
    """Fetch fresh data from models.dev.

    Args:
        force: If ``True``, fetch even when the cache is still within TTL.

    Returns:
        ``True`` if fresh data was fetched and cached successfully.
    """
    global _data, _loaded

    with _lock:
        if not force and _cache_is_valid():
            return False

        fresh = _fetch_from_api()
        if fresh is not None:
            _data = fresh
            _write_cache(fresh)
            _loaded = True
            return True

        return False
```
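Taken together, the module gives drivers a simple lookup path: `get_model_rates` for pricing, `get_model_info` for limits and capabilities, and `refresh_rates_cache` to force an update. A usage sketch (the model ID and token counts below are illustrative, not taken from the diff; rates are quoted per 1M tokens):

```python
# Usage sketch for the new module. The provider/model ID is illustrative;
# any model present in the models.dev dataset works.
from prompture.model_rates import get_model_rates, refresh_rates_cache

rates = get_model_rates("claude", "claude-3-5-haiku")  # hypothetical model ID
if rates is None:
    # Unknown model, or no cache and no network.
    print("no pricing data available")
else:
    # Rates are per 1M tokens, so scale raw token counts down accordingly.
    cost = (120_000 * rates["input"] + 8_000 * rates["output"]) / 1_000_000
    print(f"estimated cost: ${cost:.4f}")

refresh_rates_cache(force=True)  # bypass the TTL and re-fetch immediately
```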
prompture/runner.py
CHANGED
```diff
@@ -1,13 +1,15 @@
 """Test suite runner for executing JSON validation tests across multiple models."""
-from typing import Dict, Any, List
 
-from …
+from typing import Any
+
 from prompture.validator import validate_against_schema
 
+from .core import Driver, ask_for_json
+
 
-def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
+def run_suite_from_spec(spec: dict[str, Any], drivers: dict[str, Driver]) -> dict[str, Any]:
     """Run a test suite specified by a spec dictionary across multiple models.
 
     Args:
         spec: A dictionary containing the test suite specification with the structure:
             {
@@ -21,7 +23,7 @@ def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
             }, ...]
         }
         drivers: A dictionary mapping driver names to driver instances
 
     Returns:
         A dictionary containing test results with the structure:
         {
@@ -42,67 +44,67 @@ def run_suite_from_spec(spec: Dict[str, Any], drivers: Dict[str, Driver]) -> Dict[str, Any]:
         }
     """
     results = []
 
     for test in spec["tests"]:
         for model in spec["models"]:
             driver = drivers.get(model["driver"])
             if not driver:
                 continue
 
             # Run test for each input
             for input_data in test["inputs"]:
                 # Format prompt template with input data
                 try:
                     prompt = test["prompt_template"].format(**input_data)
                 except KeyError as e:
-                    results.append(
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": test["prompt_template"],
+                            "error": f"Template formatting error: missing key {e}",
+                            "validation": {"ok": False, "error": "Prompt formatting failed", "data": None},
+                            "usage": {"total_tokens": 0, "cost": 0},
+                        }
+                    )
                     continue
 
                 # Get JSON response from model
                 try:
                     response = ask_for_json(
                         driver=driver,
                         content_prompt=prompt,
                         json_schema=test["schema"],
-                        options=model.get("options", {})
+                        options=model.get("options", {}),
                     )
 
                     # Validate response against schema
-                    validation = validate_against_schema(
+                    validation = validate_against_schema(response["json_string"], test["schema"])
+
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": prompt,
+                            "response": response["json_object"],
+                            "validation": validation,
+                            "usage": response["usage"],
+                        }
                     )
-
-                    results.append({
-                        "test_id": test["id"],
-                        "model_id": model["id"],
-                        "input": input_data,
-                        "prompt": prompt,
-                        "response": response["json_object"],
-                        "validation": validation,
-                        "usage": response["usage"]
-                    })
 
                 except Exception as e:
-                    results.append(
+                    results.append(
+                        {
+                            "test_id": test["id"],
+                            "model_id": model["id"],
+                            "input": input_data,
+                            "prompt": prompt,
+                            "error": str(e),
+                            "validation": {"ok": False, "error": "Model response error", "data": None},
+                            "usage": {"total_tokens": 0, "cost": 0},
+                        }
+                    )
+
+    return {"meta": spec.get("meta", {}), "results": results}
```
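The spec structure described in the docstring is easiest to see with a concrete example. A minimal sketch (the test content, model entry, and `my_openai_driver` are hypothetical, not taken from the package):

```python
# Minimal spec sketch matching the documented structure.
from prompture.runner import run_suite_from_spec

spec = {
    "meta": {"name": "smoke-test"},
    "models": [
        {"id": "gpt-4o-mini", "driver": "openai", "options": {}},
    ],
    "tests": [
        {
            "id": "extract-person",
            "prompt_template": "Extract the person described here: {text}",
            "inputs": [{"text": "Maria is a 30-year-old engineer."}],
            "schema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "age": {"type": "integer"},
                },
                "required": ["name", "age"],
            },
        }
    ],
}

# Keys in `drivers` must match each model's "driver" field; the runner
# silently skips models whose driver is missing from the mapping.
report = run_suite_from_spec(spec, drivers={"openai": my_openai_driver})  # hypothetical driver instance
print(report["results"][0]["validation"])
```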
prompture/scaffold/__init__.py
ADDED

@@ -0,0 +1 @@

```python
"""Project scaffolding for Prompture-based FastAPI apps."""
```
prompture/scaffold/generator.py
ADDED

@@ -0,0 +1,84 @@

```python
"""Project scaffolding generator.

Renders Jinja2 templates into a standalone FastAPI project directory
that users can customize and deploy.
"""

from __future__ import annotations

from pathlib import Path

try:
    from jinja2 import Environment, FileSystemLoader
except ImportError:
    Environment = None  # type: ignore[assignment,misc]
    FileSystemLoader = None  # type: ignore[assignment,misc]

_TEMPLATES_DIR = Path(__file__).parent / "templates"

# Map from template file -> output path (relative to project root).
_FILE_MAP = {
    "main.py.j2": "app/main.py",
    "models.py.j2": "app/models.py",
    "config.py.j2": "app/config.py",
    "requirements.txt.j2": "requirements.txt",
    "env.example.j2": ".env.example",
    "README.md.j2": "README.md",
}

_DOCKER_FILES = {
    "Dockerfile.j2": "Dockerfile",
}


def scaffold_project(
    output_dir: str,
    project_name: str = "my_app",
    model_name: str = "openai/gpt-4o-mini",
    include_docker: bool = True,
) -> Path:
    """Render all templates and write the project to *output_dir*.

    Parameters:
        output_dir: Destination directory (created if needed).
        project_name: Human-friendly project name used in templates.
        model_name: Default model string baked into config.
        include_docker: Whether to include Dockerfile.

    Returns:
        The :class:`Path` to the generated project root.
    """
    if Environment is None:
        raise ImportError("jinja2 is required for scaffolding: pip install prompture[scaffold]")

    env = Environment(
        loader=FileSystemLoader(str(_TEMPLATES_DIR)),
        keep_trailing_newline=True,
    )

    context = {
        "project_name": project_name,
        "model_name": model_name,
        "include_docker": include_docker,
    }

    out = Path(output_dir)

    file_map = dict(_FILE_MAP)
    if include_docker:
        file_map.update(_DOCKER_FILES)

    for template_name, rel_path in file_map.items():
        template = env.get_template(template_name)
        rendered = template.render(**context)

        dest = out / rel_path
        dest.parent.mkdir(parents=True, exist_ok=True)
        dest.write_text(rendered, encoding="utf-8")

    # Create empty __init__.py for the app package
    init_path = out / "app" / "__init__.py"
    if not init_path.exists():
        init_path.write_text("", encoding="utf-8")

    return out
```
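A usage sketch (the output directory and project name are illustrative):

```python
# Writes app/main.py, app/models.py, app/config.py, requirements.txt,
# .env.example, README.md and (optionally) a Dockerfile under the target dir.
from prompture.scaffold.generator import scaffold_project

root = scaffold_project(
    "./billing_service",             # hypothetical output directory
    project_name="billing_service",
    model_name="openai/gpt-4o-mini",
    include_docker=True,
)
print(root)  # the generated project root as a Path
```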
prompture/scaffold/templates/README.md.j2
ADDED

@@ -0,0 +1,41 @@

````markdown
# {{ project_name }}

A FastAPI server powered by [Prompture](https://github.com/jhd3197/prompture) for structured LLM output.

## Quick start

```bash
# Install dependencies
pip install -r requirements.txt

# Copy and edit environment config
cp .env.example .env

# Run the server
uvicorn app.main:app --reload
```

## API endpoints

| Method | Path | Description |
|--------|------|-------------|
| POST | `/v1/chat` | Send a message, get a response |
| POST | `/v1/extract` | Extract structured JSON with schema |
| GET | `/v1/conversations/{id}` | Get conversation history |
| DELETE | `/v1/conversations/{id}` | Delete a conversation |

## Example

```bash
curl -X POST http://localhost:8000/v1/chat \
  -H "Content-Type: application/json" \
  -d '{"message": "Hello!"}'
```
{% if include_docker %}
## Docker

```bash
docker build -t {{ project_name }} .
docker run -p 8000:8000 --env-file .env {{ project_name }}
```
{% endif %}
````
prompture/scaffold/templates/config.py.j2
ADDED

@@ -0,0 +1,21 @@

```python
"""Configuration for {{ project_name }}."""

from __future__ import annotations

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    model_name: str = "{{ model_name }}"
    system_prompt: str = "You are a helpful assistant."
    cors_origins: list[str] = ["*"]

    # Provider API keys (loaded from environment / .env)
    openai_api_key: str = ""
    claude_api_key: str = ""
    google_api_key: str = ""

    model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}


settings = Settings()
```
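Because `Settings` extends `BaseSettings` with an `.env` file configured, each field is populated from a matching environment variable (case-insensitive) or from `.env`. A quick sketch, assuming the scaffolded package layout:

```python
# Sketch: pydantic-settings matches env vars to field names
# case-insensitively, so OPENAI_API_KEY fills Settings.openai_api_key.
import os

os.environ["OPENAI_API_KEY"] = "sk-example"      # illustrative value
os.environ["MODEL_NAME"] = "openai/gpt-4o-mini"  # overrides the template default

from app.config import settings  # import path inside a scaffolded project

print(settings.model_name)      # "openai/gpt-4o-mini"
print(settings.openai_api_key)  # "sk-example"
```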
prompture/scaffold/templates/main.py.j2
ADDED

@@ -0,0 +1,86 @@

```python
"""{{ project_name }} -- FastAPI server powered by Prompture."""

from __future__ import annotations

import json
import uuid
from typing import Any

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from .config import settings
from .models import (
    ChatRequest,
    ChatResponse,
    ConversationHistory,
    ExtractRequest,
    ExtractResponse,
)

from prompture import AsyncConversation

app = FastAPI(title="{{ project_name }}", version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_conversations: dict[str, AsyncConversation] = {}


def _get_or_create_conversation(conv_id: str | None) -> tuple[str, AsyncConversation]:
    if conv_id and conv_id in _conversations:
        return conv_id, _conversations[conv_id]
    new_id = conv_id or uuid.uuid4().hex[:12]
    conv = AsyncConversation(
        model_name=settings.model_name,
        system_prompt=settings.system_prompt,
    )
    _conversations[new_id] = conv
    return new_id, conv


@app.post("/v1/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    conv_id, conv = _get_or_create_conversation(request.conversation_id)
    text = await conv.ask(request.message, request.options)
    return ChatResponse(message=text, conversation_id=conv_id, usage=conv.usage)


@app.post("/v1/extract", response_model=ExtractResponse)
async def extract(request: ExtractRequest):
    conv_id, conv = _get_or_create_conversation(request.conversation_id)
    result = await conv.ask_for_json(
        content=request.text,
        json_schema=request.schema_def,
    )
    return ExtractResponse(
        json_object=result["json_object"],
        conversation_id=conv_id,
        usage=conv.usage,
    )


@app.get("/v1/conversations/{conversation_id}", response_model=ConversationHistory)
async def get_conversation(conversation_id: str):
    if conversation_id not in _conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    conv = _conversations[conversation_id]
    return ConversationHistory(
        conversation_id=conversation_id,
        messages=conv.messages,
        usage=conv.usage,
    )


@app.delete("/v1/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    if conversation_id not in _conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    del _conversations[conversation_id]
    return {"status": "deleted", "conversation_id": conversation_id}
```
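Conversations live in the module-level `_conversations` dict, so state is per-process and lost on restart; a client keeps context by echoing back the returned `conversation_id`. A client sketch (assumes the scaffolded server is running locally; `httpx` is just one way to call it):

```python
# Client sketch for the scaffolded server (assumed at localhost:8000).
import httpx

base = "http://localhost:8000"

first = httpx.post(f"{base}/v1/chat", json={"message": "My name is Ada."}).json()
conv_id = first["conversation_id"]

# Sending the same conversation_id routes the follow-up into the same
# AsyncConversation, so the model sees the earlier turn.
followup = httpx.post(
    f"{base}/v1/chat",
    json={"message": "What is my name?", "conversation_id": conv_id},
).json()
print(followup["message"])
```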
prompture/scaffold/templates/models.py.j2
ADDED

@@ -0,0 +1,40 @@

```python
"""Pydantic request/response models for {{ project_name }}."""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field


class ChatRequest(BaseModel):
    message: str
    conversation_id: str | None = None
    stream: bool = False
    options: dict[str, Any] | None = None


class ChatResponse(BaseModel):
    message: str
    conversation_id: str
    usage: dict[str, Any]


class ExtractRequest(BaseModel):
    text: str
    schema_def: dict[str, Any] = Field(..., alias="schema")
    conversation_id: str | None = None

    model_config = {"populate_by_name": True}


class ExtractResponse(BaseModel):
    json_object: dict[str, Any]
    conversation_id: str
    usage: dict[str, Any]


class ConversationHistory(BaseModel):
    conversation_id: str
    messages: list[dict[str, Any]]
    usage: dict[str, Any]
```
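One detail worth noting in `ExtractRequest`: the JSON Schema arrives under the wire key `schema`, which is remapped to `schema_def` via `Field(..., alias="schema")` (avoiding a clash with `BaseModel`'s own `schema` attribute), while `populate_by_name` lets server-side code construct the model with either name. A request-body sketch (values are illustrative):

```python
# Payload sketch for POST /v1/extract. Clients use the "schema" key;
# the server reads it as ExtractRequest.schema_def via the alias.
payload = {
    "text": "Maria is a 30-year-old engineer.",
    "schema": {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    },
}
```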