contexttrace 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
contexttrace/client.py ADDED
@@ -0,0 +1,1074 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from types import TracebackType
5
+ from typing import Any, Iterable, Optional
6
+
7
+ from contexttrace.config import ContextTraceConfig, load_config
8
+ from contexttrace.errors import ContextTraceConfigError
9
+ from contexttrace.local import LocalTransport
10
+ from contexttrace.report import ReportGenerator
11
+ from contexttrace.transport import AsyncHttpTransport, AsyncTransport, HttpTransport, Transport
12
+
13
# Shared logger for the ContextTrace SDK; debug output in this module goes here.
logger: logging.Logger = logging.getLogger("contexttrace")
14
+
15
+
16
class ContextTrace:
    """Synchronous ContextTrace client.

    Configuration is resolved through ``load_config``. When the resolved mode is
    ``"local"`` a ``LocalTransport`` is used; otherwise an ``HttpTransport`` is
    built against the hosted API. The client is also a context manager: leaving
    the ``with`` block closes the transport.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Optional[str] = None,
        local_only: Optional[bool] = None,
        transport: Transport | None = None,
        timeout: Optional[float] = None,
        retries: Optional[int] = None,
        debug: Optional[bool] = None,
        local_store_dir: Optional[str] = None,
        storage_path: Optional[str] = None,
        log_chunk_text: Optional[bool] = None,
        log_answer_text: Optional[bool] = None,
        config_path: Optional[str] = None,
    ) -> None:
        """Resolve configuration and build (or adopt) the transport.

        All keyword arguments except ``transport`` are forwarded to
        ``load_config``. ``transport``, when given, is used as-is instead of the
        transport that would be built from the resolved config.

        Raises:
            ContextTraceConfigError: no ``transport`` was given, the mode is not
                ``"local"``, and no API key was resolved.
        """
        self.config = load_config(
            api_key=api_key,
            project=project,
            base_url=base_url,
            api_url=api_url,
            mode=mode,
            local_only=local_only,
            timeout=timeout,
            retries=retries,
            debug=debug,
            local_store_dir=local_store_dir,
            storage_path=storage_path,
            log_chunk_text=log_chunk_text,
            log_answer_text=log_answer_text,
            config_path=config_path,
        )
        _configure_logging(self.config)
        self.project = self.config.project
        self.mode = self.config.mode
        # `is not None` instead of `or`: an injected transport must be honoured
        # even if it defines a falsy __bool__/__len__.
        self._transport = transport if transport is not None else self._build_transport(self.config)

    def __enter__(self) -> "ContextTrace":
        """Return the client itself; the transport is closed on exit."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Close the transport and let any exception propagate."""
        self.close()
        return False

    def _build_transport(self, config: ContextTraceConfig) -> Transport:
        """Create the transport that matches ``config.mode``.

        Raises:
            ContextTraceConfigError: hosted mode without ``config.api_key``.
        """
        if config.mode == "local":
            return LocalTransport(
                store_dir=config.local_store_dir,
                storage_path=config.storage_path,
                debug=config.debug,
                log_chunk_text=config.log_chunk_text,
                log_answer_text=config.log_answer_text,
            )
        if not config.api_key:
            raise ContextTraceConfigError(
                "ContextTrace api_key is required in hosted mode. Pass api_key=..., "
                "set CONTEXTTRACE_API_KEY, or run with mode='local'."
            )
        return HttpTransport(
            base_url=config.base_url,
            api_key=config.api_key,
            timeout=config.timeout,
            retries=config.retries,
            debug=config.debug,
        )

    def trace(self, *, query: str, metadata: dict[str, Any] | None = None) -> "TraceSession":
        """Return an (unstarted) ``TraceSession``; use it as a context manager."""
        return TraceSession(
            transport=self._transport,
            project=self.project,
            query=query,
            metadata=metadata or {},
        )

    def create_eval_set(
        self,
        name: str,
        *,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """POST a new eval set and return the transport's response."""
        return self._transport.post(
            "/v1/eval-sets",
            {
                "name": name,
                "metadata": metadata or {},
            },
        )

    def add_eval_questions(
        self,
        eval_set_id: str,
        questions: Iterable[Any],
    ) -> dict[str, Any]:
        """Normalize ``questions`` (strings or dicts) and add them to an eval set.

        Raises:
            ValueError: a question is neither a string nor a dict with text.
        """
        return self._transport.post(
            f"/v1/eval-sets/{eval_set_id}/questions",
            {"questions": [_normalize_eval_question(question) for question in questions]},
        )

    def evaluate_existing_traces(self, eval_set_id: str) -> dict[str, Any]:
        """Start an evaluation run for ``eval_set_id``."""
        return self._transport.post(f"/v1/eval-sets/{eval_set_id}/runs", {})

    def register_rag_endpoint(
        self,
        *,
        project_id: str,
        name: str,
        url: str,
        method: str = "POST",
        headers: Optional[dict[str, str]] = None,
        body_template: Optional[dict[str, Any]] = None,
        response_mapping: Optional[dict[str, str]] = None,
    ) -> dict[str, Any]:
        """Register an external RAG endpoint for ``project_id``.

        Without a ``body_template`` the query is sent as ``{"question": ...}``;
        without a ``response_mapping`` the answer, citations and retrieved
        chunks are read from same-named top-level JSON fields.
        """
        return self._transport.post(
            f"/v1/projects/{project_id}/external-endpoints",
            {
                "name": name,
                "url": url,
                "method": method,
                "headers": headers or {},
                "body_template": body_template or {"question": "{{query}}"},
                "response_mapping": response_mapping
                or {
                    "answer": "$.answer",
                    "citations": "$.citations",
                    "retrieved_chunks": "$.retrieved_chunks",
                },
            },
        )

    def test_rag_endpoint(
        self,
        endpoint_id: str,
        *,
        query: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """Send a single test query to a registered external endpoint."""
        return self._transport.post(
            f"/v1/external-endpoints/{endpoint_id}/test",
            {
                "query": query,
                "metadata": metadata or {},
            },
        )

    def evaluate_rag_endpoint(
        self,
        endpoint_id: str,
        *,
        eval_set_id: str,
    ) -> dict[str, Any]:
        """Run ``eval_set_id`` against a registered external endpoint."""
        return self._transport.post(
            f"/v1/external-endpoints/{endpoint_id}/run-eval",
            {"eval_set_id": eval_set_id},
        )

    def list_traces(self, *, limit: int = 20) -> list[dict[str, Any]]:
        """Return at most ``limit`` traces.

        Raises:
            ValueError: the response's ``traces`` value is not a list.
        """
        response = self._transport.get(f"/v1/traces?limit={limit}")
        traces = response.get("traces") or []
        if not isinstance(traces, list):
            raise ValueError("Trace list response did not include a traces list.")
        # Truncate client-side too, in case the backend ignores the parameter.
        return traces[:limit]

    def get_trace(self, trace_id: str) -> dict[str, Any]:
        """Fetch one trace by id."""
        return self._transport.get(f"/v1/traces/{trace_id}")

    def last_trace(self) -> Optional[dict[str, Any]]:
        """Return the most recent trace, or ``None`` if there are none.

        Tries ``/v1/traces/last`` first. On any failure it falls back to the
        first entry of ``list_traces(limit=1)``.
        """
        try:
            return self._transport.get("/v1/traces/last")
        except Exception:
            # Deliberate best-effort fallback, but keep the cause visible in
            # debug logs instead of swallowing it silently.
            logger.debug("GET /v1/traces/last failed; falling back to list_traces", exc_info=True)
            traces = self.list_traces(limit=1)
            return traces[0] if traces else None

    def export_report(
        self,
        *,
        trace_id: Optional[str] = None,
        path: str = "report.html",
        last: bool = False,
    ) -> str:
        """Render a trace to an HTML report and return ReportGenerator's result.

        ``last=True`` takes precedence over ``trace_id``.

        Raises:
            ValueError: neither ``trace_id`` nor ``last`` was given, or
                ``last=True`` found no traces.
        """
        if last:
            trace = self.last_trace()
            if trace is None:
                raise ValueError("No traces found to export.")
        elif trace_id:
            trace = self.get_trace(trace_id)
        else:
            raise ValueError("Pass trace_id=... or last=True.")
        return ReportGenerator().generate(trace, path=path)

    def upload_traces(
        self,
        *,
        trace_ids: Optional[Iterable[str]] = None,
        target_transport: Optional[Transport] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        project: Optional[str] = None,
        limit: int = 1000,
    ) -> dict[str, Any]:
        """Replay traces from this client's transport onto a target transport.

        Args:
            trace_ids: Specific traces to upload. If omitted, up to ``limit``
                traces from ``list_traces`` are uploaded.
            target_transport: Destination transport. If omitted, an
                ``HttpTransport`` is built from ``api_key``/``base_url``, falling
                back to this client's config. A transport built here is closed
                before returning.
            project: Fallback project passed to the replay; defaults to
                ``self.project``.
            limit: Maximum number of traces to upload when ``trace_ids`` is
                omitted (previously fixed at 1000).

        Returns:
            ``{"uploaded": <count>, "traces": [<id mapping per trace>]}``.

        Raises:
            ContextTraceConfigError: a transport must be built but no API key is
                available.
        """
        traces = (
            [self.get_trace(trace_id) for trace_id in trace_ids]
            if trace_ids is not None
            else self.list_traces(limit=limit)
        )

        created_transport = False
        transport = target_transport
        if transport is None:
            resolved_api_key = api_key or self.config.api_key
            if not resolved_api_key:
                raise ContextTraceConfigError(
                    "api_key is required to upload local traces to a hosted ContextTrace API."
                )
            transport = HttpTransport(
                base_url=base_url or self.config.base_url,
                api_key=resolved_api_key,
                timeout=self.config.timeout,
                retries=self.config.retries,
                debug=self.config.debug,
            )
            created_transport = True

        uploaded = []
        try:
            for trace in traces:
                uploaded.append(_replay_trace(trace, transport=transport, project=project or self.project))
        finally:
            # Only close what we created; a caller-supplied transport is theirs.
            if created_transport:
                close = getattr(transport, "close", None)
                if close:
                    close()

        return {"uploaded": len(uploaded), "traces": uploaded}

    def close(self) -> None:
        """Close the underlying transport if it exposes ``close()``."""
        close = getattr(self._transport, "close", None)
        if close:
            close()
250
+
251
+
252
class TraceSession:
    """Records one query's retrieval, context, answer and agent events.

    Entering the context posts ``/v1/traces/start`` and stores the returned
    ``trace_id``. Every ``log_*``, ``evaluate``, ``fetch`` and
    ``list_agent_events`` call afterwards targets ``/v1/traces/{trace_id}/...``
    on the shared transport. Calling them before entering raises
    ``RuntimeError``.
    """

    def __init__(
        self,
        *,
        transport: Transport,
        project: str,
        query: str,
        metadata: dict[str, Any],
    ) -> None:
        self._transport = transport
        self.project = project
        self.query = query
        self.metadata = metadata
        self.trace_id: str | None = None
        self.project_id: str | None = None

    def __enter__(self) -> "TraceSession":
        """Start the trace on the backend and return this session.

        Raises:
            KeyError: the start response has no ``trace_id``.
        """
        response = self._transport.post(
            "/v1/traces/start",
            {
                "project": self.project,
                "query": self.query,
                "metadata": self.metadata,
            },
        )
        self.trace_id = response["trace_id"]
        # project_id is informational (typed Optional); tolerate a start
        # response that omits it instead of aborting the trace with KeyError.
        self.project_id = response.get("project_id")
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Leave the context; nothing is posted and exceptions propagate."""
        return False

    def log_retrieval(
        self,
        chunks: Iterable[Any],
        *,
        retriever_name: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log every chunk the retriever returned (dicts or objects).

        Raises:
            ValueError: a chunk has no content/text/page_content.
        """
        return self._post(
            "retrieval",
            {
                "chunks": [_normalize_chunk(chunk) for chunk in chunks],
                "retriever_name": retriever_name,
                "metadata": metadata or {},
            },
        )

    def log_context(
        self,
        chunks: Iterable[Any] | None = None,
        *,
        chunk_ids: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log the chunks that were placed into the prompt context.

        Either full ``chunks`` or just ``chunk_ids`` may be supplied. The
        ``chunks`` key is only included when chunks are given.
        """
        payload: dict[str, Any] = {
            "chunk_ids": chunk_ids,
            "metadata": metadata or {},
        }
        if chunks is not None:
            payload["chunks"] = [_normalize_chunk(chunk) for chunk in chunks]
        return self._post("context", payload)

    def log_answer(
        self,
        answer: str,
        *,
        model: str | None = None,
        usage: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log the generated answer, plus the optional model and usage info."""
        return self._post(
            "answer",
            {
                "answer": answer,
                "model": model,
                "usage": usage or {},
                "metadata": metadata or {},
            },
        )

    def log_citations(self, citations: Iterable[dict[str, Any]]) -> dict[str, Any]:
        """Log citation dicts; the iterable is materialized into a list."""
        return self._post("citations", {"citations": list(citations)})

    def log_agent_event(
        self,
        *,
        event_type: str,
        name: str | None = None,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
        error_message: str | None = None,
    ) -> dict[str, Any]:
        """Log a generic agent event; the ``log_tool_*``/``log_memory_*`` helpers wrap this."""
        return self._post(
            "agent-events",
            _agent_event_payload(
                event_type=event_type,
                name=name,
                input_json=input_json,
                output_json=output_json,
                metadata=metadata,
                latency_ms=latency_ms,
                error_message=error_message,
            ),
        )

    def log_tool_call(
        self,
        name: str,
        *,
        input_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``tool_call`` agent event."""
        return self.log_agent_event(
            event_type="tool_call",
            name=name,
            input_json=input_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    def log_tool_result(
        self,
        name: str,
        *,
        output_json: Any | None = None,
        input_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
        error_message: str | None = None,
    ) -> dict[str, Any]:
        """Log a ``tool_result`` agent event."""
        return self.log_agent_event(
            event_type="tool_result",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
            error_message=error_message,
        )

    def log_memory_read(
        self,
        name: str = "memory_read",
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``memory_read`` agent event."""
        return self.log_agent_event(
            event_type="memory_read",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    def log_memory_write(
        self,
        name: str = "memory_write",
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``memory_write`` agent event."""
        return self.log_agent_event(
            event_type="memory_write",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    def log_planner_step(
        self,
        name: str,
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``planner_step`` agent event."""
        return self.log_agent_event(
            event_type="planner_step",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    def log_agent_error(
        self,
        error_message: str,
        *,
        name: str = "agent_error",
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log an ``error`` agent event carrying ``error_message``."""
        return self.log_agent_event(
            event_type="error",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
            error_message=error_message,
        )

    def list_agent_events(self) -> dict[str, Any]:
        """GET this trace's agent events."""
        self._require_started()
        return self._transport.get(f"/v1/traces/{self.trace_id}/agent-events")

    def evaluate(self) -> dict[str, Any]:
        """Ask the backend to evaluate this trace."""
        return self._post("evaluate", {})

    def export_report(self, *, path: str = "report.html") -> str:
        """Fetch this trace and render it with ``ReportGenerator``."""
        trace = self.fetch()
        return ReportGenerator().generate(trace, path=path)

    def fetch(self) -> dict[str, Any]:
        """GET the full trace record for this session."""
        self._require_started()
        return self._transport.get(f"/v1/traces/{self.trace_id}")

    def _post(self, endpoint: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/v1/traces/{trace_id}/{endpoint}``."""
        self._require_started()
        return self._transport.post(f"/v1/traces/{self.trace_id}/{endpoint}", payload)

    def _require_started(self) -> None:
        """Raise ``RuntimeError`` unless ``__enter__`` has set a trace id."""
        if not self.trace_id:
            raise RuntimeError("Trace has not started. Use ContextTrace.trace(...) as a context manager.")
497
+
498
+
499
class AsyncContextTrace:
    """Asyncio flavour of ``ContextTrace``.

    In ``"local"`` mode the synchronous ``LocalTransport`` is wrapped in
    ``_AsyncTransportAdapter``; otherwise an ``AsyncHttpTransport`` is used.
    The client is an async context manager: leaving ``async with`` awaits
    ``close()``.
    """

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[str] = None,
        api_url: Optional[str] = None,
        mode: Optional[str] = None,
        local_only: Optional[bool] = None,
        transport: AsyncTransport | None = None,
        timeout: Optional[float] = None,
        retries: Optional[int] = None,
        debug: Optional[bool] = None,
        local_store_dir: Optional[str] = None,
        storage_path: Optional[str] = None,
        log_chunk_text: Optional[bool] = None,
        log_answer_text: Optional[bool] = None,
        config_path: Optional[str] = None,
    ) -> None:
        """Resolve configuration and build (or adopt) the async transport.

        Raises:
            ContextTraceConfigError: no ``transport`` was given, the mode is not
                ``"local"``, and no API key was resolved.
        """
        self.config = load_config(
            api_key=api_key,
            project=project,
            base_url=base_url,
            api_url=api_url,
            mode=mode,
            local_only=local_only,
            timeout=timeout,
            retries=retries,
            debug=debug,
            local_store_dir=local_store_dir,
            storage_path=storage_path,
            log_chunk_text=log_chunk_text,
            log_answer_text=log_answer_text,
            config_path=config_path,
        )
        _configure_logging(self.config)
        self.project = self.config.project
        self.mode = self.config.mode
        # `is not None` instead of `or`: an injected transport must be honoured
        # even if it defines a falsy __bool__/__len__.
        self._transport = transport if transport is not None else self._build_transport(self.config)

    async def __aenter__(self) -> "AsyncContextTrace":
        """Return the client itself; the transport is closed on exit."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Close the transport and let any exception propagate."""
        await self.close()
        return False

    def _build_transport(self, config: ContextTraceConfig) -> AsyncTransport:
        """Create the async transport that matches ``config.mode``.

        Raises:
            ContextTraceConfigError: hosted mode without ``config.api_key``.
        """
        if config.mode == "local":
            return _AsyncTransportAdapter(
                LocalTransport(
                    store_dir=config.local_store_dir,
                    storage_path=config.storage_path,
                    debug=config.debug,
                    log_chunk_text=config.log_chunk_text,
                    log_answer_text=config.log_answer_text,
                )
            )
        if not config.api_key:
            raise ContextTraceConfigError(
                "ContextTrace api_key is required in hosted mode. Pass api_key=..., "
                "set CONTEXTTRACE_API_KEY, or run with mode='local'."
            )
        return AsyncHttpTransport(
            base_url=config.base_url,
            api_key=config.api_key,
            timeout=config.timeout,
            retries=config.retries,
            debug=config.debug,
        )

    def trace(self, *, query: str, metadata: dict[str, Any] | None = None) -> "AsyncTraceSession":
        """Return an (unstarted) ``AsyncTraceSession``; use it with ``async with``."""
        return AsyncTraceSession(
            transport=self._transport,
            project=self.project,
            query=query,
            metadata=metadata or {},
        )

    async def create_eval_set(
        self,
        name: str,
        *,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """POST a new eval set and return the transport's response."""
        return await self._transport.post(
            "/v1/eval-sets",
            {
                "name": name,
                "metadata": metadata or {},
            },
        )

    async def add_eval_questions(
        self,
        eval_set_id: str,
        questions: Iterable[Any],
    ) -> dict[str, Any]:
        """Normalize ``questions`` (strings or dicts) and add them to an eval set.

        Raises:
            ValueError: a question is neither a string nor a dict with text.
        """
        return await self._transport.post(
            f"/v1/eval-sets/{eval_set_id}/questions",
            {"questions": [_normalize_eval_question(question) for question in questions]},
        )

    async def evaluate_existing_traces(self, eval_set_id: str) -> dict[str, Any]:
        """Start an evaluation run for ``eval_set_id``."""
        return await self._transport.post(f"/v1/eval-sets/{eval_set_id}/runs", {})

    async def register_rag_endpoint(
        self,
        *,
        project_id: str,
        name: str,
        url: str,
        method: str = "POST",
        headers: Optional[dict[str, str]] = None,
        body_template: Optional[dict[str, Any]] = None,
        response_mapping: Optional[dict[str, str]] = None,
    ) -> dict[str, Any]:
        """Register an external RAG endpoint for ``project_id``.

        The defaults for ``body_template`` and ``response_mapping`` match the
        synchronous client's.
        """
        return await self._transport.post(
            f"/v1/projects/{project_id}/external-endpoints",
            {
                "name": name,
                "url": url,
                "method": method,
                "headers": headers or {},
                "body_template": body_template or {"question": "{{query}}"},
                "response_mapping": response_mapping
                or {
                    "answer": "$.answer",
                    "citations": "$.citations",
                    "retrieved_chunks": "$.retrieved_chunks",
                },
            },
        )

    async def test_rag_endpoint(
        self,
        endpoint_id: str,
        *,
        query: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """Send a single test query to a registered external endpoint."""
        return await self._transport.post(
            f"/v1/external-endpoints/{endpoint_id}/test",
            {
                "query": query,
                "metadata": metadata or {},
            },
        )

    async def evaluate_rag_endpoint(
        self,
        endpoint_id: str,
        *,
        eval_set_id: str,
    ) -> dict[str, Any]:
        """Run ``eval_set_id`` against a registered external endpoint."""
        return await self._transport.post(
            f"/v1/external-endpoints/{endpoint_id}/run-eval",
            {"eval_set_id": eval_set_id},
        )

    async def close(self) -> None:
        """Close the transport, awaiting the result if ``close()`` returns an awaitable."""
        close = getattr(self._transport, "close", None)
        if close:
            result = close()
            # Transports may expose either a synchronous or an async close().
            if hasattr(result, "__await__"):
                await result
659
+
660
+
661
class AsyncTraceSession:
    """Asyncio flavour of ``TraceSession``.

    ``async with`` posts ``/v1/traces/start`` and stores the returned
    ``trace_id``. All ``log_*``, ``evaluate``, ``fetch`` and
    ``list_agent_events`` coroutines then target ``/v1/traces/{trace_id}/...``.
    """

    def __init__(
        self,
        *,
        transport: AsyncTransport,
        project: str,
        query: str,
        metadata: dict[str, Any],
    ) -> None:
        self._transport = transport
        self.project = project
        self.query = query
        self.metadata = metadata
        self.trace_id: str | None = None
        self.project_id: str | None = None

    async def __aenter__(self) -> "AsyncTraceSession":
        """Start the trace on the backend and return this session.

        Raises:
            KeyError: the start response has no ``trace_id``.
        """
        response = await self._transport.post(
            "/v1/traces/start",
            {
                "project": self.project,
                "query": self.query,
                "metadata": self.metadata,
            },
        )
        self.trace_id = response["trace_id"]
        # project_id is informational (typed Optional); tolerate a start
        # response that omits it instead of aborting the trace with KeyError.
        self.project_id = response.get("project_id")
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool:
        """Leave the context; nothing is posted and exceptions propagate."""
        return False

    async def log_retrieval(
        self,
        chunks: Iterable[Any],
        *,
        retriever_name: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log every chunk the retriever returned (dicts or objects).

        Raises:
            ValueError: a chunk has no content/text/page_content.
        """
        return await self._post(
            "retrieval",
            {
                "chunks": [_normalize_chunk(chunk) for chunk in chunks],
                "retriever_name": retriever_name,
                "metadata": metadata or {},
            },
        )

    async def log_context(
        self,
        chunks: Iterable[Any] | None = None,
        *,
        chunk_ids: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log the chunks placed into the prompt context.

        The ``chunks`` key is only sent when chunks are given.
        """
        payload: dict[str, Any] = {
            "chunk_ids": chunk_ids,
            "metadata": metadata or {},
        }
        if chunks is not None:
            payload["chunks"] = [_normalize_chunk(chunk) for chunk in chunks]
        return await self._post("context", payload)

    async def log_answer(
        self,
        answer: str,
        *,
        model: str | None = None,
        usage: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Log the generated answer, plus the optional model and usage info."""
        return await self._post(
            "answer",
            {
                "answer": answer,
                "model": model,
                "usage": usage or {},
                "metadata": metadata or {},
            },
        )

    async def log_citations(self, citations: Iterable[dict[str, Any]]) -> dict[str, Any]:
        """Log citation dicts; the iterable is materialized into a list."""
        return await self._post("citations", {"citations": list(citations)})

    async def log_agent_event(
        self,
        *,
        event_type: str,
        name: str | None = None,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
        error_message: str | None = None,
    ) -> dict[str, Any]:
        """Log a generic agent event; the specialised ``log_*`` helpers wrap this."""
        return await self._post(
            "agent-events",
            _agent_event_payload(
                event_type=event_type,
                name=name,
                input_json=input_json,
                output_json=output_json,
                metadata=metadata,
                latency_ms=latency_ms,
                error_message=error_message,
            ),
        )

    async def log_tool_call(
        self,
        name: str,
        *,
        input_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``tool_call`` agent event."""
        return await self.log_agent_event(
            event_type="tool_call",
            name=name,
            input_json=input_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    async def log_tool_result(
        self,
        name: str,
        *,
        output_json: Any | None = None,
        input_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
        error_message: str | None = None,
    ) -> dict[str, Any]:
        """Log a ``tool_result`` agent event."""
        return await self.log_agent_event(
            event_type="tool_result",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
            error_message=error_message,
        )

    async def log_memory_read(
        self,
        name: str = "memory_read",
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``memory_read`` agent event."""
        return await self.log_agent_event(
            event_type="memory_read",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    async def log_memory_write(
        self,
        name: str = "memory_write",
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``memory_write`` agent event."""
        return await self.log_agent_event(
            event_type="memory_write",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    async def log_planner_step(
        self,
        name: str,
        *,
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log a ``planner_step`` agent event."""
        return await self.log_agent_event(
            event_type="planner_step",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
        )

    async def log_agent_error(
        self,
        error_message: str,
        *,
        name: str = "agent_error",
        input_json: Any | None = None,
        output_json: Any | None = None,
        metadata: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> dict[str, Any]:
        """Log an ``error`` agent event carrying ``error_message``."""
        return await self.log_agent_event(
            event_type="error",
            name=name,
            input_json=input_json,
            output_json=output_json,
            metadata=metadata,
            latency_ms=latency_ms,
            error_message=error_message,
        )

    async def list_agent_events(self) -> dict[str, Any]:
        """GET this trace's agent events."""
        self._require_started()
        return await self._transport.get(f"/v1/traces/{self.trace_id}/agent-events")

    async def evaluate(self) -> dict[str, Any]:
        """Ask the backend to evaluate this trace."""
        return await self._post("evaluate", {})

    async def export_report(self, *, path: str = "report.html") -> str:
        """Fetch this trace and render it with ``ReportGenerator``."""
        trace = await self.fetch()
        return ReportGenerator().generate(trace, path=path)

    async def fetch(self) -> dict[str, Any]:
        """GET the full trace record for this session."""
        self._require_started()
        return await self._transport.get(f"/v1/traces/{self.trace_id}")

    async def _post(self, endpoint: str, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/v1/traces/{trace_id}/{endpoint}``."""
        self._require_started()
        return await self._transport.post(f"/v1/traces/{self.trace_id}/{endpoint}", payload)

    def _require_started(self) -> None:
        """Raise ``RuntimeError`` unless ``__aenter__`` has set a trace id."""
        if not self.trace_id:
            raise RuntimeError("Trace has not started. Use AsyncContextTrace.trace(...) as an async context manager.")
906
+
907
+
908
+ class _AsyncTransportAdapter:
909
+ def __init__(self, transport: Transport) -> None:
910
+ self._transport = transport
911
+
912
+ async def post(self, path: str, payload: dict[str, Any] | None = None) -> dict[str, Any]:
913
+ return self._transport.post(path, payload)
914
+
915
+ async def get(self, path: str) -> dict[str, Any]:
916
+ return self._transport.get(path)
917
+
918
+ def close(self) -> None:
919
+ close = getattr(self._transport, "close", None)
920
+ if close:
921
+ close()
922
+
923
+
924
+ def _normalize_chunk(chunk: Any) -> dict[str, Any]:
925
+ if isinstance(chunk, dict):
926
+ chunk_id = (
927
+ chunk.get("chunk_id")
928
+ or chunk.get("id")
929
+ or chunk.get("source_chunk_id")
930
+ )
931
+ content = chunk.get("content") or chunk.get("text") or chunk.get("page_content")
932
+ source = chunk.get("source")
933
+ metadata = chunk.get("metadata") or {}
934
+ relevance_score = chunk.get("relevance_score") or chunk.get("score")
935
+ else:
936
+ chunk_id = (
937
+ getattr(chunk, "chunk_id", None)
938
+ or getattr(chunk, "id", None)
939
+ or getattr(chunk, "source_chunk_id", None)
940
+ )
941
+ content = (
942
+ getattr(chunk, "content", None)
943
+ or getattr(chunk, "text", None)
944
+ or getattr(chunk, "page_content", None)
945
+ )
946
+ source = getattr(chunk, "source", None)
947
+ metadata = getattr(chunk, "metadata", None) or {}
948
+ relevance_score = getattr(chunk, "relevance_score", None) or getattr(chunk, "score", None)
949
+
950
+ if not content:
951
+ raise ValueError("Each chunk must include content, text, or page_content.")
952
+
953
+ return {
954
+ "chunk_id": str(chunk_id) if chunk_id is not None else None,
955
+ "content": str(content),
956
+ "source": source,
957
+ "metadata": metadata,
958
+ "relevance_score": relevance_score,
959
+ }
960
+
961
+
962
+ def _normalize_eval_question(question: Any) -> dict[str, Any]:
963
+ if isinstance(question, str):
964
+ return {
965
+ "question": question,
966
+ "trace_id": None,
967
+ "expected_answer": None,
968
+ "metadata": {},
969
+ }
970
+
971
+ if not isinstance(question, dict):
972
+ raise ValueError("Eval questions must be strings or dictionaries.")
973
+
974
+ text = question.get("question") or question.get("query")
975
+ if not text:
976
+ raise ValueError("Each eval question must include question or query.")
977
+
978
+ return {
979
+ "question": str(text),
980
+ "trace_id": question.get("trace_id"),
981
+ "expected_answer": question.get("expected_answer"),
982
+ "metadata": question.get("metadata") or {},
983
+ }
984
+
985
+
986
+ def _agent_event_payload(
987
+ *,
988
+ event_type: str,
989
+ name: str | None,
990
+ input_json: Any | None,
991
+ output_json: Any | None,
992
+ metadata: dict[str, Any] | None,
993
+ latency_ms: float | None,
994
+ error_message: str | None,
995
+ ) -> dict[str, Any]:
996
+ return {
997
+ "event_type": event_type,
998
+ "name": name,
999
+ "input_json": input_json if input_json is not None else {},
1000
+ "output_json": output_json if output_json is not None else {},
1001
+ "metadata_json": metadata or {},
1002
+ "latency_ms": latency_ms,
1003
+ "error_message": error_message,
1004
+ }
1005
+
1006
+
1007
+ def _replay_trace(trace: dict[str, Any], *, transport: Transport, project: str) -> dict[str, Any]:
1008
+ started = transport.post(
1009
+ "/v1/traces/start",
1010
+ {
1011
+ "project": trace.get("project") or project,
1012
+ "query": trace.get("query") or "",
1013
+ "metadata": trace.get("metadata") or {},
1014
+ },
1015
+ )
1016
+ remote_trace_id = started["trace_id"]
1017
+ chunks = trace.get("chunks") or []
1018
+ if chunks:
1019
+ transport.post(
1020
+ f"/v1/traces/{remote_trace_id}/retrieval",
1021
+ {"chunks": chunks, "retriever_name": "contexttrace-batch-upload", "metadata": {}},
1022
+ )
1023
+ selected = [chunk for chunk in chunks if chunk.get("selected")]
1024
+ if selected:
1025
+ transport.post(
1026
+ f"/v1/traces/{remote_trace_id}/context",
1027
+ {"chunks": selected, "metadata": {"source": "contexttrace-batch-upload"}},
1028
+ )
1029
+
1030
+ answer = trace.get("answer") or {}
1031
+ if answer.get("answer"):
1032
+ transport.post(
1033
+ f"/v1/traces/{remote_trace_id}/answer",
1034
+ {
1035
+ "answer": answer["answer"],
1036
+ "model": answer.get("model"),
1037
+ "usage": answer.get("usage") or {},
1038
+ "metadata": answer.get("metadata") or {},
1039
+ },
1040
+ )
1041
+
1042
+ checks = trace.get("citation_checks") or []
1043
+ citations = [
1044
+ {
1045
+ "claim": check.get("claim"),
1046
+ "source_chunk_id": check.get("source_chunk_id"),
1047
+ }
1048
+ for check in checks
1049
+ if check.get("claim") and check.get("source_chunk_id")
1050
+ ]
1051
+ if citations:
1052
+ transport.post(f"/v1/traces/{remote_trace_id}/citations", {"citations": citations})
1053
+
1054
+ for event in trace.get("agent_events") or []:
1055
+ transport.post(
1056
+ f"/v1/traces/{remote_trace_id}/agent-events",
1057
+ {
1058
+ "event_type": event.get("event_type"),
1059
+ "name": event.get("name"),
1060
+ "input_json": event.get("input_json") or {},
1061
+ "output_json": event.get("output_json") or {},
1062
+ "metadata_json": event.get("metadata_json") or {},
1063
+ "latency_ms": event.get("latency_ms"),
1064
+ "error_message": event.get("error_message"),
1065
+ },
1066
+ )
1067
+
1068
+ return {"local_trace_id": trace.get("id"), "trace_id": remote_trace_id}
1069
+
1070
+
1071
+ def _configure_logging(config: ContextTraceConfig) -> None:
1072
+ if config.debug:
1073
+ logging.basicConfig(level=logging.DEBUG)
1074
+ logger.debug("ContextTrace config loaded: %s", config)