docent-python 0.1.18a0 → 0.1.20a0 (py3-none-any wheel)

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

This release of docent-python has been flagged as potentially problematic.

docent/sdk/client.py CHANGED
@@ -8,6 +8,7 @@ from tqdm import tqdm
 
 from docent._log_util.logger import get_logger
 from docent.data_models.agent_run import AgentRun
+from docent.data_models.judge import JudgeRunLabel
 from docent.loaders import load_inspect
 
 logger = get_logger(__name__)
@@ -48,13 +49,18 @@ class Docent:
 
         self._login(api_key)
 
+    def _handle_response_errors(self, response: requests.Response):
+        """Handle API response and raise informative errors.
+        TODO: make this more informative."""
+        response.raise_for_status()
+
     def _login(self, api_key: str):
         """Login with email/password to establish session."""
         self._session.headers.update({"Authorization": f"Bearer {api_key}"})
 
         url = f"{self._server_url}/api-keys/test"
         response = self._session.get(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
 
         logger.info("Logged in with API key")
         return
@@ -90,7 +96,7 @@ class Docent:
         }
 
         response = self._session.post(url, json=payload)
-        response.raise_for_status()
+        self._handle_response_errors(response)
 
         response_data = response.json()
         collection_id = response_data.get("collection_id")
@@ -134,13 +140,13 @@ class Docent:
             payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch]}
 
             response = self._session.post(url, json=payload)
-            response.raise_for_status()
+            self._handle_response_errors(response)
 
             pbar.update(len(batch))
 
         url = f"{self._server_url}/{collection_id}/compute_embeddings"
         response = self._session.post(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
 
         logger.info(f"Successfully added {total_runs} agent runs to Collection '{collection_id}'")
         return {"status": "success", "total_runs_added": total_runs}
@@ -156,7 +162,7 @@ class Docent:
         """
         url = f"{self._server_url}/collections"
         response = self._session.get(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
         return response.json()
 
     def list_rubrics(self, collection_id: str) -> list[dict[str, Any]]:
@@ -173,15 +179,18 @@ class Docent:
         """
         url = f"{self._server_url}/rubric/{collection_id}/rubrics"
         response = self._session.get(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
         return response.json()
 
-    def get_rubric_run_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
+    def get_rubric_run_state(
+        self, collection_id: str, rubric_id: str, version: int | None = None
+    ) -> dict[str, Any]:
         """Get rubric run state for a given collection and rubric.
 
         Args:
             collection_id: ID of the Collection.
             rubric_id: The ID of the rubric to get run state for.
+            version: The version of the rubric to get run state for. If None, the latest version is used.
 
         Returns:
             dict: Dictionary containing rubric run state with results, job_id, and total_agent_runs.
@@ -190,8 +199,8 @@ class Docent:
             requests.exceptions.HTTPError: If the API request fails.
         """
         url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/rubric_run_state"
-        response = self._session.get(url)
-        response.raise_for_status()
+        response = self._session.get(url, params={"version": version})
+        self._handle_response_errors(response)
         return response.json()
 
     def get_clustering_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
@@ -209,7 +218,7 @@ class Docent:
         """
         url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/clustering_job"
         response = self._session.get(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
         return response.json()
 
     def get_cluster_centroids(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
@@ -244,6 +253,90 @@ class Docent:
         clustering_state = self.get_clustering_state(collection_id, rubric_id)
         return clustering_state.get("assignments", {})
 
+    def add_label(
+        self,
+        collection_id: str,
+        rubric_id: str,
+        label: JudgeRunLabel,
+    ) -> dict[str, Any]:
+        """Attach a manual label to an agent run for a rubric.
+
+        Args:
+            collection_id: ID of the Collection that owns the rubric.
+            rubric_id: ID of the rubric the label applies to.
+            label: A `JudgeRunLabel` that must comply with the rubric's output schema.
+
+        Returns:
+            dict: API response containing a status message.
+
+        Raises:
+            ValueError: If the label does not target the rubric specified in the path.
+            requests.exceptions.HTTPError: If the API request fails or validation errors occur.
+        """
+        if label.rubric_id != rubric_id:
+            raise ValueError("Label rubric_id must match the rubric_id argument")
+
+        url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/label"
+        payload = {"label": label.model_dump(mode="json")}
+        response = self._session.post(url, json=payload)
+        self._handle_response_errors(response)
+        return response.json()
+
+    def add_labels(
+        self,
+        collection_id: str,
+        rubric_id: str,
+        labels: list[JudgeRunLabel],
+    ) -> dict[str, Any]:
+        """Attach multiple manual labels to a rubric.
+
+        Args:
+            collection_id: ID of the Collection that owns the rubric.
+            rubric_id: ID of the rubric the labels apply to.
+            labels: List of `JudgeRunLabel` objects.
+
+        Returns:
+            dict: API response containing status information.
+
+        Raises:
+            ValueError: If no labels are provided.
+            ValueError: If any label targets a different rubric.
+            requests.exceptions.HTTPError: If the API request fails.
+        """
+        if not labels:
+            raise ValueError("labels must contain at least one entry")
+
+        rubric_ids = {label.rubric_id for label in labels}
+        if rubric_ids != {rubric_id}:
+            raise ValueError(
+                "All labels must specify the same rubric_id that is provided to add_labels"
+            )
+
+        payload = {"labels": [l.model_dump(mode="json") for l in labels]}
+
+        url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/labels"
+        response = self._session.post(url, json=payload)
+        self._handle_response_errors(response)
+        return response.json()
+
+    def get_labels(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
+        """Retrieve all manual labels for a rubric.
+
+        Args:
+            collection_id: ID of the Collection that owns the rubric.
+            rubric_id: ID of the rubric to fetch labels for.
+
+        Returns:
+            list: List of label dictionaries. Each includes agent_run_id and label content.
+
+        Raises:
+            requests.exceptions.HTTPError: If the API request fails.
+        """
+        url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/labels"
+        response = self._session.get(url)
+        self._handle_response_errors(response)
+        return response.json()
+
     def get_agent_run(self, collection_id: str, agent_run_id: str) -> AgentRun | None:
         """Get a specific agent run by its ID.
 
@@ -259,7 +352,7 @@ class Docent:
         """
         url = f"{self._server_url}/{collection_id}/agent_run"
         response = self._session.get(url, params={"agent_run_id": agent_run_id})
-        response.raise_for_status()
+        self._handle_response_errors(response)
         if response.json() is None:
             return None
         else:
@@ -281,7 +374,7 @@ class Docent:
         """
         url = f"{self._server_url}/{collection_id}/make_public"
         response = self._session.post(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
 
         logger.info(f"Successfully made Collection '{collection_id}' public")
         return response.json()
@@ -303,13 +396,7 @@ class Docent:
         payload = {"email": email}
         response = self._session.post(url, json=payload)
 
-        try:
-            response.raise_for_status()
-        except requests.exceptions.HTTPError:
-            if response.status_code == 404:
-                raise ValueError(f"The user you are trying to share with ({email}) does not exist.")
-            else:
-                raise  # Re-raise the original exception
+        self._handle_response_errors(response)
 
         logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
         return response.json()
@@ -328,7 +415,7 @@ class Docent:
         """
         url = f"{self._server_url}/{collection_id}/agent_run_ids"
         response = self._session.get(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
         return response.json()
 
     def recursively_ingest_inspect_logs(self, collection_id: str, fpath: str):
@@ -393,7 +480,7 @@ class Docent:
                 payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch_list]}
 
                 response = self._session.post(url, json=payload)
-                response.raise_for_status()
+                self._handle_response_errors(response)
 
                 runs_from_file += len(batch_list)
                 file_pbar.update(len(batch_list))
@@ -406,7 +493,7 @@ class Docent:
         logger.info("Computing embeddings for added runs...")
         url = f"{self._server_url}/{collection_id}/compute_embeddings"
         response = self._session.post(url)
-        response.raise_for_status()
+        self._handle_response_errors(response)
 
         logger.info(
             f"Successfully ingested {total_runs_added} total agent runs from {len(eval_files)} files"
docent/trace.py CHANGED
@@ -21,7 +21,7 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPExporter
 from opentelemetry.instrumentation.threading import ThreadingInstrumentor
 from opentelemetry.sdk.resources import Resource
-from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider
+from opentelemetry.sdk.trace import ReadableSpan, SpanLimits, SpanProcessor, TracerProvider
 from opentelemetry.sdk.trace.export import (
     BatchSpanProcessor,
     ConsoleSpanExporter,
@@ -29,20 +29,13 @@ from opentelemetry.sdk.trace.export import (
 )
 from opentelemetry.trace import Span
 
-# Configure logging
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.ERROR)
 
 # Default configuration
 DEFAULT_ENDPOINT = "https://api.docent.transluce.org/rest/telemetry"
 DEFAULT_COLLECTION_NAME = "default-collection-name"
 
 
-def _is_tracing_disabled() -> bool:
-    """Check if tracing is disabled via environment variable."""
-    return os.environ.get("DOCENT_DISABLE_TRACING", "").lower() == "true"
-
-
 class Instruments(Enum):
     """Enumeration of available instrument types."""
 
@@ -52,16 +45,10 @@ class Instruments(Enum):
     LANGCHAIN = "langchain"
 
 
-def _is_notebook() -> bool:
-    """Check if we're running in a Jupyter notebook."""
-    try:
-        return "ipykernel" in sys.modules
-    except Exception:
-        return False
-
-
 class DocentTracer:
-    """Manages Docent tracing setup and provides tracing utilities."""
+    """
+    Manages Docent tracing setup and provides tracing utilities.
+    """
 
     def __init__(
         self,
@@ -77,22 +64,6 @@ class DocentTracer:
         instruments: Optional[Set[Instruments]] = None,
         block_instruments: Optional[Set[Instruments]] = None,
     ):
-        """
-        Initialize Docent tracing manager.
-
-        Args:
-            collection_name: Name of the collection for resource attributes
-            collection_id: Optional collection ID (auto-generated if not provided)
-            agent_run_id: Optional agent_run_id to use for code outside of an agent run context (auto-generated if not provided)
-            endpoint: OTLP endpoint URL(s) - can be a single string or list of strings for multiple endpoints
-            headers: Optional headers for authentication
-            api_key: Optional API key for bearer token authentication (takes precedence over env var)
-            enable_console_export: Whether to export to console
-            enable_otlp_export: Whether to export to OTLP endpoint
-            disable_batch: Whether to disable batch processing (use SimpleSpanProcessor)
-            instruments: Set of instruments to enable (None = all instruments)
-            block_instruments: Set of instruments to explicitly disable
-        """
         self._initialized: bool = False
         # Check if tracing is disabled via environment variable
         if _is_tracing_disabled():
@@ -163,8 +134,12 @@ class DocentTracer:
         """
         Get the current agent run ID from context.
 
+        Retrieves the agent run ID that was set in the current execution context.
+        If no agent run context is active, returns the default agent run ID.
+
         Returns:
-            The current agent run ID if available, None otherwise
+            The current agent run ID if available, or the default agent run ID
+            if no context is active.
         """
         try:
             return self._agent_run_id_var.get()
@@ -249,12 +224,23 @@ class DocentTracer:
             return
 
         try:
+
+            # Check for OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT environment variable
+            default_attribute_limit = 1024
+            env_value = os.environ.get("OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT", "0")
+            env_limit = int(env_value) if env_value.isdigit() else 0
+            attribute_limit = max(env_limit, default_attribute_limit)
+
+            span_limits = SpanLimits(
+                max_attributes=attribute_limit,
+            )
+
             # Create our own isolated tracer provider
             self._tracer_provider = TracerProvider(
-                resource=Resource.create({"service.name": self.collection_name})
+                resource=Resource.create({"service.name": self.collection_name}),
+                span_limits=span_limits,
             )
 
-            # Add custom span processor for agent_run_id and transcript_id
             class ContextSpanProcessor(SpanProcessor):
                 def __init__(self, manager: "DocentTracer"):
                     self.manager: "DocentTracer" = manager
@@ -312,11 +298,7 @@ class DocentTracer:
                     )
 
                 def on_end(self, span: ReadableSpan) -> None:
-                    # Debug logging for span completion
-                    span_attrs = span.attributes or {}
-                    logger.debug(
-                        f"Completed span: name='{span.name}', collection_id={span_attrs.get('collection_id')}, agent_run_id={span_attrs.get('agent_run_id')}, transcript_id={span_attrs.get('transcript_id')}, duration_ns={span.end_time - span.start_time if span.end_time and span.start_time else 'unknown'}"
-                    )
+                    pass
 
                 def shutdown(self) -> None:
                     pass
@@ -422,7 +404,17 @@ class DocentTracer:
             raise
 
     def cleanup(self):
-        """Clean up Docent tracing resources and signal trace completion to backend."""
+        """
+        Clean up Docent tracing resources.
+
+        Flushes all pending spans to exporters and shuts down the tracer provider.
+        This method is automatically called during application shutdown via atexit
+        handlers, but can also be called manually for explicit cleanup.
+
+        The cleanup process:
+        1. Flushes all span processors to ensure data is exported
+        2. Shuts down the tracer provider and releases resources
+        """
         if self._disabled:
             return
 
@@ -473,7 +465,7 @@ class DocentTracer:
         if disabled and self._initialized:
             self.cleanup()
 
-    def verify_initialized(self) -> bool:
+    def is_initialized(self) -> bool:
         """Verify if the manager is properly initialized."""
         return self._initialized
 
@@ -1063,8 +1055,9 @@ def initialize_tracing(
         collection_id: Optional collection ID (auto-generated if not provided)
         endpoint: OTLP endpoint URL(s) for span export - can be a single string or list of strings for multiple endpoints
         headers: Optional headers for authentication
-        api_key: Optional API key for bearer token authentication (takes precedence over env var)
-        enable_console_export: Whether to export spans to console
+        api_key: Optional API key for bearer token authentication (takes precedence
+            over DOCENT_API_KEY environment variable)
+        enable_console_export: Whether to export spans to console for debugging
         enable_otlp_export: Whether to export spans to OTLP endpoint
         disable_batch: Whether to disable batch processing (use SimpleSpanProcessor)
         instruments: Set of instruments to enable (None = all instruments).
@@ -1074,7 +1067,6 @@ def initialize_tracing(
         The initialized Docent tracer
 
     Example:
-        # Basic setup
         initialize_tracing("my-collection")
     """
 
@@ -1137,17 +1129,17 @@ def close_tracing() -> None:
 def flush_tracing() -> None:
     """Force flush all spans to exporters."""
     if _global_tracer:
-        logger.debug("Flushing global tracer")
+        logger.debug("Flushing Docent tracer")
         _global_tracer.flush()
     else:
         logger.debug("No global tracer available to flush")
 
 
-def verify_initialized() -> bool:
+def is_initialized() -> bool:
     """Verify if the global Docent tracer is properly initialized."""
     if _global_tracer is None:
         return False
-    return _global_tracer.verify_initialized()
+    return _global_tracer.is_initialized()
 
 
 def is_disabled() -> bool:
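
Note that verify_initialized is renamed to is_initialized both on DocentTracer and at module level, and no alias for the old name appears in this diff, so existing callers have to migrate. A minimal migration sketch, assuming the module is imported as docent.trace:

# Assumes no backward-compatible alias for verify_initialized is provided.
from docent import trace

trace.initialize_tracing("my-collection")  # placeholder collection name

if trace.is_initialized():  # formerly: trace.verify_initialized()
    trace.flush_tracing()
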
@@ -1764,3 +1756,16 @@ def transcript_group_context(
     return TranscriptGroupContext(
         name, transcript_group_id, description, metadata, parent_transcript_group_id
     )
+
+
+def _is_tracing_disabled() -> bool:
+    """Check if tracing is disabled via environment variable."""
+    return os.environ.get("DOCENT_DISABLE_TRACING", "").lower() == "true"
+
+
+def _is_notebook() -> bool:
+    """Check if we're running in a Jupyter notebook."""
+    try:
+        return "ipykernel" in sys.modules
+    except Exception:
+        return False
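
The span-limit change in DocentTracer above is self-contained enough to reproduce standalone. The sketch below mirrors the added logic: the effective attribute cap is the larger of OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT (when set to a non-negative integer) and the default of 1024; the service name is a placeholder.

# Standalone sketch of the attribute-limit logic added above.
import os

from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import SpanLimits, TracerProvider

default_attribute_limit = 1024
env_value = os.environ.get("OTEL_SPAN_ATTRIBUTE_COUNT_LIMIT", "0")
env_limit = int(env_value) if env_value.isdigit() else 0  # non-numeric values fall back to 0
attribute_limit = max(env_limit, default_attribute_limit)

provider = TracerProvider(
    resource=Resource.create({"service.name": "example-collection"}),  # placeholder name
    span_limits=SpanLimits(max_attributes=attribute_limit),
)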