cats-scoring 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cats/__init__.py +3 -0
- cats/api/__init__.py +1 -0
- cats/api/main.py +145 -0
- cats/api/routes/__init__.py +1 -0
- cats/api/routes/evaluate.py +162 -0
- cats/api/schemas.py +69 -0
- cats/audit/__init__.py +1 -0
- cats/audit/logger.py +120 -0
- cats/calibration/__init__.py +24 -0
- cats/calibration/__main__.py +4 -0
- cats/calibration/build_dataset.py +220 -0
- cats/calibration/calibrate.py +109 -0
- cats/calibration/collect_rss.py +321 -0
- cats/calibration/dataset.py +65 -0
- cats/calibration/evaluate.py +237 -0
- cats/calibration/ga.py +124 -0
- cats/calibration/label_from_ratings.py +243 -0
- cats/calibration/merge_snapshots.py +144 -0
- cats/calibration/objective.py +131 -0
- cats/calibration/report.py +135 -0
- cats/calibration/split.py +145 -0
- cats/core/__init__.py +1 -0
- cats/core/config.py +58 -0
- cats/core/db.py +30 -0
- cats/core/metrics.py +33 -0
- cats/core/models.py +53 -0
- cats/core/security.py +173 -0
- cats/lite.py +123 -0
- cats/pipeline/__init__.py +1 -0
- cats/pipeline/normalizer.py +36 -0
- cats/scoring/__init__.py +1 -0
- cats/scoring/engine.py +44 -0
- cats/scoring/explainer.py +53 -0
- cats/scoring/weights.py +70 -0
- cats/signals/__init__.py +1 -0
- cats/signals/coherence.py +128 -0
- cats/signals/gaming.py +66 -0
- cats/signals/sentiment.py +76 -0
- cats/signals/silence.py +31 -0
- cats/signals/types.py +43 -0
- cats/signals/volatility.py +31 -0
- cats_scoring-1.3.0.dist-info/METADATA +223 -0
- cats_scoring-1.3.0.dist-info/RECORD +46 -0
- cats_scoring-1.3.0.dist-info/WHEEL +5 -0
- cats_scoring-1.3.0.dist-info/licenses/LICENSE +21 -0
- cats_scoring-1.3.0.dist-info/top_level.txt +1 -0
cats/__init__.py
ADDED
cats/api/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""API layer: FastAPI routes, schemas, and request handling."""
|
cats/api/main.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import time
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
|
7
|
+
from fastapi import FastAPI, Request
|
|
8
|
+
from fastapi.exceptions import RequestValidationError
|
|
9
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
10
|
+
from fastapi.responses import JSONResponse, Response
|
|
11
|
+
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
|
|
12
|
+
|
|
13
|
+
from cats.api.routes.evaluate import router as evaluate_router
|
|
14
|
+
from cats.audit.logger import purge_expired_audits
|
|
15
|
+
from cats.core.config import settings
|
|
16
|
+
from cats.core.db import AsyncSessionLocal
|
|
17
|
+
from cats.core.metrics import HTTP_LATENCY, HTTP_REQUESTS
|
|
18
|
+
from cats.core.security import init_jwt_keys, init_redis
|
|
19
|
+
from cats.signals.coherence import init_nlp
|
|
20
|
+
|
|
21
|
+
# N-06: JSON structured logging built only from structlog-native processors and
# a level-filtering bound logger. No stdlib logging backend is involved: the
# stdlib `filter_by_level` processor calls isEnabledFor(), which PrintLogger does
# not provide, so using it would crash on every log call.
_LOG_LEVEL = getattr(logging, settings.log_level.upper(), logging.INFO)
_LOG_PROCESSORS = [
    structlog.processors.add_log_level,
    structlog.processors.TimeStamper(fmt="iso"),
    structlog.processors.JSONRenderer(),
]
structlog.configure(
    processors=_LOG_PROCESSORS,
    wrapper_class=structlog.make_filtering_bound_logger(_LOG_LEVEL),
    logger_factory=structlog.PrintLoggerFactory(),
    cache_logger_on_first_use=True,
)
logger = structlog.get_logger()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialise shared resources and run the purge scheduler.

    On startup this initialises the JWT keys (S-01), the Redis client (S-03) and
    the NLP singleton (N-01), then schedules the nightly audit purge at 02:00.
    The scheduler is stopped in a ``finally`` so it is shut down even when the
    lifespan is exited with an error or cancellation, not only on a clean exit.
    """
    logger.info("startup", env=settings.environment)
    init_jwt_keys()  # S-01
    await init_redis()  # S-03
    init_nlp(settings.spacy_model)  # N-01: singleton

    # Q-03: max_instances=1 prevents overlapping purge jobs
    sched = AsyncIOScheduler()
    sched.add_job(_purge_job, "cron", hour=2, minute=0, max_instances=1, coalesce=True, misfire_grace_time=3600)
    sched.start()
    try:
        yield
    finally:
        sched.shutdown()
        logger.info("shutdown")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def _purge_job():
    """Scheduled job: open a fresh DB session and purge expired audit rows."""
    session_factory = AsyncSessionLocal
    async with session_factory() as session:
        await purge_expired_audits(session)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# NOTE(review): this wheel is cats-scoring 1.3.0 but the API advertises version
# "1.2.0" — confirm whether that is the intended API version.
app = FastAPI(title="CATS API", version="1.2.0", lifespan=lifespan)

# CORS is enabled only when `settings.cors_origins` is set. The value is a
# comma-separated list; entries are whitespace-stripped and blank entries (e.g.
# from a trailing or doubled comma) are dropped instead of being allow-listed.
# NOTE(review): if "*" is configured together with allow_credentials=True,
# Starlette reflects the caller's Origin back — confirm that is intended.
if settings.cors_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[o.strip() for o in settings.cors_origins.split(",") if o.strip()],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

app.include_router(evaluate_router, prefix="/v1/cats", tags=["cats"])
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# N-07: Prometheus request metrics. Label by the matched route template (not the
# raw path) to keep label cardinality bounded.
@app.middleware("http")
async def _prometheus_middleware(request: Request, call_next):
    """Record a request counter and a latency observation for every HTTP request.

    Metrics are recorded in a ``finally`` block, so requests whose handler
    raises are still counted before the exception propagates to the outer error
    handling. Those requests are labelled with status 500.
    """
    start = time.perf_counter()
    status_code = 500  # used when call_next raises before producing a response
    try:
        response = await call_next(request)
        status_code = response.status_code
        return response
    finally:
        # The matched route object is read from the request scope; requests with
        # no matched route share a single "unmatched" label.
        route = request.scope.get("route")
        path = getattr(route, "path", None) or "unmatched"
        HTTP_REQUESTS.labels(request.method, path, status_code).inc()
        HTTP_LATENCY.labels(request.method, path).observe(time.perf_counter() - start)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# A-03: RFC 7807 Problem Details
@app.exception_handler(RequestValidationError)
async def _val_err(request: Request, exc: RequestValidationError):
    """Render request-validation failures as an RFC 7807 problem document (422).

    With pydantic v2, ``exc.errors()`` can embed the original exception object
    under ``ctx["error"]``, e.g. the ValueError raised by a ``field_validator``.
    ``JSONResponse`` cannot serialise that object, which turned every such 422
    into a 500. Running the list through ``jsonable_encoder`` (as FastAPI's
    default handler does) makes it JSON-safe.
    """
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content={
            "type": "about:blank",
            "title": "Validation Error",
            "status": 422,
            "detail": jsonable_encoder(exc.errors()),
            "instance": str(request.url),
        },
    )
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@app.exception_handler(Exception)
async def _generic_err(request: Request, exc: Exception):
    """Catch-all handler: log the failure and return an opaque RFC 7807 500.

    The exception text goes only to the server log; the client always receives
    the fixed detail "Unexpected error".
    """
    instance = str(request.url)
    logger.error("unhandled", error=str(exc), path=instance)
    problem = {
        "type": "about:blank",
        "title": "Internal Server Error",
        "status": 500,
        "detail": "Unexpected error",
        "instance": instance,
    }
    return JSONResponse(status_code=500, content=problem)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Q-02: deep health check
@app.get("/health")
async def health():
    """Deep health check covering Redis, the database and the NLP model.

    Each dependency reports "ok" or "error:<message>" (NLP reports "ok" or
    "not_loaded"). The overall status is "healthy" only when every check is
    "ok", otherwise "degraded".
    """
    # Imported at call time so the current module-level singletons are read.
    from cats.core.db import engine
    from cats.core.security import redis_client
    from cats.signals.coherence import nlp

    checks: dict = {"api": "ok"}

    try:
        await redis_client.ping()
    except Exception as exc:
        checks["redis"] = f"error:{exc}"
    else:
        checks["redis"] = "ok"

    try:
        from sqlalchemy import text

        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception as exc:
        checks["database"] = f"error:{exc}"
    else:
        checks["database"] = "ok"

    if nlp:
        checks["nlp"] = "ok"
    else:
        checks["nlp"] = "not_loaded"

    degraded = any(value != "ok" for value in checks.values())
    return {"status": "degraded" if degraded else "healthy", "checks": checks}
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@app.get("/metrics")
async def metrics():
    """Expose every registered Prometheus metric in the text exposition format."""
    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""API route handlers."""
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import uuid
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
import structlog
|
|
6
|
+
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
7
|
+
from sqlalchemy import func, select
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
9
|
+
|
|
10
|
+
from cats.api.schemas import (
|
|
11
|
+
BatchEvaluateRequest,
|
|
12
|
+
BatchEvaluateResponse,
|
|
13
|
+
ContestRequest,
|
|
14
|
+
ContestResponse,
|
|
15
|
+
EvaluateRequest,
|
|
16
|
+
EvaluateResponse,
|
|
17
|
+
ExplainResponse,
|
|
18
|
+
StatsResponse,
|
|
19
|
+
)
|
|
20
|
+
from cats.audit.logger import log_contest, log_evaluation
|
|
21
|
+
from cats.core.db import get_db
|
|
22
|
+
from cats.core.metrics import EVALUATIONS, TRUST_SCORE
|
|
23
|
+
from cats.core.models import TrustScore
|
|
24
|
+
from cats.core.security import api_key_bearer, get_client_ip, get_tenant
|
|
25
|
+
from cats.pipeline.normalizer import normalize_messages
|
|
26
|
+
from cats.scoring.engine import aggregate_score, determine_band, requires_human_review
|
|
27
|
+
from cats.scoring.explainer import generate_explanation
|
|
28
|
+
from cats.scoring.weights import get_dynamic_weights
|
|
29
|
+
from cats.signals.coherence import compute_coherence
|
|
30
|
+
from cats.signals.gaming import compute_gaming
|
|
31
|
+
from cats.signals.silence import compute_silence
|
|
32
|
+
from cats.signals.types import SignalResult
|
|
33
|
+
from cats.signals.volatility import compute_volatility
|
|
34
|
+
|
|
35
|
+
# Module-level router holding the /evaluate, /batch, /explain, /contest,
# /review and /stats endpoints, plus this module's structured logger.
router = APIRouter()
logger = structlog.get_logger()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
async def _evaluate_item(item: EvaluateRequest, request: Request, db: AsyncSession) -> EvaluateResponse:
    """Score one source with the four signals and stage its persistence.

    The coherence, volatility, silence and gaming extractors run concurrently
    in the event loop's default executor. The resulting TrustScore row and its
    audit entry are added to ``db`` but NOT committed: the calling route owns
    the transaction boundary, which is what makes /batch all-or-nothing.
    """
    trace_id = str(uuid.uuid4())
    tenant_id = get_tenant(request)
    msgs = normalize_messages([m.model_dump() for m in item.messages])
    context = item.context or {}
    source_type = context.get("source_type", "social")

    loop = asyncio.get_running_loop()
    pending = (
        loop.run_in_executor(None, compute_coherence, msgs),
        loop.run_in_executor(None, compute_volatility, msgs),
        loop.run_in_executor(None, partial(compute_silence, msgs, source_type)),
        loop.run_in_executor(None, compute_gaming, msgs),
    )
    signals: list[SignalResult] = list(await asyncio.gather(*pending))

    weights = get_dynamic_weights(context)
    score = aggregate_score(signals, weights)
    band = determine_band(score)
    review = requires_human_review(score, band, signals)

    EVALUATIONS.labels(band=band).inc()
    TRUST_SCORE.observe(score)

    # Only value/confidence are persisted per signal; metadata is returned to
    # the caller but not stored.
    stored_signals = {sig.name: {"value": sig.value, "confidence": sig.confidence} for sig in signals}
    db.add(
        TrustScore(
            tenant_id=tenant_id,
            trace_id=trace_id,
            source_id=item.source_id,
            score=score,
            band=band,
            signals=stored_signals,
            weights=weights,
            context_data=context,
        )
    )
    audit_payload = {"source_id": item.source_id, "score": score, "band": band}
    await log_evaluation(db, trace_id, audit_payload, ip=get_client_ip(request), tenant_id=tenant_id)

    response_signals = [
        {"name": sig.name, "value": sig.value, "confidence": sig.confidence, "metadata": sig.metadata}
        for sig in signals
    ]
    return EvaluateResponse(
        trace_id=trace_id,
        score=score,
        band=band,
        requires_review=review,
        signals=response_signals,
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@router.post("/evaluate", response_model=EvaluateResponse, dependencies=[Depends(api_key_bearer)])
async def evaluate(req: EvaluateRequest, request: Request, db: AsyncSession = Depends(get_db)):
    """Evaluate a single source, then commit its score and audit row together."""
    response = await _evaluate_item(req, request, db)
    await db.commit()
    return response
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
@router.post("/batch", response_model=BatchEvaluateResponse, dependencies=[Depends(api_key_bearer)])
async def batch_evaluate(req: BatchEvaluateRequest, request: Request, db: AsyncSession = Depends(get_db)):
    """Evaluate every item in the batch and commit them as one transaction.

    Items are processed one after another on the shared session (the signals
    within each item already run in parallel). The single commit at the end
    keeps the batch atomic.
    """
    results: list[EvaluateResponse] = []
    for item in req.items:
        results.append(await _evaluate_item(item, request, db))
    await db.commit()
    return BatchEvaluateResponse(count=len(results), results=results)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@router.get("/explain/{trace_id}", response_model=ExplainResponse, dependencies=[Depends(api_key_bearer)])
async def explain(trace_id: str, request: Request, db: AsyncSession = Depends(get_db)):
    """Rebuild the stored signals for a trace of the caller's tenant and explain the score.

    Raises:
        HTTPException: 404 when no TrustScore with this trace_id exists for the tenant.
    """
    tenant_id = get_tenant(request)
    query = select(TrustScore).where(TrustScore.trace_id == trace_id, TrustScore.tenant_id == tenant_id)
    ts = (await db.execute(query)).scalars().first()
    if not ts:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Trace ID not found")
    rebuilt = [
        SignalResult(name=name, value=stored["value"], confidence=stored["confidence"])
        for name, stored in ts.signals.items()
    ]
    explanation = generate_explanation(ts.score, ts.band, rebuilt, ts.weights)
    return ExplainResponse(trace_id=trace_id, explanation=explanation)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@router.post("/contest/{trace_id}", response_model=ContestResponse, dependencies=[Depends(api_key_bearer)])
async def contest(trace_id: str, body: ContestRequest, request: Request, db: AsyncSession = Depends(get_db)):
    """Open a contest against an existing trace owned by the caller's tenant.

    Raises:
        HTTPException: 404 when the trace does not exist for the tenant.
    """
    tenant_id = get_tenant(request)
    lookup = select(TrustScore).where(TrustScore.trace_id == trace_id, TrustScore.tenant_id == tenant_id)
    found = (await db.execute(lookup)).scalars().first()
    if not found:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Trace ID not found")
    # D-02: log_contest persists the contest itself and returns its id.
    contest_id = await log_contest(db, trace_id, body.reason, tenant_id=tenant_id)
    return ContestResponse(contest_id=contest_id, status="pending")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
@router.post("/review/{trace_id}", dependencies=[Depends(api_key_bearer)])
async def review(trace_id: str, request: Request, db: AsyncSession = Depends(get_db)):
    """Log a human-review request for an existing trace of the caller's tenant.

    Raises:
        HTTPException: 404 when the trace does not exist for the tenant.
    """
    tenant_id = get_tenant(request)
    lookup = select(TrustScore).where(TrustScore.trace_id == trace_id, TrustScore.tenant_id == tenant_id)
    if not (await db.execute(lookup)).scalars().first():
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Trace ID not found")
    client_ip = get_client_ip(request)
    await log_evaluation(db, trace_id, {"event": "human_review_requested"}, ip=client_ip, tenant_id=tenant_id)
    await db.commit()
    return {"message": "Review logged", "trace_id": trace_id}
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@router.get("/stats", response_model=StatsResponse, dependencies=[Depends(api_key_bearer)])
async def stats(request: Request, db: AsyncSession = Depends(get_db)):
    """Return per-tenant totals: evaluation count, mean score, and count per band.

    Empty tenants report 0 evaluations and an average of 0.0.
    """
    tenant_id = get_tenant(request)
    same_tenant = TrustScore.tenant_id == tenant_id

    total = await db.scalar(select(func.count(TrustScore.id)).where(same_tenant))
    average = await db.scalar(select(func.avg(TrustScore.score)).where(same_tenant))
    per_band = await db.execute(
        select(TrustScore.band, func.count(TrustScore.id)).where(same_tenant).group_by(TrustScore.band)
    )
    distribution = {band: count for band, count in per_band.all()}

    return StatsResponse(
        total_evaluations=total or 0,
        average_score=average or 0.0,
        band_distribution=distribution,
    )
|
cats/api/schemas.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field, field_validator
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# One timestamped message from a source. `timestamp` must parse as ISO 8601
# (a trailing "Z" is accepted as UTC); `text` is 1-10,000 characters.
class MessageSchema(BaseModel):
    timestamp: str = Field(..., description="ISO 8601 UTC (e.g. 2024-01-15T10:00:00Z)")
    text: str = Field(..., min_length=1, max_length=10_000)
    metadata: Optional[Dict[str, Any]] = None

    @field_validator("timestamp")
    @classmethod
    def validate_timestamp(cls, v: str) -> str:
        # "Z" is rewritten to "+00:00" because datetime.fromisoformat only
        # accepts the "Z" suffix on Python 3.11+. The original string is
        # returned unchanged.
        # NOTE(review): timestamps without any offset are also accepted even
        # though the description says UTC — confirm downstream handles naive values.
        from datetime import datetime

        normalised = v.replace("Z", "+00:00")
        try:
            datetime.fromisoformat(normalised)
        except ValueError:
            raise ValueError("timestamp must be ISO 8601 UTC")
        return v
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Body of POST /evaluate: a source identifier (1-256 chars), 1-500 messages,
# and an optional free-form context mapping.
class EvaluateRequest(BaseModel):
    source_id: str = Field(..., min_length=1, max_length=256)
    messages: List[MessageSchema] = Field(..., min_length=1, max_length=500)
    context: Optional[Dict[str, Any]] = None
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Result of one evaluation: trace id, aggregate score, band, whether human
# review is required, and one dict per signal.
class EvaluateResponse(BaseModel):
    trace_id: str
    score: float
    band: str
    requires_review: bool
    signals: List[Dict[str, Any]]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Body of POST /batch: between 1 and 50 evaluation requests.
class BatchEvaluateRequest(BaseModel):
    items: List[EvaluateRequest] = Field(..., min_length=1, max_length=50)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Batch result: number of items evaluated plus the per-item responses.
class BatchEvaluateResponse(BaseModel):
    count: int
    results: List[EvaluateResponse]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Explanation payload for a stored trace.
class ExplainResponse(BaseModel):
    trace_id: str
    explanation: Dict[str, Any]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Body of POST /contest/{trace_id}: a free-text reason of 10-2000 characters.
class ContestRequest(BaseModel):
    reason: str = Field(..., min_length=10, max_length=2000)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Identifier and status of a newly created contest.
class ContestResponse(BaseModel):
    contest_id: int
    status: str
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Per-tenant aggregate statistics: evaluation count, mean score, count per band.
class StatsResponse(BaseModel):
    total_evaluations: int
    average_score: float
    band_distribution: Dict[str, int]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Health report shape: overall status, per-dependency checks, API version.
# NOTE(review): the default version "1.2.0" differs from the 1.3.0 package
# version — confirm which is intended.
class HealthResponse(BaseModel):
    status: str
    checks: Dict[str, str]
    version: str = "1.2.0"
|
cats/audit/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Audit logging and GDPR compliance."""
|
cats/audit/logger.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from datetime import datetime, timedelta, timezone
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import structlog
|
|
8
|
+
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
9
|
+
from sqlalchemy import select
|
|
10
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
11
|
+
|
|
12
|
+
from cats.core.config import settings
|
|
13
|
+
from cats.core.models import AuditLog, Contest
|
|
14
|
+
|
|
15
|
+
logger = structlog.get_logger()

# AES-GCM cipher built lazily on first use; None until then.
_gcm: Optional[AESGCM] = None


def _get_gcm() -> AESGCM:
    """Return the module-wide AES-256-GCM cipher, creating it on the first call.

    The key is read from ``settings.audit_encryption_key`` as base64.

    Raises:
        ValueError: if the decoded key is not exactly 32 bytes. Malformed base64
            surfaces as ``binascii.Error``, which is a ValueError subclass.
    """
    global _gcm
    if _gcm is not None:
        return _gcm
    key = base64.b64decode(settings.audit_encryption_key)
    if len(key) != 32:
        raise ValueError("AUDIT_ENCRYPTION_KEY must be exactly 32 bytes (256-bit) base64-encoded")
    _gcm = AESGCM(key)
    return _gcm
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _encrypt(data: dict) -> str:
    """Serialise *data* to JSON and seal it with AES-GCM.

    Returns base64 of ``nonce || ciphertext``. The nonce is 12 fresh random
    bytes and no associated data is bound; ``_decrypt`` reverses this layout.
    """
    nonce = os.urandom(12)
    cipher = _get_gcm()
    sealed = cipher.encrypt(nonce, json.dumps(data).encode(), None)
    return base64.b64encode(nonce + sealed).decode()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _decrypt(blob: str) -> dict:
    """Inverse of ``_encrypt``: split off the 12-byte nonce, open, parse the JSON."""
    raw = base64.b64decode(blob)
    nonce, sealed = raw[:12], raw[12:]
    plaintext = _get_gcm().decrypt(nonce, sealed, None)
    return json.loads(plaintext.decode())
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
async def log_evaluation(
    db: AsyncSession,
    trace_id: str,
    data: dict,
    user_id: Optional[str] = None,
    ip: Optional[str] = None,
    tenant_id: str = "default",
) -> None:
    """Stage an encrypted audit row with event_type "evaluation" and flush it.

    The row is flushed but not committed; committing is left to the caller.
    """
    entry = AuditLog(
        tenant_id=tenant_id,
        trace_id=trace_id,
        event_type="evaluation",
        encrypted_data=_encrypt(data),
        user_id=user_id,
        ip_address=ip,
        timestamp=datetime.now(timezone.utc),
    )
    db.add(entry)
    await db.flush()
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
async def log_contest(
    db: AsyncSession,
    trace_id: str,
    reason: str,
    user_id: Optional[str] = None,
    tenant_id: str = "default",
) -> int:
    """Create a "pending" Contest for *trace_id*, commit it, and return its id.

    Unlike ``log_evaluation`` (which only flushes), this commits the session and
    refreshes the row so its generated primary key can be returned.
    """
    contest = Contest(
        tenant_id=tenant_id,
        trace_id=trace_id,
        reason=reason,
        status="pending",
        user_id=user_id,
        created_at=datetime.now(timezone.utc),
    )
    db.add(contest)
    await db.commit()
    await db.refresh(contest)
    return contest.id
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
async def get_audit_log(db: AsyncSession, trace_id: str, tenant_id: Optional[str] = None) -> Optional[dict]:
    """Return the most recent audit entry for *trace_id*, decrypted, or None.

    Args:
        db: Session to query.
        trace_id: Trace whose audit entries are searched.
        tenant_id: When given, only entries belonging to this tenant are
            considered. Callers serving tenant-scoped requests should pass it so
            one tenant cannot read another tenant's audit data. ``None`` keeps
            the original unscoped lookup.

    Returns:
        A dict with ``trace_id``, ``event_type``, decrypted ``data`` and an ISO
        ``timestamp``, or None if no matching entry exists.
    """
    stmt = select(AuditLog).where(AuditLog.trace_id == trace_id)
    if tenant_id is not None:
        stmt = stmt.where(AuditLog.tenant_id == tenant_id)
    # Only the newest row is needed; LIMIT 1 avoids fetching the whole history.
    stmt = stmt.order_by(AuditLog.timestamp.desc()).limit(1)
    r = await db.execute(stmt)
    a = r.scalars().first()
    if not a:
        return None
    return {
        "trace_id": a.trace_id,
        "event_type": a.event_type,
        "data": _decrypt(a.encrypted_data),
        "timestamp": a.timestamp.isoformat(),
    }
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _get_redis():
    """Return the shared Redis client from ``cats.core.security``.

    The import happens at call time, so the module's current ``redis_client``
    binding is read rather than one captured at import time.

    Raises:
        RuntimeError: if the client is still None (Redis not initialised).
    """
    from cats.core.security import redis_client

    if redis_client is not None:
        return redis_client
    raise RuntimeError("Redis not initialized — call init_redis() first")
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
_PURGE_LOCK_KEY = "cats:purge_lock"

# Atomic compare-and-delete: release the lock only if it still holds our token.
_RELEASE_PURGE_LOCK = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


async def purge_expired_audits(db: AsyncSession) -> None:
    """Delete audit rows older than ``settings.audit_retention_days``.

    A Redis lock (``SET NX EX 300``) prevents concurrent purges; if it is
    already held, the call logs ``purge_skipped`` and returns.

    The lock value is a random per-run token, and the lock is released with an
    atomic compare-and-delete. A run that outlives the 300 s TTL therefore
    cannot delete a lock another worker has since acquired. Expired rows are
    removed with a single bulk DELETE rather than loading each row into memory.

    NOTE(review): the bulk DELETE bypasses ORM-level cascades and events on
    AuditLog — confirm none are relied upon.
    """
    import secrets

    from sqlalchemy import delete

    redis = _get_redis()
    token = secrets.token_hex(16)
    acquired = await redis.set(_PURGE_LOCK_KEY, token, nx=True, ex=300)
    if not acquired:
        logger.info("purge_skipped", reason="lock_held")
        return
    try:
        cutoff = datetime.now(timezone.utc) - timedelta(days=settings.audit_retention_days)
        stmt = (
            delete(AuditLog)
            .where(AuditLog.timestamp < cutoff)
            .execution_options(synchronize_session=False)
        )
        res = await db.execute(stmt)
        await db.commit()
        logger.info("purge_done", deleted=res.rowcount)
    finally:
        await redis.eval(_RELEASE_PURGE_LOCK, 1, _PURGE_LOCK_KEY, token)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""CATS weight calibration.
|
|
2
|
+
|
|
3
|
+
Tooling to empirically calibrate the signal weight matrix against a labelled
|
|
4
|
+
dataset, replacing the initial hand-picked estimates flagged in WP 4.1.
|
|
5
|
+
|
|
6
|
+
The optimisation uses a small, dependency-free genetic algorithm (see
|
|
7
|
+
``cats.calibration.ga``). Because CATS scores are *ordinal* rankings (WP 4.3),
|
|
8
|
+
the objective maximises a rank-agreement metric (Spearman / pairwise
|
|
9
|
+
concordance) rather than absolute error.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from cats.calibration.calibrate import CalibrationOutput, calibrate
|
|
13
|
+
from cats.calibration.dataset import LabeledSample, load_dataset
|
|
14
|
+
from cats.calibration.ga import GAConfig, GAResult, GeneticOptimizer
|
|
15
|
+
|
|
16
|
+
# Public API of cats.calibration, grouped by the submodule each name comes from.
__all__ = [
    # cats.calibration.calibrate
    "calibrate",
    "CalibrationOutput",
    # cats.calibration.dataset
    "load_dataset",
    "LabeledSample",
    # cats.calibration.ga
    "GeneticOptimizer",
    "GAConfig",
    "GAResult",
]
|