freelm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- freelm/__init__.py +63 -0
- freelm/_backoff.py +13 -0
- freelm/_breaker.py +41 -0
- freelm/_cache.py +62 -0
- freelm/_engine.py +87 -0
- freelm/_keys.py +94 -0
- freelm/_ratelimit.py +52 -0
- freelm/_types.py +119 -0
- freelm/client.py +272 -0
- freelm/compat/__init__.py +3 -0
- freelm/compat/openai.py +76 -0
- freelm/config.py +53 -0
- freelm/discovery.py +149 -0
- freelm/errors.py +124 -0
- freelm/providers/__init__.py +6 -0
- freelm/providers/base.py +133 -0
- freelm/providers/google.py +34 -0
- freelm/providers/nim.py +30 -0
- freelm/providers/openrouter.py +49 -0
- freelm/py.typed +0 -0
- freelm/registry.py +53 -0
- freelm/strategy.py +56 -0
- freelm/types_compat.py +83 -0
- freelm-0.1.0.dist-info/METADATA +182 -0
- freelm-0.1.0.dist-info/RECORD +27 -0
- freelm-0.1.0.dist-info/WHEEL +4 -0
- freelm-0.1.0.dist-info/licenses/LICENSE +21 -0
freelm/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""freelm — one always-up LLM client over free-tier providers.
|
|
2
|
+
|
|
3
|
+
Quick start::
|
|
4
|
+
|
|
5
|
+
import freelm
|
|
6
|
+
llm = freelm.FreeLLM.from_env()
|
|
7
|
+
print(llm.text("Explain black holes in one sentence."))
|
|
8
|
+
|
|
9
|
+
Explicit config::
|
|
10
|
+
|
|
11
|
+
from freelm import FreeLLM, OpenRouter, GoogleAIStudio, NIM
|
|
12
|
+
llm = FreeLLM(
|
|
13
|
+
providers=[OpenRouter("sk-or-..."), GoogleAIStudio("AIza..."), NIM("nvapi-...")],
|
|
14
|
+
strategy="quota_aware",
|
|
15
|
+
)
|
|
16
|
+
"""
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from ._types import ChatRequest, ChatResponse, Choice, Message, Usage
|
|
20
|
+
from .client import AsyncFreeLLM, FreeLLM
|
|
21
|
+
from .config import providers_from_env
|
|
22
|
+
from .discovery import list_free_models
|
|
23
|
+
from .errors import (
|
|
24
|
+
AuthError,
|
|
25
|
+
ConfigError,
|
|
26
|
+
FreeLLMError,
|
|
27
|
+
ModelNotFound,
|
|
28
|
+
NoProvidersAvailable,
|
|
29
|
+
ProviderError,
|
|
30
|
+
RateLimited,
|
|
31
|
+
Transient,
|
|
32
|
+
)
|
|
33
|
+
from .providers import Gemini, GoogleAIStudio, NIM, OpenRouter, Provider
|
|
34
|
+
from .registry import ModelSpec
|
|
35
|
+
|
|
36
|
+
# Package version string; presumably kept in sync with the wheel METADATA by hand.
__version__ = "0.1.0"

# Public API re-exported at package level; this is what ``from freelm import *`` exposes.
__all__ = [
    # clients (from .client)
    "FreeLLM",
    "AsyncFreeLLM",
    # provider classes (from .providers)
    "Provider",
    "OpenRouter",
    "GoogleAIStudio",
    "Gemini",
    "NIM",
    # model registry entry (from .registry)
    "ModelSpec",
    # request/response data types (from ._types)
    "Message",
    "ChatRequest",
    "ChatResponse",
    "Choice",
    "Usage",
    # helpers (from .config / .discovery)
    "providers_from_env",
    "list_free_models",
    # error types (from .errors)
    "FreeLLMError",
    "ConfigError",
    "ProviderError",
    "AuthError",
    "RateLimited",
    "Transient",
    "ModelNotFound",
    "NoProvidersAvailable",
    "__version__",
]
|
freelm/_backoff.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Exponential backoff with full jitter."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import random
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def compute_delay(attempt: int, base: float = 0.5, factor: float = 2.0, cap: float = 30.0, jitter: bool = True) -> float:
    """Return the delay (seconds) before retry number *attempt* (1-based).

    The un-jittered delay is ``base * factor ** (attempt - 1)`` clamped to
    *cap*; ``attempt`` values below 1 are treated as 1. With *jitter* (the
    default) the result is drawn uniformly from ``[0, raw]`` ("full jitter").

    Arbitrarily large *attempt* values are safe. If the exponential overflows
    a float (e.g. ``2.0 ** 1024``), the delay is simply *cap* instead of
    ``OverflowError`` propagating. This matters because callers may pass an
    ever-growing failure counter.
    """
    attempt = max(1, attempt)
    try:
        raw = min(cap, base * (factor ** (attempt - 1)))
    except OverflowError:
        # The exact value is astronomically large and would be clamped to cap
        # anyway (assumes a positive base, as with the defaults).
        raw = cap
    if jitter:
        return random.uniform(0.0, raw)
    return raw
|
freelm/_breaker.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Per-key circuit breaker. Time is injected (monotonic seconds) for testability."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
CLOSED = "closed"
OPEN = "open"
HALF_OPEN = "half_open"


@dataclass
class CircuitBreaker:
    """Three-state breaker guarding a single key.

    CLOSED: traffic flows and failures are counted.
    OPEN: traffic is refused until ``cooldown`` seconds after ``opened_at``.
    HALF_OPEN: traffic flows again; the next failure re-opens immediately,
    and a success closes the breaker and resets the count.

    Callers supply ``now`` (monotonic seconds) explicitly.
    """

    fail_threshold: int = 4  # failure count at which the breaker opens
    cooldown: float = 30.0  # seconds to stay OPEN before allowing a probe
    state: str = CLOSED
    failures: int = 0  # reset only by on_success()
    opened_at: float = 0.0  # timestamp of the most recent trip to OPEN

    def allow(self, now: float) -> bool:
        """Return whether a request may proceed; may advance OPEN -> HALF_OPEN."""
        if self.state != OPEN:
            return True
        cooled_down = now - self.opened_at >= self.cooldown
        if cooled_down:
            self.state = HALF_OPEN
        return cooled_down

    def on_success(self) -> None:
        """Close the breaker and forget accumulated failures."""
        self.state, self.failures = CLOSED, 0

    def on_failure(self, now: float) -> None:
        """Count a failure; trip to OPEN on a failed probe or once at the threshold."""
        self.failures += 1
        tripped = self.state == HALF_OPEN or self.failures >= self.fail_threshold
        if not tripped:
            return
        self.state = OPEN
        self.opened_at = now

    def time_until_half_open(self, now: float) -> float:
        """Seconds left in the OPEN cooldown (0.0 when not OPEN or already elapsed)."""
        if self.state == OPEN:
            remaining = self.cooldown - (now - self.opened_at)
            return max(0.0, remaining)
        return 0.0
|
freelm/_cache.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Tiny TTL disk cache for discovered model lists.
|
|
2
|
+
|
|
3
|
+
Mirrors the openrouter-free skill: JSON file, TTL (default 1h, env override),
|
|
4
|
+
0o600 perms. Path: $FREELM_CACHE_DIR or ~/.cache/freelm/models-<provider>.json
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from typing import Any, List, Optional
|
|
12
|
+
|
|
13
|
+
# Fallback cache lifetime in seconds (1 hour); FREELM_CACHE_TTL overrides it via default_ttl().
DEFAULT_TTL = 3600.0
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def cache_dir() -> str:
    """Directory holding cached model lists.

    ``$FREELM_CACHE_DIR`` wins when set to a non-empty value; otherwise the
    per-user ``~/.cache/freelm`` directory is used.
    """
    override = os.environ.get("FREELM_CACHE_DIR")
    if override:
        return override
    home = os.path.expanduser("~")
    return os.path.join(home, ".cache", "freelm")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def default_ttl() -> float:
    """Cache lifetime in seconds.

    Uses ``$FREELM_CACHE_TTL`` when it is set, non-empty and parses as a
    float; otherwise falls back to ``DEFAULT_TTL``.
    """
    configured = os.environ.get("FREELM_CACHE_TTL")
    if configured:
        try:
            return float(configured)
        except ValueError:
            pass  # unparseable override: ignore it
    return DEFAULT_TTL
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _path(name: str) -> str:
    """Cache file path for *name*: ``<cache_dir>/models-<name>.json``.

    Forward slashes in *name* become underscores so the name cannot create
    subdirectories.
    """
    filename = "models-" + name.replace("/", "_") + ".json"
    return os.path.join(cache_dir(), filename)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load(name: str) -> Optional[List[Any]]:
    """Return the cached model list for *name*, or None.

    None is returned in any of these cases:

    * the file is missing or unreadable;
    * it is not valid JSON;
    * the entry has expired;
    * it lacks the shape written by ``save`` (a JSON object with a numeric
      ``expires_at`` and a list ``data``).

    A corrupt or hand-edited cache file is therefore treated as a miss
    instead of crashing the caller with ``AttributeError``/``TypeError``.
    """
    try:
        with open(_path(name), "r", encoding="utf-8") as f:
            entry = json.load(f)
    except (OSError, ValueError):  # ValueError covers JSONDecodeError and bad UTF-8
        return None
    if not isinstance(entry, dict):
        return None
    expires_at = entry.get("expires_at", 0)  # missing -> treated as long expired
    if not isinstance(expires_at, (int, float)) or time.time() > expires_at:
        return None
    data = entry.get("data")
    return data if isinstance(data, list) else None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def save(name: str, data: List[Any], ttl: Optional[float] = None) -> None:
    """Best-effort write of *data* to the cache for *name*.

    The entry expires *ttl* seconds from now (``default_ttl()`` when *ttl* is
    None). The JSON is written to a private temp file in the cache directory
    and then atomically renamed over the target. Readers therefore never see
    a half-written file, and the file never exists with permissions looser
    than 0o600 (``tempfile.mkstemp`` creates it owner-only).

    Failures are swallowed, including an unserialisable *data* or an invalid
    *ttl*: the cache is an optimisation and must never be fatal.
    """
    import tempfile  # local import: only needed on the write path

    p = _path(name)
    directory = os.path.dirname(p)
    tmp: Optional[str] = None
    try:
        os.makedirs(directory, exist_ok=True)
        entry = {"data": data, "expires_at": time.time() + (ttl if ttl is not None else default_ttl())}
        payload = json.dumps(entry)  # serialise first so a TypeError leaves no file behind
        fd, tmp = tempfile.mkstemp(dir=directory, prefix=".models-", suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp, p)
        tmp = None  # renamed into place; nothing to clean up
    except (OSError, TypeError, ValueError):
        pass  # cache is best-effort; never fatal
    finally:
        if tmp is not None:
            try:
                os.remove(tmp)
            except OSError:
                pass
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def clear(name: str) -> None:
    """Delete the cache file for *name*, ignoring a missing file or any OS error."""
    try:
        target = _path(name)
        os.unlink(target)
    except OSError:
        return
|
freelm/_engine.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
"""Pure (no-I/O) orchestration helpers shared by the sync and async clients.
|
|
2
|
+
|
|
3
|
+
The actual HTTP call differs between sync/async, but candidate selection and
|
|
4
|
+
post-attempt state updates are identical, so they live here.
|
|
5
|
+
"""
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
9
|
+
|
|
10
|
+
from ._backoff import compute_delay
|
|
11
|
+
from .errors import AuthError, ModelNotFound, ProviderError, RateLimited, Transient
|
|
12
|
+
from .strategy import Candidate, order_candidates
|
|
13
|
+
|
|
14
|
+
# (provider name, key string, model) identifying one attempt; select_candidate
# skips any candidate whose triple is already in the per-call ``tried`` set.
TriedKey = Tuple[str, str, str]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def select_candidate(
    providers: List[Any],
    strategy: str,
    rr: Dict[str, int],
    alias: str,
    tried: Set[TriedKey],
    now: float,
) -> Optional[Candidate]:
    """Return the first usable candidate in strategy order, or None.

    A candidate is usable when its (provider, key, model) triple has not been
    tried during this call and its key reports ``ready(now)``. Readiness is
    only evaluated for untried candidates, and evaluation stops at the first
    match. ``ready`` can mutate breaker state, so this order matters.
    """
    ordered = order_candidates(providers, alias, now, strategy, rr)
    return next(
        (
            cand
            for cand in ordered
            if (cand.provider.name, cand.key.key, cand.model) not in tried and cand.key.ready(now)
        ),
        None,
    )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def soonest_wait(providers: List[Any], now: float) -> Optional[float]:
    """Shortest time until any key could become ready, or None.

    Keys whose ``wait_time`` is None (permanently off) are ignored; if every
    key is ignored, or there are no keys at all, None is returned.
    """
    per_key = (key.wait_time(now) for prov in providers for key in prov.keys)
    return min((w for w in per_key if w is not None), default=None)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def apply_success(cand: Candidate, latency_ms: float) -> None:
    """Record a successful attempt on ``cand.key``.

    The call closes the key's breaker and clears ``last_error``. It also folds
    *latency_ms* into an exponentially weighted moving average: weight 0.3 on
    the new sample, with the very first sample seeding the average.
    """
    state = cand.key
    state.breaker.on_success()
    state.last_error = None
    if state.ewma_latency == 0:
        state.ewma_latency = latency_ms  # no prior sample yet
    else:
        state.ewma_latency = 0.7 * state.ewma_latency + 0.3 * latency_ms
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def apply_error(cand: Candidate, exc: ProviderError, now: float) -> None:
    """Update ``cand.key`` after a failed attempt. Returns nothing; raising is the
    caller's decision (see ``should_raise``).

    * AuthError: the key is disabled for good.
    * RateLimited: a model-scoped limit only tags the key, so the ``tried``
      set moves on to another model on the same key. Otherwise the key cools
      down for ``retry_after`` seconds (60 when that is falsy).
    * ModelNotFound: only tagged; the key is not penalised for a bad model id.
    * Transient: counts against the breaker, and the key cools down for
      ``retry_after``, or a backoff based on the failure count, capped at 30s.
    * anything else: counts against the breaker.
    """
    ks = cand.key
    if isinstance(exc, AuthError):
        ks.disabled = True
        ks.last_error = f"auth:{exc.status}"
        return
    if isinstance(exc, RateLimited):
        if getattr(exc, "scope", "key") == "model":
            ks.last_error = "rate_limited:model"
        else:
            ks.cooldown_until = now + (exc.retry_after or 60.0)
            ks.last_error = "rate_limited"
        return
    if isinstance(exc, ModelNotFound):
        ks.last_error = "model_missing"
        return
    # Transient and generic errors both count against the breaker first, so the
    # backoff below sees the updated failure count.
    ks.breaker.on_failure(now)
    if isinstance(exc, Transient):
        if exc.retry_after is None:
            delay = compute_delay(ks.breaker.failures)
        else:
            delay = exc.retry_after
        ks.cooldown_until = now + min(30.0, delay)
        ks.last_error = f"transient:{exc.status}"
    else:  # non-retryable ProviderError
        ks.last_error = f"error:{exc.status}"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def should_raise(exc: ProviderError) -> bool:
    """Decide whether *exc* should abort the whole call.

    A non-retryable, non-model error (e.g. a malformed 400/422) is a caller
    bug, so the call bails immediately instead of burning every key on the
    same broken request.

    Auth errors are *not* fatal to the whole call: the key is disabled and
    the call fails over to other keys/providers. Retryable errors and
    missing-model errors likewise never abort.
    """
    if isinstance(exc, AuthError):
        return False
    if exc.retryable:
        return False
    return not exc.model_missing
|
freelm/_keys.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Per-key runtime state: breaker + rpm bucket + daily quota + cooldowns."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
from ._breaker import CircuitBreaker
|
|
8
|
+
from ._ratelimit import TokenBucket
|
|
9
|
+
|
|
10
|
+
# Length in seconds of the rolling requests-per-day window (see KeyState._roll_daily).
DAY = 86400.0
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class KeyState:
    """Runtime state for one API key: gating, quotas and health stats.

    A key may serve a request only when all of these hold:

    * it is not ``disabled``;
    * it is not inside a ``cooldown_until`` window;
    * its circuit breaker allows traffic;
    * both optional quotas have headroom: the ``rpd`` daily counter and the
      ``bucket`` requests-per-minute token bucket.

    Every method takes ``now`` explicitly so time can be injected in tests.
    """

    key: str  # the secret itself; masked() gives a log-safe form
    tier: str = "free"
    breaker: CircuitBreaker = field(default_factory=CircuitBreaker)
    bucket: Optional[TokenBucket] = None  # rpm pacing; None = no per-minute limit
    rpd: Optional[int] = None  # requests-per-day cap (None = unknown/unlimited)
    rpd_used: int = 0  # requests reserved in the current daily window
    rpd_reset: float = 0.0  # monotonic ts at which the daily counter rolls (0.0 = window not started)
    cooldown_until: float = 0.0  # key is not ready before this timestamp
    disabled: bool = False  # hard-off after auth failure
    ewma_latency: float = 0.0  # smoothed latency; 0.0 = no sample yet
    last_error: Optional[str] = None  # short tag for the most recent failure

    # -- daily window ----------------------------------------------------
    def _roll_daily(self, now: float) -> None:
        """Open or roll the daily quota window.

        The window starts on first use (``now + DAY``), not at a calendar
        boundary. Once ``now`` reaches ``rpd_reset`` the counter is zeroed and
        a new window begins. This is a no-op when there is no daily cap.
        """
        if self.rpd is None:
            return
        if self.rpd_reset == 0.0:
            self.rpd_reset = now + DAY
        elif now >= self.rpd_reset:
            self.rpd_used = 0
            self.rpd_reset = now + DAY

    # -- gating ----------------------------------------------------------
    def ready(self, now: float) -> bool:
        """Return True if a request may be sent on this key at *now*.

        This does not consume quota (see ``reserve``). It can still mutate
        state: ``breaker.allow`` may move an OPEN breaker to HALF_OPEN, and
        the daily window may be rolled.
        """
        if self.disabled:
            return False
        if now < self.cooldown_until:
            return False
        if not self.breaker.allow(now):
            return False
        self._roll_daily(now)
        if self.rpd is not None and self.rpd_used >= self.rpd:
            return False
        if self.bucket is not None and self.bucket.peek(now) < 1:
            return False
        return True

    def reserve(self, now: float) -> bool:
        """Consume one rpm token + one daily slot just before firing a request.

        Returns False, consuming nothing, when the daily cap is already
        exhausted or the rpm bucket is empty. The daily cap is re-checked here
        because quota may have been used up between ``ready()`` and this call.
        """
        self._roll_daily(now)
        if self.rpd is not None and self.rpd_used >= self.rpd:
            return False  # checked before the bucket so no rpm token is wasted
        if self.bucket is not None and not self.bucket.consume(1, now):
            return False
        self.rpd_used += 1
        return True

    def remaining(self, now: float) -> float:
        """A rough 'how much headroom' score for quota-aware routing.

        This is the minimum of the remaining daily slots and the current
        bucket tokens. Either term is ``inf`` when that limit is not set.
        """
        self._roll_daily(now)
        daily = float("inf") if self.rpd is None else float(max(0, self.rpd - self.rpd_used))
        burst = self.bucket.peek(now) if self.bucket is not None else float("inf")
        return min(daily, burst)

    def wait_time(self, now: float) -> Optional[float]:
        """Seconds until this key could be ready again, or None if permanently off.

        The result is the longest of the outstanding blockers: cooldown,
        breaker cooldown, daily-window reset, and bucket refill. It is 0.0
        when nothing blocks.
        """
        if self.disabled:
            return None
        waits = []
        if now < self.cooldown_until:
            waits.append(self.cooldown_until - now)
        waits.append(self.breaker.time_until_half_open(now))
        self._roll_daily(now)
        if self.rpd is not None and self.rpd_used >= self.rpd:
            waits.append(max(0.0, self.rpd_reset - now))
        if self.bucket is not None and self.bucket.peek(now) < 1:
            waits.append(self.bucket.time_until(1, now))
        return max(waits) if waits else 0.0

    def masked(self) -> str:
        """Log-safe key: first 6 + '...' + last 4 chars when longer than 12, else '***'."""
        k = self.key
        return (k[:6] + "..." + k[-4:]) if len(k) > 12 else "***"
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def new_key_state(key: str, *, tier: str, rpm: Optional[float], rpd: Optional[int]) -> KeyState:
    """Build a fresh KeyState for *key* with a new circuit breaker.

    A token bucket is attached only when *rpm* is truthy, so ``None`` or ``0``
    means no per-minute pacing. *rpd* is stored unchanged (None = no daily cap).
    """
    pacing: Optional[TokenBucket] = None
    if rpm:
        pacing = TokenBucket(rate_per_min=rpm)
    fresh_breaker = CircuitBreaker()
    return KeyState(key=key, tier=tier, breaker=fresh_breaker, bucket=pacing, rpd=rpd)
|
freelm/_ratelimit.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Token bucket for requests-per-minute pacing. Time injected for testability."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import threading
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class TokenBucket:
    """Continuous-refill token bucket used to pace requests per minute.

    Tokens accrue at ``rate_per_min / 60`` per second up to ``capacity``.
    Each public method takes ``now`` explicitly and refills, then reads or
    spends, while holding ``_lock``.
    """

    rate_per_min: float
    capacity: Optional[float] = None  # None -> max(1.0, rate_per_min)
    tokens: Optional[float] = None  # None -> start full (== capacity)
    updated: float = 0.0  # time of the last refill; 0.0 = no clock reading seen yet
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False, compare=False)

    def __post_init__(self) -> None:
        # Resolve defaults that depend on other fields.
        if self.capacity is None:
            self.capacity = max(1.0, float(self.rate_per_min))
        if self.tokens is None:
            self.tokens = self.capacity

    def _refill(self, now: float) -> None:
        """Credit tokens for time elapsed since ``updated``; callers hold ``_lock``."""
        previous = self.updated
        if previous == 0.0:
            # The first observation only anchors the clock; no tokens are credited.
            self.updated = now
            return
        elapsed = now - previous
        if elapsed <= 0:
            return  # clock did not advance (or went backwards): leave state alone
        per_second = self.rate_per_min / 60.0
        self.tokens = min(self.capacity, self.tokens + elapsed * per_second)  # type: ignore[operator]
        self.updated = now

    def peek(self, now: float) -> float:
        """Current token count after refilling; nothing is consumed."""
        with self._lock:
            self._refill(now)
            current = self.tokens
        return current  # type: ignore[return-value]

    def consume(self, n: float, now: float) -> bool:
        """Spend *n* tokens if that many are available; return whether they were spent."""
        with self._lock:
            self._refill(now)
            enough = self.tokens >= n  # type: ignore[operator]
            if enough:
                self.tokens -= n  # type: ignore[operator]
        return enough

    def time_until(self, n: float, now: float) -> float:
        """Seconds until *n* tokens will be available (0.0 if they already are)."""
        with self._lock:
            self._refill(now)
            have = self.tokens
            if have >= n:  # type: ignore[operator]
                return 0.0
            per_second = self.rate_per_min / 60.0
            return (n - have) / per_second  # type: ignore[operator]
|
freelm/_types.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Provider-agnostic data types (OpenAI-shaped, but pure dataclasses)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class Message:
    """One chat message in OpenAI wire shape."""

    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Wire dict. ``content`` is kept whenever it is not None; the other
        optional fields are emitted only when truthy."""
        out: Dict[str, Any] = {"role": self.role}
        if self.content is not None:
            out["content"] = self.content
        for attr in ("name", "tool_calls", "tool_call_id"):
            value = getattr(self, attr)
            if value:
                out[attr] = value
        return out

    @classmethod
    def from_any(cls, m: Union["Message", Dict[str, Any], str]) -> "Message":
        """Coerce *m* into a Message.

        An existing Message is returned as-is. A plain string becomes a user
        message. A dict is read field-by-field, with ``role`` defaulting to
        "user". Any other type raises TypeError.
        """
        if isinstance(m, Message):
            return m
        if isinstance(m, str):
            return cls(role="user", content=m)
        if not isinstance(m, dict):
            raise TypeError(f"unsupported message type: {type(m)!r}")
        optional = {attr: m.get(attr) for attr in ("content", "name", "tool_calls", "tool_call_id")}
        return cls(role=m.get("role", "user"), **optional)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class Usage:
    """Token accounting for one completion (OpenAI ``usage`` shape)."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    @classmethod
    def from_dict(cls, d: Optional[Dict[str, Any]]) -> "Usage":
        """Build from a provider ``usage`` dict; missing or None counts become 0.

        Some providers omit ``total_tokens``. When it is absent (or reported as
        0/None), it is derived as ``prompt_tokens + completion_tokens`` so the
        three fields stay consistent. A reported non-zero total is kept as-is.
        """
        d = d or {}
        prompt = int(d.get("prompt_tokens") or 0)
        completion = int(d.get("completion_tokens") or 0)
        total = int(d.get("total_tokens") or 0) or prompt + completion
        return cls(prompt_tokens=prompt, completion_tokens=completion, total_tokens=total)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class Choice:
    """One completion alternative within a ChatResponse."""

    index: int  # position of this choice in the response
    message: Message  # the generated message for this choice
    finish_reason: Optional[str] = None  # provider-reported stop reason, if any
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class ChatResponse:
    """Normalised chat-completion result.

    ``resp.text`` (and therefore ``str(resp)``) gives the first choice's
    content, or "" when there is none.
    """

    id: Optional[str]
    model: Optional[str]
    provider: Optional[str]
    choices: List[Choice]
    usage: Usage
    latency_ms: float = 0.0
    raw: Optional[Dict[str, Any]] = None  # presumably the provider's untouched response body

    @property
    def text(self) -> str:
        """Content of the first choice; "" for no choices or empty/None content."""
        if self.choices:
            return self.choices[0].message.content or ""
        return ""

    def __str__(self) -> str:  # so print(resp) gives the text
        return self.text
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# Sampling knobs copied 1:1 from ChatRequest attributes into the request body
# (and lifted out of caller kwargs by build_request).
_SAMPLING_FIELDS = ("temperature", "max_tokens", "top_p", "stop", "seed", "frequency_penalty", "presence_penalty")


@dataclass
class ChatRequest:
    """A provider-agnostic chat-completion request.

    ``model`` is a virtual alias (``"auto"`` by default) resolved per provider.
    ``payload`` renders the body for one concrete model id.
    """

    messages: List[Dict[str, Any]]
    model: str = "auto"  # virtual alias resolved per provider
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    stop: Optional[Union[str, List[str]]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    extra: Dict[str, Any] = field(default_factory=dict)  # passed through verbatim, merged last

    def payload(self, concrete_model: str) -> Dict[str, Any]:
        """Return the JSON body for *concrete_model*.

        Sampling fields appear only when not None. ``extra`` is merged last,
        so its keys override anything set before it.
        """
        sampling = {attr: getattr(self, attr) for attr in _SAMPLING_FIELDS}
        body: Dict[str, Any] = {"model": concrete_model, "messages": self.messages}
        body.update((attr, value) for attr, value in sampling.items() if value is not None)
        body.update(self.extra)
        return body
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def build_request(messages: Any, model: str, kw: Dict[str, Any]) -> ChatRequest:
    """Normalise caller input into a ChatRequest.

    *messages* may be a single message (str, dict or Message) or a
    list/tuple of them; each one is converted with ``Message.from_any``.
    Keys of *kw* listed in ``_SAMPLING_FIELDS`` become ChatRequest
    attributes. Every other key goes into ``extra`` and is passed through to
    the provider body.

    *kw* is not modified, so callers can safely reuse the same dict across
    calls.
    """
    if not isinstance(messages, (list, tuple)):
        messages = [messages]
    msgs = [Message.from_any(m).to_dict() for m in messages]
    # Partition without popping: the original mutated the caller's dict.
    fields = {k: v for k, v in kw.items() if k in _SAMPLING_FIELDS}
    extra = {k: v for k, v in kw.items() if k not in _SAMPLING_FIELDS}
    return ChatRequest(messages=msgs, model=model, extra=extra, **fields)
|