spanforge 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spanforge/__init__.py +815 -0
- spanforge/_ansi.py +93 -0
- spanforge/_batch_exporter.py +409 -0
- spanforge/_cli.py +2094 -0
- spanforge/_cli_audit.py +639 -0
- spanforge/_cli_compliance.py +711 -0
- spanforge/_cli_cost.py +243 -0
- spanforge/_cli_ops.py +791 -0
- spanforge/_cli_phase11.py +356 -0
- spanforge/_hooks.py +337 -0
- spanforge/_server.py +1708 -0
- spanforge/_span.py +1036 -0
- spanforge/_store.py +288 -0
- spanforge/_stream.py +664 -0
- spanforge/_trace.py +335 -0
- spanforge/_tracer.py +254 -0
- spanforge/actor.py +141 -0
- spanforge/alerts.py +469 -0
- spanforge/auto.py +464 -0
- spanforge/baseline.py +335 -0
- spanforge/cache.py +635 -0
- spanforge/compliance.py +325 -0
- spanforge/config.py +532 -0
- spanforge/consent.py +228 -0
- spanforge/consumer.py +377 -0
- spanforge/core/__init__.py +5 -0
- spanforge/core/compliance_mapping.py +1254 -0
- spanforge/cost.py +600 -0
- spanforge/debug.py +548 -0
- spanforge/deprecations.py +205 -0
- spanforge/drift.py +482 -0
- spanforge/egress.py +58 -0
- spanforge/eval.py +648 -0
- spanforge/event.py +1064 -0
- spanforge/exceptions.py +240 -0
- spanforge/explain.py +178 -0
- spanforge/export/__init__.py +69 -0
- spanforge/export/append_only.py +337 -0
- spanforge/export/cloud.py +357 -0
- spanforge/export/datadog.py +497 -0
- spanforge/export/grafana.py +320 -0
- spanforge/export/jsonl.py +195 -0
- spanforge/export/openinference.py +158 -0
- spanforge/export/otel_bridge.py +294 -0
- spanforge/export/otlp.py +811 -0
- spanforge/export/otlp_bridge.py +233 -0
- spanforge/export/redis_backend.py +282 -0
- spanforge/export/siem_schema.py +98 -0
- spanforge/export/siem_splunk.py +264 -0
- spanforge/export/siem_syslog.py +212 -0
- spanforge/export/webhook.py +299 -0
- spanforge/exporters/__init__.py +30 -0
- spanforge/exporters/console.py +271 -0
- spanforge/exporters/jsonl.py +144 -0
- spanforge/exporters/sqlite.py +142 -0
- spanforge/gate.py +1150 -0
- spanforge/governance.py +181 -0
- spanforge/hitl.py +295 -0
- spanforge/http.py +187 -0
- spanforge/inspect.py +427 -0
- spanforge/integrations/__init__.py +45 -0
- spanforge/integrations/_pricing.py +280 -0
- spanforge/integrations/anthropic.py +388 -0
- spanforge/integrations/azure_openai.py +133 -0
- spanforge/integrations/bedrock.py +292 -0
- spanforge/integrations/crewai.py +251 -0
- spanforge/integrations/gemini.py +351 -0
- spanforge/integrations/groq.py +442 -0
- spanforge/integrations/langchain.py +349 -0
- spanforge/integrations/langgraph.py +306 -0
- spanforge/integrations/llamaindex.py +373 -0
- spanforge/integrations/ollama.py +287 -0
- spanforge/integrations/openai.py +368 -0
- spanforge/integrations/together.py +483 -0
- spanforge/io.py +214 -0
- spanforge/lint.py +322 -0
- spanforge/metrics.py +417 -0
- spanforge/metrics_export.py +343 -0
- spanforge/migrate.py +402 -0
- spanforge/model_registry.py +278 -0
- spanforge/models.py +389 -0
- spanforge/namespaces/__init__.py +254 -0
- spanforge/namespaces/audit.py +256 -0
- spanforge/namespaces/cache.py +237 -0
- spanforge/namespaces/chain.py +77 -0
- spanforge/namespaces/confidence.py +72 -0
- spanforge/namespaces/consent.py +92 -0
- spanforge/namespaces/cost.py +179 -0
- spanforge/namespaces/decision.py +143 -0
- spanforge/namespaces/diff.py +157 -0
- spanforge/namespaces/drift.py +80 -0
- spanforge/namespaces/eval_.py +251 -0
- spanforge/namespaces/feedback.py +241 -0
- spanforge/namespaces/fence.py +193 -0
- spanforge/namespaces/guard.py +105 -0
- spanforge/namespaces/hitl.py +91 -0
- spanforge/namespaces/latency.py +72 -0
- spanforge/namespaces/prompt.py +190 -0
- spanforge/namespaces/redact.py +173 -0
- spanforge/namespaces/retrieval.py +379 -0
- spanforge/namespaces/runtime_governance.py +494 -0
- spanforge/namespaces/template.py +208 -0
- spanforge/namespaces/tool_call.py +77 -0
- spanforge/namespaces/trace.py +1029 -0
- spanforge/normalizer.py +171 -0
- spanforge/plugins.py +82 -0
- spanforge/presidio_backend.py +349 -0
- spanforge/processor.py +258 -0
- spanforge/prompt_registry.py +418 -0
- spanforge/py.typed +0 -0
- spanforge/redact.py +914 -0
- spanforge/regression.py +192 -0
- spanforge/runtime_policy.py +159 -0
- spanforge/sampling.py +511 -0
- spanforge/schema.py +183 -0
- spanforge/schemas/v1.0/schema.json +170 -0
- spanforge/schemas/v2.0/schema.json +536 -0
- spanforge/sdk/__init__.py +625 -0
- spanforge/sdk/_base.py +584 -0
- spanforge/sdk/_base.pyi +71 -0
- spanforge/sdk/_exceptions.py +1096 -0
- spanforge/sdk/_types.py +2184 -0
- spanforge/sdk/alert.py +1514 -0
- spanforge/sdk/alert.pyi +56 -0
- spanforge/sdk/audit.py +1196 -0
- spanforge/sdk/audit.pyi +67 -0
- spanforge/sdk/cec.py +1215 -0
- spanforge/sdk/cec.pyi +37 -0
- spanforge/sdk/config.py +641 -0
- spanforge/sdk/config.pyi +55 -0
- spanforge/sdk/enterprise.py +714 -0
- spanforge/sdk/enterprise.pyi +79 -0
- spanforge/sdk/explain.py +170 -0
- spanforge/sdk/fallback.py +432 -0
- spanforge/sdk/feedback.py +351 -0
- spanforge/sdk/gate.py +874 -0
- spanforge/sdk/gate.pyi +51 -0
- spanforge/sdk/identity.py +2114 -0
- spanforge/sdk/identity.pyi +47 -0
- spanforge/sdk/lineage.py +175 -0
- spanforge/sdk/observe.py +1065 -0
- spanforge/sdk/observe.pyi +50 -0
- spanforge/sdk/operator.py +338 -0
- spanforge/sdk/pii.py +1473 -0
- spanforge/sdk/pii.pyi +119 -0
- spanforge/sdk/pipelines.py +458 -0
- spanforge/sdk/pipelines.pyi +39 -0
- spanforge/sdk/policy.py +930 -0
- spanforge/sdk/rag.py +594 -0
- spanforge/sdk/rbac.py +280 -0
- spanforge/sdk/registry.py +430 -0
- spanforge/sdk/registry.pyi +46 -0
- spanforge/sdk/scope.py +279 -0
- spanforge/sdk/secrets.py +293 -0
- spanforge/sdk/secrets.pyi +25 -0
- spanforge/sdk/security.py +560 -0
- spanforge/sdk/security.pyi +57 -0
- spanforge/sdk/trust.py +472 -0
- spanforge/sdk/trust.pyi +41 -0
- spanforge/secrets.py +799 -0
- spanforge/signing.py +1179 -0
- spanforge/stats.py +100 -0
- spanforge/stream.py +560 -0
- spanforge/testing.py +378 -0
- spanforge/testing_mocks.py +1052 -0
- spanforge/trace.py +199 -0
- spanforge/types.py +696 -0
- spanforge/ulid.py +300 -0
- spanforge/validate.py +379 -0
- spanforge-1.0.0.dist-info/METADATA +1509 -0
- spanforge-1.0.0.dist-info/RECORD +174 -0
- spanforge-1.0.0.dist-info/WHEEL +4 -0
- spanforge-1.0.0.dist-info/entry_points.txt +5 -0
- spanforge-1.0.0.dist-info/licenses/LICENSE +128 -0
spanforge/sdk/rag.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
1
|
+
"""spanforge.sdk.rag — SpanForge sf-rag RAG Tracing client (Phase 13).
|
|
2
|
+
|
|
3
|
+
Implements RAG-001 through RAG-006: full tracing for Retrieval-Augmented
|
|
4
|
+
Generation pipelines including query tracing, retrieval tracing, generation
|
|
5
|
+
tracing, grounding scoring, and session summaries.
|
|
6
|
+
|
|
7
|
+
Architecture
|
|
8
|
+
------------
|
|
9
|
+
* :meth:`trace_query` records a ``llm.rag.query`` event and returns the
|
|
10
|
+
auto-generated ``session_id`` to be threaded through subsequent calls.
|
|
11
|
+
* :meth:`trace_retrieval` records a ``llm.rag.retrieved`` event with chunk
|
|
12
|
+
metadata (raw chunk text is NEVER stored).
|
|
13
|
+
* :meth:`trace_generation` records a ``llm.rag.generated`` event linking the
|
|
14
|
+
LLM generation span to the retrieved chunk IDs.
|
|
15
|
+
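* :meth:`get_session` returns a live
  :class:`~spanforge.namespaces.retrieval.RAGSessionPayload` snapshot for a
  given ``session_id`` (or ``None`` if the session is unknown), and
  :meth:`end_session` finalises the session and emits a ``llm.rag.session``
  summary event.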
|
|
17
|
+
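* :meth:`get_status` returns service health and session statistics.
* :meth:`assess_grounding` / :meth:`assess_grounding_with_policy` build a
  :class:`~spanforge.namespaces.runtime_governance.GroundingPayload`, keep it
  in memory for lookup, and append it to sf-audit.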
|
|
18
|
+
|
|
19
|
+
All operations run locally in-process when ``config.endpoint`` is empty or
|
|
20
|
+
the remote service is unreachable and ``local_fallback_enabled`` is ``True``.
|
|
21
|
+
|
|
22
|
+
Security requirements
|
|
23
|
+
---------------------
|
|
24
|
+
* Raw user queries and retrieved document text are **never** stored; only
|
|
25
|
+
SHA-256 content hashes are kept.
|
|
26
|
+
* Chunk ``chunk_id`` values are stored as-is — callers must ensure they do
|
|
27
|
+
not contain personally identifiable information.
|
|
28
|
+
* Thread-safety: all in-process state uses locks.
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import hashlib
|
|
34
|
+
import logging
|
|
35
|
+
import threading
|
|
36
|
+
from dataclasses import dataclass, field
|
|
37
|
+
from datetime import datetime, timezone
|
|
38
|
+
from typing import Any
|
|
39
|
+
|
|
40
|
+
from spanforge.namespaces.retrieval import (
|
|
41
|
+
RAGSessionPayload,
|
|
42
|
+
RAGSpanPayload,
|
|
43
|
+
RetrievalQueryPayload,
|
|
44
|
+
RetrievalResultPayload,
|
|
45
|
+
RetrievedChunk,
|
|
46
|
+
)
|
|
47
|
+
from spanforge.namespaces.runtime_governance import GroundingClaim, GroundingPayload
|
|
48
|
+
from spanforge.sdk._base import SFClientConfig, SFServiceClient
|
|
49
|
+
|
|
50
|
+
# Public star-import surface of this module: only the client class.
__all__: list[str] = ["SFRAGClient"]

# Module-level logger named after this module; used for best-effort
# diagnostics when a RAG event cannot be forwarded.
_log: logging.Logger = logging.getLogger(__name__)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
# Local session store
|
|
57
|
+
# ---------------------------------------------------------------------------
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
|
|
61
|
+
class _RAGSession:
|
|
62
|
+
"""Internal in-process state accumulated across a single RAG session."""
|
|
63
|
+
|
|
64
|
+
session_id: str
|
|
65
|
+
queries: int = 0
|
|
66
|
+
chunk_ids: set[str] = field(default_factory=set)
|
|
67
|
+
input_tokens: int = 0
|
|
68
|
+
output_tokens: int = 0
|
|
69
|
+
grounding_scores: list[float] = field(default_factory=list)
|
|
70
|
+
total_latency_ms: float = 0.0
|
|
71
|
+
retriever_name: str = ""
|
|
72
|
+
status: str = "ok"
|
|
73
|
+
started_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# ---------------------------------------------------------------------------
|
|
77
|
+
# Status dataclass
|
|
78
|
+
# ---------------------------------------------------------------------------
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass()
class RAGStatusInfo:
    """Snapshot of sf-rag service health, produced by :meth:`SFRAGClient.get_status`.

    Attributes:
        status: Health indicator, ``"ok"`` or ``"degraded"``.
        active_sessions: Sessions opened through ``trace_query`` that have
            not yet been closed with :meth:`SFRAGClient.end_session`.
        total_queries: Count of ``trace_query`` calls made on the client
            over its lifetime.
        total_spans: Count of ``trace_generation`` calls made on the client
            over its lifetime.
    """

    status: str
    active_sessions: int
    total_queries: int
    total_spans: int
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# ---------------------------------------------------------------------------
|
|
103
|
+
# Client
|
|
104
|
+
# ---------------------------------------------------------------------------
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class SFRAGClient(SFServiceClient):
    """SpanForge RAG Tracing service client.

    Provides end-to-end observability for Retrieval-Augmented Generation
    pipelines. Complements :class:`~spanforge.sdk.observe.SFObserveClient`
    by adding RAG-specific query → retrieval → generation correlation.

    Example usage::

        import spanforge
        from spanforge.sdk import sf_rag

        session_id = sf_rag.trace_query(
            query="What is the capital of France?",
            top_k=5,
            retriever_name="chroma-main",
        )

        sf_rag.trace_retrieval(
            session_id=session_id,
            chunks=[
                {"chunk_id": "doc-42-p3", "score": 0.93, "source": "docs/geo.md"},
            ],
            latency_ms=45.2,
        )

        sf_rag.trace_generation(
            session_id=session_id,
            model="gpt-4o",
            chunk_ids_used=["doc-42-p3"],
            prompt_tokens=512,
            output_tokens=128,
            grounding_score=0.91,
            latency_ms=1230.0,
        )

        summary = sf_rag.end_session(session_id)

    Thread-safety:
        Session aggregates, counters and grounding records are read and
        written only while holding ``self._lock``. The lock is not
        re-entrant, so helpers that acquire it are never called with it held.
    """

    def __init__(self, config: SFClientConfig) -> None:
        super().__init__(config, service_name="rag")
        # Guards every mutable attribute below; threading.Lock is not re-entrant.
        self._lock = threading.Lock()
        # Open sessions keyed by session_id; entries are removed by end_session().
        self._sessions: dict[str, _RAGSession] = {}
        # Grounding assessments keyed by grounding_id, plus a trace_id index.
        self._grounding_records: dict[str, GroundingPayload] = {}
        self._grounding_by_trace: dict[str, list[str]] = {}
        # Lifetime counters surfaced by get_status().
        self._total_queries: int = 0
        self._total_spans: int = 0

    # ------------------------------------------------------------------
    # RAG-001: trace_query
    # ------------------------------------------------------------------

    def trace_query(
        self,
        query: str,
        *,
        session_id: str | None = None,
        top_k: int = 5,
        retriever_name: str = "",
        embedding_model: str = "",
        namespace: str = "",
        latency_ms: float = 0.0,
        filters: dict[str, Any] | None = None,
    ) -> str:
        """Record a RAG query and return the session ID.

        The raw *query* text is **never stored**; only its SHA-256 hash is
        retained for correlation.

        Args:
            query: Raw user query text (hashed, not stored).
            session_id: Optional existing session ID to continue. A new
                ULID is generated when ``None`` (or empty).
            top_k: Number of chunks to retrieve.
            retriever_name: Name of the vector store / retriever.
            embedding_model: Embedding model used to encode the query.
            namespace: Optional vector store namespace or collection.
            latency_ms: Time to submit the query (ms).
            filters: Metadata filters applied to the retrieval query.

        Returns:
            The ``session_id`` to pass to :meth:`trace_retrieval` and
            :meth:`trace_generation`.
        """
        from spanforge.ulid import generate as _ulid

        sid = session_id or _ulid()
        query_hash = hashlib.sha256(query.encode("utf-8")).hexdigest()

        payload = RetrievalQueryPayload(
            session_id=sid,
            query_hash=query_hash,
            top_k=top_k,
            retriever_name=retriever_name,
            embedding_model=embedding_model,
            namespace=namespace,
            latency_ms=latency_ms,
            filters=filters or {},
        )

        with self._lock:
            sess = self._sessions.get(sid)
            if sess is None:
                sess = _RAGSession(session_id=sid, retriever_name=retriever_name)
                self._sessions[sid] = sess
            sess.queries += 1
            sess.total_latency_ms += latency_ms
            # Keep the first non-empty retriever name seen for the session.
            if retriever_name and not sess.retriever_name:
                sess.retriever_name = retriever_name
            self._total_queries += 1

        self._emit_local("llm.rag.query", payload.to_dict(), session_id=sid)
        return sid

    # ------------------------------------------------------------------
    # RAG-002: trace_retrieval
    # ------------------------------------------------------------------

    def trace_retrieval(
        self,
        session_id: str,
        chunks: list[dict[str, Any]],
        *,
        total_found: int | None = None,
        latency_ms: float = 0.0,
        query_hash: str = "",
        status: str = "ok",
        error_message: str | None = None,
    ) -> None:
        """Record retrieved chunks for a session.

        The event is emitted even when *session_id* is unknown. In that case
        no in-memory aggregate is updated, and a debug message is logged.

        Args:
            session_id: The session ID returned by :meth:`trace_query`.
            chunks: List of chunk dicts, each requiring ``chunk_id``
                (str) and ``score`` (float 0–1). ``source`` and
                ``content_hash`` are optional.
            total_found: Total matching chunks before ``top_k`` truncation.
                Defaults to ``len(chunks)``.
            latency_ms: Time taken for retrieval (ms).
            query_hash: SHA-256 hash of the triggering query (optional).
            status: ``"ok"``, ``"partial"``, ``"error"``, or ``"timeout"``.
            error_message: Error detail when *status* is not ``"ok"``.
        """
        parsed_chunks = [RetrievedChunk.from_dict(c) for c in chunks]
        payload = RetrievalResultPayload(
            session_id=session_id,
            query_hash=query_hash,
            chunks=parsed_chunks,
            total_found=total_found if total_found is not None else len(parsed_chunks),
            latency_ms=latency_ms,
            status=status,  # type: ignore[arg-type]
            error_message=error_message,
        )

        with self._lock:
            sess = self._sessions.get(session_id)
            if sess is not None:
                sess.chunk_ids.update(chunk.chunk_id for chunk in parsed_chunks)
                sess.total_latency_ms += latency_ms
                # "partial" results do not degrade the session status.
                if status not in ("ok", "partial"):
                    sess.status = status
        if sess is None:
            _log.debug("sf-rag trace_retrieval for unknown session=%s", session_id)

        self._emit_local("llm.rag.retrieved", payload.to_dict(), session_id=session_id)

    # ------------------------------------------------------------------
    # RAG-003: trace_generation
    # ------------------------------------------------------------------

    def trace_generation(
        self,
        session_id: str,
        model: str,
        *,
        span_name: str = "rag-generation",
        chunk_ids_used: list[str] | None = None,
        context_tokens: int = 0,
        prompt_tokens: int = 0,
        output_tokens: int = 0,
        latency_ms: float = 0.0,
        status: str = "ok",
        grounding_score: float | None = None,
        error_message: str | None = None,
    ) -> None:
        """Record an LLM generation span over retrieved context.

        The span counter reported by :meth:`get_status` is incremented even
        when *session_id* is unknown. Per-session aggregates are only updated
        for open sessions.

        Args:
            session_id: The session ID returned by :meth:`trace_query`.
            model: Model identifier (e.g. ``"gpt-4o"``).
            span_name: Human-readable label for this generation step.
            chunk_ids_used: Chunk IDs included in the context window.
            context_tokens: Tokens consumed by the retrieved context.
            prompt_tokens: Total prompt tokens (context + instruction).
            output_tokens: Tokens in the generated response.
            latency_ms: Generation latency in milliseconds.
            status: ``"ok"``, ``"error"``, or ``"timeout"``.
            grounding_score: 0.0–1.0 grounding quality score (optional).
            error_message: Error detail when *status* is not ``"ok"``.
        """
        payload = RAGSpanPayload(
            session_id=session_id,
            span_name=span_name,
            model=model,
            chunk_ids_used=chunk_ids_used or [],
            context_tokens=context_tokens,
            prompt_tokens=prompt_tokens,
            output_tokens=output_tokens,
            latency_ms=latency_ms,
            status=status,  # type: ignore[arg-type]
            grounding_score=grounding_score,
            error_message=error_message,
        )

        with self._lock:
            sess = self._sessions.get(session_id)
            if sess is not None:
                sess.input_tokens += prompt_tokens
                sess.output_tokens += output_tokens
                sess.total_latency_ms += latency_ms
                if grounding_score is not None:
                    sess.grounding_scores.append(grounding_score)
                if chunk_ids_used:
                    sess.chunk_ids.update(chunk_ids_used)
                if status != "ok":
                    sess.status = status
            self._total_spans += 1
        if sess is None:
            _log.debug("sf-rag trace_generation for unknown session=%s", session_id)

        self._emit_local("llm.rag.generated", payload.to_dict(), session_id=session_id)

    # ------------------------------------------------------------------
    # RAG-004: end_session / get_session
    # ------------------------------------------------------------------

    def end_session(self, session_id: str) -> RAGSessionPayload:
        """Finalise a session and emit a ``llm.rag.session`` summary event.

        Args:
            session_id: The session ID to finalise.

        Returns:
            A :class:`~spanforge.namespaces.retrieval.RAGSessionPayload`
            capturing session-level aggregates.

        Raises:
            KeyError: If *session_id* is unknown.
        """
        with self._lock:
            sess = self._sessions.pop(session_id)
            payload = self._build_session_payload(session_id, sess)

        self._emit_local("llm.rag.session", payload.to_dict(), session_id=session_id)
        return payload

    def get_session(self, session_id: str) -> RAGSessionPayload | None:
        """Return a live snapshot for *session_id* without finalising it.

        The snapshot is built while holding the lock, so it is consistent
        with concurrent ``trace_*`` calls.

        Returns ``None`` if the session is unknown.
        """
        with self._lock:
            sess = self._sessions.get(session_id)
            if sess is None:
                return None
            return self._build_session_payload(session_id, sess)

    # ------------------------------------------------------------------
    # RAG-005: get_status
    # ------------------------------------------------------------------

    def get_status(self) -> RAGStatusInfo:
        """Return service health and session statistics.

        This implementation always reports ``status="ok"``. The counts are
        read together under the lock.
        """
        with self._lock:
            active = len(self._sessions)
            queries = self._total_queries
            spans = self._total_spans
        return RAGStatusInfo(
            status="ok",
            active_sessions=active,
            total_queries=queries,
            total_spans=spans,
        )

    def assess_grounding(
        self,
        *,
        trace_id: str,
        decision_id: str,
        session_id: str,
        claims: list[GroundingClaim | dict[str, Any]] | None,
        threshold: float,
        policy_action: str,
        assessed_at: str,
        grounding_id: str | None = None,
        model_id: str | None = None,
        retriever_name: str | None = None,
    ) -> GroundingPayload:
        """Create and persist a canonical runtime grounding record.

        Claims given as dicts are converted with ``GroundingClaim.from_dict``.
        The record is stored in memory (retrievable via :meth:`get_grounding`
        and :meth:`list_grounding_for_trace`) and then appended to sf-audit.

        Args:
            trace_id: Trace the assessment belongs to (used for indexing).
            decision_id: Identifier of the decision being assessed.
            session_id: RAG session ID. It is used to infer *retriever_name*
                when that argument is not given.
            claims: Claims to assess; ``None`` is treated as no claims.
            threshold: Minimum average score for a ``"grounded"`` status.
            policy_action: Policy action recorded on the payload.
            assessed_at: Assessment timestamp recorded on the payload.
            grounding_id: Explicit record ID; a new ULID is generated if
                ``None``.
            model_id: Optional model identifier.
            retriever_name: Optional retriever name override.

        Returns:
            The stored :class:`GroundingPayload`.
        """
        from spanforge.ulid import generate as _ulid

        parsed_claims = self._parse_claims(claims)
        average_score = self._average_claim_score(parsed_claims)
        payload = GroundingPayload(
            grounding_id=grounding_id or _ulid(),
            trace_id=trace_id,
            decision_id=decision_id,
            session_id=session_id,
            status=self._grounding_status(parsed_claims, average_score, threshold),
            average_score=average_score,
            threshold=threshold,
            policy_action=policy_action,
            assessed_at=assessed_at,
            claims=parsed_claims,
            model_id=model_id,
            # _infer_retriever_name takes the lock itself, so call it here
            # before acquiring the lock below.
            retriever_name=retriever_name or self._infer_retriever_name(session_id),
        )

        with self._lock:
            self._grounding_records[payload.grounding_id] = payload
            self._grounding_by_trace.setdefault(trace_id, []).append(payload.grounding_id)

        self._emit_signed_record(payload)
        return payload

    def assess_grounding_with_policy(
        self,
        *,
        environment: str,
        trace_id: str,
        decision_id: str,
        session_id: str,
        claims: list[GroundingClaim | dict[str, Any]] | None,
        assessed_at: str,
        policy_client: Any | None = None,
        control: str = "grounding_threshold",
        threshold: float = 0.0,
        retriever_name: str | None = None,
        model_id: str | None = None,
    ) -> GroundingPayload:
        """Assess grounding using the active runtime policy threshold and action.

        The average claim score is passed to ``engine.evaluate`` as the
        observed value. The returned decision's ``threshold`` (falling back
        to *threshold* when it is ``None``) and ``action`` are then used for
        :meth:`assess_grounding`.

        Args:
            environment: Policy environment to evaluate against.
            trace_id: Trace the assessment belongs to.
            decision_id: Identifier of the decision being assessed.
            session_id: RAG session ID (also sent as policy metadata).
            claims: Claims to assess; ``None`` is treated as no claims.
            assessed_at: Timestamp used for both evaluation and the record.
            policy_client: Policy engine; defaults to ``spanforge.sdk.sf_policy``.
            control: Policy control name to evaluate.
            threshold: Fallback threshold when the policy supplies none.
            retriever_name: Optional retriever name override.
            model_id: Optional model identifier.

        Returns:
            The stored :class:`GroundingPayload`.
        """
        parsed_claims = self._parse_claims(claims)
        observed_value = self._average_claim_score(parsed_claims)
        engine = policy_client or self._default_policy_client()
        decision = engine.evaluate(
            environment=environment,
            trace_id=trace_id,
            service="sf_rag",
            control=control,
            evaluated_at=assessed_at,
            observed_value=observed_value,
            metadata={"session_id": session_id},
        )
        return self.assess_grounding(
            trace_id=trace_id,
            decision_id=decision_id,
            session_id=session_id,
            claims=parsed_claims,  # type: ignore[arg-type]
            threshold=decision.threshold if decision.threshold is not None else threshold,
            policy_action=decision.action,
            assessed_at=assessed_at,
            retriever_name=retriever_name,
            model_id=model_id,
        )

    async def assess_grounding_async(self, **kwargs: Any) -> GroundingPayload:
        """Async wrapper around :meth:`assess_grounding`.

        Runs the synchronous assessment (which may write to sf-audit) in the
        running loop's default executor, so the event loop is not blocked.
        """
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, functools.partial(self.assess_grounding, **kwargs)
        )

    def get_grounding(self, grounding_id: str) -> GroundingPayload | None:
        """Return a previously emitted grounding assessment, or ``None``."""
        with self._lock:
            return self._grounding_records.get(grounding_id)

    def list_grounding_for_trace(self, trace_id: str) -> list[GroundingPayload]:
        """Return all grounding assessments emitted for a trace, oldest first."""
        with self._lock:
            # Resolve IDs while still holding the lock so the index and the
            # record map are read as one consistent snapshot.
            return [
                self._grounding_records[item]
                for item in self._grounding_by_trace.get(trace_id, [])
                if item in self._grounding_records
            ]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_session_payload(
        self, session_id: str, sess: _RAGSession
    ) -> RAGSessionPayload:
        """Summarise *sess* as a RAGSessionPayload.

        Callers must hold ``self._lock``, or own *sess* exclusively, so that
        the aggregates cannot change mid-read.
        """
        scores = sess.grounding_scores
        avg_score: float | None = sum(scores) / len(scores) if scores else None
        return RAGSessionPayload(
            session_id=session_id,
            total_queries=sess.queries,
            total_chunks_used=len(sess.chunk_ids),
            total_input_tokens=sess.input_tokens,
            total_output_tokens=sess.output_tokens,
            avg_grounding_score=avg_score,
            total_latency_ms=sess.total_latency_ms,
            status=self._session_status(sess.status),  # type: ignore[arg-type]
            retriever_name=sess.retriever_name,
        )

    @staticmethod
    def _parse_claims(
        claims: list[GroundingClaim | dict[str, Any]] | None,
    ) -> list[GroundingClaim]:
        """Normalise *claims* to GroundingClaim instances; ``None`` yields ``[]``."""
        return [
            claim if isinstance(claim, GroundingClaim) else GroundingClaim.from_dict(claim)
            for claim in (claims or [])
        ]

    @staticmethod
    def _average_claim_score(claims: list[GroundingClaim]) -> float:
        """Mean claim score rounded to 6 decimals; ``0.0`` when there are no claims."""
        if not claims:
            return 0.0
        return round(sum(claim.score for claim in claims) / len(claims), 6)

    @staticmethod
    def _grounding_status(
        claims: list[GroundingClaim],
        average_score: float,
        threshold: float,
    ) -> str:
        """Classify a claim set.

        Returns ``"grounded"`` only when every claim is grounded and the
        average meets *threshold*, and ``"ungrounded"`` when there are no
        claims or none are grounded. Anything else, including all-grounded
        claims below the threshold, is ``"partially_grounded"``.
        """
        if not claims:
            return "ungrounded"
        grounded_count = sum(1 for claim in claims if claim.grounded)
        if grounded_count == len(claims) and average_score >= threshold:
            return "grounded"
        if grounded_count == 0:
            return "ungrounded"
        return "partially_grounded"

    def _infer_retriever_name(self, session_id: str) -> str | None:
        """Return the open session's retriever name, or ``None`` if unknown/empty.

        Acquires ``self._lock``; must not be called while it is held.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if session is None or not session.retriever_name:
                return None
            return session.retriever_name

    @staticmethod
    def _session_status(status: str) -> str:
        """Map an internal step status to a session status (``timeout`` → ``error``)."""
        if status == "timeout":
            return "error"
        return status

    def _emit_signed_record(self, payload: GroundingPayload) -> None:
        """Write the grounding payload into sf-audit under ``spanforge.grounding.v1``."""
        from spanforge.sdk import sf_audit

        sf_audit.append(payload.to_dict(), "spanforge.grounding.v1")

    @staticmethod
    def _default_policy_client() -> Any:
        """Return the process-wide ``sf_policy`` client (imported lazily)."""
        from spanforge.sdk import sf_policy

        return sf_policy

    def _emit_local(
        self,
        event_type: str,
        payload: dict[str, Any],
        *,
        session_id: str = "",
    ) -> None:
        """Best-effort emit of a RAG event via an observe client.

        Failures never propagate to the caller. They are logged at debug
        level with the traceback attached, so misconfiguration is diagnosable.
        """
        try:
            from spanforge.sdk.observe import SFObserveClient

            # NOTE(review): a fresh observe client is built per event; consider
            # caching one per SFRAGClient if construction proves expensive.
            observe_config = SFClientConfig(
                endpoint=self._config.endpoint,
                api_key=self._config.api_key,
            )
            obs = SFObserveClient(observe_config)
            obs.emit_span(  # type: ignore[call-arg]
                name=event_type,
                payload=payload,
                trace_id=session_id,
            )
        except Exception:  # NOSONAR — local fallback; never suppress intentionally
            _log.debug(
                "sf-rag local emit %s session=%s", event_type, session_id, exc_info=True
            )
|