khazad 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- khazad/__init__.py +109 -0
- khazad/_models.py +81 -0
- khazad/_transport.py +290 -0
- khazad/adapters/__init__.py +0 -0
- khazad/adapters/embedders/__init__.py +0 -0
- khazad/adapters/embedders/huggingface.py +41 -0
- khazad/adapters/embedders/openai.py +50 -0
- khazad/adapters/parsers/__init__.py +0 -0
- khazad/adapters/parsers/anthropic.py +122 -0
- khazad/adapters/parsers/gemini.py +50 -0
- khazad/adapters/parsers/openai.py +125 -0
- khazad/adapters/parsers/openai_responses.py +168 -0
- khazad/adapters/redis/__init__.py +0 -0
- khazad/adapters/redis/store.py +142 -0
- khazad/khazad.py +290 -0
- khazad/ports/__init__.py +7 -0
- khazad/ports/embedder.py +18 -0
- khazad/ports/parser.py +91 -0
- khazad/ports/store.py +50 -0
- khazad/py.typed +0 -0
- khazad-0.1.2.dist-info/METADATA +443 -0
- khazad-0.1.2.dist-info/RECORD +24 -0
- khazad-0.1.2.dist-info/WHEEL +4 -0
- khazad-0.1.2.dist-info/licenses/LICENSE +21 -0
khazad/__init__.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Khazad — transparent semantic cache for LLM API calls.
|
|
2
|
+
|
|
3
|
+
"You shall not pass" — Khazad stands between your application and
|
|
4
|
+
expensive LLM API calls, turning semantically equivalent requests
|
|
5
|
+
away at the bridge.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
# Functional singleton API
|
|
10
|
+
import khazad
|
|
11
|
+
khazad.init(redis_url="redis://localhost:6379", threshold=0.92)
|
|
12
|
+
khazad.stop()
|
|
13
|
+
|
|
14
|
+
# Or manage the instance explicitly
|
|
15
|
+
from khazad import Khazad
|
|
16
|
+
cache = Khazad(redis_url="redis://localhost:6379", threshold=0.92)
|
|
17
|
+
cache.stop()
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import logging
|
|
23
|
+
from typing import Literal
|
|
24
|
+
|
|
25
|
+
from khazad._models import CacheHit, CacheScope, ParsedRequest, Stats
|
|
26
|
+
from khazad.khazad import Khazad
|
|
27
|
+
|
|
28
|
+
__version__ = "0.1.2"
|
|
29
|
+
__all__ = [
|
|
30
|
+
"CacheHit",
|
|
31
|
+
"CacheScope",
|
|
32
|
+
"Khazad",
|
|
33
|
+
"ParsedRequest",
|
|
34
|
+
"Stats",
|
|
35
|
+
"flush",
|
|
36
|
+
"get_stats",
|
|
37
|
+
"init",
|
|
38
|
+
"is_active",
|
|
39
|
+
"stop",
|
|
40
|
+
]
|
|
41
|
+
|
|
42
|
+
# ---------------------------------------------------------------------------
|
|
43
|
+
# Module-level singleton — functional interface over a single Khazad instance
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
|
|
46
|
+
_instance: Khazad | None = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def init(
    redis_url: str = "redis://localhost:6379",
    threshold: float = 0.90,
    ttl: int | None = None,
    namespace: str = "khazad",
    embedder: str = "huggingface",
    embedding_model: str = "redis/langcache-embed-v2",
    log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR"] = "INFO",
    hosts: list[str] | None = None,
    cache_scope: CacheScope | Literal["model", "host"] = CacheScope.MODEL,
) -> None:
    """Create the module-level ``Khazad`` singleton from the given settings.

    Every argument is forwarded unchanged, by keyword, to the ``Khazad``
    constructor. If a singleton already exists and reports itself active,
    a warning is logged and the existing instance is left untouched.
    """
    global _instance

    already_running = _instance is not None and _instance.is_active()
    if already_running:
        logging.getLogger("khazad").warning("[Khazad] Already initialized, call stop() first")
        return

    settings = {
        "redis_url": redis_url,
        "threshold": threshold,
        "ttl": ttl,
        "namespace": namespace,
        "embedder": embedder,
        "embedding_model": embedding_model,
        "log_level": log_level,
        "hosts": hosts,
        "cache_scope": cache_scope,
    }
    _instance = Khazad(**settings)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def stop() -> None:
    """Stop the module-level singleton and forget it.

    When there is no singleton, or it reports itself inactive, only a
    warning is logged and the module state is left as it was.
    """
    global _instance

    current = _instance
    if current is not None and current.is_active():
        current.stop()
        _instance = None
        return

    logging.getLogger("khazad").warning("[Khazad] Not currently active")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_stats() -> dict:
    """Return the singleton's stats as a plain dictionary.

    Without a singleton, the dictionary form of a freshly constructed
    ``Stats`` object is returned instead.
    """
    snapshot = Stats() if _instance is None else _instance.get_stats()
    return snapshot.to_dict()
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def flush() -> None:
    """Ask the singleton to flush its cache; warn if there is no singleton."""
    if _instance is not None:
        _instance.flush()
        return

    logging.getLogger("khazad").warning("[Khazad] Not initialized, nothing to flush")
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def is_active() -> bool:
    """Report whether a singleton exists and its ``is_active()`` says so."""
    if _instance is None:
        return False
    return _instance.is_active()
|
khazad/_models.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Domain models for parsed requests, cache hits and observability stats."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from enum import Enum
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CacheScope(str, Enum):
    """Partitioning of cache entries inside a single provider host.

    The provider host always belongs to the scope, so a cached response is
    never replayed across providers; this enum only decides whether the
    **model** is part of the key as well.

    - :attr:`MODEL` (default): one vector set per ``(host, model)`` pair, so
      a ``gpt-4o`` answer is never served to a ``gpt-4o-mini`` call.
    - :attr:`HOST`: all models or deployments of a provider share a single
      vector set. Only safe for format-compatible pools (for example several
      Azure OpenAI deployments, or when ``gpt-4o`` and ``gpt-4o-mini`` are
      treated as interchangeable).

    Because the enum mixes in ``str``, each member compares equal to its
    plain string value (``"model"`` / ``"host"``).
    """

    # Key on host *and* model.
    MODEL = "model"
    # Key on host only.
    HOST = "host"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True, slots=True)
|
|
32
|
+
class ParsedRequest:
|
|
33
|
+
"""Semantic content extracted from a provider request body."""
|
|
34
|
+
|
|
35
|
+
prompt: str
|
|
36
|
+
model: str | None = None
|
|
37
|
+
stream: bool = False
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass(frozen=True, slots=True)
|
|
41
|
+
class CacheHit:
|
|
42
|
+
"""Result of a successful cache lookup."""
|
|
43
|
+
|
|
44
|
+
key: str
|
|
45
|
+
similarity: float
|
|
46
|
+
response_data: bytes
|
|
47
|
+
latency_ms: float
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass(slots=True)
|
|
51
|
+
class Stats:
|
|
52
|
+
"""Observable cache performance metrics."""
|
|
53
|
+
|
|
54
|
+
total_requests: int = 0
|
|
55
|
+
cache_hits: int = 0
|
|
56
|
+
cache_misses: int = 0
|
|
57
|
+
total_hit_similarity: float = 0.0
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def hit_rate(self) -> float:
|
|
61
|
+
"""Return the cache hit ratio as a value between 0 and 1."""
|
|
62
|
+
if self.total_requests == 0:
|
|
63
|
+
return 0.0
|
|
64
|
+
return self.cache_hits / self.total_requests
|
|
65
|
+
|
|
66
|
+
@property
|
|
67
|
+
def avg_hit_similarity(self) -> float:
|
|
68
|
+
"""Return the average cosine similarity of cache hits."""
|
|
69
|
+
if self.cache_hits == 0:
|
|
70
|
+
return 0.0
|
|
71
|
+
return self.total_hit_similarity / self.cache_hits
|
|
72
|
+
|
|
73
|
+
def to_dict(self) -> dict:
|
|
74
|
+
"""Serialize stats to a plain dictionary."""
|
|
75
|
+
return {
|
|
76
|
+
"total_requests": self.total_requests,
|
|
77
|
+
"cache_hits": self.cache_hits,
|
|
78
|
+
"cache_misses": self.cache_misses,
|
|
79
|
+
"hit_rate": round(self.hit_rate, 4),
|
|
80
|
+
"avg_hit_similarity": round(self.avg_hit_similarity, 4),
|
|
81
|
+
}
|
khazad/_transport.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""HTTP transport interceptor — patches httpx to route LLM traffic through the cache."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from collections.abc import Callable, Iterator
|
|
13
|
+
|
|
14
|
+
from khazad.khazad import Khazad, PreparedRequest
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger("khazad")
|
|
17
|
+
|
|
18
|
+
# ---------------------------------------------------------------------------
|
|
19
|
+
# Patch install / uninstall
|
|
20
|
+
# ---------------------------------------------------------------------------
|
|
21
|
+
|
|
22
|
+
_original_async_init = None
|
|
23
|
+
_original_sync_init = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def install(cache: Khazad) -> None:
    """Replace ``httpx.Client`` / ``httpx.AsyncClient`` ``__init__`` with caching wrappers.

    The pristine ``__init__`` functions are remembered in module globals only
    when nothing is stored there yet, so calling this again with another
    cache re-patches on top of the same originals and ``uninstall()`` can
    still restore them.

    Each patched ``__init__`` first runs the original one, then wraps the
    client's ``_transport`` in the matching cached transport bound to *cache*.
    """
    global _original_async_init, _original_sync_init

    # Capture the real constructors once; later installs must not overwrite them.
    if _original_sync_init is None:
        _original_sync_init = httpx.Client.__init__
    if _original_async_init is None:
        _original_async_init = httpx.AsyncClient.__init__

    # Bind to locals so the closures keep working after the globals are reset.
    pristine_sync = _original_sync_init
    pristine_async = _original_async_init

    def patched_sync_init(self: httpx.Client, *args, **kwargs) -> None:
        pristine_sync(self, *args, **kwargs)
        self._transport = CachedSyncTransport(cache, self._transport)

    def patched_async_init(self: httpx.AsyncClient, *args, **kwargs) -> None:
        pristine_async(self, *args, **kwargs)
        self._transport = CachedAsyncTransport(cache, self._transport)

    httpx.Client.__init__ = patched_sync_init  # type: ignore[method-assign]
    httpx.AsyncClient.__init__ = patched_async_init  # type: ignore[method-assign]

    logger.info("[Khazad] HTTP transport patches installed")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def uninstall() -> None:
    """Put back the saved httpx ``__init__`` functions and forget them.

    A constructor is only restored when one was previously saved. The
    removal message is logged in every case.
    """
    global _original_async_init, _original_sync_init

    saved_sync, saved_async = _original_sync_init, _original_async_init

    if saved_sync is not None:
        httpx.Client.__init__ = saved_sync  # type: ignore[method-assign]
        _original_sync_init = None

    if saved_async is not None:
        httpx.AsyncClient.__init__ = saved_async  # type: ignore[method-assign]
        _original_async_init = None

    logger.info("[Khazad] HTTP transport patches removed")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# ---------------------------------------------------------------------------
|
|
73
|
+
# Transport wrappers
|
|
74
|
+
# ---------------------------------------------------------------------------
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class CachedSyncTransport(httpx.BaseTransport):
    """Sync httpx transport that consults the cache before the wrapped transport."""

    def __init__(self, cache: Khazad, wrapped: httpx.BaseTransport) -> None:
        """Remember the cache to consult and the transport to delegate to."""
        self._cache = cache
        self._wrapped = wrapped

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Serve *request* from the cache when possible, else forward it.

        Flow:

        1. If the cache is inactive or ``cache.prepare`` returns ``None``,
           the request is forwarded untouched.
        2. On a cache hit the stored response is replayed.
        3. On a miss the request is forwarded; non-200 responses are returned
           as they are and never cached.
        4. SSE responses are tee'd so the body is stored as the caller
           consumes it; if the body cannot be captured it is passed through.
        5. Other responses are read in full and stored on a best-effort basis.

        Raises:
            Whatever the wrapped transport raises while sending the request
            or while its response body is being read.
        """
        cache = self._cache
        prepared = cache.prepare(request) if cache.is_active() else None
        if prepared is None:
            return self._wrapped.handle_request(request)

        hit = cache.lookup(prepared)
        if hit is not None:
            return _replay(prepared, hit)

        response = self._wrapped.handle_request(request)
        if response.status_code != 200:
            return response

        if _is_sse(response):
            if not _can_capture(response):
                return response
            stream = _SyncTeeStream(
                response.stream, lambda raw: _store_stream(cache, prepared, raw)
            )
            return _swap_stream(response, stream)

        # Read the body outside the best-effort block: a failure here is a
        # transport error the caller must see. Swallowing it would hand back
        # a half-consumed response and mislabel the problem as a cache error.
        response.read()
        try:
            cache.store(prepared, response.content)
        except Exception:
            # Caching is optional; never fail the caller's request over it.
            logger.warning("[Khazad] Failed to store response in cache", exc_info=True)
        return response

    def close(self) -> None:
        """Close the wrapped transport."""
        self._wrapped.close()
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class CachedAsyncTransport(httpx.AsyncBaseTransport):
    """Async httpx transport that consults the cache before the wrapped transport."""

    def __init__(self, cache: Khazad, wrapped: httpx.AsyncBaseTransport) -> None:
        """Remember the cache to consult and the transport to delegate to."""
        self._cache = cache
        self._wrapped = wrapped

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Serve *request* from the cache when possible, else forward it.

        Mirrors the sync transport. The cache lookup and the cache store are
        dispatched to the loop's default executor so they do not run on the
        event loop thread.

        Raises:
            Whatever the wrapped transport raises while sending the request
            or while its response body is being read.
        """
        cache = self._cache
        prepared = cache.prepare(request) if cache.is_active() else None
        if prepared is None:
            return await self._wrapped.handle_async_request(request)

        # Embedding and Redis search are blocking — keep the event loop free.
        loop = asyncio.get_running_loop()
        hit = await loop.run_in_executor(None, cache.lookup, prepared)
        if hit is not None:
            return _replay(prepared, hit)

        response = await self._wrapped.handle_async_request(request)
        if response.status_code != 200:
            return response

        if _is_sse(response):
            if not _can_capture(response):
                return response
            stream = _AsyncTeeStream(
                response.stream, lambda raw: _store_stream(cache, prepared, raw)
            )
            return _swap_stream(response, stream)

        # Read the body outside the best-effort block: a failure here is a
        # transport error the caller must see. Swallowing it would hand back
        # a half-consumed response and mislabel the problem as a cache error.
        await response.aread()
        try:
            await loop.run_in_executor(None, cache.store, prepared, response.content)
        except Exception:
            # Caching is optional; never fail the caller's request over it.
            logger.warning("[Khazad] Failed to store response in cache", exc_info=True)
        return response

    async def aclose(self) -> None:
        """Close the wrapped transport."""
        await self._wrapped.aclose()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
# ---------------------------------------------------------------------------
|
|
160
|
+
# Cache hit / miss plumbing
|
|
161
|
+
# ---------------------------------------------------------------------------
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _replay(prepared: PreparedRequest, hit) -> httpx.Response:
|
|
165
|
+
"""Build the response for a cache hit — JSON body or simulated SSE stream."""
|
|
166
|
+
if prepared.stream:
|
|
167
|
+
return httpx.Response(
|
|
168
|
+
status_code=200,
|
|
169
|
+
headers={"content-type": "text/event-stream"},
|
|
170
|
+
stream=_ReplayStream(prepared.parser.stream_chunks(hit.response_data)),
|
|
171
|
+
)
|
|
172
|
+
return prepared.parser.build_response(hit.response_data)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _store_stream(cache: Khazad, prepared: PreparedRequest, raw: bytes) -> None:
|
|
176
|
+
"""Reconstruct a canonical JSON response from raw SSE bytes and cache it."""
|
|
177
|
+
body = prepared.parser.response_from_stream(raw)
|
|
178
|
+
if body:
|
|
179
|
+
cache.store(prepared, body)
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _is_sse(response: httpx.Response) -> bool:
|
|
183
|
+
return "text/event-stream" in response.headers.get("content-type", "")
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _can_capture(response: httpx.Response) -> bool:
|
|
187
|
+
"""Raw stream bytes are only parseable when the body is not compressed."""
|
|
188
|
+
return response.headers.get("content-encoding", "identity").lower() in ("", "identity")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _swap_stream(
    response: httpx.Response, stream: httpx.SyncByteStream | httpx.AsyncByteStream
) -> httpx.Response:
    """Build a new response with *response*'s metadata and *stream* as its body.

    Status code, headers and extensions are taken from the original; only
    the byte stream is replaced.
    """
    carried_over = {
        "status_code": response.status_code,
        "headers": response.headers,
        "extensions": response.extensions,
    }
    return httpx.Response(stream=stream, **carried_over)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# ---------------------------------------------------------------------------
|
|
204
|
+
# Stream helpers
|
|
205
|
+
# ---------------------------------------------------------------------------
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class _ReplayStream(httpx.SyncByteStream, httpx.AsyncByteStream):
    """Byte stream over pre-built SSE chunks, iterable both sync and async."""

    def __init__(self, chunks: Iterator[bytes]) -> None:
        """Keep the chunk iterator; both iteration styles draw from it."""
        self._chunks = chunks

    def __iter__(self):
        """Yield each chunk in order."""
        for frame in self._chunks:
            yield frame

    async def __aiter__(self):
        """Yield each chunk in order, yielding control to the loop after each."""
        for frame in self._chunks:
            yield frame
            # Zero-length sleep: let other tasks run between frames.
            await asyncio.sleep(0)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class _SyncTeeStream(httpx.SyncByteStream):
    """Forward chunks unchanged while recording them for the cache.

    The recorded body is handed to ``on_complete`` exactly once: either when
    the inner stream is exhausted or when ``close()`` is called, whichever
    happens first. This matters because SDKs (e.g. the OpenAI client) stop
    reading at the terminal SSE sentinel and call ``close()`` without
    draining the stream to EOF. Whether the captured bytes form a complete
    response is left to the parser's :meth:`response_from_stream`; an
    aborted, partial stream reconstructs to ``None`` and is never cached.
    """

    def __init__(self, inner: httpx.SyncByteStream, on_complete: Callable[[bytes], None]) -> None:
        """Wrap *inner*; *on_complete* receives the concatenated body."""
        self._inner = inner
        self._on_complete = on_complete
        self._parts: list[bytes] = []
        self._finished = False

    def __iter__(self):
        """Yield every chunk of the inner stream, recording each one first."""
        record = self._parts.append
        for chunk in self._inner:
            record(chunk)
            yield chunk
        self._finish()

    def _finish(self) -> None:
        """Invoke the completion callback once; log instead of raising."""
        if not self._finished:
            self._finished = True
            captured = b"".join(self._parts)
            try:
                self._on_complete(captured)
            except Exception:
                logger.warning("[Khazad] Failed to store streamed response", exc_info=True)

    def close(self) -> None:
        """Flush the captured body to the callback, then close the inner stream."""
        self._finish()
        self._inner.close()
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class _AsyncTeeStream(httpx.AsyncByteStream):
    """Async counterpart of :class:`_SyncTeeStream`.

    Chunks are forwarded unchanged and recorded. The completion callback is
    submitted once to the running loop's default executor, on exhaustion of
    the inner stream or on ``aclose()``, so it does not run on the event loop.
    """

    def __init__(self, inner: httpx.AsyncByteStream, on_complete: Callable[[bytes], None]) -> None:
        """Wrap *inner*; *on_complete* receives the concatenated body."""
        self._inner = inner
        self._on_complete = on_complete
        self._parts: list[bytes] = []
        self._finished = False

    async def __aiter__(self):
        """Yield every chunk of the inner stream, recording each one first."""
        record = self._parts.append
        async for chunk in self._inner:
            record(chunk)
            yield chunk
        self._finish()

    def _finish(self) -> None:
        """Schedule the completion callback on the executor, at most once."""
        if not self._finished:
            self._finished = True
            captured = b"".join(self._parts)
            loop = asyncio.get_running_loop()
            # Fire and forget: errors are handled inside _safe_complete.
            loop.run_in_executor(None, self._safe_complete, captured)

    def _safe_complete(self, body: bytes) -> None:
        """Run the completion callback, logging any exception it raises."""
        try:
            self._on_complete(body)
        except Exception:
            logger.warning("[Khazad] Failed to store streamed response", exc_info=True)

    async def aclose(self) -> None:
        """Schedule the store of the captured body, then close the inner stream."""
        self._finish()
        await self._inner.aclose()
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""HuggingFace embedding adapter using sentence-transformers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import threading
|
|
7
|
+
from functools import cached_property
|
|
8
|
+
|
|
9
|
+
from sentence_transformers import SentenceTransformer
|
|
10
|
+
|
|
11
|
+
from khazad.ports.embedder import Embedder
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("khazad")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HuggingFaceEmbedder(Embedder):
    """Embedder that runs a sentence-transformers model, loaded on first use."""

    def __init__(self, model_name: str = "redis/langcache-embed-v2") -> None:
        """Record the model name; the model itself is not loaded here."""
        self._model_name = model_name
        self._model: SentenceTransformer | None = None
        self._load_lock = threading.Lock()

    def _get_model(self) -> SentenceTransformer:
        """Return the model, loading it on first call.

        The lock plus the second ``None`` check ensure that concurrent first
        callers do not each construct their own model.
        """
        if self._model is not None:
            return self._model

        with self._load_lock:
            # Re-check: another thread may have loaded it while we waited.
            if self._model is None:
                logger.info("[Khazad] Loading embedding model: %s", self._model_name)
                self._model = SentenceTransformer(self._model_name)
            return self._model

    def embed(self, text: str) -> list[float]:
        """Encode *text* with ``normalize_embeddings=True`` and return a list."""
        return self._get_model().encode(text, normalize_embeddings=True).tolist()

    @cached_property
    def dimension(self) -> int:
        """Embedding size reported by the model; computed once, then cached."""
        model = self._get_model()
        return model.get_sentence_embedding_dimension()  # type: ignore[return-value]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""OpenAI embedding adapter (optional paid backend)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from khazad.ports.embedder import Embedder
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from openai import OpenAI
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger("khazad")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OpenAIEmbedder(Embedder):
    """Embedder that calls the OpenAI Embeddings API."""

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
        dimension: int = 1536,
    ) -> None:
        """Store the configuration; the API client is created on first use."""
        self._model = model
        self._api_key = api_key
        self._dimension = dimension
        self._client: OpenAI | None = None

    def _get_client(self) -> OpenAI:
        """Return the OpenAI client, creating it on the first call.

        Raises:
            ImportError: if the ``openai`` package cannot be imported; the
                message names the extra to install.
        """
        if self._client is not None:
            return self._client

        # Imported here so the package is only required when this embedder is used.
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError(
                "Install the 'openai' package: pip install khazad[openai-embeddings]"
            ) from exc

        self._client = OpenAI(api_key=self._api_key)
        return self._client

    def embed(self, text: str) -> list[float]:
        """Request an embedding for *text* and return the first result's vector."""
        client = self._get_client()
        result = client.embeddings.create(input=text, model=self._model)
        return result.data[0].embedding

    @property
    def dimension(self) -> int:
        """Return the dimension passed to the constructor."""
        return self._dimension
|
|
File without changes
|