flashlite 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flashlite/__init__.py +169 -0
- flashlite/cache/__init__.py +14 -0
- flashlite/cache/base.py +194 -0
- flashlite/cache/disk.py +285 -0
- flashlite/cache/memory.py +157 -0
- flashlite/client.py +671 -0
- flashlite/config.py +154 -0
- flashlite/conversation/__init__.py +30 -0
- flashlite/conversation/context.py +319 -0
- flashlite/conversation/manager.py +385 -0
- flashlite/conversation/multi_agent.py +378 -0
- flashlite/core/__init__.py +13 -0
- flashlite/core/completion.py +145 -0
- flashlite/core/messages.py +130 -0
- flashlite/middleware/__init__.py +18 -0
- flashlite/middleware/base.py +90 -0
- flashlite/middleware/cache.py +121 -0
- flashlite/middleware/logging.py +159 -0
- flashlite/middleware/rate_limit.py +211 -0
- flashlite/middleware/retry.py +149 -0
- flashlite/observability/__init__.py +34 -0
- flashlite/observability/callbacks.py +155 -0
- flashlite/observability/inspect_compat.py +266 -0
- flashlite/observability/logging.py +293 -0
- flashlite/observability/metrics.py +221 -0
- flashlite/py.typed +0 -0
- flashlite/structured/__init__.py +31 -0
- flashlite/structured/outputs.py +189 -0
- flashlite/structured/schema.py +165 -0
- flashlite/templating/__init__.py +11 -0
- flashlite/templating/engine.py +217 -0
- flashlite/templating/filters.py +143 -0
- flashlite/templating/registry.py +165 -0
- flashlite/tools/__init__.py +74 -0
- flashlite/tools/definitions.py +382 -0
- flashlite/tools/execution.py +353 -0
- flashlite/types.py +233 -0
- flashlite-0.1.0.dist-info/METADATA +173 -0
- flashlite-0.1.0.dist-info/RECORD +41 -0
- flashlite-0.1.0.dist-info/WHEEL +4 -0
- flashlite-0.1.0.dist-info/licenses/LICENSE.md +21 -0
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Retry middleware with exponential backoff."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
import random
|
|
6
|
+
|
|
7
|
+
from tenacity import (
|
|
8
|
+
AsyncRetrying,
|
|
9
|
+
RetryError,
|
|
10
|
+
retry_if_exception,
|
|
11
|
+
stop_after_attempt,
|
|
12
|
+
wait_exponential_jitter,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from ..types import (
|
|
16
|
+
CompletionError,
|
|
17
|
+
CompletionRequest,
|
|
18
|
+
CompletionResponse,
|
|
19
|
+
RetryConfig,
|
|
20
|
+
)
|
|
21
|
+
from .base import CompletionHandler, Middleware
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _should_retry(exception: BaseException) -> bool:
    """Return True when *exception* represents a transient failure worth retrying."""
    if isinstance(exception, CompletionError):
        status = exception.status_code
        # Rate limiting and server-side failures are transient.
        if status in (429, 500, 502, 503, 504):
            return True
        # Any other 4xx is a caller error; retrying will not help.
        if status and 400 <= status < 500:
            return False
    # Network-level hiccups (dropped connections, timeouts) get another attempt.
    return isinstance(exception, (ConnectionError, TimeoutError, asyncio.TimeoutError))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class RetryMiddleware(Middleware):
    """
    Middleware that retries failed completions via tenacity.

    Behaviour:
    - exponential backoff with optional jitter between attempts
    - stops after ``config.max_attempts`` tries
    - only transient failures (429, 5xx, connection errors) are retried
    """

    def __init__(self, config: RetryConfig | None = None):
        """
        Initialize retry middleware.

        Args:
            config: Retry configuration. Uses defaults if not provided.
        """
        self.config = config or RetryConfig()

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        """Invoke ``next_handler`` under tenacity retry control."""
        cfg = self.config
        retryer = AsyncRetrying(
            stop=stop_after_attempt(cfg.max_attempts),
            wait=wait_exponential_jitter(
                initial=cfg.initial_delay,
                max=cfg.max_delay,
                exp_base=cfg.exponential_base,
                jitter=cfg.initial_delay if cfg.jitter else 0,
            ),
            retry=retry_if_exception(_should_retry),
            reraise=True,
        )

        try:
            async for attempt in retryer:
                with attempt:
                    attempt_number = attempt.retry_state.attempt_number
                    if attempt_number > 1:
                        logger.info(
                            f"Retry attempt {attempt_number}/{cfg.max_attempts} "
                            f"for model={request.model}"
                        )
                    return await next_handler(request)

        except RetryError as retry_err:
            # With reraise=True tenacity normally re-raises the underlying
            # exception itself; keep this path as a defensive fallback and
            # surface the last failure if one is recorded.
            if retry_err.last_attempt.failed:
                original = retry_err.last_attempt.exception()
                if original is not None:
                    raise original from retry_err
            raise CompletionError("Retry attempts exhausted") from retry_err

        # Unreachable in practice; keeps the type checker satisfied.
        raise CompletionError("Retry logic failed unexpectedly")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class SimpleRetryMiddleware(Middleware):
    """
    Dependency-free retry middleware with exponential backoff.

    Mirrors the retry semantics without tenacity; useful as a reference
    implementation or when tenacity isn't available.
    """

    def __init__(self, config: RetryConfig | None = None):
        self.config = config or RetryConfig()

    async def __call__(
        self,
        request: CompletionRequest,
        next_handler: CompletionHandler,
    ) -> CompletionResponse:
        cfg = self.config
        pending_error: Exception | None = None
        backoff = cfg.initial_delay

        for try_number in range(1, cfg.max_attempts + 1):
            try:
                return await next_handler(request)
            except Exception as err:
                pending_error = err

                # Permanent failures and the final attempt propagate immediately.
                if not _should_retry(err) or try_number == cfg.max_attempts:
                    raise

                # Jitter spreads the sleep over 0.5x..1.5x of the base delay.
                sleep_for = backoff * (0.5 + random.random()) if cfg.jitter else backoff

                logger.info(
                    f"Attempt {try_number}/{cfg.max_attempts} failed, "
                    f"retrying in {sleep_for:.2f}s: {err}"
                )

                await asyncio.sleep(sleep_for)

                # Grow the base delay for the next round, capped at max_delay.
                backoff = min(backoff * cfg.exponential_base, cfg.max_delay)

        # Defensive: the loop always returns or raises before reaching here.
        if pending_error:
            raise pending_error
        raise CompletionError("Retry logic failed unexpectedly")
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Observability module for flashlite."""
|
|
2
|
+
|
|
3
|
+
from .callbacks import (
|
|
4
|
+
CallbackManager,
|
|
5
|
+
OnErrorCallback,
|
|
6
|
+
OnRequestCallback,
|
|
7
|
+
OnResponseCallback,
|
|
8
|
+
create_logging_callbacks,
|
|
9
|
+
)
|
|
10
|
+
from .inspect_compat import FlashliteModelAPI, InspectLogEntry, InspectLogger
|
|
11
|
+
from .logging import RequestContext, RequestLogEntry, ResponseLogEntry, StructuredLogger
|
|
12
|
+
from .metrics import BudgetExceededError, CostMetrics, CostTracker
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
# Logging
|
|
16
|
+
"StructuredLogger",
|
|
17
|
+
"RequestLogEntry",
|
|
18
|
+
"ResponseLogEntry",
|
|
19
|
+
"RequestContext",
|
|
20
|
+
# Metrics
|
|
21
|
+
"CostTracker",
|
|
22
|
+
"CostMetrics",
|
|
23
|
+
"BudgetExceededError",
|
|
24
|
+
# Callbacks
|
|
25
|
+
"CallbackManager",
|
|
26
|
+
"OnRequestCallback",
|
|
27
|
+
"OnResponseCallback",
|
|
28
|
+
"OnErrorCallback",
|
|
29
|
+
"create_logging_callbacks",
|
|
30
|
+
# Inspect
|
|
31
|
+
"InspectLogger",
|
|
32
|
+
"InspectLogEntry",
|
|
33
|
+
"FlashliteModelAPI",
|
|
34
|
+
]
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Callback system for flashlite observability."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ..types import CompletionRequest, CompletionResponse
|
|
8
|
+
|
|
9
|
+
# Callback type definitions.
# A callback may be synchronous (returns None) or asynchronous (returns an
# awaitable, which CallbackManager awaits in its emit_* methods).
#
# OnRequestCallback(request, request_id) — fired before each request.
OnRequestCallback = Callable[[CompletionRequest, str], Awaitable[None] | None]
# OnResponseCallback(response, request_id, latency_ms, cached) — fired after
# each successful response.
OnResponseCallback = Callable[
    [CompletionResponse, str, float, bool], Awaitable[None] | None
]
# OnErrorCallback(error, request_id, latency_ms) — fired when a request fails.
OnErrorCallback = Callable[[Exception, str, float], Awaitable[None] | None]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class CallbackManager:
    """
    Registry and dispatcher for request/response/error lifecycle callbacks.

    Callbacks may be plain functions or coroutine functions; awaitable
    results are awaited.

    Example:
        callbacks = CallbackManager()

        @callbacks.on_request
        async def log_request(request, request_id):
            print(f"Request {request_id}: {request.model}")

        @callbacks.on_response
        async def log_response(response, request_id, latency_ms, cached):
            print(f"Response {request_id}: {latency_ms}ms")

        # Or register directly
        callbacks.add_on_request(my_callback)
    """

    _on_request: list[OnRequestCallback] = field(default_factory=list)
    _on_response: list[OnResponseCallback] = field(default_factory=list)
    _on_error: list[OnErrorCallback] = field(default_factory=list)

    # -- registration ------------------------------------------------------

    def add_on_request(self, callback: OnRequestCallback) -> None:
        """Add a callback to be called before each request."""
        self._on_request.append(callback)

    def add_on_response(self, callback: OnResponseCallback) -> None:
        """Add a callback to be called after each successful response."""
        self._on_response.append(callback)

    def add_on_error(self, callback: OnErrorCallback) -> None:
        """Add a callback to be called on request errors."""
        self._on_error.append(callback)

    # -- decorator forms ---------------------------------------------------

    def on_request(self, callback: OnRequestCallback) -> OnRequestCallback:
        """Decorator to register a request callback."""
        self.add_on_request(callback)
        return callback

    def on_response(self, callback: OnResponseCallback) -> OnResponseCallback:
        """Decorator to register a response callback."""
        self.add_on_response(callback)
        return callback

    def on_error(self, callback: OnErrorCallback) -> OnErrorCallback:
        """Decorator to register an error callback."""
        self.add_on_error(callback)
        return callback

    # -- dispatch ----------------------------------------------------------

    @staticmethod
    async def _dispatch(callbacks, *args) -> None:
        """Call each callback in order, awaiting any awaitable results."""
        for cb in callbacks:
            outcome = cb(*args)
            if isinstance(outcome, Awaitable):
                await outcome

    async def emit_request(
        self,
        request: CompletionRequest,
        request_id: str,
    ) -> None:
        """Emit a request event to all registered callbacks."""
        await self._dispatch(self._on_request, request, request_id)

    async def emit_response(
        self,
        response: CompletionResponse,
        request_id: str,
        latency_ms: float,
        cached: bool = False,
    ) -> None:
        """Emit a response event to all registered callbacks."""
        await self._dispatch(self._on_response, response, request_id, latency_ms, cached)

    async def emit_error(
        self,
        error: Exception,
        request_id: str,
        latency_ms: float,
    ) -> None:
        """Emit an error event to all registered callbacks."""
        await self._dispatch(self._on_error, error, request_id, latency_ms)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def create_logging_callbacks(
    logger: Any,
    level: str = "INFO",
) -> CallbackManager:
    """
    Build a CallbackManager wired with standard logging callbacks.

    Args:
        logger: A logging.Logger instance
        level: Log level to use

    Returns:
        A configured CallbackManager
    """
    import logging as stdlib_logging

    numeric_level = getattr(stdlib_logging, level.upper())
    manager = CallbackManager()

    @manager.on_request
    def log_request(request: CompletionRequest, request_id: str) -> None:
        # Truncate the request id to 8 chars for compact log lines.
        logger.log(
            numeric_level,
            f"[{request_id[:8]}] Request: model={request.model}",
        )

    @manager.on_response
    def log_response(
        response: CompletionResponse,
        request_id: str,
        latency_ms: float,
        cached: bool,
    ) -> None:
        cache_str = " (cached)" if cached else ""
        tokens = response.usage.total_tokens if response.usage else 0
        logger.log(
            numeric_level,
            f"[{request_id[:8]}] Response: {latency_ms:.1f}ms, {tokens} tokens{cache_str}",
        )

    @manager.on_error
    def log_error(error: Exception, request_id: str, latency_ms: float) -> None:
        # Errors always log at ERROR level regardless of the configured level.
        logger.error(
            f"[{request_id[:8]}] Error after {latency_ms:.1f}ms: {error}",
        )

    return manager
|
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
"""Inspect framework compatibility layer for flashlite.
|
|
2
|
+
|
|
3
|
+
This module provides interoperability with the UK AISI's Inspect framework
|
|
4
|
+
(https://inspect.ai-safety-institute.org.uk/).
|
|
5
|
+
|
|
6
|
+
It includes:
|
|
7
|
+
- Log format compatible with Inspect's eval logging
|
|
8
|
+
- ModelAPI protocol implementation for use as an Inspect solver backend
|
|
9
|
+
- Hooks for Inspect's TaskState integration
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import logging
|
|
14
|
+
from dataclasses import dataclass, field
|
|
15
|
+
from datetime import UTC, datetime
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import TYPE_CHECKING, Any
|
|
18
|
+
|
|
19
|
+
from ..types import CompletionRequest, CompletionResponse
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from ..client import Flashlite
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
|
|
28
|
+
class InspectLogEntry:
|
|
29
|
+
"""A log entry in Inspect-compatible format."""
|
|
30
|
+
|
|
31
|
+
eval_id: str
|
|
32
|
+
sample_id: str | int
|
|
33
|
+
epoch: int
|
|
34
|
+
model: str
|
|
35
|
+
input: list[dict[str, Any]]
|
|
36
|
+
output: str
|
|
37
|
+
tokens: dict[str, int]
|
|
38
|
+
timestamp: str
|
|
39
|
+
metadata: dict[str, Any] = field(default_factory=dict)
|
|
40
|
+
|
|
41
|
+
def to_dict(self) -> dict[str, Any]:
|
|
42
|
+
"""Convert to Inspect log format."""
|
|
43
|
+
return {
|
|
44
|
+
"eval_id": self.eval_id,
|
|
45
|
+
"sample_id": self.sample_id,
|
|
46
|
+
"epoch": self.epoch,
|
|
47
|
+
"model": self.model,
|
|
48
|
+
"input": self.input,
|
|
49
|
+
"output": self.output,
|
|
50
|
+
"tokens": self.tokens,
|
|
51
|
+
"timestamp": self.timestamp,
|
|
52
|
+
"metadata": self.metadata,
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class InspectLogger:
|
|
57
|
+
"""
|
|
58
|
+
A logger that outputs in Inspect-compatible format.
|
|
59
|
+
|
|
60
|
+
This allows Flashlite logs to be analyzed alongside Inspect eval logs,
|
|
61
|
+
enabling unified observability across evaluation runs.
|
|
62
|
+
|
|
63
|
+
Example:
|
|
64
|
+
inspect_logger = InspectLogger(
|
|
65
|
+
log_dir="./logs",
|
|
66
|
+
eval_id="my-eval-001",
|
|
67
|
+
)
|
|
68
|
+
|
|
69
|
+
# Log a completion
|
|
70
|
+
inspect_logger.log(
|
|
71
|
+
request=request,
|
|
72
|
+
response=response,
|
|
73
|
+
sample_id="sample_123",
|
|
74
|
+
epoch=0,
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
# Close when done
|
|
78
|
+
inspect_logger.close()
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
def __init__(
|
|
82
|
+
self,
|
|
83
|
+
log_dir: str | Path,
|
|
84
|
+
eval_id: str | None = None,
|
|
85
|
+
append: bool = True,
|
|
86
|
+
):
|
|
87
|
+
"""
|
|
88
|
+
Initialize the Inspect logger.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
log_dir: Directory to write log files
|
|
92
|
+
eval_id: Evaluation ID (auto-generated if not provided)
|
|
93
|
+
append: Whether to append to existing log file
|
|
94
|
+
"""
|
|
95
|
+
self._log_dir = Path(log_dir)
|
|
96
|
+
self._log_dir.mkdir(parents=True, exist_ok=True)
|
|
97
|
+
|
|
98
|
+
self._eval_id = eval_id or self._generate_eval_id()
|
|
99
|
+
self._log_file = self._log_dir / f"{self._eval_id}.jsonl"
|
|
100
|
+
self._mode = "a" if append else "w"
|
|
101
|
+
self._file_handle = open(self._log_file, self._mode)
|
|
102
|
+
self._sample_count = 0
|
|
103
|
+
|
|
104
|
+
def _generate_eval_id(self) -> str:
|
|
105
|
+
"""Generate a unique evaluation ID."""
|
|
106
|
+
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
|
|
107
|
+
return f"flashlite_eval_{timestamp}"
|
|
108
|
+
|
|
109
|
+
def log(
|
|
110
|
+
self,
|
|
111
|
+
request: CompletionRequest,
|
|
112
|
+
response: CompletionResponse,
|
|
113
|
+
sample_id: str | int | None = None,
|
|
114
|
+
epoch: int = 0,
|
|
115
|
+
metadata: dict[str, Any] | None = None,
|
|
116
|
+
) -> None:
|
|
117
|
+
"""
|
|
118
|
+
Log a request/response pair in Inspect format.
|
|
119
|
+
|
|
120
|
+
Args:
|
|
121
|
+
request: The completion request
|
|
122
|
+
response: The completion response
|
|
123
|
+
sample_id: Sample identifier (auto-incremented if not provided)
|
|
124
|
+
epoch: Epoch number for multi-epoch evals
|
|
125
|
+
metadata: Additional metadata to include
|
|
126
|
+
"""
|
|
127
|
+
if sample_id is None:
|
|
128
|
+
sample_id = self._sample_count
|
|
129
|
+
self._sample_count += 1
|
|
130
|
+
|
|
131
|
+
# Convert messages to Inspect format
|
|
132
|
+
input_messages = [
|
|
133
|
+
{"role": msg.get("role", "user"), "content": msg.get("content", "")}
|
|
134
|
+
for msg in request.messages
|
|
135
|
+
]
|
|
136
|
+
|
|
137
|
+
entry = InspectLogEntry(
|
|
138
|
+
eval_id=self._eval_id,
|
|
139
|
+
sample_id=sample_id,
|
|
140
|
+
epoch=epoch,
|
|
141
|
+
model=response.model,
|
|
142
|
+
input=input_messages,
|
|
143
|
+
output=response.content,
|
|
144
|
+
tokens={
|
|
145
|
+
"input": response.usage.input_tokens if response.usage else 0,
|
|
146
|
+
"output": response.usage.output_tokens if response.usage else 0,
|
|
147
|
+
"total": response.usage.total_tokens if response.usage else 0,
|
|
148
|
+
},
|
|
149
|
+
timestamp=datetime.now(UTC).isoformat(),
|
|
150
|
+
metadata=metadata or {},
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
json_str = json.dumps(entry.to_dict())
|
|
154
|
+
self._file_handle.write(json_str + "\n")
|
|
155
|
+
self._file_handle.flush()
|
|
156
|
+
|
|
157
|
+
def close(self) -> None:
|
|
158
|
+
"""Close the log file."""
|
|
159
|
+
if self._file_handle:
|
|
160
|
+
self._file_handle.close()
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def eval_id(self) -> str:
|
|
164
|
+
"""Get the evaluation ID."""
|
|
165
|
+
return self._eval_id
|
|
166
|
+
|
|
167
|
+
@property
|
|
168
|
+
def log_file(self) -> Path:
|
|
169
|
+
"""Get the log file path."""
|
|
170
|
+
return self._log_file
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class FlashliteModelAPI:
    """
    An adapter exposing a ModelAPI-like interface so that Flashlite can act
    as a model backend inside Inspect evaluations.

    Example:
        from flashlite import Flashlite
        from flashlite.observability import FlashliteModelAPI

        # Create Flashlite client
        client = Flashlite(rate_limit=RateLimitConfig(requests_per_minute=60))

        # Wrap for Inspect
        model_api = FlashliteModelAPI(client, model="gpt-4o")

        # Use in Inspect eval (pseudocode)
        # @task
        # def my_eval():
        #     return Task(dataset=my_dataset, solver=my_solver, model=model_api)
    """

    def __init__(
        self,
        client: "Flashlite",
        model: str | None = None,
        **default_kwargs: Any,
    ):
        """
        Initialize the Inspect model adapter.

        Args:
            client: The Flashlite client to use
            model: Default model to use (can be overridden per-request)
            **default_kwargs: Default parameters for completions
        """
        self._client = client
        self._model = model
        self._default_kwargs = default_kwargs

    async def generate(
        self,
        messages: list[dict[str, Any]],
        model: str | None = None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Generate a completion (Inspect-compatible interface).

        Args:
            messages: List of messages
            model: Model to use (overrides default)
            **kwargs: Additional parameters

        Returns:
            Inspect-compatible response dict
        """
        # Per-call kwargs override the adapter defaults.
        merged_kwargs = dict(self._default_kwargs)
        merged_kwargs.update(kwargs)

        response = await self._client.complete(
            model=model or self._model,
            messages=messages,
            **merged_kwargs,
        )

        # Repackage into the OpenAI-style shape Inspect expects.
        usage = response.usage
        assistant_message = {"role": "assistant", "content": response.content}
        return {
            "choices": [
                {
                    "message": assistant_message,
                    "finish_reason": response.finish_reason,
                }
            ],
            "usage": {
                "prompt_tokens": usage.input_tokens if usage else 0,
                "completion_tokens": usage.output_tokens if usage else 0,
                "total_tokens": usage.total_tokens if usage else 0,
            },
            "model": response.model,
        }

    @property
    def model_name(self) -> str | None:
        """Get the default model name."""
        return self._model
|