dbl_gateway-0.3.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dbl_gateway/__init__.py +1 -0
- dbl_gateway/adapters/__init__.py +9 -0
- dbl_gateway/adapters/execution_adapter_kl.py +133 -0
- dbl_gateway/adapters/policy_adapter_dbl_policy.py +96 -0
- dbl_gateway/adapters/store_adapter_sqlite.py +55 -0
- dbl_gateway/admission.py +67 -0
- dbl_gateway/app.py +501 -0
- dbl_gateway/auth.py +295 -0
- dbl_gateway/capabilities.py +79 -0
- dbl_gateway/digest.py +31 -0
- dbl_gateway/execution.py +15 -0
- dbl_gateway/governance.py +20 -0
- dbl_gateway/models.py +24 -0
- dbl_gateway/ports/__init__.py +11 -0
- dbl_gateway/ports/execution_port.py +19 -0
- dbl_gateway/ports/policy_port.py +18 -0
- dbl_gateway/ports/store_port.py +33 -0
- dbl_gateway/projection.py +34 -0
- dbl_gateway/providers/__init__.py +1 -0
- dbl_gateway/providers/anthropic.py +63 -0
- dbl_gateway/providers/errors.py +5 -0
- dbl_gateway/providers/openai.py +105 -0
- dbl_gateway/store/__init__.py +1 -0
- dbl_gateway/store/base.py +35 -0
- dbl_gateway/store/factory.py +12 -0
- dbl_gateway/store/sqlite.py +200 -0
- dbl_gateway/wire_contract.py +65 -0
- dbl_gateway-0.3.2.dist-info/METADATA +78 -0
- dbl_gateway-0.3.2.dist-info/RECORD +33 -0
- dbl_gateway-0.3.2.dist-info/WHEEL +5 -0
- dbl_gateway-0.3.2.dist-info/entry_points.txt +2 -0
- dbl_gateway-0.3.2.dist-info/licenses/LICENSE +21 -0
- dbl_gateway-0.3.2.dist-info/top_level.txt +1 -0
dbl_gateway/auth.py
ADDED
@@ -0,0 +1,295 @@
from __future__ import annotations

import json
import os
import time
from dataclasses import dataclass
from typing import Any, Mapping, Sequence


@dataclass(frozen=True)
class AuthConfig:
    mode: str
    issuer: str
    audience: str
    jwks_url: str
    allowed_tenants: tuple[str, ...]
    allow_all_tenants: bool
    tenant_claim: str
    role_claims: tuple[str, ...]
    role_map: dict[str, list[str]] | None
    dev_actor: str
    dev_roles: tuple[str, ...]


@dataclass(frozen=True)
class Actor:
    actor_id: str
    tenant_id: str
    client_id: str
    roles: tuple[str, ...]
    raw_claims: dict[str, Any]


class AuthError(Exception):
    pass


class ForbiddenError(Exception):
    pass


_JWKS_BY_URL: dict[str, dict[str, Any]] = {}
_JWKS_TS_BY_URL: dict[str, float] = {}
_JWKS_CACHE_TTL_S: float = 300.0


def load_auth_config() -> AuthConfig:
    mode = os.getenv("DBL_GATEWAY_AUTH_MODE", "dev").strip().lower()
    issuer = os.getenv("DBL_GATEWAY_OIDC_ISSUER", "").strip()
    audience = os.getenv("DBL_GATEWAY_OIDC_AUDIENCE", "").strip()
    jwks_url = os.getenv("DBL_GATEWAY_OIDC_JWKS_URL", "").strip()
    allowed_tenants_raw = os.getenv("DBL_GATEWAY_ALLOWED_TENANTS", "*").strip()
    tenant_claim = os.getenv("DBL_GATEWAY_TENANT_CLAIM", "tid").strip() or "tid"
    role_claims_raw = os.getenv("DBL_GATEWAY_ROLE_CLAIMS", "roles").strip()
    role_map_raw = os.getenv("DBL_GATEWAY_ROLE_MAP", "").strip()

    dev_actor = os.getenv("DBL_GATEWAY_DEV_ACTOR", "dev-user").strip()
    dev_roles_raw = os.getenv(
        "DBL_GATEWAY_DEV_ROLES",
        "gateway.intent.write,gateway.decision.write,gateway.snapshot.read",
    ).strip()
    dev_roles = tuple([r.strip() for r in dev_roles_raw.split(",") if r.strip()])

    allow_all_tenants = allowed_tenants_raw == "*" or allowed_tenants_raw == ""
    allowed_tenants = tuple([t.strip() for t in allowed_tenants_raw.split(",") if t.strip()])
    role_claims = tuple([c.strip() for c in role_claims_raw.split(",") if c.strip()])
    role_map = _parse_role_map(role_map_raw)

    return AuthConfig(
        mode=mode,
        issuer=issuer,
        audience=audience,
        jwks_url=jwks_url,
        allowed_tenants=allowed_tenants,
        allow_all_tenants=allow_all_tenants,
        tenant_claim=tenant_claim,
        role_claims=role_claims,
        role_map=role_map,
        dev_actor=dev_actor,
        dev_roles=dev_roles,
    )


def require_roles(actor: Actor, required: Sequence[str]) -> None:
    missing = [r for r in required if r not in actor.roles]
    if missing:
        raise ForbiddenError(f"missing roles: {', '.join(missing)}")


def require_tenant(actor: Actor, cfg: AuthConfig | None = None) -> None:
    cfg = cfg or load_auth_config()
    if cfg.allow_all_tenants:
        return
    if actor.tenant_id in cfg.allowed_tenants:
        return
    raise ForbiddenError("tenant not allowed")


async def authenticate_request(headers: Mapping[str, str], cfg: AuthConfig | None = None) -> Actor:
    cfg = cfg or load_auth_config()
    if cfg.mode == "dev":
        return _authenticate_dev(headers, cfg)
    if cfg.mode == "oidc":
        claims = await _authenticate_oidc(headers, cfg)
        return _authorize_oidc_claims(claims, cfg)
    raise AuthError(f"unsupported auth mode: {cfg.mode}")


def _authenticate_dev(headers: Mapping[str, str], cfg: AuthConfig) -> Actor:
    actor_id = headers.get("x-dev-actor", cfg.dev_actor).strip() or cfg.dev_actor
    roles_header = headers.get("x-dev-roles", "")
    roles = cfg.dev_roles
    if roles_header.strip():
        roles = tuple([r.strip() for r in roles_header.split(",") if r.strip()])

    return Actor(
        actor_id=actor_id,
        tenant_id=headers.get("x-dev-tenant", "dev-tenant").strip() or "dev-tenant",
        client_id=headers.get("x-dev-client", "dev-client").strip() or "dev-client",
        roles=roles,
        raw_claims={"dev": True},
    )


async def _authenticate_oidc(headers: Mapping[str, str], cfg: AuthConfig) -> dict[str, Any]:
    if not cfg.issuer or not cfg.audience or not cfg.jwks_url:
        raise AuthError("OIDC config incomplete: issuer, audience, jwks_url required")

    auth = headers.get("authorization", "")
    if not auth.lower().startswith("bearer "):
        raise AuthError("missing bearer token")
    token = auth.split(" ", 1)[1].strip()
    if not token:
        raise AuthError("missing bearer token")

    jwks = await _get_jwks(cfg.jwks_url)
    try:
        from jose import jwk, jwt
        from jose.exceptions import JWTError

        header = jwt.get_unverified_header(token)
        try:
            jwk_data = _select_jwk(header, jwks)
        except AuthError as exc:
            if str(exc) != "no matching JWKS key for kid":
                raise
            jwks = await _get_jwks(cfg.jwks_url, force=True)
            jwk_data = _select_jwk(header, jwks)
        key = jwk.construct(jwk_data)
        claims = jwt.decode(
            token,
            key,
            algorithms=["RS256"],
            issuer=cfg.issuer,
            audience=cfg.audience,
            options={
                "verify_aud": True,
                "verify_iss": True,
                "verify_exp": True,
                "verify_nbf": True,
                "verify_iat": True,
            },
            leeway=60,
        )
    except ImportError as exc:
        raise AuthError("OIDC auth requires python-jose") from exc
    except JWTError as exc:
        raise AuthError(f"invalid token: {exc}") from exc

    return dict(claims)


async def _get_jwks(jwks_url: str, *, force: bool = False) -> dict[str, Any]:
    global _JWKS_BY_URL, _JWKS_TS_BY_URL
    now = time.time()
    if not force and jwks_url in _JWKS_BY_URL:
        ts = _JWKS_TS_BY_URL.get(jwks_url, 0.0)
        if (now - ts) < _JWKS_CACHE_TTL_S:
            return _JWKS_BY_URL[jwks_url]

    try:
        import httpx
    except ImportError as exc:
        raise AuthError("OIDC auth requires httpx") from exc

    async with httpx.AsyncClient(timeout=5.0) as client:
        resp = await client.get(jwks_url)
        resp.raise_for_status()
        jwks = resp.json()

    if not isinstance(jwks, dict) or "keys" not in jwks:
        raise AuthError("invalid JWKS payload")

    _JWKS_BY_URL[jwks_url] = jwks
    _JWKS_TS_BY_URL[jwks_url] = now
    return jwks


def _pick_first_str(claims: Mapping[str, Any], keys: list[str], default: str) -> str:
    for k in keys:
        v = claims.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return default


def _select_jwk(header: Mapping[str, Any], jwks: Mapping[str, Any]) -> Mapping[str, Any]:
    kid = header.get("kid")
    alg = header.get("alg")
    if alg != "RS256":
        raise AuthError("unsupported token algorithm")
    if not isinstance(kid, str) or not kid.strip():
        raise AuthError("token missing kid")
    keys = jwks.get("keys")
    if not isinstance(keys, list):
        raise AuthError("invalid JWKS payload")
    for key in keys:
        if isinstance(key, Mapping) and key.get("kid") == kid:
            return key
    raise AuthError("no matching JWKS key for kid")


def _authorize_oidc_claims(claims: Mapping[str, Any], cfg: AuthConfig) -> Actor:
    actor_id = _pick_first_str(claims, ["oid", "sub"], default="")
    if not actor_id:
        raise AuthError("token missing actor id claim (oid/sub)")
    tenant_id = _pick_first_str(claims, [cfg.tenant_claim], default="unknown")
    client_id = _pick_first_str(claims, ["azp", "appid"], default="unknown")
    roles = _extract_roles(claims, cfg.role_claims)
    roles = _apply_role_map(roles, cfg.role_map)
    return Actor(
        actor_id=actor_id,
        tenant_id=tenant_id,
        client_id=client_id,
        roles=roles,
        raw_claims=dict(claims),
    )


def _extract_roles(claims: Mapping[str, Any], claim_names: Sequence[str]) -> tuple[str, ...]:
    roles: list[str] = []
    for claim in claim_names:
        v = claims.get(claim)
        if isinstance(v, list):
            for item in v:
                if isinstance(item, str) and item.strip():
                    roles.append(item.strip())
        elif isinstance(v, str) and v.strip():
            for item in v.replace(",", " ").split():
                if item.strip():
                    roles.append(item.strip())
    seen: set[str] = set()
    unique: list[str] = []
    for r in roles:
        if r not in seen:
            seen.add(r)
            unique.append(r)
    return tuple(unique)


def _apply_role_map(roles: tuple[str, ...], role_map: dict[str, list[str]] | None) -> tuple[str, ...]:
    if role_map is None:
        return roles
    mapped: list[str] = []
    for role in roles:
        if role in role_map:
            mapped.extend(role_map[role])
        else:
            mapped.append(role)
    seen: set[str] = set()
    unique: list[str] = []
    for role in mapped:
        if role not in seen:
            seen.add(role)
            unique.append(role)
    return tuple(unique)


def _parse_role_map(role_map_raw: str) -> dict[str, list[str]] | None:
    if role_map_raw == "":
        return None
    try:
        parsed = json.loads(role_map_raw)
    except json.JSONDecodeError as exc:
        raise AuthError("DBL_GATEWAY_ROLE_MAP must be valid JSON") from exc
    if not isinstance(parsed, dict):
        raise AuthError("DBL_GATEWAY_ROLE_MAP must be a JSON object")
    role_map: dict[str, list[str]] = {}
    for key, value in parsed.items():
        if isinstance(value, str):
            role_map[str(key)] = [value]
        elif isinstance(value, list):
            mapped = [v for v in value if isinstance(v, str) and v.strip()]
            role_map[str(key)] = mapped
    return role_map
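For orientation, a short usage sketch of the auth module above in dev mode (illustrative only, not part of the wheel; the environment variables, header names, and role strings are the ones read by the code, everything else is an assumption):

import asyncio
import os

from dbl_gateway.auth import ForbiddenError, authenticate_request, load_auth_config, require_roles

# Dev mode is the default; set it explicitly for clarity, and narrow the granted roles.
os.environ["DBL_GATEWAY_AUTH_MODE"] = "dev"
os.environ["DBL_GATEWAY_DEV_ROLES"] = "gateway.intent.write,gateway.snapshot.read"


async def main() -> None:
    cfg = load_auth_config()
    # _authenticate_dev looks up lower-cased x-dev-* headers.
    actor = await authenticate_request({"x-dev-actor": "alice", "x-dev-tenant": "tenant-a"}, cfg)
    print(actor.actor_id, actor.tenant_id, actor.roles)
    try:
        require_roles(actor, ["gateway.decision.write"])
    except ForbiddenError as exc:
        print(f"denied: {exc}")  # missing roles: gateway.decision.write


asyncio.run(main())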
dbl_gateway/capabilities.py
ADDED
@@ -0,0 +1,79 @@
from __future__ import annotations

import os


def get_capabilities() -> dict[str, object]:
    models = _available_models()
    allowed_model_ids = [model["model_id"] for model in models]
    default_model_id = _default_model_id(allowed_model_ids)
    return {
        "allowed_model_ids": allowed_model_ids,
        "default_model_id": default_model_id,
    }


def _available_models() -> list[dict[str, str]]:
    models: list[dict[str, str]] = []
    if _get_openai_key():
        for model_id in _openai_models():
            models.append({"provider": "openai", "model_id": model_id})
    if _get_anthropic_key():
        for model_id in _anthropic_models():
            models.append({"provider": "anthropic", "model_id": model_id})
    return models


def resolve_provider(model_id: str) -> tuple[str | None, str | None]:
    if model_id in _openai_models():
        if not _get_openai_key():
            return None, "provider.missing_credentials"
        return "openai", None
    if model_id in _anthropic_models():
        if not _get_anthropic_key():
            return None, "provider.missing_credentials"
        return "anthropic", None
    return None, "model.unavailable"


def _openai_models() -> list[str]:
    chat_models = _parse_csv("OPENAI_CHAT_MODEL_IDS")
    if not chat_models:
        chat_models = _parse_csv("OPENAI_MODEL_IDS")
    if not chat_models:
        chat_models = ["gpt-4o-mini"]
    response_models = _parse_csv("OPENAI_RESPONSES_MODEL_IDS") or []
    combined: list[str] = []
    seen: set[str] = set()
    for model_id in chat_models + response_models:
        if model_id in seen:
            continue
        seen.add(model_id)
        combined.append(model_id)
    return combined


def _anthropic_models() -> list[str]:
    models = _parse_csv("ANTHROPIC_MODEL_IDS")
    return models or ["claude-3-haiku-20240307"]


def _default_model_id(allowed: list[str]) -> str | None:
    if not allowed:
        return None
    return allowed[0]


def _get_openai_key() -> str:
    return os.getenv("OPENAI_API_KEY", "").strip()


def _get_anthropic_key() -> str:
    return os.getenv("ANTHROPIC_API_KEY", "").strip()


def _parse_csv(name: str) -> list[str]:
    raw = os.getenv(name, "").strip()
    if not raw:
        return []
    return [item.strip() for item in raw.split(",") if item.strip()]
dbl_gateway/digest.py
ADDED
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Any

from dbl_core import DblEvent, DblEventKind
from dbl_core.events.canonical import canonicalize_value, digest_bytes, json_dumps

__all__ = ["event_digest", "v_digest"]


def event_digest(kind: str, correlation_id: str, payload: dict[str, Any]) -> tuple[str, int]:
    event_kind = DblEventKind(kind)
    event_payload = _strip_obs(payload)
    event = DblEvent(event_kind=event_kind, correlation_id=correlation_id, data=event_payload)
    canonical_json = event.to_json(include_observational=False)
    return event.digest(), len(canonical_json)


def v_digest(indexed: list[tuple[int, str]]) -> str:
    items = [{"index": idx, "digest": digest} for idx, digest in indexed]
    canonical = canonicalize_value(items)
    canonical_json = json_dumps(canonical)
    return digest_bytes(canonical_json)


def _strip_obs(payload: dict[str, Any]) -> dict[str, Any]:
    if "_obs" not in payload:
        return payload
    sanitized = dict(payload)
    sanitized.pop("_obs", None)
    return sanitized
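A sketch of how v_digest folds ordered per-event digests into one snapshot digest (illustrative only, not part of the wheel; it requires the dbl_core dependency, and the hex strings are placeholders, not real digests):

from dbl_gateway.digest import v_digest

# (index, digest) pairs as they would appear in stored EventRecords.
indexed = [
    (0, "3a7bd3e2360a3d29eea436fcfb7e44c735d117c42d1c1835420b6b9942dd4f1b"),
    (1, "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"),
]
print(v_digest(indexed))  # one deterministic digest over the ordered pairs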
dbl_gateway/execution.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import annotations

from .adapters.execution_adapter_kl import KlExecutionAdapter, schedule_execution
from .ports.execution_port import ExecutionResult

__all__ = [
    "ExecutionResult",
    "KlExecutionAdapter",
    "run_execution",
    "schedule_execution",
]


async def run_execution(intent_event):
    return await KlExecutionAdapter().run(intent_event)
dbl_gateway/governance.py
ADDED
@@ -0,0 +1,20 @@
from __future__ import annotations

from .adapters.policy_adapter_dbl_policy import (
    ALLOWED_CONTEXT_KEYS,
    DblPolicyAdapter,
    _build_policy_context,
)
from .ports.policy_port import DecisionResult

__all__ = [
    "ALLOWED_CONTEXT_KEYS",
    "DecisionResult",
    "DblPolicyAdapter",
    "_build_policy_context",
    "decide_for_intent",
]


def decide_for_intent(authoritative_input):
    return DblPolicyAdapter().decide(authoritative_input)
dbl_gateway/models.py
ADDED
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Any, TypedDict


class EventRecord(TypedDict):
    index: int
    kind: str
    lane: str
    actor: str
    intent_type: str
    stream_id: str
    correlation_id: str
    payload: dict[str, Any]
    digest: str
    canon_len: int


class Snapshot(TypedDict):
    length: int
    offset: int
    limit: int
    v_digest: str
    events: list[EventRecord]
dbl_gateway/ports/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .execution_port import ExecutionPort, ExecutionResult
from .policy_port import DecisionResult, PolicyPort
from .store_port import StorePort

__all__ = [
    "DecisionResult",
    "ExecutionPort",
    "ExecutionResult",
    "PolicyPort",
    "StorePort",
]
dbl_gateway/ports/execution_port.py
ADDED
@@ -0,0 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Mapping, Protocol


@dataclass(frozen=True)
class ExecutionResult:
    output_text: str | None = None
    provider: str | None = None
    model_id: str | None = None
    trace: dict[str, Any] | None = None
    trace_digest: str | None = None
    error: dict[str, Any] | None = None


class ExecutionPort(Protocol):
    async def run(self, intent_event: Mapping[str, Any]) -> ExecutionResult:
        ...
dbl_gateway/ports/policy_port.py
ADDED
@@ -0,0 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Mapping, Protocol


@dataclass(frozen=True)
class DecisionResult:
    decision: str
    reason_codes: list[str]
    policy_id: str | None = None
    policy_version: int | None = None
    gate_event: object | None = None


class PolicyPort(Protocol):
    def decide(self, authoritative_input: Mapping[str, Any]) -> DecisionResult:
        ...
dbl_gateway/ports/store_port.py
ADDED
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import Protocol

from ..models import EventRecord, Snapshot


class StorePort(Protocol):
    def append(
        self,
        *,
        kind: str,
        lane: str,
        actor: str,
        intent_type: str,
        stream_id: str,
        correlation_id: str,
        payload: dict[str, object],
    ) -> EventRecord:
        ...

    def snapshot(
        self,
        *,
        limit: int,
        offset: int,
        stream_id: str | None = None,
        lane: str | None = None,
    ) -> Snapshot:
        ...

    def close(self) -> None:
        ...
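To show the shape of a StorePort implementation, an in-memory sketch against the Protocol above (hypothetical, not part of the wheel; the packaged adapter is store_adapter_sqlite.py, digests here are plain SHA-256 rather than the gateway's dbl_gateway.digest canonicalization, and the Snapshot field semantics are simplifying assumptions):

from __future__ import annotations

import hashlib
import json

from dbl_gateway.models import EventRecord, Snapshot


class InMemoryStore:
    """Illustrative StorePort implementation for tests or local experiments."""

    def __init__(self) -> None:
        self._events: list[EventRecord] = []

    def append(self, *, kind: str, lane: str, actor: str, intent_type: str,
               stream_id: str, correlation_id: str, payload: dict[str, object]) -> EventRecord:
        canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
        record: EventRecord = {
            "index": len(self._events),
            "kind": kind,
            "lane": lane,
            "actor": actor,
            "intent_type": intent_type,
            "stream_id": stream_id,
            "correlation_id": correlation_id,
            "payload": dict(payload),
            "digest": hashlib.sha256(canonical.encode()).hexdigest(),
            "canon_len": len(canonical),
        }
        self._events.append(record)
        return record

    def snapshot(self, *, limit: int, offset: int,
                 stream_id: str | None = None, lane: str | None = None) -> Snapshot:
        selected = [
            e for e in self._events
            if (stream_id is None or e["stream_id"] == stream_id)
            and (lane is None or e["lane"] == lane)
        ]
        return {
            "length": len(selected),  # assumed: total matching events, not window size
            "offset": offset,
            "limit": limit,
            "v_digest": "",           # the gateway fills this via dbl_gateway.digest.v_digest
            "events": selected[offset:offset + limit],
        }

    def close(self) -> None:
        self._events.clear()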
dbl_gateway/projection.py
ADDED
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Iterable

from dbl_core import DblEvent, DblEventKind, GateDecision
from dbl_main import Phase, RunnerStatus, State, project_state


def project_runner_state(events: Iterable[dict[str, object]]) -> State:
    dbl_events: list[DblEvent] = []
    for event in events:
        kind = DblEventKind(str(event.get("kind")))
        correlation_id = str(event.get("correlation_id"))
        data = event.get("payload")
        if kind == DblEventKind.DECISION and isinstance(data, dict):
            decision = str(data.get("decision", "DENY"))
            reason_codes = data.get("reason_codes")
            if isinstance(reason_codes, list) and reason_codes:
                reason_code = str(reason_codes[0])
            else:
                reason_code = str(data.get("reason_code", "unspecified"))
            reason_message = data.get("reason_message") if isinstance(data.get("reason_message"), str) else None
            data = GateDecision(decision, reason_code, reason_message)
        dbl_events.append(DblEvent(event_kind=kind, correlation_id=correlation_id, data=data))
    return project_state(dbl_events)


def state_payload(state: State) -> dict[str, object]:
    return {
        "phase": state.phase.value,
        "runner_status": state.runner_status.value,
        "t_index": state.t_index,
        "note": state.note,
    }
dbl_gateway/providers/__init__.py
ADDED
@@ -0,0 +1 @@
__all__ = ["openai", "anthropic", "errors"]
dbl_gateway/providers/anthropic.py
ADDED
@@ -0,0 +1,63 @@
from __future__ import annotations

import os
from typing import Any

import httpx

from .errors import ProviderError


def execute(message: str, model_id: str) -> str:
    api_key = os.getenv("ANTHROPIC_API_KEY", "").strip()
    if not api_key:
        raise ProviderError("missing Anthropic credentials")
    headers = {
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
    }
    payload: dict[str, Any] = {
        "model": model_id,
        "max_tokens": 256,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": message}]}
        ],
    }
    with httpx.Client(timeout=60.0) as client:
        resp = client.post("https://api.anthropic.com/v1/messages", json=payload, headers=headers)
        if resp.status_code >= 400:
            _raise_anthropic(resp)
        data = resp.json()
        return _parse_text(data)


def _parse_text(data: dict[str, Any]) -> str:
    content = data.get("content", [])
    if not isinstance(content, list):
        return ""
    parts: list[str] = []
    for entry in content:
        if entry.get("type") == "text":
            text = entry.get("text")
            if isinstance(text, str):
                parts.append(text)
    return "\n".join(parts)


def _raise_anthropic(resp: httpx.Response) -> None:
    code = None
    msg = None
    try:
        j = resp.json()
        err = j.get("error") if isinstance(j, dict) else None
        if isinstance(err, dict):
            code = err.get("type")
            msg = err.get("message")
    except Exception:
        pass
    detail = msg or resp.text[:500]
    raise ProviderError(
        f"anthropic.messages failed: {detail}",
        status_code=resp.status_code,
        code=str(code) if code else None,
    )