ai-lib-python 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_lib_python/__init__.py +43 -0
- ai_lib_python/batch/__init__.py +15 -0
- ai_lib_python/batch/collector.py +244 -0
- ai_lib_python/batch/executor.py +224 -0
- ai_lib_python/cache/__init__.py +26 -0
- ai_lib_python/cache/backends.py +380 -0
- ai_lib_python/cache/key.py +237 -0
- ai_lib_python/cache/manager.py +332 -0
- ai_lib_python/client/__init__.py +37 -0
- ai_lib_python/client/builder.py +528 -0
- ai_lib_python/client/cancel.py +368 -0
- ai_lib_python/client/core.py +433 -0
- ai_lib_python/client/response.py +134 -0
- ai_lib_python/embeddings/__init__.py +36 -0
- ai_lib_python/embeddings/client.py +339 -0
- ai_lib_python/embeddings/types.py +234 -0
- ai_lib_python/embeddings/vectors.py +246 -0
- ai_lib_python/errors/__init__.py +41 -0
- ai_lib_python/errors/base.py +316 -0
- ai_lib_python/errors/classification.py +210 -0
- ai_lib_python/guardrails/__init__.py +35 -0
- ai_lib_python/guardrails/base.py +336 -0
- ai_lib_python/guardrails/filters.py +583 -0
- ai_lib_python/guardrails/validators.py +475 -0
- ai_lib_python/pipeline/__init__.py +55 -0
- ai_lib_python/pipeline/accumulate.py +248 -0
- ai_lib_python/pipeline/base.py +240 -0
- ai_lib_python/pipeline/decode.py +281 -0
- ai_lib_python/pipeline/event_map.py +506 -0
- ai_lib_python/pipeline/fan_out.py +284 -0
- ai_lib_python/pipeline/select.py +297 -0
- ai_lib_python/plugins/__init__.py +32 -0
- ai_lib_python/plugins/base.py +294 -0
- ai_lib_python/plugins/hooks.py +296 -0
- ai_lib_python/plugins/middleware.py +285 -0
- ai_lib_python/plugins/registry.py +294 -0
- ai_lib_python/protocol/__init__.py +71 -0
- ai_lib_python/protocol/loader.py +317 -0
- ai_lib_python/protocol/manifest.py +385 -0
- ai_lib_python/protocol/validator.py +460 -0
- ai_lib_python/py.typed +1 -0
- ai_lib_python/resilience/__init__.py +102 -0
- ai_lib_python/resilience/backpressure.py +225 -0
- ai_lib_python/resilience/circuit_breaker.py +318 -0
- ai_lib_python/resilience/executor.py +343 -0
- ai_lib_python/resilience/fallback.py +341 -0
- ai_lib_python/resilience/preflight.py +413 -0
- ai_lib_python/resilience/rate_limiter.py +291 -0
- ai_lib_python/resilience/retry.py +299 -0
- ai_lib_python/resilience/signals.py +283 -0
- ai_lib_python/routing/__init__.py +118 -0
- ai_lib_python/routing/manager.py +593 -0
- ai_lib_python/routing/strategy.py +345 -0
- ai_lib_python/routing/types.py +397 -0
- ai_lib_python/structured/__init__.py +33 -0
- ai_lib_python/structured/json_mode.py +281 -0
- ai_lib_python/structured/schema.py +316 -0
- ai_lib_python/structured/validator.py +334 -0
- ai_lib_python/telemetry/__init__.py +127 -0
- ai_lib_python/telemetry/exporters/__init__.py +9 -0
- ai_lib_python/telemetry/exporters/prometheus.py +111 -0
- ai_lib_python/telemetry/feedback.py +446 -0
- ai_lib_python/telemetry/health.py +409 -0
- ai_lib_python/telemetry/logger.py +389 -0
- ai_lib_python/telemetry/metrics.py +496 -0
- ai_lib_python/telemetry/tracer.py +473 -0
- ai_lib_python/tokens/__init__.py +25 -0
- ai_lib_python/tokens/counter.py +282 -0
- ai_lib_python/tokens/estimator.py +286 -0
- ai_lib_python/transport/__init__.py +34 -0
- ai_lib_python/transport/auth.py +141 -0
- ai_lib_python/transport/http.py +364 -0
- ai_lib_python/transport/pool.py +425 -0
- ai_lib_python/types/__init__.py +41 -0
- ai_lib_python/types/events.py +343 -0
- ai_lib_python/types/message.py +332 -0
- ai_lib_python/types/tool.py +191 -0
- ai_lib_python/utils/__init__.py +21 -0
- ai_lib_python/utils/tool_call_assembler.py +317 -0
- ai_lib_python-0.5.0.dist-info/METADATA +837 -0
- ai_lib_python-0.5.0.dist-info/RECORD +84 -0
- ai_lib_python-0.5.0.dist-info/WHEEL +4 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-APACHE +201 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-MIT +21 -0
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Middleware system for request/response processing.
|
|
3
|
+
|
|
4
|
+
Provides a chain-of-responsibility pattern: each middleware can inspect the context, call the next step, or abort early.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Awaitable, Callable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class MiddlewareContext:
    """Context for middleware execution.

    A single instance is passed through every middleware step and the final
    handler of a chain, so steps can communicate via its fields.

    Attributes:
        request: Current request
        response: Current response (if available)
        model: Model identifier
        provider: Provider identifier
        metadata: Additional metadata
        aborted: Whether processing was aborted
    """

    request: dict[str, Any] = field(default_factory=dict)
    response: dict[str, Any] | None = None
    model: str = ""
    provider: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    aborted: bool = False

    def abort(self, response: dict[str, Any] | None = None) -> None:
        """Abort the middleware chain, optionally recording the response to return.

        Args:
            response: Response to return. ``None`` keeps whatever response is
                already stored on the context; any other value, including an
                empty dict, replaces it.
        """
        self.aborted = True
        # Compare against None rather than truthiness: an explicit empty
        # response ({}) is a valid short-circuit result and must not be dropped.
        if response is not None:
            self.response = response
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class Middleware(ABC):
    """Abstract base for one step of a middleware chain.

    A subclass implements ``process``: it may inspect or modify the context,
    call ``next`` to continue down the chain, and return the resulting
    response (or return early without calling ``next``).

    Example:
        >>> class LoggingMiddleware(Middleware):
        ...     async def process(self, ctx, next):
        ...         print(f"Request: {ctx.request}")
        ...         response = await next(ctx)
        ...         print(f"Response: {response}")
        ...         return response
    """

    @property
    def name(self) -> str:
        """Identifier for this middleware; defaults to the concrete class name.

        Returns:
            Middleware name
        """
        cls = self.__class__
        return cls.__name__

    @abstractmethod
    async def process(
        self,
        ctx: MiddlewareContext,
        next: Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]],
    ) -> dict[str, Any] | None:
        """Handle one pass through this middleware.

        Args:
            ctx: Middleware context shared along the chain
            next: Callable that runs the remainder of the chain

        Returns:
            Response data
        """
        raise NotImplementedError
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class MiddlewareChain:
    """Ordered chain of middleware wrapped around a final handler.

    Middleware added first is outermost: it runs first on the way in and
    receives the handler's result last on the way out. Once ``ctx.aborted``
    is set, every remaining step (including the final handler) is skipped and
    ``ctx.response`` is returned instead.

    Example:
        >>> chain = MiddlewareChain()
        >>> chain.use(LoggingMiddleware())
        >>> chain.use(CachingMiddleware())
        >>>
        >>> async def handler(ctx):
        ...     # Make actual request
        ...     return response
        >>>
        >>> result = await chain.execute(ctx, handler)
    """

    def __init__(self) -> None:
        """Initialize an empty middleware chain."""
        self._middleware: list[Middleware] = []

    @staticmethod
    def _matches(middleware: Middleware, target: type | str) -> bool:
        """Return True if *middleware* is an instance of class *target* or is named *target*."""
        if isinstance(target, type):
            return isinstance(middleware, target)
        if isinstance(target, str):
            return middleware.name == target
        return False

    def _index_of(self, target: type | str) -> int | None:
        """Return the index of the first middleware matching *target*, or None."""
        for i, m in enumerate(self._middleware):
            if self._matches(m, target):
                return i
        return None

    def use(self, middleware: Middleware) -> MiddlewareChain:
        """Add middleware to the end (innermost position) of the chain.

        Args:
            middleware: Middleware to add

        Returns:
            Self for chaining
        """
        self._middleware.append(middleware)
        return self

    def use_before(
        self, middleware: Middleware, before: type | str
    ) -> MiddlewareChain:
        """Add middleware before another.

        Args:
            middleware: Middleware to add
            before: Middleware class or name to insert before. If nothing
                matches, the middleware is inserted at the start.

        Returns:
            Self for chaining
        """
        idx = self._index_of(before)
        self._middleware.insert(0 if idx is None else idx, middleware)
        return self

    def use_after(
        self, middleware: Middleware, after: type | str
    ) -> MiddlewareChain:
        """Add middleware after another.

        Args:
            middleware: Middleware to add
            after: Middleware class or name to insert after. If nothing
                matches, the middleware is appended at the end.

        Returns:
            Self for chaining
        """
        idx = self._index_of(after)
        if idx is None:
            self._middleware.append(middleware)
        else:
            self._middleware.insert(idx + 1, middleware)
        return self

    def remove(self, middleware: type | str | Middleware) -> bool:
        """Remove the first matching middleware from the chain.

        Args:
            middleware: Middleware to remove (class, name, or instance)

        Returns:
            True if removed, False if not found
        """
        for i, m in enumerate(self._middleware):
            if m is middleware or self._matches(m, middleware):
                del self._middleware[i]
                return True
        return False

    async def execute(
        self,
        ctx: MiddlewareContext,
        handler: Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]],
    ) -> dict[str, Any] | None:
        """Execute the middleware chain.

        If ``ctx.aborted`` is set, either before execution starts or by a
        middleware that still calls ``next``, the remaining middleware and the
        final handler are skipped and ``ctx.response`` is returned.

        Args:
            ctx: Middleware context
            handler: Final handler to execute

        Returns:
            Response data
        """

        async def terminal(c: MiddlewareContext) -> dict[str, Any] | None:
            # Guard the final handler too, so an abort by the innermost
            # middleware (or before execute) short-circuits consistently with
            # aborts elsewhere in the chain.
            if c.aborted:
                return c.response
            return await handler(c)

        # Build the chain from the innermost step outwards.
        chain = terminal
        for middleware in reversed(self._middleware):
            chain = self._create_next(middleware, chain)

        return await chain(ctx)

    def _create_next(
        self,
        middleware: Middleware,
        next_handler: Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]],
    ) -> Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]]:
        """Wrap *middleware* so it runs only while the context is not aborted.

        Args:
            middleware: Current middleware
            next_handler: Next handler

        Returns:
            Wrapped handler
        """

        async def handler(ctx: MiddlewareContext) -> dict[str, Any] | None:
            if ctx.aborted:
                return ctx.response
            return await middleware.process(ctx, next_handler)

        return handler

    @property
    def middleware_names(self) -> list[str]:
        """Get list of middleware names in chain order."""
        return [m.name for m in self._middleware]

    def __len__(self) -> int:
        """Get number of middleware."""
        return len(self._middleware)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class FunctionMiddleware(Middleware):
    """Adapter that exposes a plain async function as a Middleware.

    The wrapped function receives the same ``(ctx, next)`` arguments that
    ``process`` receives, and its result is returned unchanged.

    Example:
        >>> async def log_request(ctx, next):
        ...     print(f"Request: {ctx.request}")
        ...     return await next(ctx)
        >>>
        >>> middleware = FunctionMiddleware("logging", log_request)
    """

    def __init__(
        self,
        name: str,
        func: Callable[
            [
                MiddlewareContext,
                Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]],
            ],
            Awaitable[dict[str, Any] | None],
        ],
    ) -> None:
        """Store the name and the function to delegate to.

        Args:
            name: Value reported by the ``name`` property
            func: Async callable invoked on every ``process`` call
        """
        self._func = func
        self._name = name

    @property
    def name(self) -> str:
        """Return the name supplied at construction time."""
        return self._name

    async def process(
        self,
        ctx: MiddlewareContext,
        next: Callable[[MiddlewareContext], Awaitable[dict[str, Any] | None]],
    ) -> dict[str, Any] | None:
        """Delegate to the wrapped function."""
        wrapped = self._func
        return await wrapped(ctx, next)
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plugin registry for managing plugins.
|
|
3
|
+
|
|
4
|
+
Provides plugin registration, discovery, and lifecycle management.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from ai_lib_python.plugins.base import Plugin, PluginContext
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PluginRegistry:
    """Registry for managing plugins.

    Handles plugin registration, lookup, and lifecycle. Plugins are ordered
    by ascending ``priority``; ties keep registration order. Initialization,
    requests, errors, stream start and stream chunks visit plugins in that
    order. Responses, stream end and shutdown visit them in reverse. Plugins
    whose ``enabled`` attribute is falsy are skipped by every hook.

    Example:
        >>> registry = PluginRegistry()
        >>> registry.register(LoggingPlugin())
        >>> registry.register(CachingPlugin())
        >>>
        >>> # Initialize all plugins
        >>> await registry.init_all(ctx)
        >>>
        >>> # Process request through plugins
        >>> request = await registry.process_request(ctx, request)
    """

    def __init__(self) -> None:
        """Initialize an empty plugin registry."""
        self._plugins: dict[str, Plugin] = {}  # name -> plugin, registration order
        self._ordered: list[Plugin] = []  # same plugins, sorted by priority
        self._initialized = False

    def register(self, plugin: Plugin) -> PluginRegistry:
        """Register a plugin.

        Plugins registered after ``init_all()`` has completed are not
        initialized automatically.

        Args:
            plugin: Plugin to register

        Returns:
            Self for chaining

        Raises:
            ValueError: If a plugin with the same name is already registered
        """
        if plugin.name in self._plugins:
            raise ValueError(f"Plugin already registered: {plugin.name}")

        self._plugins[plugin.name] = plugin
        self._update_order()
        return self

    def unregister(self, name: str) -> bool:
        """Unregister a plugin.

        The removed plugin's ``on_shutdown`` hook is not called.

        Args:
            name: Plugin name

        Returns:
            True if unregistered, False if not found
        """
        if name not in self._plugins:
            return False
        del self._plugins[name]
        self._update_order()
        return True

    def get(self, name: str) -> Plugin | None:
        """Get a plugin by name.

        Args:
            name: Plugin name

        Returns:
            Plugin or None
        """
        return self._plugins.get(name)

    def has(self, name: str) -> bool:
        """Check if a plugin is registered.

        Args:
            name: Plugin name

        Returns:
            True if registered
        """
        return name in self._plugins

    def enable(self, name: str) -> bool:
        """Report whether a plugin is registered, and therefore active.

        This does not modify the plugin's ``enabled`` attribute. A plugin
        removed with ``disable()`` must be registered again to become active.

        Args:
            name: Plugin name

        Returns:
            True if plugin exists
        """
        return name in self._plugins

    def disable(self, name: str) -> bool:
        """Disable a plugin by unregistering it from this registry.

        Args:
            name: Plugin name

        Returns:
            True if disabled (it was registered), False otherwise
        """
        return self.unregister(name)

    async def init_all(self, ctx: PluginContext | None = None) -> None:
        """Initialize all enabled plugins in execution order.

        Does nothing if already initialized. If a plugin's ``on_init`` raises,
        the exception propagates and the registry stays uninitialized.

        Args:
            ctx: Plugin context; a default ``PluginContext()`` is used if None
        """
        if self._initialized:
            return

        ctx = ctx if ctx is not None else PluginContext()

        for plugin in self._ordered:
            if plugin.enabled:
                await plugin.on_init(ctx)

        self._initialized = True

    async def shutdown_all(self, ctx: PluginContext | None = None) -> None:
        """Shut down all enabled plugins in reverse execution order.

        Does nothing unless the registry is initialized. Every enabled
        plugin's ``on_shutdown`` is attempted even if an earlier one raises.
        Afterwards the registry is marked uninitialized, and the first
        exception (if any) is re-raised.

        Args:
            ctx: Plugin context; a default ``PluginContext()`` is used if None
        """
        if not self._initialized:
            return

        ctx = ctx if ctx is not None else PluginContext()

        first_error: Exception | None = None
        for plugin in reversed(self._ordered):
            if not plugin.enabled:
                continue
            try:
                await plugin.on_shutdown(ctx)
            except Exception as exc:
                # Keep shutting down the remaining plugins; report afterwards.
                if first_error is None:
                    first_error = exc

        self._initialized = False

        if first_error is not None:
            raise first_error

    async def process_request(
        self,
        ctx: PluginContext,
        request: dict[str, Any],
    ) -> dict[str, Any]:
        """Pass the request through each enabled plugin in execution order.

        Args:
            ctx: Plugin context
            request: Request payload

        Returns:
            Modified request
        """
        for plugin in self._ordered:
            if plugin.enabled:
                request = await plugin.on_request(ctx, request)
        return request

    async def process_response(
        self,
        ctx: PluginContext,
        response: dict[str, Any],
    ) -> dict[str, Any]:
        """Pass the response through each enabled plugin in reverse order.

        Args:
            ctx: Plugin context
            response: Response data

        Returns:
            Modified response
        """
        for plugin in reversed(self._ordered):
            if plugin.enabled:
                response = await plugin.on_response(ctx, response)
        return response

    async def process_error(
        self,
        ctx: PluginContext,
        error: Exception,
    ) -> Exception | None:
        """Pass an error through each enabled plugin in execution order.

        Each plugin may return a replacement exception. The first plugin that
        returns None suppresses the error, and later plugins are not called.

        Args:
            ctx: Plugin context
            error: The error

        Returns:
            Modified error or None to suppress
        """
        for plugin in self._ordered:
            if plugin.enabled:
                result = await plugin.on_error(ctx, error)
                if result is None:
                    return None
                error = result
        return error

    async def on_stream_start(self, ctx: PluginContext) -> None:
        """Notify enabled plugins of stream start, in execution order.

        Args:
            ctx: Plugin context
        """
        for plugin in self._ordered:
            if plugin.enabled:
                await plugin.on_stream_start(ctx)

    async def process_stream_chunk(
        self,
        ctx: PluginContext,
        chunk: dict[str, Any],
    ) -> dict[str, Any]:
        """Pass a stream chunk through each enabled plugin in execution order.

        Args:
            ctx: Plugin context
            chunk: Stream chunk

        Returns:
            Modified chunk
        """
        for plugin in self._ordered:
            if plugin.enabled:
                chunk = await plugin.on_stream_chunk(ctx, chunk)
        return chunk

    async def on_stream_end(self, ctx: PluginContext) -> None:
        """Notify enabled plugins of stream end, in reverse order.

        Args:
            ctx: Plugin context
        """
        for plugin in reversed(self._ordered):
            if plugin.enabled:
                await plugin.on_stream_end(ctx)

    def _update_order(self) -> None:
        """Recompute execution order: ascending priority, stable for ties."""
        self._ordered = sorted(
            self._plugins.values(),
            key=lambda p: p.priority,
        )

    @property
    def plugins(self) -> list[Plugin]:
        """Get a copy of all plugins in execution order."""
        return list(self._ordered)

    @property
    def plugin_names(self) -> list[str]:
        """Get all plugin names in execution order."""
        return [p.name for p in self._ordered]

    def __len__(self) -> int:
        """Get number of registered plugins."""
        return len(self._plugins)

    def __contains__(self, name: str) -> bool:
        """Check if plugin is registered."""
        return name in self._plugins
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
# Module-level plugin registry used by get_plugin_registry() and
# set_plugin_registry(); stays None until one of them assigns it.
_global_registry: PluginRegistry | None = None
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def get_plugin_registry() -> PluginRegistry:
    """Return the module-level plugin registry, creating it on first use.

    Returns:
        Global PluginRegistry instance
    """
    global _global_registry
    registry = _global_registry
    if registry is None:
        registry = PluginRegistry()
        _global_registry = registry
    return registry
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def set_plugin_registry(registry: PluginRegistry) -> None:
    """Replace the module-level plugin registry.

    The given instance is what get_plugin_registry() returns afterwards.

    Args:
        registry: PluginRegistry instance
    """
    global _global_registry
    _global_registry = registry
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Protocol layer - Protocol loading, validation, and manifest models.
|
|
3
|
+
|
|
4
|
+
This module handles:
|
|
5
|
+
- Loading protocol manifests from various sources
|
|
6
|
+
- Validating manifests against JSON Schema
|
|
7
|
+
- Protocol version validation
|
|
8
|
+
- Strict streaming validation
|
|
9
|
+
- Typed manifest models for runtime use
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from ai_lib_python.protocol.loader import ProtocolLoader
|
|
13
|
+
from ai_lib_python.protocol.manifest import (
|
|
14
|
+
AccumulatorConfig,
|
|
15
|
+
AuthConfig,
|
|
16
|
+
AvailabilityConfig,
|
|
17
|
+
CandidateConfig,
|
|
18
|
+
CapabilitiesConfig,
|
|
19
|
+
DecoderConfig,
|
|
20
|
+
EndpointConfig,
|
|
21
|
+
ErrorClassification,
|
|
22
|
+
EventMapRule,
|
|
23
|
+
HealthCheckConfig,
|
|
24
|
+
ProtocolManifest,
|
|
25
|
+
RateLimitHeaders,
|
|
26
|
+
RetryPolicy,
|
|
27
|
+
ServiceConfig,
|
|
28
|
+
StreamingConfig,
|
|
29
|
+
ToolingConfig,
|
|
30
|
+
ToolUseConfig,
|
|
31
|
+
)
|
|
32
|
+
from ai_lib_python.protocol.validator import (
|
|
33
|
+
SUPPORTED_PROTOCOL_VERSIONS,
|
|
34
|
+
ProtocolValidator,
|
|
35
|
+
ValidationResult,
|
|
36
|
+
validate_manifest,
|
|
37
|
+
validate_manifest_or_raise,
|
|
38
|
+
validate_protocol_version,
|
|
39
|
+
validate_streaming_config,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
# Public API, kept in ASCII-sorted order. The trailing comment on each entry
# names the protocol submodule it is imported from.
__all__ = [
    "AccumulatorConfig",  # manifest
    "AuthConfig",  # manifest
    "AvailabilityConfig",  # manifest
    "CandidateConfig",  # manifest
    "CapabilitiesConfig",  # manifest
    "DecoderConfig",  # manifest
    "EndpointConfig",  # manifest
    "ErrorClassification",  # manifest
    "EventMapRule",  # manifest
    "HealthCheckConfig",  # manifest
    "ProtocolLoader",  # loader
    "ProtocolManifest",  # manifest
    "ProtocolValidator",  # validator
    "RateLimitHeaders",  # manifest
    "RetryPolicy",  # manifest
    "SUPPORTED_PROTOCOL_VERSIONS",  # validator
    "ServiceConfig",  # manifest
    "StreamingConfig",  # manifest
    "ToolUseConfig",  # manifest
    "ToolingConfig",  # manifest
    "ValidationResult",  # validator
    "validate_manifest",  # validator
    "validate_manifest_or_raise",  # validator
    "validate_protocol_version",  # validator
    "validate_streaming_config",  # validator
]
|