prismcortex 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,44 @@
1
+ """Salience gate — the cheap novelty check that runs *before* the expensive extraction.
2
+
3
+ Biology gates encoding on novelty/urgency (the amygdala) instead of recording every
4
+ moment. We do the same: low-value turns ("ok thanks") never trigger an LLM extraction
5
+ call, and high-urgency turns fast-track straight to consolidation. This is the
6
+ difference between a demo and something with a sane per-turn cost.
7
+
8
+ These are deterministic heuristics — no randomness, no model call. A production build
9
+ can replace this with prismresonance's FrequencyFamily classifier behind the same
10
+ function signature.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ from .models import Band
15
+
16
# Cue tables consulted by assess() on the lower-cased, whitespace-normalized
# text. A trailing space in a cue (e.g. "down ", "update ") stops it from
# matching a longer word that merely starts with it, such as "download".

# Cues that put a turn in Band.EMERGENCY.
_URGENCY = (
    "urgent",
    "asap",
    "critical",
    "emergency",
    "immediately",
    "right now",
    "breaking",
    "alert",
    "deadline",
    "outage",
    "down ",
    "failure",
)

# Cues that signal the user is correcting or updating an earlier fact
# (Band.ALERT).
_CORRECTION = (
    "actually",
    "correction",
    "i meant",
    "not ",
    "no, ",
    "wrong",
    "instead",
    "update ",
    "change ",
    "rather ",
    "should be",
    "is now",
)

# Whole-message pleasantries (after trailing "!"/"." are stripped) that are
# archived without extraction.
_LOW_VALUE = frozenset((
    "ok", "okay", "k",
    "thanks", "thank you", "thx",
    "cool", "nice", "great", "good", "fine",
    "got it", "sure",
    "yes", "no", "yep", "nope",
    "hi", "hello", "hey", "bye",
    "lol", "haha",
))
29
+
30
+
31
def _has_cue(text: str, cues: tuple[str, ...]) -> bool:
    """Return True if any cue in *cues* starts a word in *text*.

    *text* is padded with one trailing space so cues that end in a space
    ("down ", "not ") also match at the very end of the message. A hit only
    counts when the character before it is not alphanumeric, so "not " no
    longer fires inside "cannot " and "actually" not inside "factually".
    """
    padded = text + " "
    for cue in cues:
        pos = padded.find(cue)
        while pos != -1:
            if pos == 0 or not padded[pos - 1].isalnum():
                return True
            pos = padded.find(cue, pos + 1)
    return False


def assess(text: str) -> Band:
    """Classify a payload's salience band. Cheap, deterministic, runs on every turn.

    Checks are applied to the lower-cased, whitespace-normalized text in order:

    1. empty, or a stock pleasantry in ``_LOW_VALUE`` (ignoring trailing
       "!"/".")                                   -> ``Band.ARCHIVE``
    2. an urgency cue from ``_URGENCY``           -> ``Band.EMERGENCY``
    3. two words or fewer                         -> ``Band.NEUTRAL``
    4. a correction cue from ``_CORRECTION``      -> ``Band.ALERT``
    5. anything else                              -> ``Band.NORMAL``

    Urgency is tested before the length cut-off so terse alarms such as
    "Outage!" or "server down" are not downgraded to NEUTRAL.
    """
    t = " ".join(text.lower().split())
    if not t:
        return Band.ARCHIVE
    if t.rstrip("!.") in _LOW_VALUE:
        return Band.ARCHIVE
    if _has_cue(t, _URGENCY):
        return Band.EMERGENCY
    # Short fragments carry too little content to be worth a correction
    # check (e.g. "no, thanks" would otherwise look like a correction).
    if len(t.split()) <= 2:
        return Band.NEUTRAL
    if _has_cue(t, _CORRECTION):
        return Band.ALERT
    return Band.NORMAL
prismcortex/server.py ADDED
@@ -0,0 +1,520 @@
1
+ """PrismCortex memory service — multi-tenant, RBAC, observability, enterprise APIs."""
2
+ from __future__ import annotations
3
+
4
+ import asyncio
5
+ import hmac
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+ from collections import deque
11
+ from threading import Lock, Semaphore
12
+ from typing import Any, Callable, Optional
13
+
14
+ from fastapi import FastAPI, Request
15
+ from fastapi.responses import JSONResponse
16
+ from fastapi.staticfiles import StaticFiles
17
+ from pydantic import BaseModel, Field
18
+
19
+ from .auth import (
20
+ ROLE_ADMIN,
21
+ ROLE_FORGET,
22
+ ROLE_READ,
23
+ ROLE_WRITE,
24
+ AuthContext,
25
+ auth_required,
26
+ authenticate,
27
+ )
28
+ from .engine import Memory
29
+ from .labels import aliases_snapshot, load_aliases, register_alias, save_aliases
30
+ from .policy import PolicyEngine
31
+ from .server_helpers import CountingGemini, rate_limiter_from_env, read_executor, write_executor
32
+ from .tenant import TenantMemoryManager
33
+ from .tracing import current_trace, start_trace, trace_enabled, traced
34
+
35
DATA_DIR = os.getenv("PRISMCORTEX_DATA", ".prismcortex_data")
os.makedirs(DATA_DIR, exist_ok=True)

# log_event() already renders each record as one compact JSON object, so the
# formatter passes the message through untouched. Records are written both to
# the console stream (stderr by default) and to <DATA_DIR>/server.jsonl.
logger = logging.getLogger("prismcortex")
logger.setLevel(logging.INFO)
_fmt = logging.Formatter("%(message)s")
_sh = logging.StreamHandler()
_fh = logging.FileHandler(os.path.join(DATA_DIR, "server.jsonl"))
for _handler in (_sh, _fh):
    _handler.setFormatter(_fmt)
    logger.addHandler(_handler)
del _handler
47
+
48
+
49
def log_event(**fields) -> None:
    """Write *fields* as a single compact JSON line to the ``prismcortex`` logger.

    When a trace is active (``current_trace()`` is truthy) its ``trace_id`` is
    attached, and a ``ts`` epoch timestamp (4 decimals) is always stamped.
    Caller-supplied keys keep their position; ``trace_id``/``ts`` overwrite
    any same-named caller values.
    """
    active = current_trace()
    trace_part = {"trace_id": active.trace_id} if active else {}
    record = {**fields, **trace_part, "ts": round(time.time(), 4)}
    logger.info(json.dumps(record, separators=(",", ":")))
55
+
56
+
57
class Metrics:
    """Process-wide service counters and latency samples.

    ``counts``, ``cache_hits``, ``cache_misses`` and ``raw_bytes`` are plain
    attributes that request handlers bump directly. Latency samples are added
    through :meth:`record`, which keeps only the most recent 5000 values for
    each tracked operation ("digest" and "recall"); other op names are ignored.
    """

    def __init__(self) -> None:
        self._lock = Lock()
        self._lat = {"digest": deque(maxlen=5000), "recall": deque(maxlen=5000)}
        self._zero_counters()

    def _zero_counters(self) -> None:
        """(Re)initialise every counter and restart the uptime clock."""
        self.started = time.time()
        self.counts = {"digest": 0, "recall": 0, "sleep": 0, "errors": 0, "rate_limited": 0}
        self.cache_hits = 0
        self.cache_misses = 0
        self.raw_bytes = 0

    def record(self, op: str, ms: float) -> None:
        """Store one latency sample (milliseconds) for *op* if it is tracked."""
        with self._lock:
            samples = self._lat.get(op)
            if samples is not None:
                samples.append(ms)

    def reset(self) -> None:
        """Zero all counters and drop every stored latency sample."""
        with self._lock:
            self._zero_counters()
            for samples in self._lat.values():
                samples.clear()

    @staticmethod
    def _pct(vals, p):
        """Nearest-rank style percentile *p* (0-100) of *vals*, rounded to 2 dp; None if empty."""
        ordered = sorted(vals)
        if not ordered:
            return None
        last = len(ordered) - 1
        idx = int(round(p / 100 * last))
        return round(ordered[min(last, idx)], 2)

    def snapshot(self, gemini_calls: int, backend: str, graph_version: int, *, staging: int = 0, conflicts: int = 0) -> dict:
        """Return a JSON-ready view of the counters plus caller-supplied gauges."""
        with self._lock:
            latency = {}
            for op, samples in self._lat.items():
                latency[op] = {
                    "n": len(samples),
                    "p50": self._pct(samples, 50),
                    "p95": self._pct(samples, 95),
                    "p99": self._pct(samples, 99),
                }
            hits, misses = self.cache_hits, self.cache_misses
            lookups = hits + misses
            return {
                "backend": backend,
                "uptime_s": round(time.time() - self.started, 1),
                "counts": dict(self.counts),
                "cache_hits": hits,
                "cache_misses": misses,
                "cache_hit_rate": round(hits / lookups, 4) if lookups else None,
                "gemini_calls": gemini_calls,
                "graph_version": graph_version,
                "staging_pending": staging,
                "conflicts_open": conflicts,
                "latency_ms": latency,
            }
107
+
108
+
109
# Process-wide singletons shared by every request handler.
metrics = Metrics()
_backend = os.getenv("PRISMCORTEX_BACKEND", "lite")
# ANN indexing stays on unless PRISMCORTEX_USE_ANN is exactly "0".
_use_ann = os.getenv("PRISMCORTEX_USE_ANN", "1") != "0"
_tenant_mgr: Optional[TenantMemoryManager] = None  # built lazily by _tenant_mgr_instance()
_policy = PolicyEngine(DATA_DIR)
_rate_limiter = rate_limiter_from_env()  # None when rate limiting is disabled
# Caps concurrent /digest requests; write_executor() sizes its pool from the
# same environment variable and default.
_digest_sem = Semaphore(int(os.getenv("PRISMCORTEX_MAX_CONCURRENT_DIGEST", "16")))
_build_lock = Lock()  # serialises first construction of _tenant_mgr
117
+
118
+
119
async def _run_read(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run ``fn(*args, **kwargs)`` on the dedicated read pool and await the result.

    ``loop.run_in_executor`` does not carry the caller's ``contextvars`` into
    the worker thread, so the call is wrapped in ``copy_context().run`` (the
    same technique ``asyncio.to_thread`` uses). This keeps context-bound request
    state visible inside ``fn``. NOTE(review): presumably this is the trace
    started by the middleware and read by ``traced(...)``; confirm in .tracing.
    """
    import contextvars
    import functools

    loop = asyncio.get_running_loop()
    call = functools.partial(contextvars.copy_context().run, fn, *args, **kwargs)
    return await loop.run_in_executor(read_executor(), call)
122
+
123
+
124
async def _run_write(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run ``fn(*args, **kwargs)`` on the write pool and await the result.

    ``loop.run_in_executor`` does not carry the caller's ``contextvars`` into
    the worker thread, so the call is wrapped in ``copy_context().run`` (the
    same technique ``asyncio.to_thread`` uses). This keeps context-bound request
    state visible inside ``fn``. NOTE(review): presumably this is the trace
    started by the middleware and read by ``traced(...)``; confirm in .tracing.
    """
    import contextvars
    import functools

    loop = asyncio.get_running_loop()
    call = functools.partial(contextvars.copy_context().run, fn, *args, **kwargs)
    return await loop.run_in_executor(write_executor(), call)
127
+
128
# Back-compat hook for tests: when _memory is assigned, get_memory() returns it
# (with _llm) directly instead of going through the tenant manager.
_memory: Optional[Memory] = None
_llm: Optional[CountingGemini] = None

API_KEY = os.getenv("PRISMCORTEX_API_KEY")
# Paths the middleware lets through without credentials: health probe,
# OpenAPI docs, and the console entry points.
_OPEN_PATHS = set(
    (
        "/health",
        "/docs",
        "/openapi.json",
        "/redoc",
        "/docs/oauth2-redirect",
        "/console",
        "/console/",
    )
)
if not auth_required():
    # With no keys configured, the middleware and handlers skip every auth/role check.
    logger.warning(json.dumps({"event": "auth_disabled", "warn": "No API keys configured — UNAUTHENTICATED (dev only)"}))

# Load previously saved aliases from <DATA_DIR>/aliases.json when present.
# NOTE(review): POST /aliases saves per-tenant files named aliases_<tenant>.json,
# but only aliases.json is loaded here -- confirm tenant alias files are
# restored somewhere else after a restart.
_alias_path = os.path.join(DATA_DIR, "aliases.json")
if os.path.isfile(_alias_path):
    load_aliases(_alias_path)
140
+
141
+
142
def _tenant_mgr_instance() -> TenantMemoryManager:
    """Return the shared TenantMemoryManager, constructing it on first use.

    The unlocked fast path returns an existing manager; otherwise
    ``_build_lock`` is taken and the check repeated, so concurrent first
    callers build only one instance.
    """
    global _tenant_mgr
    if _tenant_mgr is not None:
        return _tenant_mgr
    with _build_lock:
        if _tenant_mgr is None:
            _tenant_mgr = TenantMemoryManager(DATA_DIR, _backend, use_ann=_use_ann)
    return _tenant_mgr
149
+
150
+
151
def get_memory(auth: Optional[AuthContext] = None) -> tuple[Memory, Optional[CountingGemini]]:
    """Return the ``(Memory, LLM)`` pair serving *auth*'s tenant and region.

    If the test hook ``_memory`` is set, it is returned with ``_llm``.
    Otherwise the tenant manager is asked for ``(tenant, region)``; without
    *auth*, tenant ``"default"`` and region ``$PRISMCORTEX_REGION`` (default
    ``"default"``) are used.

    An already-built memory is returned via ``peek`` so the ``memory_built``
    event is logged only when ``get`` actually has to build one, not on every
    request. NOTE(review): assumes ``peek`` returns ``(None, ...)`` for an
    unbuilt tenant, as ``get_metrics`` also relies on.
    """
    if _memory is not None:
        return _memory, _llm
    tenant = auth.tenant_id if auth else "default"
    region = auth.region if auth else os.environ.get("PRISMCORTEX_REGION", "default")
    mgr = _tenant_mgr_instance()
    cached, cached_llm = mgr.peek(tenant, region)
    if cached is not None:
        return cached, cached_llm
    mem, llm = mgr.get(tenant, region)
    log_event(event="memory_built", tenant=tenant, region=region, backend=_backend)
    return mem, llm
160
+
161
+
162
def _auth_ctx(request: Request) -> Optional[AuthContext]:
    """Resolve the caller's AuthContext from the request headers.

    With auth disabled everyone gets a default ``AuthContext()``. Otherwise
    the credential is taken from ``X-API-Key`` or, failing that, an
    ``Authorization: Bearer`` header, and passed to ``authenticate``.
    """
    if auth_required():
        headers = request.headers
        credential = headers.get("x-api-key") or _bearer(headers.get("authorization"))
        return authenticate(credential)
    return AuthContext()
167
+
168
+
169
def _bearer(auth: Optional[str]) -> Optional[str]:
    """Extract the token from an ``Authorization: Bearer <token>`` header value.

    The scheme is matched case-insensitively and surrounding whitespace is
    ignored. Returns None for a missing header, a non-Bearer scheme, or an
    empty token, so callers never pass an empty string on as a credential.
    """
    if not auth:
        return None
    scheme, _, credentials = auth.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token or None
173
+
174
+
175
def _deny(msg: str, code: int = 403) -> JSONResponse:
    """Build an error response with body ``{"detail": msg}`` (HTTP 403 by default)."""
    payload = {"detail": msg}
    return JSONResponse(payload, status_code=code)
177
+
178
+
179
app = FastAPI(title="PrismCortex", version="0.2.1")

# Serve the bundled web console at /console when the package ships a
# "static" directory next to this module.
_static = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(_static):
    _console_files = StaticFiles(directory=_static, html=True)
    app.mount("/console", _console_files, name="console")
183
+
184
+
185
@app.middleware("http")
async def _middleware(request: Request, call_next):
    """Per-request tracing, authentication and rate limiting.

    For paths outside ``_OPEN_PATHS``:

    * a trace is started when tracing is enabled (seeded from ``X-Trace-Id``);
    * when auth is required, an invalid or missing credential gets 401, and
      the optional rate limiter (keyed by the credential, else client host,
      else "anon") gets 429.

    The resolved AuthContext (possibly None on open paths) is stored on
    ``request.state.auth`` for the handlers, and the finished trace is logged.
    """
    path = request.url.path
    guarded = path not in _OPEN_PATHS
    if trace_enabled() and guarded:
        start_trace(request.headers.get("x-trace-id"))
    auth = _auth_ctx(request)
    if auth_required() and guarded:
        if not auth:
            return JSONResponse({"detail": "invalid or missing API key"}, status_code=401)
        if _rate_limiter:
            token = request.headers.get("x-api-key") or _bearer(request.headers.get("authorization"))
            # request.client is None under some ASGI servers / in-process test
            # transports; fall back instead of raising AttributeError.
            client_host = request.client.host if request.client is not None else None
            if not _rate_limiter.allow(token or client_host or "anon"):
                metrics.counts["rate_limited"] += 1
                return JSONResponse({"detail": "rate limit exceeded"}, status_code=429)
    request.state.auth = auth
    resp = await call_next(request)
    tr = current_trace()
    if tr and trace_enabled():
        log_event(event="trace", **tr.to_dict())
    return resp
203
+
204
+
205
class DigestBody(BaseModel):
    # POST /digest payload. Length caps bound request size:
    # text <= 100,000 characters, identifiers <= 256 characters.
    text: str = Field(..., max_length=100000)
    source_id: Optional[str] = Field(None, max_length=256)
    agent_id: Optional[str] = Field(None, max_length=256)
209
+
210
+
211
class RecallBody(BaseModel):
    # Query payload shared by /recall, /explain and /recall_at
    # (capped at 8,000 characters).
    query: str = Field(..., max_length=8000)
213
+
214
+
215
class ForgetBody(BaseModel):
    # POST /forget payload. source_id is capped at 256 characters, the same
    # limit DigestBody applies when a source is ingested, so this endpoint no
    # longer accepts unbounded strings.
    source_id: str = Field(max_length=256)
217
+
218
+
219
class AliasBody(BaseModel):
    # POST /aliases payload; forwarded as register_alias(canonical, alias, ...).
    canonical: str = Field(...)
    alias: str = Field(...)
222
+
223
+
224
class ResolveBody(BaseModel):
    # POST /conflicts/resolve payload; the fields are passed in this order to
    # Memory.resolve_conflict(subject, relation, chosen_value).
    subject: str = Field(...)
    relation: str = Field(...)
    chosen_value: str = Field(...)
228
+
229
+
230
class LegalHoldBody(BaseModel):
    # POST /legal_hold payload. source_id is capped at 256 characters, the
    # same limit DigestBody applies to ingested source ids, so the stored hold
    # list cannot be inflated with arbitrarily long keys.
    source_id: str = Field(max_length=256)
232
+
233
+
234
@app.get("/health")
def health():
    # Liveness probe (in _OPEN_PATHS, so no credentials are needed).
    # Alerts cover the default tenant's staging backlog (threshold
    # $PRISMCORTEX_STAGING_WARN, default 50) and the process-wide error count.
    # The tenant manager is only peeked, never asked to build, so a probe
    # cannot construct a tenant memory or emit "memory_built" log lines.
    if _memory is not None:
        mem = _memory
    elif _tenant_mgr is not None:
        probe = AuthContext()
        mem, _ = _tenant_mgr.peek(probe.tenant_id, probe.region)
    else:
        mem = None
    alerts = []
    staging = mem.staging.pending_count() if mem is not None else 0
    if staging > int(os.environ.get("PRISMCORTEX_STAGING_WARN", "50")):
        alerts.append({"level": "warn", "msg": f"staging backlog={staging}"})
    if metrics.counts["errors"] > 10:
        alerts.append({"level": "warn", "msg": f"errors={metrics.counts['errors']}"})
    return {
        "ok": True,
        "version": app.version,  # single source of truth: the FastAPI app version
        "backend": _backend,
        "auth": auth_required(),
        "multi_tenant": True,
        "ann": _use_ann,
        "alerts": alerts,
    }
252
+
253
+
254
@app.post("/digest")
async def digest(body: DigestBody, request: Request):
    # Ingest one payload into the caller's tenant memory.
    # Flow: write-role check -> non-blocking concurrency slot (429 when every
    # slot is busy) -> Memory.digest on the write pool -> metrics + log line.
    # The slot is always released in `finally`.
    ctx: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not ctx.allows(ROLE_WRITE):
        return _deny("write role required")
    if not _digest_sem.acquire(blocking=False):
        # Shed load instead of queueing; counted with other rate-limit hits.
        metrics.counts["rate_limited"] += 1
        return JSONResponse({"detail": "digest backpressure — retry later"}, status_code=429)
    try:
        memory, _ = get_memory(ctx)
        payload = body.text
        metrics.raw_bytes += len(payload.encode("utf-8"))
        started = time.perf_counter()

        def _ingest():
            with traced("digest"):
                return memory.digest(payload, source_id=body.source_id, agent_id=body.agent_id)

        try:
            result = await _run_write(_ingest)
        except Exception as exc:  # noqa: BLE001 - counted and logged, then re-raised
            metrics.counts["errors"] += 1
            log_event(event="digest_error", error=str(exc)[:200])
            raise
        elapsed_ms = (time.perf_counter() - started) * 1000
        metrics.counts["digest"] += 1
        metrics.record("digest", elapsed_ms)
        outcome = result.outcome.value
        log_event(event="digest", tenant=ctx.tenant_id, outcome=outcome, ms=round(elapsed_ms, 2))
        return {
            "outcome": outcome,
            "band": result.band.value,
            "version": result.version.version,
            "ms": round(elapsed_ms, 2),
        }
    finally:
        _digest_sem.release()
284
+
285
+
286
@app.post("/recall")
async def recall(body: RecallBody, request: Request):
    # Answer a query from the caller's tenant memory (read role required when
    # auth is enabled). The lookup runs on the read pool. Like /digest, a
    # failure is counted in metrics.counts["errors"] (which feeds the /health
    # error alert) and logged as "recall_error" before being re-raised.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_READ):
        return _deny("read role required")
    mem, _ = get_memory(auth)
    t0 = time.perf_counter()

    def work():
        with traced("recall"):
            return mem.recall(body.query)

    try:
        res = await _run_read(work)
    except Exception as exc:  # noqa: BLE001 - counted and logged, then re-raised
        metrics.counts["errors"] += 1
        log_event(event="recall_error", error=str(exc)[:200])
        raise
    ms = (time.perf_counter() - t0) * 1000
    metrics.counts["recall"] += 1
    metrics.record("recall", ms)
    metrics.cache_hits += int(res.cache_hit)
    metrics.cache_misses += int(not res.cache_hit)
    return {
        "answer": res.answer, "cache_hit": res.cache_hit, "subgraph_hash": res.subgraph_hash,
        "version": res.version, "confidence": res.confidence,
        "freshness": res.freshness.isoformat() if res.freshness else None,
        "node_ids": res.node_ids, "edge_ids": res.edge_ids, "ms": round(ms, 2),
    }
310
+
311
+
312
@app.post("/explain")
def explain(body: RecallBody, request: Request):
    # Return Memory.explain(query) for the caller's tenant, serialized to
    # JSON-compatible types (read role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_READ):
        memory, _ = get_memory(ctx)
        explanation = memory.explain(body.query)
        return explanation.model_dump(mode="json")
    return _deny("read role required")
319
+
320
+
321
@app.post("/recall_at")
def recall_at(body: RecallBody, request: Request, at: Optional[str] = None):
    # Point-in-time recall: forwards the optional ?at= query string to
    # Memory.recall_at and returns the JSON-serialized result
    # (read role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_READ):
        memory, _ = get_memory(ctx)
        result = memory.recall_at(body.query, at=at)
        return result.model_dump(mode="json")
    return _deny("read role required")
329
+
330
+
331
@app.get("/replay_certificate")
def replay_certificate(query: str, request: Request):
    # Return Memory.replay_certificate(query) for the caller's tenant unchanged
    # (read role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    denied = auth_required() and not ctx.allows(ROLE_READ)
    if denied:
        return _deny("read role required")
    memory, _ = get_memory(ctx)
    return memory.replay_certificate(query)
338
+
339
+
340
@app.post("/forget")
def forget(body: ForgetBody, request: Request):
    # Erase one source from the caller's tenant memory (forget or admin role
    # required when auth is enabled). The policy engine is consulted first; a
    # refusal (presumably a legal hold placed via /legal_hold) becomes 409
    # carrying the policy's reason. The engine's receipt is logged and returned.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_FORGET, ROLE_ADMIN):
        return _deny("forget/admin role required")
    ok, reason = _policy.can_forget(body.source_id)
    if not ok:
        return _deny(reason, 409)
    mem, _ = get_memory(auth)
    receipt = mem.forget(body.source_id)
    # The erasure has already happened at this point, so logging must not
    # fail: a receipt key named "event" or "tenant" would otherwise raise a
    # duplicate-keyword TypeError (or override these fields).
    entry = {"event": "forget", "tenant": auth.tenant_id}
    entry.update((k, v) for k, v in receipt.items() if k not in entry)
    log_event(**entry)
    return receipt
352
+
353
+
354
@app.get("/conflicts")
def conflicts(request: Request):
    # List open conflicts in the caller's tenant memory
    # (read role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_READ):
        memory, _ = get_memory(ctx)
        return {"conflicts": memory.conflicts()}
    return _deny("read role required")
361
+
362
+
363
@app.post("/conflicts/resolve")
def resolve_conflict(body: ResolveBody, request: Request):
    # Pick the winning value for a (subject, relation) conflict and return the
    # resulting graph version (write or admin role required when auth is on).
    # A ValueError from Memory.resolve_conflict is reported as 404.
    ctx: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not ctx.allows(ROLE_WRITE, ROLE_ADMIN):
        return _deny("write/admin role required")
    memory, _ = get_memory(ctx)
    try:
        new_version = memory.resolve_conflict(body.subject, body.relation, body.chosen_value)
    except ValueError as exc:
        return _deny(str(exc), 404)
    return {"version": new_version.version, "content_hash": new_version.content_hash}
374
+
375
+
376
@app.get("/aliases")
def list_aliases(request: Request):
    # Alias table for the caller's tenant. Requires read or admin role when
    # auth is enabled, the same gate /policy uses; previously any
    # authenticated key could read it regardless of role.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_READ, ROLE_ADMIN):
        return _deny("read role required")
    return {"aliases": aliases_snapshot(tenant_id=auth.tenant_id)}
380
+
381
+
382
@app.post("/aliases")
def add_alias(body: AliasBody, request: Request):
    # Register an alias for the caller's tenant and persist that tenant's
    # alias table to <DATA_DIR>/aliases_<tenant>.json (write or admin role
    # required when auth is enabled).
    # NOTE(review): startup only loads <DATA_DIR>/aliases.json; confirm these
    # per-tenant files are reloaded after a restart.
    ctx: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not ctx.allows(ROLE_WRITE, ROLE_ADMIN):
        return _deny("write role required")
    tenant = ctx.tenant_id
    register_alias(body.canonical, body.alias, tenant_id=tenant)
    alias_file = os.path.join(DATA_DIR, f"aliases_{tenant}.json")
    save_aliases(alias_file, tenant_id=tenant)
    return {"ok": True}
390
+
391
+
392
@app.post("/legal_hold")
def legal_hold(body: LegalHoldBody, request: Request):
    # Place a legal hold on a source via the policy engine
    # (admin role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_ADMIN):
        held = body.source_id
        _policy.add_legal_hold(held)
        return {"ok": True, "source_id": held}
    return _deny("admin role required")
399
+
400
+
401
@app.delete("/legal_hold/{source_id}")
def release_hold(source_id: str, request: Request):
    # Lift a legal hold via the policy engine
    # (admin role required when auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_ADMIN):
        _policy.remove_legal_hold(source_id)
        return {"ok": True}
    return _deny("admin role required")
408
+
409
+
410
@app.get("/policy")
def policy_snapshot(request: Request):
    # Return the policy engine's snapshot (read or admin role required when
    # auth is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_READ, ROLE_ADMIN):
        return _policy.snapshot()
    return _deny("read role required")
416
+
417
+
418
@app.get("/tombstones")
def tombstones(request: Request):
    # List the tombstones kept by the caller's tenant store ([] when the store
    # has no tombstones() method). Requires read or admin role when auth is
    # enabled, matching the gate on /policy; previously any authenticated key
    # could read erasure records regardless of role.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_READ, ROLE_ADMIN):
        return _deny("read role required")
    mem, _ = get_memory(auth)
    store = mem.store
    return {"tombstones": store.tombstones() if hasattr(store, "tombstones") else []}
423
+
424
+
425
@app.post("/sleep")
async def sleep(request: Request):
    # Run Memory.sleep for the caller's tenant on the write pool and report
    # the value it returns as "consolidated" (write role required when auth
    # is enabled).
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_WRITE):
        memory, _ = get_memory(ctx)
        consolidated = await _run_write(memory.sleep)
        metrics.counts["sleep"] += 1
        log_event(event="sleep", consolidated=consolidated)
        return {"consolidated": consolidated}
    return _deny("write role required")
435
+
436
+
437
@app.get("/audit")
def audit(request: Request, at: Optional[str] = None):
    # Validity-interval audit of the caller's tenant edges. Requires read or
    # admin role when auth is enabled (like /policy).
    # - no ?at=: counts current edges (valid_to is None) and superseded edges
    #   that are still retained;
    # - ?at=<ISO-8601>: counts edges valid at that instant
    #   (valid_from <= at < valid_to). Naive datetimes are treated as UTC, and
    #   an unparseable value yields 400 instead of an unhandled 500.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_READ, ROLE_ADMIN):
        return _deny("read role required")
    if at:
        from datetime import datetime, timezone

        def _as_utc(d):
            # Naive datetimes are interpreted as UTC so they compare with aware ones.
            return d.replace(tzinfo=timezone.utc) if d.tzinfo is None else d

        try:
            ts = _as_utc(datetime.fromisoformat(at.replace("Z", "+00:00")))
        except ValueError:
            return _deny(f"invalid 'at' timestamp: {at!r}", 400)
    mem, _ = get_memory(auth)
    edges = mem.store.all_edges() if hasattr(mem.store, "all_edges") else []
    if at:
        edges_valid = sum(
            1
            for e in edges
            if _as_utc(e.valid_from) <= ts and (e.valid_to is None or ts < _as_utc(e.valid_to))
        )
        return {"at": at, "edges_valid": edges_valid, "total_edges": len(edges)}
    current = sum(1 for e in edges if e.valid_to is None)
    return {"total_edges": len(edges), "current": current, "superseded_retained": len(edges) - current}
456
+
457
+
458
@app.get("/memory_stats")
def memory_stats(request: Request):
    # Size snapshot of the caller's tenant graph. Requires read or admin role
    # when auth is enabled (like /policy).
    # gist_bytes is the length of a compact JSON dump of node labels/kinds and
    # current edges; index_bytes_est assumes 4 bytes per embedding component.
    # NOTE(review): raw_bytes_ingested comes from the process-wide
    # metrics.raw_bytes (bumped by /digest for every tenant) while the gist is
    # per-tenant, so compression_ratio_gist mixes scopes with multiple tenants.
    auth: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not auth.allows(ROLE_READ, ROLE_ADMIN):
        return _deny("read role required")
    mem, _ = get_memory(auth)
    store = mem.store
    nodes = store.all_nodes() if hasattr(store, "all_nodes") else []
    edges = [e for e in store.all_edges() if e.valid_to is None] if hasattr(store, "all_edges") else []
    id2label = {n.id: n.label for n in nodes}
    gist = json.dumps(
        {
            "nodes": [{"label": n.label, "kind": n.kind} for n in nodes],
            "edges": [{"s": id2label.get(e.src), "r": e.relation, "d": id2label.get(e.dst)} for e in edges],
        },
        separators=(",", ":"),
    )
    dim = len(nodes[0].embedding) if nodes and nodes[0].embedding else 0
    gist_bytes = len(gist.encode("utf-8"))
    index_bytes = len(nodes) * dim * 4
    raw = metrics.raw_bytes
    return {
        "tenant": auth.tenant_id, "region": auth.region,
        "raw_bytes_ingested": raw, "gist_bytes": gist_bytes, "index_bytes_est": index_bytes,
        "graph_nodes": len(nodes), "graph_current_edges": len(edges),
        "compression_ratio_gist": round(raw / gist_bytes, 2) if gist_bytes else None,
        "ann_enabled": getattr(store, "tenant_id", None) is not None and _use_ann,
    }
478
+
479
+
480
@app.get("/metrics")
def get_metrics(request: Request):
    # Metrics snapshot plus gauges from the caller's tenant memory, if one is
    # already built (the tenant manager is only peeked, never asked to build).
    # Read role required when auth is enabled.
    ctx: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not ctx.allows(ROLE_READ):
        return _deny("read role required")
    if _memory is not None:
        mem, llm = _memory, _llm
    elif _tenant_mgr is not None:
        mem, llm = _tenant_mgr.peek(ctx.tenant_id, ctx.region)
    else:
        mem, llm = None, None
    if mem:
        graph_version = mem.store.version().version
        staging = mem.staging.pending_count()
        open_conflicts = len(mem.conflicts())
    else:
        graph_version = staging = open_conflicts = 0
    calls = llm.calls if llm else 0
    return metrics.snapshot(
        gemini_calls=calls,
        backend=_backend,
        graph_version=graph_version,
        staging=staging,
        conflicts=open_conflicts,
    )
493
+
494
+
495
@app.get("/dashboard")
def dashboard(request: Request):
    """Ops snapshot for monitoring (cache, staging, conflicts, errors)."""
    # Combines /health, /metrics and /policy into one payload; an error
    # response from get_metrics is passed straight through.
    ctx: AuthContext = request.state.auth or AuthContext()
    if not auth_required() or ctx.allows(ROLE_READ):
        snapshot = get_metrics(request)
        if isinstance(snapshot, JSONResponse):
            return snapshot
        return {"health": health(), "metrics": snapshot, "policy": _policy.snapshot()}
    return _deny("read role required")
505
+
506
+
507
@app.post("/reset")
def reset(request: Request):
    # Drop the caller's memory and zero the process-wide metrics (admin role
    # required when auth is enabled). In test mode (the _memory hook is set),
    # the hook and _llm are cleared; otherwise the tenant manager resets the
    # caller's (tenant, region) memory.
    global _memory, _llm
    ctx: AuthContext = request.state.auth or AuthContext()
    if auth_required() and not ctx.allows(ROLE_ADMIN):
        return _deny("admin role required")
    if _memory is None:
        _tenant_mgr_instance().reset(ctx.tenant_id, ctx.region)
    else:
        _memory, _llm = None, None
    metrics.reset()
    log_event(event="reset", tenant=ctx.tenant_id)
    return {"ok": True}
@@ -0,0 +1,74 @@
1
+ """Shared server utilities (keeps server.py thinner)."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import time
6
+ from collections import deque
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from functools import lru_cache
9
+ from threading import Lock
10
+ from typing import Optional
11
+
12
+
13
class CountingGemini:
    """Proxy for ``GeminiClient`` that tallies every ``extract``/``render`` call.

    ``calls`` is incremented before the call is delegated, so calls that raise
    are counted as well.
    """

    def __init__(self, model: Optional[str] = None):
        # Imported here rather than at module top (presumably so the Gemini
        # client is only loaded when an LLM is actually constructed).
        from .llm.gemini import GeminiClient

        self.calls = 0
        self._g = GeminiClient(model=model)

    @property
    def model_id(self):
        """Model identifier reported by the wrapped client."""
        inner = self._g
        return inner.model_id

    def _tick(self):
        """Count one model call and hand back the wrapped client."""
        self.calls = self.calls + 1
        return self._g

    def extract(self, text, context):
        return self._tick().extract(text, context)

    def render(self, query, subgraph):
        return self._tick().render(query, subgraph)
33
+
34
+
35
class RateLimiter:
    """Sliding-window rate limit per client key (API key or IP).

    For each key the limiter remembers when its admitted requests happened. A
    request is allowed while fewer than ``rpm`` of them fall inside the last
    ``window`` seconds (60 by default, i.e. requests per minute). Rejected
    requests are not recorded.

    Timestamps come from ``time.monotonic`` so wall-clock adjustments cannot
    lock clients out or reset their budget. Keys with no request inside the
    window are purged at most once per window, so the table does not grow
    without bound as new clients or IPs appear.
    """

    def __init__(self, rpm: int = 600, *, window: float = 60.0) -> None:
        self._rpm = max(1, rpm)
        self._window = float(window)
        self._windows: dict[str, deque] = {}
        self._lock = Lock()
        self._last_sweep = time.monotonic()

    def allow(self, key: str) -> bool:
        """Return True and record the request if *key* is under its limit."""
        now = time.monotonic()
        with self._lock:
            if now - self._last_sweep > self._window:
                self._sweep(now)
            q = self._windows.setdefault(key, deque())
            while q and now - q[0] > self._window:
                q.popleft()
            if len(q) >= self._rpm:
                return False
            q.append(now)
            return True

    def _sweep(self, now: float) -> None:
        """Drop keys whose newest request is outside the window (caller holds the lock)."""
        stale = [k for k, q in self._windows.items() if not q or now - q[-1] > self._window]
        for k in stale:
            del self._windows[k]
        self._last_sweep = now
54
+
55
+
56
def rate_limiter_from_env() -> Optional[RateLimiter]:
    """Build the per-key limiter from ``$PRISMCORTEX_RATE_LIMIT_RPM``.

    Rate limiting is disabled (None) when the variable is unset, blank, or
    parses to an integer <= 0. Previously spellings such as "00", " 0" or
    "-1" silently produced a 1-request-per-minute limit. A non-integer value
    raises ValueError so misconfiguration fails fast at startup.
    """
    raw = os.environ.get("PRISMCORTEX_RATE_LIMIT_RPM")
    if raw is None or not raw.strip():
        return None
    rpm = int(raw)
    if rpm <= 0:
        return None
    return RateLimiter(rpm=rpm)
61
+
62
+
63
@lru_cache(maxsize=1)
def read_executor() -> ThreadPoolExecutor:
    """Process-wide pool for /recall and other read paths.

    It is kept separate from the write pool so digest work cannot starve
    reads. Size comes from ``$PRISMCORTEX_READ_POOL`` (default 64), never
    below 4. The lru_cache makes this a lazily created singleton.
    """
    workers = max(4, int(os.environ.get("PRISMCORTEX_READ_POOL", "64")))
    return ThreadPoolExecutor(thread_name_prefix="pc-read", max_workers=workers)
68
+
69
+
70
@lru_cache(maxsize=1)
def write_executor() -> ThreadPoolExecutor:
    """Process-wide pool for digest and other write paths.

    Sized from ``$PRISMCORTEX_MAX_CONCURRENT_DIGEST`` (default 16, minimum 1),
    the same variable that bounds concurrent /digest requests. The lru_cache
    makes this a lazily created singleton.
    """
    workers = max(1, int(os.environ.get("PRISMCORTEX_MAX_CONCURRENT_DIGEST", "16")))
    return ThreadPoolExecutor(thread_name_prefix="pc-write", max_workers=workers)