docent-python 0.1.19a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python has been flagged as potentially problematic.

Files changed (38)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -2
  17. docent/data_models/agent_run.py +1 -0
  18. docent/data_models/judge.py +7 -4
  19. docent/data_models/transcript.py +2 -0
  20. docent/data_models/util.py +170 -0
  21. docent/judges/__init__.py +23 -0
  22. docent/judges/analysis.py +77 -0
  23. docent/judges/impl.py +587 -0
  24. docent/judges/runner.py +129 -0
  25. docent/judges/stats.py +205 -0
  26. docent/judges/types.py +311 -0
  27. docent/judges/util/forgiving_json.py +108 -0
  28. docent/judges/util/meta_schema.json +86 -0
  29. docent/judges/util/meta_schema.py +29 -0
  30. docent/judges/util/parse_output.py +87 -0
  31. docent/judges/util/voting.py +139 -0
  32. docent/sdk/client.py +181 -44
  33. docent/trace.py +362 -44
  34. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
  35. docent_python-0.1.27a0.dist-info/RECORD +59 -0
  36. docent_python-0.1.19a0.dist-info/RECORD +0 -32
  37. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
  38. {docent_python-0.1.19a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/sdk/client.py CHANGED
@@ -8,7 +8,8 @@ from tqdm import tqdm
8
8
 
9
9
  from docent._log_util.logger import get_logger
10
10
  from docent.data_models.agent_run import AgentRun
11
- from docent.data_models.judge import JudgeRunLabel
11
+ from docent.data_models.judge import Label
12
+ from docent.judges.util.meta_schema import validate_judge_result_schema
12
13
  from docent.loaders import load_inspect
13
14
 
14
15
  logger = get_logger(__name__)
@@ -50,9 +51,15 @@ class Docent:
50
51
  self._login(api_key)
51
52
 
52
53
  def _handle_response_errors(self, response: requests.Response):
53
- """Handle API response and raise informative errors.
54
- TODO: make this more informative."""
55
- response.raise_for_status()
54
+ """Handle API response and raise informative errors."""
55
+ if response.status_code >= 400:
56
+ try:
57
+ error_data = response.json()
58
+ detail = error_data.get("detail", response.text)
59
+ except Exception:
60
+ detail = response.text
61
+
62
+ raise requests.HTTPError(f"HTTP {response.status_code}: {detail}", response=response)
56
63
 
57
64
  def _login(self, api_key: str):
58
65
  """Login with email/password to establish session."""
@@ -182,21 +189,24 @@ class Docent:
182
189
  self._handle_response_errors(response)
183
190
  return response.json()
184
191
 
185
- def get_rubric_run_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
192
+ def get_rubric_run_state(
193
+ self, collection_id: str, rubric_id: str, version: int | None = None
194
+ ) -> dict[str, Any]:
186
195
  """Get rubric run state for a given collection and rubric.
187
196
 
188
197
  Args:
189
198
  collection_id: ID of the Collection.
190
199
  rubric_id: The ID of the rubric to get run state for.
200
+ version: The version of the rubric to get run state for. If None, the latest version is used.
191
201
 
192
202
  Returns:
193
- dict: Dictionary containing rubric run state with results, job_id, and total_agent_runs.
203
+ dict: Dictionary containing rubric run state with results, job_id, and total_results_needed.
194
204
 
195
205
  Raises:
196
206
  requests.exceptions.HTTPError: If the API request fails.
197
207
  """
198
208
  url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/rubric_run_state"
199
- response = self._session.get(url)
209
+ response = self._session.get(url, params={"version": version})
200
210
  self._handle_response_errors(response)
201
211
  return response.json()
202
212
 
@@ -250,30 +260,59 @@ class Docent:
250
260
  clustering_state = self.get_clustering_state(collection_id, rubric_id)
251
261
  return clustering_state.get("assignments", {})
252
262
 
263
+ def create_label_set(
264
+ self,
265
+ collection_id: str,
266
+ name: str,
267
+ label_schema: dict[str, Any],
268
+ description: str | None = None,
269
+ ) -> str:
270
+ """Create a new label set with a JSON schema.
271
+
272
+ Args:
273
+ collection_id: ID of the collection.
274
+ name: Name of the label set.
275
+ label_schema: JSON schema for validating labels in this set.
276
+ description: Optional description of the label set.
277
+
278
+ Returns:
279
+ str: The ID of the created label set.
280
+
281
+ Raises:
282
+ ValueError: If the response is missing the label_set_id.
283
+ jsonschema.ValidationError: If the label schema is invalid.
284
+ requests.exceptions.HTTPError: If the API request fails.
285
+ """
286
+ validate_judge_result_schema(label_schema)
287
+
288
+ url = f"{self._server_url}/label/{collection_id}/label_set"
289
+ payload = {
290
+ "name": name,
291
+ "label_schema": label_schema,
292
+ "description": description,
293
+ }
294
+ response = self._session.post(url, json=payload)
295
+ self._handle_response_errors(response)
296
+ return response.json()["label_set_id"]
297
+
253
298
  def add_label(
254
299
  self,
255
300
  collection_id: str,
256
- rubric_id: str,
257
- label: JudgeRunLabel,
258
- ) -> dict[str, Any]:
259
- """Attach a manual label to an agent run for a rubric.
301
+ label: Label,
302
+ ) -> dict[str, str]:
303
+ """Create a label in a label set.
260
304
 
261
305
  Args:
262
- collection_id: ID of the Collection that owns the rubric.
263
- rubric_id: ID of the rubric the label applies to.
264
- label: A `JudgeRunLabel` that must comply with the rubric's output schema.
306
+ collection_id: ID of the Collection.
307
+ label: A `Label` object that must comply with the label set's schema.
265
308
 
266
309
  Returns:
267
- dict: API response containing a status message.
310
+ dict: API response containing the label_id.
268
311
 
269
312
  Raises:
270
- ValueError: If the label does not target the rubric specified in the path.
271
313
  requests.exceptions.HTTPError: If the API request fails or validation errors occur.
272
314
  """
273
- if label.rubric_id != rubric_id:
274
- raise ValueError("Label rubric_id must match the rubric_id argument")
275
-
276
- url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/label"
315
+ url = f"{self._server_url}/label/{collection_id}/label"
277
316
  payload = {"label": label.model_dump(mode="json")}
278
317
  response = self._session.post(url, json=payload)
279
318
  self._handle_response_errors(response)
@@ -282,55 +321,50 @@ class Docent:
282
321
  def add_labels(
283
322
  self,
284
323
  collection_id: str,
285
- rubric_id: str,
286
- labels: list[JudgeRunLabel],
324
+ labels: list[Label],
287
325
  ) -> dict[str, Any]:
288
- """Attach multiple manual labels to a rubric.
326
+ """Create multiple labels.
289
327
 
290
328
  Args:
291
- collection_id: ID of the Collection that owns the rubric.
292
- rubric_id: ID of the rubric the labels apply to.
293
- labels: List of `JudgeRunLabel` objects.
329
+ collection_id: ID of the Collection.
330
+ labels: List of `Label` objects.
294
331
 
295
332
  Returns:
296
- dict: API response containing status information.
333
+ dict: API response containing label_ids list and optional errors list.
297
334
 
298
335
  Raises:
299
336
  ValueError: If no labels are provided.
300
- ValueError: If any label targets a different rubric.
301
337
  requests.exceptions.HTTPError: If the API request fails.
302
338
  """
303
339
  if not labels:
304
340
  raise ValueError("labels must contain at least one entry")
305
341
 
306
- rubric_ids = {label.rubric_id for label in labels}
307
- if rubric_ids != {rubric_id}:
308
- raise ValueError(
309
- "All labels must specify the same rubric_id that is provided to add_labels"
310
- )
311
-
312
- payload = {"labels": [l.model_dump(mode="json") for l in labels]}
313
-
314
- url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/labels"
342
+ url = f"{self._server_url}/label/{collection_id}/labels"
343
+ payload = {"labels": [label.model_dump(mode="json") for label in labels]}
315
344
  response = self._session.post(url, json=payload)
316
345
  self._handle_response_errors(response)
317
346
  return response.json()
318
347
 
319
- def get_labels(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
320
- """Retrieve all manual labels for a rubric.
348
+ def get_labels(
349
+ self, collection_id: str, label_set_id: str, filter_valid_labels: bool = False
350
+ ) -> list[dict[str, Any]]:
351
+ """Retrieve all labels in a label set.
321
352
 
322
353
  Args:
323
- collection_id: ID of the Collection that owns the rubric.
324
- rubric_id: ID of the rubric to fetch labels for.
354
+ collection_id: ID of the Collection.
355
+ label_set_id: ID of the label set to fetch labels for.
356
+ filter_valid_labels: If True, only return labels that match the label set schema
357
+ INCLUDING requirements. Default is False (returns all labels).
325
358
 
326
359
  Returns:
327
- list: List of label dictionaries. Each includes agent_run_id and label content.
360
+ list: List of label dictionaries.
328
361
 
329
362
  Raises:
330
363
  requests.exceptions.HTTPError: If the API request fails.
331
364
  """
332
- url = f"{self._server_url}/rubric/{collection_id}/rubric/{rubric_id}/labels"
333
- response = self._session.get(url)
365
+ url = f"{self._server_url}/label/{collection_id}/label_set/{label_set_id}/labels"
366
+ params = {"filter_valid_labels": filter_valid_labels}
367
+ response = self._session.get(url, params=params)
334
368
  self._handle_response_errors(response)
335
369
  return response.json()
336
370
 
@@ -357,6 +391,24 @@ class Docent:
357
391
  # TODO(mengk): kinda hacky
358
392
  return AgentRun.model_validate(response.json())
359
393
 
394
+ def get_chat_sessions(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
395
+ """Get all chat sessions for an agent run, excluding judge result sessions.
396
+
397
+ Args:
398
+ collection_id: ID of the Collection.
399
+ agent_run_id: The ID of the agent run to retrieve chat sessions for.
400
+
401
+ Returns:
402
+ list: List of chat session dictionaries.
403
+
404
+ Raises:
405
+ requests.exceptions.HTTPError: If the API request fails.
406
+ """
407
+ url = f"{self._server_url}/chat/{collection_id}/{agent_run_id}/sessions"
408
+ response = self._session.get(url)
409
+ self._handle_response_errors(response)
410
+ return response.json()
411
+
360
412
  def make_collection_public(self, collection_id: str) -> dict[str, Any]:
361
413
  """Make a collection publicly accessible to anyone with the link.
362
414
 
@@ -398,6 +450,91 @@ class Docent:
398
450
  logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
399
451
  return response.json()
400
452
 
453
+ def get_dql_schema(self, collection_id: str) -> dict[str, Any]:
454
+ """Retrieve the DQL schema for a collection.
455
+
456
+ Args:
457
+ collection_id: ID of the Collection.
458
+
459
+ Returns:
460
+ dict: Dictionary containing available tables, columns, and metadata for DQL queries.
461
+
462
+ Raises:
463
+ requests.exceptions.HTTPError: If the API request fails.
464
+ """
465
+ url = f"{self._server_url}/dql/{collection_id}/schema"
466
+ response = self._session.get(url)
467
+ self._handle_response_errors(response)
468
+ return response.json()
469
+
470
+ def execute_dql(self, collection_id: str, dql: str) -> dict[str, Any]:
471
+ """Execute a DQL query against a collection.
472
+
473
+ Args:
474
+ collection_id: ID of the Collection.
475
+ dql: The DQL query string to execute.
476
+
477
+ Returns:
478
+ dict: Query execution results including rows, columns, execution metadata, and selected columns.
479
+
480
+ Raises:
481
+ ValueError: If `dql` is empty.
482
+ requests.exceptions.HTTPError: If the API request fails or the query is invalid.
483
+ """
484
+ if not dql.strip():
485
+ raise ValueError("dql must be a non-empty string")
486
+
487
+ url = f"{self._server_url}/dql/{collection_id}/execute"
488
+ response = self._session.post(url, json={"dql": dql})
489
+ self._handle_response_errors(response)
490
+ return response.json()
491
+
492
+ def select_agent_run_ids(
493
+ self,
494
+ collection_id: str,
495
+ where_clause: str | None = None,
496
+ limit: int | None = None,
497
+ ) -> list[str]:
498
+ """Convenience helper to fetch agent run IDs via DQL.
499
+
500
+ Args:
501
+ collection_id: ID of the Collection to query.
502
+ where_clause: Optional DQL WHERE clause applied to the agent_runs table.
503
+ limit: Optional LIMIT applied to the underlying DQL query.
504
+
505
+ Returns:
506
+ list[str]: Agent run IDs matching the criteria.
507
+
508
+ Raises:
509
+ ValueError: If the inputs are invalid.
510
+ requests.exceptions.HTTPError: If the API request fails.
511
+ """
512
+ query = "SELECT agent_runs.id AS agent_run_id FROM agent_runs"
513
+
514
+ if where_clause:
515
+ where_clause = where_clause.strip()
516
+ if not where_clause:
517
+ raise ValueError("where_clause must be a non-empty string when provided")
518
+ query += f" WHERE {where_clause}"
519
+
520
+ if limit is not None:
521
+ if limit <= 0:
522
+ raise ValueError("limit must be a positive integer when provided")
523
+ query += f" LIMIT {limit}"
524
+
525
+ result = self.execute_dql(collection_id, query)
526
+ rows = result.get("rows", [])
527
+ agent_run_ids = [str(row[0]) for row in rows if row]
528
+
529
+ if result.get("truncated"):
530
+ logger.warning(
531
+ "DQL query truncated at applied limit %s; returning %s agent run IDs",
532
+ result.get("applied_limit"),
533
+ len(agent_run_ids),
534
+ )
535
+
536
+ return agent_run_ids
537
+
401
538
  def list_agent_run_ids(self, collection_id: str) -> list[str]:
402
539
  """Get all agent run IDs for a collection.
403
540