docent-python 0.1.4a0__py3-none-any.whl → 0.1.6a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

docent/data_models/agent_run.py CHANGED
@@ -15,6 +15,7 @@ from pydantic import (
15
15
  from docent.data_models._tiktoken_util import get_token_count, group_messages_into_ranges
16
16
  from docent.data_models.transcript import (
17
17
  Transcript,
18
+ TranscriptGroup,
18
19
  TranscriptWithoutMetadataValidator,
19
20
  fake_model_dump,
20
21
  )
@@ -36,6 +37,7 @@ class AgentRun(BaseModel):
36
37
  name: Optional human-readable name for the agent run.
37
38
  description: Optional description of the agent run.
38
39
  transcripts: Dict mapping transcript IDs to Transcript objects.
40
+ transcript_groups: Dict mapping transcript group IDs to TranscriptGroup objects.
39
41
  metadata: Additional structured metadata about the agent run as a JSON-serializable dictionary.
40
42
  """
41
43
 
@@ -44,6 +46,7 @@ class AgentRun(BaseModel):
44
46
  description: str | None = None
45
47
 
46
48
  transcripts: dict[str, Transcript]
49
+ transcript_groups: dict[str, TranscriptGroup] = Field(default_factory=dict)
47
50
  metadata: dict[str, Any] = Field(default_factory=dict)
48
51
 
49
52
  @field_serializer("metadata")
docent/data_models/transcript.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import sys
2
+ from datetime import datetime
2
3
  from typing import Any
3
4
  from uuid import uuid4
4
5
 
@@ -73,6 +74,8 @@ class TranscriptGroup(BaseModel):
73
74
  id: Unique identifier for the transcript group, auto-generated by default.
74
75
  name: Optional human-readable name for the transcript group.
75
76
  description: Optional description of the transcript group.
77
+ collection_id: ID of the collection this transcript group belongs to.
78
+ agent_run_id: ID of the agent run this transcript group belongs to.
76
79
  parent_transcript_group_id: Optional ID of the parent transcript group.
77
80
  metadata: Additional structured metadata about the transcript group.
78
81
  """
@@ -80,7 +83,10 @@ class TranscriptGroup(BaseModel):
80
83
  id: str = Field(default_factory=lambda: str(uuid4()))
81
84
  name: str | None = None
82
85
  description: str | None = None
86
+ collection_id: str
87
+ agent_run_id: str
83
88
  parent_transcript_group_id: str | None = None
89
+ created_at: datetime | None = None
84
90
  metadata: dict[str, Any] = Field(default_factory=dict)
85
91
 
86
92
  @field_serializer("metadata")
@@ -129,6 +135,7 @@ class Transcript(BaseModel):
129
135
  name: str | None = None
130
136
  description: str | None = None
131
137
  transcript_group_id: str | None = None
138
+ created_at: datetime | None = None
132
139
 
133
140
  messages: list[ChatMessage]
134
141
  metadata: dict[str, Any] = Field(default_factory=dict)
docent/loaders/load_inspect.py CHANGED
@@ -1,4 +1,7 @@
1
- from typing import Any
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Any, BinaryIO, Generator, Tuple
4
+ from zipfile import ZipFile
2
5
 
3
6
  from inspect_ai.log import EvalLog
4
7
  from inspect_ai.scorer import CORRECT, INCORRECT, NOANSWER, PARTIAL, Score
@@ -7,9 +10,9 @@ from docent.data_models import AgentRun, Transcript
7
10
  from docent.data_models.chat import parse_chat_message
8
11
 
9
12
 
10
- def _normalize_inspect_score(score: Score) -> Any:
13
+ def _normalize_inspect_score(score: Score | dict[str, Any]) -> Any:
11
14
  """
12
- Normalize an inspect score to a float. This implements the same logic as inspect_ai.scorer._metric.value_to_float, but fails more conspicuously.
15
+ Normalize an inspect score to a float. Logic mirrors inspect_ai.scorer._metric.value_to_float.
13
16
 
14
17
  Args:
15
18
  score: The inspect score to normalize.
@@ -18,7 +21,7 @@ def _normalize_inspect_score(score: Score) -> Any:
18
21
  The normalized score as a float, or None if the score is not a valid value.
19
22
  """
20
23
 
21
- def _leaf_normalize(value: int | float | bool | str | None) -> float | str | None:
24
+ def _leaf_normalize(value: Any) -> Any:
22
25
  if value is None:
23
26
  return None
24
27
  if isinstance(value, int | float | bool):
@@ -38,12 +41,17 @@ def _normalize_inspect_score(score: Score) -> Any:
38
41
  return float(value)
39
42
  return value
40
43
 
41
- if isinstance(score.value, int | float | bool | str):
42
- return _leaf_normalize(score.value)
43
- if isinstance(score.value, list):
44
- return [_leaf_normalize(v) for v in score.value]
45
- assert isinstance(score.value, dict), "Inspect score must be leaf value, list, or dict"
46
- return {k: _leaf_normalize(v) for k, v in score.value.items()}
44
+ if isinstance(score, dict):
45
+ value = score["value"]
46
+ else:
47
+ value = score.value
48
+
49
+ if isinstance(value, int | float | bool | str):
50
+ return _leaf_normalize(value)
51
+ if isinstance(value, list):
52
+ return [_leaf_normalize(v) for v in value] # type: ignore
53
+ assert isinstance(value, dict), "Inspect score must be leaf value, list, or dict"
54
+ return {k: _leaf_normalize(v) for k, v in value.items()} # type: ignore
47
55
 
48
56
 
49
57
  def load_inspect_log(log: EvalLog) -> list[AgentRun]:
@@ -86,3 +94,117 @@ def load_inspect_log(log: EvalLog) -> list[AgentRun]:
86
94
  )
87
95
 
88
96
  return agent_runs
97
+
98
+
99
+ def _read_sample_as_run(data: dict[str, Any], header_metadata: dict[str, Any] = {}) -> AgentRun:
100
+ if "scores" in data:
101
+ normalized_scores = {k: _normalize_inspect_score(v) for k, v in data["scores"].items()}
102
+ else:
103
+ normalized_scores = {}
104
+
105
+ if "metadata" in data:
106
+ sample_metadata = data["metadata"]
107
+ else:
108
+ sample_metadata = {}
109
+
110
+ run_metadata: dict[str, Any] = {
111
+ "sample_id": data.get("id"),
112
+ "epoch": data.get("epoch"),
113
+ "target": data.get("target"),
114
+ # Scores could have answers, explanations, and other metadata besides the values we extract
115
+ "scoring_metadata": data.get("scores"),
116
+ "scores": normalized_scores,
117
+ # If a key exists in header and sample, sample takes precedence
118
+ **header_metadata,
119
+ **sample_metadata,
120
+ }
121
+
122
+ run = AgentRun(
123
+ transcripts={
124
+ "main": Transcript(
125
+ messages=[parse_chat_message(m) for m in data["messages"]], metadata={}
126
+ ),
127
+ },
128
+ metadata=run_metadata,
129
+ )
130
+ return run
131
+
132
+
133
+ def _run_metadata_from_header(header: dict[str, Any]) -> dict[str, Any]:
134
+ """
135
+ Inspect logs often have a lot of metadata.
136
+ This function tries to get the most important stuff without adding clutter.
137
+ """
138
+ m: dict[str, Any] = {}
139
+ if e := header.get("eval"):
140
+ m["task"] = e["task"]
141
+ m["model"] = e["model"]
142
+ return m
143
+
144
+
145
+ def get_total_samples(file_path: Path, format: str = "json") -> int:
146
+ """Return the total number of samples in the provided file."""
147
+ with open(file_path, "rb") as f:
148
+ if format == "json":
149
+ data = json.load(f)
150
+ return len(data.get("samples", []))
151
+ elif format == "eval":
152
+ z = ZipFile(f, mode="r")
153
+ try:
154
+ return sum(
155
+ 1
156
+ for name in z.namelist()
157
+ if name.startswith("samples/") and name.endswith(".json")
158
+ )
159
+ finally:
160
+ z.close()
161
+ else:
162
+ raise ValueError(f"Format must be 'json' or 'eval': {format}")
163
+
164
+
165
+ def _runs_from_eval_file(
166
+ file: BinaryIO,
167
+ ) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
168
+ zip = ZipFile(file, mode="r")
169
+ header: dict[str, Any] = json.load(zip.open("header.json", "r"))
170
+ header_metadata = _run_metadata_from_header(header)
171
+
172
+ def _iter_runs() -> Generator[AgentRun, None, None]:
173
+ try:
174
+ for sample_file in zip.namelist():
175
+ if not (sample_file.startswith("samples/") and sample_file.endswith(".json")):
176
+ continue
177
+ with zip.open(sample_file, "r") as f:
178
+ data = json.load(f)
179
+ run: AgentRun = _read_sample_as_run(data, header_metadata)
180
+ yield run
181
+ finally:
182
+ zip.close()
183
+
184
+ return header_metadata, _iter_runs()
185
+
186
+
187
+ def _runs_from_json_file(
188
+ file: BinaryIO,
189
+ ) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
190
+ data = json.load(file)
191
+ header_metadata = _run_metadata_from_header(data)
192
+
193
+ def _iter_runs() -> Generator[AgentRun, None, None]:
194
+ for sample in data["samples"]:
195
+ run: AgentRun = _read_sample_as_run(sample, header_metadata)
196
+ yield run
197
+
198
+ return header_metadata, _iter_runs()
199
+
200
+
201
+ def runs_from_file(
202
+ file: BinaryIO, format: str = "json"
203
+ ) -> Tuple[dict[str, Any], Generator[AgentRun, None, None]]:
204
+ if format == "json":
205
+ result = _runs_from_json_file(file)
206
+ elif format == "eval":
207
+ result = _runs_from_eval_file(file)
208
+ else:
209
+ raise ValueError(f"Format must be 'json' or 'eval': {format}")
210
+ return result
docent/sdk/client.py CHANGED
@@ -197,75 +197,85 @@ class Docent:
197
197
  return response.json()
198
198
 
199
199
  def list_searches(self, collection_id: str) -> list[dict[str, Any]]:
200
- """List all searches for a given collection.
200
+ """List all rubrics for a given collection.
201
201
 
202
202
  Args:
203
203
  collection_id: ID of the Collection.
204
204
 
205
205
  Returns:
206
- list: List of dictionaries containing search query information.
206
+ list: List of dictionaries containing rubric information.
207
207
 
208
208
  Raises:
209
209
  requests.exceptions.HTTPError: If the API request fails.
210
210
  """
211
- url = f"{self._server_url}/{collection_id}/list_search_queries"
211
+ url = f"{self._server_url}/rubric/{collection_id}/rubrics"
212
212
  response = self._session.get(url)
213
213
  response.raise_for_status()
214
214
  return response.json()
215
215
 
216
- def get_search_results(self, collection_id: str, search_query: str) -> list[dict[str, Any]]:
217
- """Get search results for a given collection and search query.
218
- Pass in either search_query or query_id.
216
+ def get_search_results(
217
+ self, collection_id: str, rubric_id: str, rubric_version: int
218
+ ) -> list[dict[str, Any]]:
219
+ """Get rubric results for a given collection, rubric and version.
219
220
 
220
221
  Args:
221
222
  collection_id: ID of the Collection.
222
- search_query: The search query to get results for.
223
+ rubric_id: The ID of the rubric to get results for.
224
+ rubric_version: The version of the rubric to get results for.
223
225
 
224
226
  Returns:
225
- list: List of dictionaries containing search result information.
227
+ list: List of dictionaries containing rubric result information.
226
228
 
227
229
  Raises:
228
230
  requests.exceptions.HTTPError: If the API request fails.
229
231
  """
230
- url = f"{self._server_url}/{collection_id}/get_search_results"
231
- response = self._session.post(url, json={"search_query": search_query})
232
+ url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/results"
233
+ response = self._session.get(url, params={"rubric_version": rubric_version})
232
234
  response.raise_for_status()
233
235
  return response.json()
234
236
 
235
- def list_search_clusters(self, collection_id: str, search_query: str) -> list[dict[str, Any]]:
236
- """List all search clusters for a given collection.
237
- Pass in either search_query or query_id.
237
+ def list_search_clusters(
238
+ self, collection_id: str, rubric_id: str, rubric_version: int | None = None
239
+ ) -> list[dict[str, Any]]:
240
+ """List all centroids for a given collection and rubric.
238
241
 
239
242
  Args:
240
243
  collection_id: ID of the Collection.
241
- search_query: The search query to get clusters for.
244
+ rubric_id: The ID of the rubric to get centroids for.
245
+ rubric_version: Optional version of the rubric. If not provided, uses latest.
242
246
 
243
247
  Returns:
244
- list: List of dictionaries containing search cluster information.
248
+ list: List of dictionaries containing centroid information.
245
249
 
246
250
  Raises:
247
251
  requests.exceptions.HTTPError: If the API request fails.
248
252
  """
249
- url = f"{self._server_url}/{collection_id}/list_search_clusters"
250
- response = self._session.post(url, json={"search_query": search_query})
253
+ url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/centroids"
254
+ params: dict[str, int] = {}
255
+ if rubric_version is not None:
256
+ params["rubric_version"] = rubric_version
257
+ response = self._session.get(url, params=params)
251
258
  response.raise_for_status()
252
259
  return response.json()
253
260
 
254
- def get_cluster_matches(self, collection_id: str, centroid: str) -> list[dict[str, Any]]:
255
- """Get the matches for a given cluster.
261
+ def get_cluster_matches(
262
+ self, collection_id: str, rubric_id: str, rubric_version: int
263
+ ) -> list[dict[str, Any]]:
264
+ """Get centroid assignments for a given rubric.
256
265
 
257
266
  Args:
258
267
  collection_id: ID of the Collection.
259
- cluster_id: The ID of the cluster to get matches for.
268
+ rubric_id: The ID of the rubric to get assignments for.
269
+ rubric_version: The version of the rubric to get assignments for.
260
270
 
261
271
  Returns:
262
- list: List of dictionaries containing the search results that match the cluster.
272
+ list: List of dictionaries containing centroid assignment information.
263
273
 
264
274
  Raises:
265
275
  requests.exceptions.HTTPError: If the API request fails.
266
276
  """
267
- url = f"{self._server_url}/{collection_id}/get_cluster_matches"
268
- response = self._session.post(url, json={"centroid": centroid})
277
+ url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/assignments"
278
+ response = self._session.get(url, params={"rubric_version": rubric_version})
269
279
  response.raise_for_status()
270
280
  return response.json()
271
281
 
docent/trace.py CHANGED
@@ -11,17 +11,15 @@ from collections import defaultdict
11
11
  from contextlib import asynccontextmanager, contextmanager
12
12
  from contextvars import ContextVar, Token
13
13
  from datetime import datetime, timezone
14
- from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
14
+ from enum import Enum
15
+ from importlib.metadata import Distribution, distributions
16
+ from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Set, Union
15
17
 
16
18
  import requests
17
19
  from opentelemetry import trace
18
20
  from opentelemetry.context import Context
19
21
  from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCExporter
20
22
  from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPExporter
21
- from opentelemetry.instrumentation.anthropic import AnthropicInstrumentor
22
- from opentelemetry.instrumentation.bedrock import BedrockInstrumentor
23
- from opentelemetry.instrumentation.langchain import LangchainInstrumentor
24
- from opentelemetry.instrumentation.openai import OpenAIInstrumentor
25
23
  from opentelemetry.instrumentation.threading import ThreadingInstrumentor
26
24
  from opentelemetry.sdk.resources import Resource
27
25
  from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider
@@ -33,15 +31,23 @@ from opentelemetry.sdk.trace.export import (
33
31
  from opentelemetry.trace import Span
34
32
 
35
33
  # Configure logging
36
- logging.basicConfig(level=logging.INFO)
37
34
  logger = logging.getLogger(__name__)
38
- logger.disabled = True
35
+ logger.setLevel(logging.ERROR)
39
36
 
40
37
  # Default configuration
41
38
  DEFAULT_ENDPOINT = "https://api.docent.transluce.org/rest/telemetry"
42
39
  DEFAULT_COLLECTION_NAME = "default-collection-name"
43
40
 
44
41
 
42
+ class Instruments(Enum):
43
+ """Enumeration of available instrument types."""
44
+
45
+ OPENAI = "openai"
46
+ ANTHROPIC = "anthropic"
47
+ BEDROCK = "bedrock"
48
+ LANGCHAIN = "langchain"
49
+
50
+
45
51
  def _is_notebook() -> bool:
46
52
  """Check if we're running in a Jupyter notebook."""
47
53
  try:
@@ -64,6 +70,8 @@ class DocentTracer:
64
70
  enable_console_export: bool = False,
65
71
  enable_otlp_export: bool = True,
66
72
  disable_batch: bool = False,
73
+ instruments: Optional[Set[Instruments]] = None,
74
+ block_instruments: Optional[Set[Instruments]] = None,
67
75
  ):
68
76
  """
69
77
  Initialize Docent tracing manager.
@@ -78,6 +86,8 @@ class DocentTracer:
78
86
  enable_console_export: Whether to export to console
79
87
  enable_otlp_export: Whether to export to OTLP endpoint
80
88
  disable_batch: Whether to disable batch processing (use SimpleSpanProcessor)
89
+ instruments: Set of instruments to enable (None = all instruments)
90
+ block_instruments: Set of instruments to explicitly disable
81
91
  """
82
92
  self.collection_name: str = collection_name
83
93
  self.collection_id: str = collection_id if collection_id else str(uuid.uuid4())
@@ -105,6 +115,9 @@ class DocentTracer:
105
115
  self.enable_console_export = enable_console_export
106
116
  self.enable_otlp_export = enable_otlp_export
107
117
  self.disable_batch = disable_batch
118
+ self.disabled_instruments: Set[Instruments] = {Instruments.LANGCHAIN}
119
+ self.instruments = instruments or (set(Instruments) - self.disabled_instruments)
120
+ self.block_instruments = block_instruments or set()
108
121
 
109
122
  # Use separate tracer provider to avoid interfering with existing OTEL setup
110
123
  self._tracer_provider: Optional[TracerProvider] = None
@@ -206,7 +219,7 @@ class DocentTracer:
206
219
  exporters.append(exporter)
207
220
  logger.info(f"Initialized exporter for endpoint: {endpoint}")
208
221
  else:
209
- logger.warning(f"Failed to initialize exporter for endpoint: {endpoint}")
222
+ logger.critical(f"Failed to initialize exporter for endpoint: {endpoint}")
210
223
 
211
224
  return exporters
212
225
 
@@ -309,8 +322,6 @@ class DocentTracer:
309
322
  logger.info(
310
323
  f"Added {len(otlp_exporters)} OTLP exporters for {len(self.endpoints)} endpoints"
311
324
  )
312
- else:
313
- logger.warning("Failed to initialize OTLP exporter")
314
325
 
315
326
  if self.enable_console_export:
316
327
  console_exporter: ConsoleSpanExporter = ConsoleSpanExporter()
@@ -333,33 +344,51 @@ class DocentTracer:
333
344
  except Exception as e:
334
345
  logger.warning(f"Failed to instrument threading: {e}")
335
346
 
347
+ enabled_instruments = self.instruments - self.block_instruments
348
+
336
349
  # Instrument OpenAI with our isolated tracer provider
337
- try:
338
- OpenAIInstrumentor().instrument(tracer_provider=self._tracer_provider)
339
- logger.info("Instrumented OpenAI")
340
- except Exception as e:
341
- logger.warning(f"Failed to instrument OpenAI: {e}")
350
+ if Instruments.OPENAI in enabled_instruments:
351
+ try:
352
+ if is_package_installed("openai"):
353
+ from opentelemetry.instrumentation.openai import OpenAIInstrumentor
354
+
355
+ OpenAIInstrumentor().instrument(tracer_provider=self._tracer_provider)
356
+ logger.info("Instrumented OpenAI")
357
+ except Exception as e:
358
+ logger.warning(f"Failed to instrument OpenAI: {e}")
342
359
 
343
360
  # Instrument Anthropic with our isolated tracer provider
344
- try:
345
- AnthropicInstrumentor().instrument(tracer_provider=self._tracer_provider)
346
- logger.info("Instrumented Anthropic")
347
- except Exception as e:
348
- logger.warning(f"Failed to instrument Anthropic: {e}")
361
+ if Instruments.ANTHROPIC in enabled_instruments:
362
+ try:
363
+ if is_package_installed("anthropic"):
364
+ from opentelemetry.instrumentation.anthropic import AnthropicInstrumentor
365
+
366
+ AnthropicInstrumentor().instrument(tracer_provider=self._tracer_provider)
367
+ logger.info("Instrumented Anthropic")
368
+ except Exception as e:
369
+ logger.warning(f"Failed to instrument Anthropic: {e}")
349
370
 
350
371
  # Instrument Bedrock with our isolated tracer provider
351
- try:
352
- BedrockInstrumentor().instrument(tracer_provider=self._tracer_provider)
353
- logger.info("Instrumented Bedrock")
354
- except Exception as e:
355
- logger.warning(f"Failed to instrument Bedrock: {e}")
372
+ if Instruments.BEDROCK in enabled_instruments:
373
+ try:
374
+ if is_package_installed("boto3"):
375
+ from opentelemetry.instrumentation.bedrock import BedrockInstrumentor
376
+
377
+ BedrockInstrumentor().instrument(tracer_provider=self._tracer_provider)
378
+ logger.info("Instrumented Bedrock")
379
+ except Exception as e:
380
+ logger.warning(f"Failed to instrument Bedrock: {e}")
356
381
 
357
382
  # Instrument LangChain with our isolated tracer provider
358
- try:
359
- LangchainInstrumentor().instrument(tracer_provider=self._tracer_provider)
360
- logger.info("Instrumented LangChain")
361
- except Exception as e:
362
- logger.warning(f"Failed to instrument LangChain: {e}")
383
+ if Instruments.LANGCHAIN in enabled_instruments:
384
+ try:
385
+ if is_package_installed("langchain") or is_package_installed("langgraph"):
386
+ from opentelemetry.instrumentation.langchain import LangchainInstrumentor
387
+
388
+ LangchainInstrumentor().instrument(tracer_provider=self._tracer_provider)
389
+ logger.info("Instrumented LangChain")
390
+ except Exception as e:
391
+ logger.warning(f"Failed to instrument LangChain: {e}")
363
392
 
364
393
  # Register cleanup handlers
365
394
  self._register_cleanup()
@@ -789,9 +818,19 @@ class DocentTracer:
789
818
  metadata: Optional metadata to send
790
819
  """
791
820
  collection_id = self.collection_id
821
+
822
+ # Get agent_run_id from current context
823
+ agent_run_id = self.get_current_agent_run_id()
824
+ if not agent_run_id:
825
+ logger.error(
826
+ f"Cannot send transcript group metadata for {transcript_group_id} - no agent_run_id in context"
827
+ )
828
+ return
829
+
792
830
  payload: Dict[str, Any] = {
793
831
  "collection_id": collection_id,
794
832
  "transcript_group_id": transcript_group_id,
833
+ "agent_run_id": agent_run_id,
795
834
  "timestamp": datetime.now(timezone.utc).isoformat(),
796
835
  }
797
836
 
@@ -942,6 +981,8 @@ def initialize_tracing(
942
981
  enable_console_export: bool = False,
943
982
  enable_otlp_export: bool = True,
944
983
  disable_batch: bool = False,
984
+ instruments: Optional[Set[Instruments]] = None,
985
+ block_instruments: Optional[Set[Instruments]] = None,
945
986
  ) -> DocentTracer:
946
987
  """
947
988
  Initialize the global Docent tracer.
@@ -958,6 +999,8 @@ def initialize_tracing(
958
999
  enable_console_export: Whether to export spans to console
959
1000
  enable_otlp_export: Whether to export spans to OTLP endpoint
960
1001
  disable_batch: Whether to disable batch processing (use SimpleSpanProcessor)
1002
+ instruments: Set of instruments to enable (None = all instruments).
1003
+ block_instruments: Set of instruments to explicitly disable.
961
1004
 
962
1005
  Returns:
963
1006
  The initialized Docent tracer
@@ -966,6 +1009,7 @@ def initialize_tracing(
966
1009
  # Basic setup
967
1010
  initialize_tracing("my-collection")
968
1011
  """
1012
+
969
1013
  global _global_tracer
970
1014
 
971
1015
  # Check for API key in environment variable if not provided as parameter
@@ -983,12 +1027,30 @@ def initialize_tracing(
983
1027
  enable_console_export=enable_console_export,
984
1028
  enable_otlp_export=enable_otlp_export,
985
1029
  disable_batch=disable_batch,
1030
+ instruments=instruments,
1031
+ block_instruments=block_instruments,
986
1032
  )
987
1033
  _global_tracer.initialize()
988
1034
 
989
1035
  return _global_tracer
990
1036
 
991
1037
 
1038
+ def _get_package_name(dist: Distribution) -> str | None:
1039
+ try:
1040
+ return dist.name.lower()
1041
+ except (KeyError, AttributeError):
1042
+ return None
1043
+
1044
+
1045
+ installed_packages = {
1046
+ name for dist in distributions() if (name := _get_package_name(dist)) is not None
1047
+ }
1048
+
1049
+
1050
+ def is_package_installed(package_name: str) -> bool:
1051
+ return package_name.lower() in installed_packages
1052
+
1053
+
992
1054
  def get_tracer() -> DocentTracer:
993
1055
  """Get the global Docent tracer."""
994
1056
  if _global_tracer is None:
docent_python-0.1.4a0.dist-info/METADATA → docent_python-0.1.6a0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: docent-python
3
- Version: 0.1.4a0
3
+ Version: 0.1.6a0
4
4
  Summary: Docent SDK
5
5
  Project-URL: Homepage, https://github.com/TransluceAI/docent
6
6
  Project-URL: Issues, https://github.com/TransluceAI/docent/issues
@@ -22,4 +22,3 @@ Requires-Dist: pydantic>=2.11.7
22
22
  Requires-Dist: pyyaml>=6.0.2
23
23
  Requires-Dist: tiktoken>=0.7.0
24
24
  Requires-Dist: tqdm>=4.67.1
25
- Requires-Dist: traceloop-sdk>=0.44.1
docent_python-0.1.4a0.dist-info/RECORD → docent_python-0.1.6a0.dist-info/RECORD CHANGED
@@ -1,30 +1,29 @@
1
1
  docent/__init__.py,sha256=J2BbO6rzilfw9WXRUeolr439EGFezqbMU_kCpCCryRA,59
2
2
  docent/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- docent/trace.py,sha256=T6F9jbj0JcnGBxlcR65iDPF0lmeFR6laTlhDBcw0Mh0,61167
4
- docent/trace_alt.py,sha256=k0EZ_jyvVDqQ3HExDPZzdZYj6tKFL_cgbNRnOq4oGyg,17747
3
+ docent/trace.py,sha256=oGhNizXcOU-FZJZkmnd8WQjlCvlRKWI3PncMHUHKQ_4,63667
5
4
  docent/trace_temp.py,sha256=Z0lAPwVzXjFvxpiU-CuvfWIslq9Q4alNkZMoQ77Xudk,40711
6
5
  docent/_log_util/__init__.py,sha256=3HXXrxrSm8PxwG4llotrCnSnp7GuroK1FNHsdg6f7aE,73
7
6
  docent/_log_util/logger.py,sha256=kwM0yRW1IJd6-XTorjWn48B4l8qvD2ZM6VDjY5eskQI,4422
8
7
  docent/data_models/__init__.py,sha256=4JbTDVzRhS5VZgo8MALwd_YI17GaN7X9E3rOc4Xl7kw,327
9
8
  docent/data_models/_tiktoken_util.py,sha256=hC0EDDWItv5-0cONBnHWgZtQOflDU7ZNEhXPFo4DvPc,3057
10
- docent/data_models/agent_run.py,sha256=lw-odD2zzFi-RGvkAFjz9x8l6XWPrGT6uRGqTj9h8qU,9621
9
+ docent/data_models/agent_run.py,sha256=bDRToWUlY52PugoHWU1D9hasr5t_fnTmRLpkzWP1s_k,9811
11
10
  docent/data_models/citation.py,sha256=WsVQZcBT2EJD24ysyeVOC5Xfo165RI7P5_cOnJBgHj0,10015
12
11
  docent/data_models/metadata.py,sha256=r0SYC4i2x096dXMLfw_rAMtcJQCsoV6EOMPZuEngbGA,9062
13
12
  docent/data_models/regex.py,sha256=0ciIerkrNwb91bY5mTcyO5nDWH67xx2tZYObV52fmBo,1684
14
13
  docent/data_models/shared_types.py,sha256=jjm-Dh5S6v7UKInW7SEqoziOsx6Z7Uu4e3VzgCbTWvc,225
15
- docent/data_models/transcript.py,sha256=NDcpvil4dJ8YhG_JJ0X-w0prkXhwhsdO-zoL-CZMipM,15446
14
+ docent/data_models/transcript.py,sha256=0iF2ujcWhTss8WkkpNMeIKJyKOfMEsiMoAQMGwY4ing,15753
16
15
  docent/data_models/chat/__init__.py,sha256=O04XQ2NmO8GTWqkkB_Iydj8j_CucZuLhoyMVTxJN_cs,570
17
16
  docent/data_models/chat/content.py,sha256=Co-jO8frQa_DSP11wJuhPX0s-GpJk8yqtKqPeiAIZ_U,1672
18
17
  docent/data_models/chat/message.py,sha256=iAo38kbV6wYbFh8S23cxLy6HY4C_i3PzQ6RpSQG5dxM,3861
19
18
  docent/data_models/chat/tool.py,sha256=x7NKINswPe0Kqvcx4ubjHzB-n0-i4DbFodvaBb2vitk,3042
20
- docent/loaders/load_inspect.py,sha256=yK6LZgprT8kc0Jg4N_cnbhsGCq9lINmMcgALXA9AibY,2812
19
+ docent/loaders/load_inspect.py,sha256=_cK2Qd6gyLQuJVzOlsvEZz7TrqzNmH6ZsLTkSCWAPqQ,6628
21
20
  docent/samples/__init__.py,sha256=roDFnU6515l9Q8v17Es_SpWyY9jbm5d6X9lV01V0MZo,143
22
21
  docent/samples/load.py,sha256=ZGE07r83GBNO4A0QBh5aQ18WAu3mTWA1vxUoHd90nrM,207
23
22
  docent/samples/log.eval,sha256=orrW__9WBfANq7NwKsPSq9oTsQRcG6KohG5tMr_X_XY,397708
24
23
  docent/samples/tb_airline.json,sha256=eR2jFFRtOw06xqbEglh6-dPewjifOk-cuxJq67Dtu5I,47028
25
24
  docent/sdk/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- docent/sdk/client.py,sha256=uyhTisb9bHk7Hd2G4UKLdfvuiAmYOOqJiwEPbYWN9IE,12371
27
- docent_python-0.1.4a0.dist-info/METADATA,sha256=AUptJVGzZtABJ8V1Hpzhi2EOaMdBJB_nSaldpN6J8Bg,1074
28
- docent_python-0.1.4a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
29
- docent_python-0.1.4a0.dist-info/licenses/LICENSE.md,sha256=vOHzq3K4Ndu0UV9hPrtXvlD7pHOjyDQmGjHuLSIkRQY,1087
30
- docent_python-0.1.4a0.dist-info/RECORD,,
25
+ docent/sdk/client.py,sha256=fLdniy8JzMLoZpaS9SP2pHban_ToavgtI8VeHZLMNZo,12773
26
+ docent_python-0.1.6a0.dist-info/METADATA,sha256=ib_GqBFrOmPvacYb4uncrC5qLsoygIJ0wU852MOea_8,1037
27
+ docent_python-0.1.6a0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
28
+ docent_python-0.1.6a0.dist-info/licenses/LICENSE.md,sha256=vOHzq3K4Ndu0UV9hPrtXvlD7pHOjyDQmGjHuLSIkRQY,1087
29
+ docent_python-0.1.6a0.dist-info/RECORD,,
docent/trace_alt.py DELETED
@@ -1,513 +0,0 @@
1
- import asyncio
2
- import atexit
3
- import functools
4
- import io
5
- import logging
6
- import os
7
- import uuid
8
- from contextlib import asynccontextmanager, contextmanager, redirect_stdout
9
- from contextvars import ContextVar, Token
10
- from typing import Any, AsyncIterator, Callable, Dict, Iterator, Optional, Set
11
-
12
- import requests
13
- from opentelemetry.context import Context
14
- from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor
15
- from opentelemetry.trace import Span
16
- from traceloop.sdk import Traceloop
17
-
18
# Module-level logger. NOTE(review): it is disabled here, while the functions in
# this module emit their messages through the root ``logging`` functions rather
# than through ``logger`` -- confirm that is intended.
logger = logging.getLogger(__name__)
logger.disabled = True

# Telemetry endpoint used when initialize_tracing() is not given one.
DEFAULT_ENDPOINT = "https://api.docent.transluce.org/rest/telemetry"

# Context-local identifiers of the active agent run and collection.
_current_agent_run_id: ContextVar[Optional[str]] = ContextVar("current_agent_run_id", default=None)
_current_collection_id: ContextVar[Optional[str]] = ContextVar(
    "current_collection_id", default=None
)

# Process-wide configuration; every value below is assigned by initialize_tracing().
_tracing_initialized = False
_collection_name: Optional[str] = None
_collection_id: Optional[str] = None
_default_agent_run_id: Optional[str] = None
_endpoint: Optional[str] = None
_api_key: Optional[str] = None
_enable_console_export = False
_disable_batch = False
_instruments: Optional[Set[Any]] = None
_block_instruments: Optional[Set[Any]] = None
41
-
42
-
43
class DocentSpanProcessor(SpanProcessor):
    """Span processor that stamps Docent identifiers onto every span.

    The processor only annotates spans when they start; it does not export
    anything itself.
    """

    def __init__(self, collection_id: str, enable_console_export: bool = False):
        self.collection_id = collection_id
        self.enable_console_export = enable_console_export

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Attach collection, agent-run and service-name attributes to *span*."""
        span.set_attribute("collection_id", self.collection_id)

        agent_run_id = _get_current_agent_run_id()
        if not agent_run_id:
            # No run is active in this context: use the shared default ID and flag it.
            span.set_attribute("agent_run_id", _get_default_agent_run_id())
            span.set_attribute("agent_run_id_default", True)
        else:
            span.set_attribute("agent_run_id", agent_run_id)

        # The service name mirrors the collection name when one has been configured.
        service_name = _collection_name or "docent-trace"
        span.set_attribute("service.name", service_name)

        if self.enable_console_export:
            logging.debug(
                f"Span started - collection_id: {self.collection_id}, agent_run_id: {agent_run_id}"
            )

    def on_end(self, span: ReadableSpan) -> None:
        """Do nothing when a span ends."""

    def shutdown(self) -> None:
        """Called when the processor is shut down; nothing to release."""

    def force_flush(self, timeout_millis: float = 30000) -> bool:
        """Report success immediately; this processor keeps no pending spans."""
        return True
84
-
85
-
86
def initialize_tracing(
    collection_name: str,
    collection_id: Optional[str] = None,
    endpoint: Optional[str] = None,
    api_key: Optional[str] = None,
    enable_console_export: bool = False,
    disable_batch: bool = False,
    instruments: Optional[Set[Any]] = None,
    block_instruments: Optional[Set[Any]] = None,
) -> None:
    """Initialize Docent tracing with the specified configuration.

    Calling this again after a successful initialization only logs a warning.

    Args:
        collection_name: Name for your application/collection
        collection_id: Optional collection ID (auto-generated if not provided)
        endpoint: Optional OTLP endpoint URL (defaults to DEFAULT_ENDPOINT)
        api_key: Optional API key (uses DOCENT_API_KEY environment variable if not provided)
        enable_console_export: Whether span starts are also reported via debug logging
        disable_batch: Whether to disable batch processing
        instruments: Set of instruments to enable (None = all instruments)
        block_instruments: Set of instruments to explicitly disable

    Raises:
        ValueError: If no API key is passed and DOCENT_API_KEY is not set. No
            module state is modified in that case.
    """
    global _tracing_initialized, _collection_name, _collection_id, _default_agent_run_id, _endpoint, _api_key
    global _enable_console_export, _disable_batch, _instruments, _block_instruments

    if _tracing_initialized:
        logging.warning("Docent tracing already initialized")
        return

    # Validate the API key before touching any module state, so that a failed
    # call does not leave a half-configured module behind.
    resolved_api_key = api_key or os.getenv("DOCENT_API_KEY")
    if not resolved_api_key:
        raise ValueError(
            "API key is required. Set DOCENT_API_KEY environment variable or pass api_key parameter."
        )

    _collection_name = collection_name
    _collection_id = collection_id or _generate_id()
    _default_agent_run_id = _get_default_agent_run_id()  # Generate default ID if not set
    _endpoint = endpoint or DEFAULT_ENDPOINT
    _api_key = resolved_api_key
    _enable_console_export = enable_console_export
    _disable_batch = disable_batch
    _instruments = instruments
    _block_instruments = block_instruments

    _set_current_collection_id(_collection_id)

    # Imported lazily, as in the original implementation.
    from traceloop.sdk.tracing.tracing import get_default_span_processor

    # Our processor only adds metadata; it does not export.
    docent_processor = DocentSpanProcessor(_collection_id, enable_console_export)

    # Traceloop's default span processor performs the actual export.
    export_processor = get_default_span_processor(
        disable_batch=_disable_batch,
        api_endpoint=_endpoint,
        headers={"Authorization": f"Bearer {_api_key}"},
    )

    processors = [docent_processor, export_processor]

    os.environ["TRACELOOP_METRICS_ENABLED"] = "false"
    os.environ["TRACELOOP_TRACE_ENABLED"] = "true"

    # Temporarily redirect stdout to suppress print statements
    with redirect_stdout(io.StringIO()):
        Traceloop.init(  # type: ignore
            app_name=collection_name,
            api_endpoint=_endpoint,
            api_key=_api_key,
            telemetry_enabled=False,  # don't send analytics to traceloop's backend
            disable_batch=_disable_batch,
            instruments=_instruments,
            block_instruments=_block_instruments,
            processor=processors,  # both the metadata processor and the export processor
        )

    _tracing_initialized = True
    logging.info(
        f"Docent tracing initialized for collection: {collection_name} with collection_id: {_collection_id}"
    )

    # Register cleanup handlers
    atexit.register(_cleanup_tracing)
176
-
177
-
178
def _cleanup_tracing() -> None:
    """Shutdown hook: report trace completion and mark tracing as uninitialized.

    Errors raised while notifying are logged as warnings, not propagated.
    """
    global _tracing_initialized
    if not _tracing_initialized:
        return

    try:
        # Tell the API that the trace is over.
        _notify_trace_done()
        logging.info("Docent tracing cleanup completed")
    except Exception as e:
        logging.warning(f"Error during tracing cleanup: {e}")
    finally:
        _tracing_initialized = False
191
-
192
-
193
def _ensure_tracing_initialized() -> None:
    """Raise RuntimeError unless tracing has been initialized."""
    if _tracing_initialized:
        return
    raise RuntimeError("Docent tracing not initialized. Call initialize_tracing() first.")
197
-
198
-
199
def _generate_id() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    new_id = uuid.uuid4()
    return str(new_id)
202
-
203
-
204
def _get_current_agent_run_id() -> Optional[str]:
    """Return the agent run ID stored in the current context, or None if unset."""
    active_run_id = _current_agent_run_id.get()
    return active_run_id
207
-
208
-
209
def _get_current_collection_id() -> Optional[str]:
    """Return the collection ID stored in the current context, or None if unset."""
    active_collection_id = _current_collection_id.get()
    return active_collection_id
212
-
213
-
214
def _get_default_agent_run_id() -> str:
    """Return the module-wide default agent run ID, creating it on first use."""
    global _default_agent_run_id
    if _default_agent_run_id is not None:
        return _default_agent_run_id

    # First request: generate the ID once and keep it for later calls.
    _default_agent_run_id = _generate_id()
    return _default_agent_run_id
220
-
221
-
222
def _set_current_agent_run_id(agent_run_id: Optional[str]) -> Token[Optional[str]]:
    """Store *agent_run_id* in the current context.

    Returns:
        The ContextVar token that can later be used to restore the previous value.
    """
    reset_token = _current_agent_run_id.set(agent_run_id)
    return reset_token
225
-
226
-
227
def _set_current_collection_id(collection_id: Optional[str]) -> Token[Optional[str]]:
    """Store *collection_id* in the current context.

    Returns:
        The ContextVar token that can later be used to restore the previous value.
    """
    reset_token = _current_collection_id.set(collection_id)
    return reset_token
230
-
231
-
232
def _send_to_api(endpoint: str, data: Dict[str, Any]) -> None:
    """POST *data* as JSON to a Docent API endpoint.

    Failures are logged at error level instead of being raised.

    Args:
        endpoint: The API endpoint URL
        data: The JSON-serializable payload to send
    """
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {_api_key}",
    }

    try:
        response = requests.post(endpoint, json=data, headers=request_headers, timeout=10)
        # Treat HTTP error statuses as failures as well.
        response.raise_for_status()
        logging.debug(f"Successfully sent data to {endpoint}")
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to send data to {endpoint}: {e}")
    except Exception as e:
        logging.error(f"Unexpected error sending data to {endpoint}: {e}")
250
-
251
-
252
def _notify_trace_done() -> None:
    """Notify the Docent API that the trace is done.

    Nothing is sent when no collection ID or endpoint is available.
    """
    # Fall back to the module-level collection ID, as agent_run_score() and
    # agent_run_metadata() do: the ContextVar is only visible in the context
    # that set it, and this function may run from a different one.
    collection_id = _get_current_collection_id() or _collection_id
    if collection_id and _endpoint:
        data = {"collection_id": collection_id, "status": "completed"}
        _send_to_api(f"{_endpoint}/v1/trace-done", data)
258
-
259
-
260
def agent_run_score(name: str, score: float, attributes: Optional[Dict[str, Any]] = None) -> None:
    """Send a score for the current agent run to the Docent API.

    The score is posted to the ``/v1/scores`` path under the configured
    endpoint. If there is no active agent run or no collection ID, a warning
    is logged and nothing is sent.

    Args:
        name: Name of the score metric
        score: Numeric score value
        attributes: Optional additional key/value pairs merged into the
            payload. They cannot override the reserved ``collection_id``,
            ``agent_run_id``, ``score_name`` and ``score_value`` keys.

    Raises:
        RuntimeError: If tracing has not been initialized.
    """
    _ensure_tracing_initialized()

    agent_run_id = _get_current_agent_run_id()
    if not agent_run_id:
        logging.warning("No active agent run context. Score will not be sent.")
        return

    collection_id = _get_current_collection_id() or _collection_id
    if not collection_id:
        logging.warning("No collection ID available. Score will not be sent.")
        return

    # Start from the caller's attributes and apply the reserved fields last so
    # that a colliding attribute key cannot redirect or corrupt the score.
    score_data: Dict[str, Any] = dict(attributes) if attributes else {}
    score_data.update(
        {
            "collection_id": collection_id,
            "agent_run_id": agent_run_id,
            "score_name": name,
            "score_value": score,
        }
    )

    _send_to_api(f"{_endpoint}/v1/scores", score_data)
295
-
296
-
297
def agent_run_metadata(metadata: Dict[str, Any]) -> None:
    """Send metadata for the current agent run to the Docent API.

    If there is no active agent run or no collection ID, a warning is logged
    and nothing is sent.

    Args:
        metadata: Dictionary of metadata to attach

    Raises:
        RuntimeError: If tracing has not been initialized.
    """
    _ensure_tracing_initialized()

    run_id = _get_current_agent_run_id()
    if not run_id:
        logging.warning("No active agent run context. Metadata will not be sent.")
        return

    # Prefer the context-local collection ID, then the module-level one.
    target_collection = _get_current_collection_id() or _collection_id
    if not target_collection:
        logging.warning("No collection ID available. Metadata will not be sent.")
        return

    payload = {
        "collection_id": target_collection,
        "agent_run_id": run_id,
        "metadata": metadata,
    }
    _send_to_api(f"{_endpoint}/v1/metadata", payload)
323
-
324
-
325
@contextmanager
def _agent_run_context_sync(
    agent_run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Iterator[tuple[str, Optional[str]]]:
    """Synchronous context manager that scopes an agent run.

    Yields:
        A ``(agent_run_id, None)`` tuple; the second element is a placeholder
        for a transcript ID, which this function never fills in.

    Raises:
        RuntimeError: If tracing has not been initialized.
    """
    _ensure_tracing_initialized()

    # Use the caller's ID when given, otherwise generate one.
    run_id = agent_run_id or _generate_id()
    reset_token = _set_current_agent_run_id(run_id)

    try:
        if metadata:
            agent_run_metadata(metadata)
        yield (run_id, None)
    finally:
        # Restore whatever agent run ID was active before entering.
        _current_agent_run_id.reset(reset_token)
351
-
352
-
353
@asynccontextmanager
async def _agent_run_context_async(
    agent_run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> AsyncIterator[tuple[str, Optional[str]]]:
    """Asynchronous context manager that scopes an agent run.

    Yields:
        A ``(agent_run_id, None)`` tuple; the second element is a placeholder
        for a transcript ID, which this function never fills in.

    Raises:
        RuntimeError: If tracing has not been initialized.
    """
    _ensure_tracing_initialized()

    # Use the caller's ID when given, otherwise generate one.
    run_id = agent_run_id or _generate_id()
    reset_token = _set_current_agent_run_id(run_id)

    try:
        if metadata:
            agent_run_metadata(metadata)
        yield (run_id, None)
    finally:
        # Restore whatever agent run ID was active before entering.
        _current_agent_run_id.reset(reset_token)
379
-
380
-
381
def agent_run_context(
    agent_run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Return a context manager that scopes an agent run.

    The kind of context manager is chosen by inspecting the call stack: if any
    frame on it belongs to a coroutine function, the asynchronous variant is
    returned (use ``async with``); otherwise the synchronous variant is
    returned (use ``with``). Because the whole stack is examined, a plain
    function that is called from inside a coroutine also receives the
    asynchronous variant.

    Args:
        agent_run_id: Optional agent run ID (auto-generated if not provided)
        metadata: Optional metadata to attach to the agent run

    Returns:
        A context manager that yields a tuple of (agent_run_id, transcript_id),
        where transcript_id is always None.
    """
    import inspect

    frame = inspect.currentframe()
    try:
        while frame is not None:
            # CO_COROUTINE marks code objects compiled from ``async def`` functions.
            if frame.f_code.co_flags & inspect.CO_COROUTINE:
                return _agent_run_context_async(agent_run_id=agent_run_id, metadata=metadata)
            frame = frame.f_back
    finally:
        # Drop the frame reference so it does not keep the stack alive.
        del frame

    # No coroutine frame found: default to the synchronous context manager.
    return _agent_run_context_sync(agent_run_id=agent_run_id, metadata=metadata)
415
-
416
-
417
def agent_run(
    func: Optional[Callable[..., Any]] = None,
    *,
    agent_run_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Callable[..., Any]:
    """Decorator that runs a function inside an agent run context.

    Works both bare (``@agent_run``) and with arguments
    (``@agent_run(agent_run_id=..., metadata=...)``), on regular and
    ``async def`` functions. After each call the wrapper exposes the run ID of
    that call as ``wrapper.docent.agent_run_id``; this is also set when the
    wrapped function raises.

    Args:
        func: Function to decorate
        agent_run_id: Optional agent run ID (auto-generated if not provided)
        metadata: Optional metadata to attach to the agent run

    Returns:
        The decorated function, or a decorator when called without ``func``.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            with _agent_run_context_sync(agent_run_id=agent_run_id, metadata=metadata) as (
                run_id,
                _,
            ):
                try:
                    return func(*args, **kwargs)
                finally:
                    # Publish the run ID even if func raised, so a failed run
                    # can still be identified by the caller.
                    setattr(sync_wrapper, "docent", type("DocentInfo", (), {"agent_run_id": run_id})())  # type: ignore

        @functools.wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            async with _agent_run_context_async(agent_run_id=agent_run_id, metadata=metadata) as (
                run_id,
                _,
            ):
                try:
                    return await func(*args, **kwargs)
                finally:
                    # Publish the run ID even if func raised, so a failed run
                    # can still be identified by the caller.
                    setattr(async_wrapper, "docent", type("DocentInfo", (), {"agent_run_id": run_id})())  # type: ignore

        # Pick the wrapper that matches the kind of function being decorated.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    # Handle both @agent_run and @agent_run(agent_run_id=..., metadata=...)
    if func is None:
        return decorator
    return decorator(func)
468
-
469
-
470
- # Additional utility functions for better integration
471
-
472
-
473
def get_current_agent_run_id() -> Optional[str]:
    """Public accessor for the agent run ID of the current context.

    Returns:
        The current agent run ID if available, None otherwise
    """
    current_run_id = _get_current_agent_run_id()
    return current_run_id
480
-
481
-
482
def get_current_collection_id() -> Optional[str]:
    """Public accessor for the collection ID of the current context.

    Returns:
        The current collection ID if available, None otherwise
    """
    current_collection = _get_current_collection_id()
    return current_collection
489
-
490
-
491
def is_tracing_initialized() -> bool:
    """Report whether tracing is currently initialized.

    Returns:
        True if tracing is initialized, False otherwise
    """
    initialized = _tracing_initialized
    return initialized
498
-
499
-
500
def flush_spans() -> None:
    """Force flush any pending spans to the backend.

    Does nothing when tracing is not initialized. Errors raised while flushing
    are logged as warnings, not propagated.
    """
    if not _tracing_initialized:
        return

    try:
        traceloop_instance = Traceloop.get()
        # NOTE(review): the hasattr guard suggests a flush() method is not
        # guaranteed on the returned object -- confirm against the Traceloop SDK.
        if hasattr(traceloop_instance, "flush"):
            traceloop_instance.flush()  # type: ignore
            # Only report success when a flush was actually performed.
            logging.debug("Spans flushed successfully")
    except Exception as e:
        logging.warning(f"Error flushing spans: {e}")