evalgate-sdk 3.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80) hide show
  1. evalgate_sdk/__init__.py +707 -0
  2. evalgate_sdk/_version.py +3 -0
  3. evalgate_sdk/assertions.py +1362 -0
  4. evalgate_sdk/auto.py +247 -0
  5. evalgate_sdk/batch.py +174 -0
  6. evalgate_sdk/cache.py +111 -0
  7. evalgate_sdk/ci_context.py +123 -0
  8. evalgate_sdk/cli/__init__.py +111 -0
  9. evalgate_sdk/cli/api.py +261 -0
  10. evalgate_sdk/cli/cli_constants.py +20 -0
  11. evalgate_sdk/cli/commands.py +1041 -0
  12. evalgate_sdk/cli/config.py +228 -0
  13. evalgate_sdk/cli/env.py +43 -0
  14. evalgate_sdk/cli/formatters/types.py +132 -0
  15. evalgate_sdk/cli/golden_commands.py +322 -0
  16. evalgate_sdk/cli/manifest.py +301 -0
  17. evalgate_sdk/cli/new_commands.py +435 -0
  18. evalgate_sdk/cli/policy_packs.py +103 -0
  19. evalgate_sdk/cli/profiles.py +12 -0
  20. evalgate_sdk/cli/regression_gate.py +312 -0
  21. evalgate_sdk/cli/render/__init__.py +1 -0
  22. evalgate_sdk/cli/render/snippet.py +18 -0
  23. evalgate_sdk/cli/render/sort.py +29 -0
  24. evalgate_sdk/cli/report/__init__.py +1 -0
  25. evalgate_sdk/cli/report/build_check_report.py +209 -0
  26. evalgate_sdk/cli/traces.py +186 -0
  27. evalgate_sdk/cli/workspace.py +63 -0
  28. evalgate_sdk/client.py +609 -0
  29. evalgate_sdk/cluster.py +359 -0
  30. evalgate_sdk/collector.py +161 -0
  31. evalgate_sdk/constants.py +6 -0
  32. evalgate_sdk/context.py +151 -0
  33. evalgate_sdk/errors.py +236 -0
  34. evalgate_sdk/export.py +238 -0
  35. evalgate_sdk/formatters/__init__.py +11 -0
  36. evalgate_sdk/formatters/github.py +51 -0
  37. evalgate_sdk/formatters/human.py +68 -0
  38. evalgate_sdk/formatters/json_fmt.py +11 -0
  39. evalgate_sdk/formatters/pr_comment.py +80 -0
  40. evalgate_sdk/golden.py +426 -0
  41. evalgate_sdk/integrations/__init__.py +1 -0
  42. evalgate_sdk/integrations/anthropic.py +99 -0
  43. evalgate_sdk/integrations/autogen.py +62 -0
  44. evalgate_sdk/integrations/crewai.py +61 -0
  45. evalgate_sdk/integrations/langchain.py +100 -0
  46. evalgate_sdk/integrations/openai.py +155 -0
  47. evalgate_sdk/integrations/openai_eval.py +221 -0
  48. evalgate_sdk/local.py +144 -0
  49. evalgate_sdk/logger.py +123 -0
  50. evalgate_sdk/matchers.py +62 -0
  51. evalgate_sdk/otel.py +256 -0
  52. evalgate_sdk/pagination.py +145 -0
  53. evalgate_sdk/py.typed +0 -0
  54. evalgate_sdk/pytest_plugin.py +96 -0
  55. evalgate_sdk/reason_codes.py +103 -0
  56. evalgate_sdk/regression.py +196 -0
  57. evalgate_sdk/replay_decision.py +115 -0
  58. evalgate_sdk/runtime/__init__.py +50 -0
  59. evalgate_sdk/runtime/adapters/__init__.py +1 -0
  60. evalgate_sdk/runtime/adapters/config_to_dsl.py +270 -0
  61. evalgate_sdk/runtime/adapters/testsuite_to_dsl.py +213 -0
  62. evalgate_sdk/runtime/context.py +68 -0
  63. evalgate_sdk/runtime/eval.py +318 -0
  64. evalgate_sdk/runtime/execution_mode.py +170 -0
  65. evalgate_sdk/runtime/executor.py +92 -0
  66. evalgate_sdk/runtime/registry.py +125 -0
  67. evalgate_sdk/runtime/run_report.py +249 -0
  68. evalgate_sdk/runtime/types.py +143 -0
  69. evalgate_sdk/snapshot.py +219 -0
  70. evalgate_sdk/streaming.py +124 -0
  71. evalgate_sdk/synthesize.py +226 -0
  72. evalgate_sdk/testing.py +128 -0
  73. evalgate_sdk/types.py +666 -0
  74. evalgate_sdk/utils/__init__.py +1 -0
  75. evalgate_sdk/utils/input_hash.py +42 -0
  76. evalgate_sdk/workflows.py +264 -0
  77. evalgate_sdk-3.3.1.dist-info/METADATA +608 -0
  78. evalgate_sdk-3.3.1.dist-info/RECORD +80 -0
  79. evalgate_sdk-3.3.1.dist-info/WHEEL +4 -0
  80. evalgate_sdk-3.3.1.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,359 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any
5
+
6
+ from evalgate_sdk.golden import NormalizedRunArtifact, NormalizedRunCase
7
+
8
# Words dropped by ``_tokenize`` before similarity is computed: common English
# function words plus labels/status names that occur in almost every case
# record and therefore carry no clustering signal.
STOP_WORDS = set(
    """
    the and for with that this from into your have
    should would could about what when where while were them
    then than also been because
    expected actual input output error
    failed passed skipped result spec case file
    """.split()
)
47
+
48
+
49
+ @dataclass(slots=True)
50
+ class ClusterSample:
51
+ case_id: str
52
+ name: str
53
+
54
+ def to_dict(self) -> dict[str, Any]:
55
+ return {
56
+ "caseId": self.case_id,
57
+ "name": self.name,
58
+ }
59
+
60
+
61
+ @dataclass(slots=True)
62
+ class ClusterCase:
63
+ case_id: str
64
+ name: str
65
+ file_path: str
66
+ status: str
67
+ input: str
68
+ expected: str
69
+ actual: str
70
+
71
+ def to_dict(self) -> dict[str, Any]:
72
+ return {
73
+ "caseId": self.case_id,
74
+ "name": self.name,
75
+ "filePath": self.file_path,
76
+ "status": self.status,
77
+ "input": self.input,
78
+ "expected": self.expected,
79
+ "actual": self.actual,
80
+ }
81
+
82
+
83
+ @dataclass(slots=True)
84
+ class TraceCluster:
85
+ id: str
86
+ cluster_label: str
87
+ dominant_pattern: str
88
+ suggested_failure_mode: str | None
89
+ similarity_threshold: float
90
+ trace_ids: list[str]
91
+ trace_count: int
92
+ keywords: list[str]
93
+ member_ids: list[str]
94
+ member_count: int
95
+ density: float
96
+ status_counts: dict[str, int]
97
+ samples: list[ClusterSample] = field(default_factory=list)
98
+ cases: list[ClusterCase] = field(default_factory=list)
99
+
100
+ def to_dict(self) -> dict[str, Any]:
101
+ return {
102
+ "id": self.id,
103
+ "clusterLabel": self.cluster_label,
104
+ "dominantPattern": self.dominant_pattern,
105
+ "suggestedFailureMode": self.suggested_failure_mode,
106
+ "similarityThreshold": self.similarity_threshold,
107
+ "traceIds": list(self.trace_ids),
108
+ "traceCount": self.trace_count,
109
+ "keywords": list(self.keywords),
110
+ "memberIds": list(self.member_ids),
111
+ "memberCount": self.member_count,
112
+ "density": self.density,
113
+ "statusCounts": dict(self.status_counts),
114
+ "samples": [sample.to_dict() for sample in self.samples],
115
+ "cases": [case.to_dict() for case in self.cases],
116
+ }
117
+
118
+
119
+ @dataclass(slots=True)
120
+ class ClusterSummary:
121
+ run_id: str
122
+ total_run_results: int
123
+ clustered_cases: int
124
+ skipped_cases: int
125
+ requested_clusters: int | None
126
+ include_passed: bool
127
+ clusters: list[TraceCluster]
128
+
129
+ def to_dict(self) -> dict[str, Any]:
130
+ return {
131
+ "runId": self.run_id,
132
+ "totalRunResults": self.total_run_results,
133
+ "clusteredCases": self.clustered_cases,
134
+ "skippedCases": self.skipped_cases,
135
+ "requestedClusters": self.requested_clusters,
136
+ "includePassed": self.include_passed,
137
+ "clusters": [cluster.to_dict() for cluster in self.clusters],
138
+ }
139
+
140
+
141
def _tokenize(text: str) -> list[str]:
    """Split *text* into lowercase alphanumeric words useful for clustering.

    Every non-alphanumeric character acts as a separator. Words of two
    characters or fewer, and words in ``STOP_WORDS``, are discarded. Order and
    duplicates are preserved.
    """
    # Replacing separators with spaces lets str.split() find the runs; no
    # alphanumeric character is whitespace, so runs are never split internally.
    spaced = "".join(ch if ch.isalnum() else " " for ch in text.lower())
    return [word for word in spaced.split() if len(word) > 2 and word not in STOP_WORDS]
158
+
159
+
160
def _token_set(text: str) -> set[str]:
    """Return the distinct tokens of *text* (see ``_tokenize``) as a set."""
    return {*_tokenize(text)}
162
+
163
+
164
+ def _jaccard(a: set[str], b: set[str]) -> float:
165
+ if not a and not b:
166
+ return 1.0
167
+ if not a or not b:
168
+ return 0.0
169
+ intersection = len(a.intersection(b))
170
+ union = len(a.union(b))
171
+ return intersection / union if union else 0.0
172
+
173
+
174
+ def _build_case_text(case: NormalizedRunCase) -> str:
175
+ return "\n".join(
176
+ part
177
+ for part in [case.name, case.file_path, case.error or "", case.input, case.expected, case.actual]
178
+ if part.strip()
179
+ )
180
+
181
+
182
def _centroid_keywords(texts: list[str], top_n: int = 4) -> list[str]:
    """Return the *top_n* most frequent tokens across *texts*.

    Ties in frequency are broken alphabetically so the result is deterministic.
    """
    counts: dict[str, int] = {}
    for token in (tok for text in texts for tok in _tokenize(text)):
        counts[token] = counts.get(token, 0) + 1
    # Keys are unique, so sorting keys by (-count, token) matches sorting items.
    ranked = sorted(counts, key=lambda tok: (-counts[tok], tok))
    return ranked[:top_n]
189
+
190
+
191
+ def _suggest_failure_mode(text: str, has_failed_members: bool) -> str | None:
192
+ if not has_failed_members:
193
+ return None
194
+ lower = text.lower()
195
+ if "timeout" in lower or "slow" in lower:
196
+ return "performance_timeout"
197
+ if "null" in lower or "undefined" in lower:
198
+ return "null_reference"
199
+ if "format" in lower or "parse" in lower:
200
+ return "format_mismatch"
201
+ if "constraint" in lower or "validation" in lower:
202
+ return "constraint_violation"
203
+ if "tone" in lower or "empathetic" in lower or "professional" in lower:
204
+ return "tone_mismatch"
205
+ if "halluc" in lower or "invented" in lower or "grounding" in lower:
206
+ return "hallucination"
207
+ return "general_failure"
208
+
209
+
210
+ def _density(member_sets: list[set[str]]) -> float:
211
+ if len(member_sets) < 2:
212
+ return 1.0
213
+ total = 0.0
214
+ count = 0
215
+ for index, current in enumerate(member_sets):
216
+ for other in member_sets[index + 1 :]:
217
+ total += _jaccard(current, other)
218
+ count += 1
219
+ return total / count if count else 1.0
220
+
221
+
222
+ def _assign_clusters(points: list[dict[str, Any]], cluster_count: int) -> list[list[dict[str, Any]]]:
223
+ if not points:
224
+ return []
225
+ clusters: list[list[dict[str, Any]]] = []
226
+ for point in points:
227
+ if len(clusters) < cluster_count:
228
+ clusters.append([point])
229
+ continue
230
+ best_index = 0
231
+ best_score = -1.0
232
+ for index, cluster in enumerate(clusters):
233
+ scores = [_jaccard(point["tokens"], member["tokens"]) for member in cluster]
234
+ avg_score = sum(scores) / len(scores) if scores else 0.0
235
+ if avg_score > best_score:
236
+ best_score = avg_score
237
+ best_index = index
238
+ clusters[best_index].append(point)
239
+ return clusters
240
+
241
+
242
def cluster_run_result(
    run_artifact: NormalizedRunArtifact,
    *,
    clusters: int | None = None,
    include_passed: bool = False,
) -> ClusterSummary:
    """Group the cases of a run into clusters of textually similar cases.

    By default only cases whose status is ``"failed"`` are clustered. With
    ``include_passed=True`` every case is clustered. Each case is reduced to a
    token set built from its text fields, and cases are assigned greedily by
    mean Jaccard similarity.

    Args:
        run_artifact: normalized run whose ``cases`` are clustered.
        clusters: requested number of clusters. ``None`` or ``0`` selects
            ``round(sqrt(n))`` bounded to 1..8. Negative values are treated
            as 1. The count never exceeds the number of candidate cases.
        include_passed: cluster every case instead of only failed ones.

    Returns:
        A ``ClusterSummary``. Clusters are ordered by size (descending), then
        density (descending), then id.
    """
    candidates = [case for case in run_artifact.cases if include_passed or case.status == "failed"]
    if not candidates:
        return ClusterSummary(
            run_id=run_artifact.run_id,
            total_run_results=run_artifact.total_run_results,
            clustered_cases=0,
            skipped_cases=run_artifact.total_run_results,
            requested_clusters=clusters,
            include_passed=include_passed,
            clusters=[],
        )

    points = []
    for case in candidates:
        text = _build_case_text(case)
        points.append({"case": case, "text": text, "tokens": _token_set(text)})

    # ``clusters or ...`` maps None/0 to the sqrt heuristic. The outer max()
    # guards against negative requests, which would otherwise seed no clusters.
    cluster_count = max(1, clusters or min(8, max(1, round(len(points) ** 0.5))))
    grouped = _assign_clusters(points, min(cluster_count, len(points)))

    # Any status other than these two counts as a failure, both in
    # ``status_counts`` and in the text fed to ``_suggest_failure_mode``.
    non_failure_statuses = ("passed", "skipped")

    built_clusters: list[TraceCluster] = []
    for index, members in enumerate(grouped):
        texts = [member["text"] for member in members]
        keywords = _centroid_keywords(texts)
        dominant_pattern = ", ".join(keywords[:3]) if keywords else "No clear pattern"
        failed_text = " ".join(
            member["text"] for member in members if member["case"].status not in non_failure_statuses
        )
        status_counts = {"passed": 0, "failed": 0, "skipped": 0}
        for member in members:
            status = member["case"].status
            bucket = status if status in non_failure_statuses else "failed"
            status_counts[bucket] += 1
        member_token_sets = [member["tokens"] for member in members]
        density = _density(member_token_sets)
        cases = [
            ClusterCase(
                case_id=member["case"].case_id,
                name=member["case"].name,
                file_path=member["case"].file_path,
                status=member["case"].status,
                input=member["case"].input,
                expected=member["case"].expected,
                actual=member["case"].actual,
            )
            for member in members
        ]
        built_clusters.append(
            TraceCluster(
                id=f"cluster-{index}",
                cluster_label=", ".join(keywords[:3]) if keywords else f"Cluster {index + 1}",
                dominant_pattern=dominant_pattern,
                suggested_failure_mode=_suggest_failure_mode(failed_text, status_counts["failed"] > 0),
                similarity_threshold=density,
                trace_ids=[member["case"].case_id for member in members],
                trace_count=len(members),
                keywords=keywords,
                member_ids=[member["case"].case_id for member in members],
                member_count=len(members),
                density=density,
                status_counts=status_counts,
                samples=[
                    ClusterSample(case_id=member["case"].case_id, name=member["case"].name)
                    for member in members[:3]
                ],
                cases=cases,
            )
        )

    built_clusters.sort(key=lambda item: (-item.trace_count, -item.similarity_threshold, item.id))
    return ClusterSummary(
        run_id=run_artifact.run_id,
        total_run_results=run_artifact.total_run_results,
        clustered_cases=len(candidates),
        skipped_cases=run_artifact.total_run_results - len(candidates),
        requested_clusters=clusters,
        include_passed=include_passed,
        clusters=built_clusters,
    )
329
+
330
+
331
def format_cluster_human(summary: ClusterSummary) -> str:
    """Render *summary* as a plain-text, newline-separated report for terminals."""
    out = [
        "Cluster phase",
        f"Run: {summary.run_id}",
        f"Clustered {summary.clustered_cases} case(s) into {len(summary.clusters)} cluster(s)",
    ]
    if summary.skipped_cases > 0:
        hint = "none filtered" if summary.include_passed else "use --include-passed to include non-failures"
        out.append(f"Skipped {summary.skipped_cases} case(s) ({hint})")
    if not summary.clusters:
        out.append("No cases available for clustering")
        return "\n".join(out)

    for position, cluster in enumerate(summary.clusters, start=1):
        counts = cluster.status_counts
        similarity_pct = cluster.similarity_threshold * 100
        out.append("")
        out.append(
            f"{position}. {cluster.id} — {cluster.cluster_label} "
            f"({cluster.trace_count} case(s), {similarity_pct:.1f}% similarity)"
        )
        out.append(
            f" status: {counts['failed']} failed, {counts['passed']} passed, {counts['skipped']} skipped"
        )
        if cluster.dominant_pattern:
            out.append(f" pattern: {cluster.dominant_pattern}")
        if cluster.suggested_failure_mode:
            out.append(f" suggested failure mode: {cluster.suggested_failure_mode}")
        if cluster.samples:
            listed = ", ".join(f"{sample.case_id} ({sample.name})" for sample in cluster.samples)
            out.append(f" samples: {listed}")
    return "\n".join(out)
@@ -0,0 +1,161 @@
1
+ """Production trace collector — report traces with client-side sampling.
2
+
3
+ Port of the TypeScript SDK's ``collector.ts``.
4
+
5
+ Usage::
6
+
7
+ from evalgate_sdk import AIEvalClient
8
+ from evalgate_sdk.collector import CollectorSpanInput, ReportTraceInput, report_trace
9
+
10
+ client = AIEvalClient(api_key="...")
11
+ result = await report_trace(client, ReportTraceInput(
12
+ trace_id="t-123",
13
+ name="chat-completion",
14
+ spans=[CollectorSpanInput(span_id="s-1", name="llm-call")],
15
+ ))
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import random
21
+ from dataclasses import dataclass, field
22
+ from typing import Any, Literal
23
+
24
+
25
@dataclass
class CollectorSpanInput:
    """One span inside a reported trace; only ``span_id`` and ``name`` are required."""

    span_id: str
    name: str
    type: Literal["llm", "tool", "agent", "retrieval", "default"] | None = None
    parent_span_id: str | None = None
    input: Any = None
    output: Any = None
    model: str | None = None
    vendor: str | None = None
    params: dict[str, Any] | None = None
    metrics: dict[str, Any] | None = None
    timestamps: dict[str, float] | None = None
    error: dict[str, Any] | None = None
    behavioral: dict[str, Any] | None = None
    metadata: dict[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for the collector API, omitting optional fields that are ``None``.

        Falsy-but-set values (``0``, ``""``, ``{}``) are kept.
        """
        optional = (
            "type",
            "parent_span_id",
            "input",
            "output",
            "model",
            "vendor",
            "params",
            "metrics",
            "timestamps",
            "error",
            "behavioral",
            "metadata",
        )
        present = {attr: getattr(self, attr) for attr in optional if getattr(self, attr) is not None}
        return {"span_id": self.span_id, "name": self.name, **present}
62
+
63
+
64
@dataclass
class CollectorFeedbackInput:
    """End-user feedback attached to a reported trace."""

    type: Literal["thumbs_up", "thumbs_down", "rating", "comment"]
    value: Any = None
    user_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for the collector API; ``value`` and ``user_id`` are omitted when ``None``."""
        extras = {"value": self.value, "user_id": self.user_id}
        return {"type": self.type, **{key: val for key, val in extras.items() if val is not None}}
77
+
78
+
79
@dataclass
class ReportTraceInput:
    """A trace to send to the collector, with its spans and optional feedback."""

    trace_id: str
    name: str
    spans: list[CollectorSpanInput] = field(default_factory=list)
    status: Literal["pending", "success", "error"] | None = None
    duration_ms: float | None = None
    source: Literal["sdk", "api", "cli"] | None = None
    environment: Literal["production", "staging", "dev"] | None = None
    metadata: dict[str, Any] | None = None
    user_feedback: CollectorFeedbackInput | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize for the collector API.

        Spans and feedback are serialized through their own ``to_dict``.
        Optional scalar fields are omitted when ``None``.
        """
        payload: dict[str, Any] = dict(
            trace_id=self.trace_id,
            name=self.name,
            spans=[span.to_dict() for span in self.spans],
        )
        for attr in ("status", "duration_ms", "source", "environment", "metadata"):
            value = getattr(self, attr)
            if value is None:
                continue
            payload[attr] = value
        feedback = self.user_feedback
        if feedback is not None:
            payload["user_feedback"] = feedback.to_dict()
        return payload
104
+
105
+
106
@dataclass
class ReportTraceOptions:
    """Tuning knobs accepted by ``report_trace``."""

    # Fraction (0–1) of ordinary traces to send; 1.0 disables client-side sampling.
    sample_rate: float = field(default=1.0)
111
+
112
+
113
@dataclass
class ReportTraceResult:
    """Outcome of a ``report_trace`` call."""

    # True only when the collector request completed without raising.
    sent: bool
    trace_id: str
    # Details echoed back by the collector; populated when ``sent`` is True.
    trace_db_id: int | None = field(default=None)
    span_count: int | None = field(default=None)
    queued_for_analysis: bool | None = field(default=None)
    # Why nothing was sent, e.g. "sampled_out" or "request_failed: <error>".
    skip_reason: str | None = field(default=None)
121
+
122
+
123
async def report_trace(
    client: Any,
    input: ReportTraceInput,
    options: ReportTraceOptions | None = None,
) -> ReportTraceResult:
    """Report a production trace to the collector endpoint.

    Client-side sampling: set ``options.sample_rate`` (0–1). A trace is
    dropped when ``random.random() >= sample_rate``, so ``0`` drops every
    trace and any value ``>= 1`` sends every trace. Error traces
    (``status == "error"``) and thumbs-down feedback bypass sampling.

    The call is best-effort and does not raise when the request fails.
    Instead it returns ``sent=False`` with a ``skip_reason`` that starts with
    ``"request_failed:"``.

    Args:
        client: object exposing an awaitable ``_request(method, path, json=...)``.
        input: trace to send, serialized with ``input.to_dict()``.
        options: sampling options; defaults to ``ReportTraceOptions()``.

    Returns:
        A ``ReportTraceResult`` describing whether and how the trace was sent.
    """
    opts = options or ReportTraceOptions()

    is_error = input.status == "error"
    is_negative_feedback = input.user_feedback is not None and input.user_feedback.type == "thumbs_down"
    bypass_sampling = is_error or is_negative_feedback

    if not bypass_sampling and opts.sample_rate < 1.0 and random.random() >= opts.sample_rate:
        return ReportTraceResult(
            sent=False,
            trace_id=input.trace_id,
            skip_reason="sampled_out",
        )

    try:
        response = await client._request("POST", "/api/collector", json=input.to_dict())
    except Exception as exc:  # deliberate best-effort: surface the failure in the result
        return ReportTraceResult(
            sent=False,
            trace_id=input.trace_id,
            skip_reason=f"request_failed: {exc}",
        )

    # The trace has been delivered at this point. A body without ``.get``
    # (e.g. ``None`` for an empty reply) must not turn that success into an
    # AttributeError, so treat it as carrying no details.
    body = response if hasattr(response, "get") else {}
    return ReportTraceResult(
        sent=True,
        trace_id=input.trace_id,
        trace_db_id=body.get("trace_db_id"),
        span_count=body.get("span_count"),
        queued_for_analysis=body.get("queued_for_analysis"),
    )
@@ -0,0 +1,6 @@
1
+ """Default constants for the EvalGate SDK.
2
+
3
+ Port of ``constants.ts``.
4
+ """
5
+
6
# Default base URL of the hosted EvalGate API.
DEFAULT_BASE_URL: str = "https://api.evalgate.com"
@@ -0,0 +1,151 @@
1
+ """Context propagation using Python's contextvars — thread-safe and async-safe."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import contextvars
6
+ import functools
7
+ from collections.abc import Callable
8
+ from typing import Any, TypeVar, overload
9
+
10
+ T = TypeVar("T")
11
+ F = TypeVar("F", bound=Callable[..., Any])
12
+
13
+ ContextMetadata = dict[str, Any]
14
+
15
+ _ctx_var: contextvars.ContextVar[ContextMetadata | None] = contextvars.ContextVar("evalai_context", default=None)
16
+
17
+
18
class EvalContext:
    """A metadata scope layered on top of whatever scope is currently active.

    Entering merges the parent metadata with this scope's metadata (this
    scope wins on key conflicts) and makes the result current. Exiting
    restores the previous value. Works as a sync or async context manager.

    The same instance may be re-entered while already active (nested
    ``with``); each ``exit`` undoes the most recent ``enter``. Tokens are
    stored on the instance, so one instance should not be shared across
    concurrently running tasks or threads. Create one instance per scope
    instead.
    """

    def __init__(self, metadata: ContextMetadata) -> None:
        # Shallow copy so later mutation of the caller's dict does not leak in.
        self._metadata = dict(metadata)
        # One token per active ``enter``. A stack (rather than a single slot)
        # lets nested re-entry of the same instance unwind correctly.
        self._tokens: list[contextvars.Token[ContextMetadata | None]] = []

    @property
    def metadata(self) -> ContextMetadata:
        """A shallow copy of this scope's own (unmerged) metadata."""
        return dict(self._metadata)

    def enter(self) -> None:
        """Activate the scope: the current context becomes parent merged with own metadata."""
        parent = _ctx_var.get()
        merged = {**(parent or {}), **self._metadata}
        self._tokens.append(_ctx_var.set(merged))

    def exit(self) -> None:
        """Undo the most recent ``enter``; a no-op when the scope is not active."""
        if self._tokens:
            _ctx_var.reset(self._tokens.pop())

    def __enter__(self) -> EvalContext:
        self.enter()
        return self

    def __exit__(self, *args: Any) -> None:
        self.exit()

    async def __aenter__(self) -> EvalContext:
        self.enter()
        return self

    async def __aexit__(self, *args: Any) -> None:
        self.exit()
52
+
53
+
54
def create_context(metadata: ContextMetadata) -> EvalContext:
    """Build, without entering, an ``EvalContext`` for *metadata*."""
    scope = EvalContext(metadata)
    return scope
57
+
58
+
59
def get_current_context() -> ContextMetadata | None:
    """Return the innermost active scope's merged metadata, or ``None`` outside any scope.

    The dict returned is the value stored in the context variable, not a copy.
    """
    active = _ctx_var.get()
    return active
62
+
63
+
64
def merge_with_context(metadata: dict[str, Any] | None = None) -> dict[str, Any]:
    """Return a new dict of the active context overlaid with *metadata*.

    Keys in *metadata* override context keys. Neither input is mutated.
    """
    base = _ctx_var.get() or {}
    extra = metadata or {}
    return {**base, **extra}
68
+
69
+
70
async def with_context(metadata: ContextMetadata, fn: Callable[[], Any]) -> Any:
    """Call *fn* inside a fresh context scope and return its result.

    *fn* may be sync or async. If its return value has ``__await__``, it is
    awaited while the scope is still active.
    """
    async with create_context(metadata):
        outcome = fn()
        return (await outcome) if hasattr(outcome, "__await__") else outcome
78
+
79
+
80
def with_context_sync(metadata: ContextMetadata, fn: Callable[[], T]) -> T:
    """Call *fn* synchronously inside a fresh context scope and return its result."""
    scope = create_context(metadata)
    with scope:
        value = fn()
    return value
85
+
86
+
87
def clone_context(metadata: ContextMetadata) -> ContextMetadata:
    """Return an independent deep copy of *metadata*; nested values are copied too."""
    from copy import deepcopy

    return deepcopy(metadata)
92
+
93
+
94
def merge_contexts(*contexts: ContextMetadata) -> ContextMetadata:
    """Combine *contexts* into a new dict; later arguments override earlier ones."""
    merged: ContextMetadata = {}
    for layer in contexts:
        # ``|=`` has the same semantics as ``dict.update``.
        merged |= layer
    return merged
100
+
101
+
102
def validate_context(metadata: ContextMetadata) -> None:
    """Check that *metadata* is a dict whose keys are all strings.

    Raises:
        ValueError: if *metadata* is not a dict, or for the first non-string key.
    """
    if isinstance(metadata, dict):
        offenders = (key for key in metadata if not isinstance(key, str))
        for bad_key in offenders:
            raise ValueError(f"Context keys must be strings, got {type(bad_key)}")
        return
    raise ValueError("Context metadata must be a dict")
109
+
110
+
111
class WithContext:
    """Decorator that runs every call of the wrapped function in a fresh context scope.

    Works with both sync and async functions::

        @WithContext({"service": "MyService"})
        async def process(self, data):
            ...

        @WithContext({"component": "parser"})
        def parse(text):
            ...

    A new ``EvalContext`` is created per call, so calls do not share scope
    state. Coroutine functions get an ``async`` wrapper, so the scope stays
    active for the whole awaited body.
    """

    def __init__(self, metadata: ContextMetadata) -> None:
        # Stored by reference; each call builds its own scope from it.
        self._metadata = metadata

    @overload
    def __call__(self, fn: F) -> F: ...

    def __call__(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        # ``inspect.iscoroutinefunction`` replaces ``asyncio.iscoroutinefunction``,
        # which is deprecated as of Python 3.14.
        import inspect

        if inspect.iscoroutinefunction(fn):

            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                ctx = create_context(self._metadata)
                async with ctx:
                    return await fn(*args, **kwargs)

            return async_wrapper  # type: ignore[return-value]

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            ctx = create_context(self._metadata)
            with ctx:
                return fn(*args, **kwargs)

        return sync_wrapper  # type: ignore[return-value]