docent-python 0.1.17a0__py3-none-any.whl → 0.1.27a0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docent-python might be problematic. Click here for more details.

Files changed (45)
  1. docent/_llm_util/__init__.py +0 -0
  2. docent/_llm_util/data_models/__init__.py +0 -0
  3. docent/_llm_util/data_models/exceptions.py +48 -0
  4. docent/_llm_util/data_models/llm_output.py +331 -0
  5. docent/_llm_util/llm_cache.py +193 -0
  6. docent/_llm_util/llm_svc.py +472 -0
  7. docent/_llm_util/model_registry.py +130 -0
  8. docent/_llm_util/providers/__init__.py +0 -0
  9. docent/_llm_util/providers/anthropic.py +537 -0
  10. docent/_llm_util/providers/common.py +41 -0
  11. docent/_llm_util/providers/google.py +530 -0
  12. docent/_llm_util/providers/openai.py +745 -0
  13. docent/_llm_util/providers/openrouter.py +375 -0
  14. docent/_llm_util/providers/preference_types.py +104 -0
  15. docent/_llm_util/providers/provider_registry.py +164 -0
  16. docent/data_models/__init__.py +2 -0
  17. docent/data_models/agent_run.py +6 -5
  18. docent/data_models/chat/__init__.py +6 -1
  19. docent/data_models/citation.py +103 -22
  20. docent/data_models/judge.py +19 -0
  21. docent/data_models/metadata_util.py +16 -0
  22. docent/data_models/remove_invalid_citation_ranges.py +23 -10
  23. docent/data_models/transcript.py +20 -16
  24. docent/data_models/util.py +170 -0
  25. docent/judges/__init__.py +23 -0
  26. docent/judges/analysis.py +77 -0
  27. docent/judges/impl.py +587 -0
  28. docent/judges/runner.py +129 -0
  29. docent/judges/stats.py +205 -0
  30. docent/judges/types.py +311 -0
  31. docent/judges/util/forgiving_json.py +108 -0
  32. docent/judges/util/meta_schema.json +86 -0
  33. docent/judges/util/meta_schema.py +29 -0
  34. docent/judges/util/parse_output.py +87 -0
  35. docent/judges/util/voting.py +139 -0
  36. docent/sdk/agent_run_writer.py +62 -19
  37. docent/sdk/client.py +244 -23
  38. docent/trace.py +413 -90
  39. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/METADATA +11 -5
  40. docent_python-0.1.27a0.dist-info/RECORD +59 -0
  41. docent/data_models/metadata.py +0 -229
  42. docent/data_models/yaml_util.py +0 -12
  43. docent_python-0.1.17a0.dist-info/RECORD +0 -32
  44. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/WHEEL +0 -0
  45. {docent_python-0.1.17a0.dist-info → docent_python-0.1.27a0.dist-info}/licenses/LICENSE.md +0 -0
docent/sdk/client.py CHANGED
@@ -8,6 +8,8 @@ from tqdm import tqdm
8
8
 
9
9
  from docent._log_util.logger import get_logger
10
10
  from docent.data_models.agent_run import AgentRun
11
+ from docent.data_models.judge import Label
12
+ from docent.judges.util.meta_schema import validate_judge_result_schema
11
13
  from docent.loaders import load_inspect
12
14
 
13
15
  logger = get_logger(__name__)
@@ -48,13 +50,24 @@ class Docent:
48
50
 
49
51
  self._login(api_key)
50
52
 
53
+ def _handle_response_errors(self, response: requests.Response):
54
+ """Handle API response and raise informative errors."""
55
+ if response.status_code >= 400:
56
+ try:
57
+ error_data = response.json()
58
+ detail = error_data.get("detail", response.text)
59
+ except Exception:
60
+ detail = response.text
61
+
62
+ raise requests.HTTPError(f"HTTP {response.status_code}: {detail}", response=response)
63
+
51
64
  def _login(self, api_key: str):
52
65
  """Login with email/password to establish session."""
53
66
  self._session.headers.update({"Authorization": f"Bearer {api_key}"})
54
67
 
55
68
  url = f"{self._server_url}/api-keys/test"
56
69
  response = self._session.get(url)
57
- response.raise_for_status()
70
+ self._handle_response_errors(response)
58
71
 
59
72
  logger.info("Logged in with API key")
60
73
  return
@@ -90,7 +103,7 @@ class Docent:
90
103
  }
91
104
 
92
105
  response = self._session.post(url, json=payload)
93
- response.raise_for_status()
106
+ self._handle_response_errors(response)
94
107
 
95
108
  response_data = response.json()
96
109
  collection_id = response_data.get("collection_id")
@@ -134,13 +147,13 @@ class Docent:
134
147
  payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch]}
135
148
 
136
149
  response = self._session.post(url, json=payload)
137
- response.raise_for_status()
150
+ self._handle_response_errors(response)
138
151
 
139
152
  pbar.update(len(batch))
140
153
 
141
154
  url = f"{self._server_url}/{collection_id}/compute_embeddings"
142
155
  response = self._session.post(url)
143
- response.raise_for_status()
156
+ self._handle_response_errors(response)
144
157
 
145
158
  logger.info(f"Successfully added {total_runs} agent runs to Collection '{collection_id}'")
146
159
  return {"status": "success", "total_runs_added": total_runs}
@@ -156,7 +169,7 @@ class Docent:
156
169
  """
157
170
  url = f"{self._server_url}/collections"
158
171
  response = self._session.get(url)
159
- response.raise_for_status()
172
+ self._handle_response_errors(response)
160
173
  return response.json()
161
174
 
162
175
  def list_rubrics(self, collection_id: str) -> list[dict[str, Any]]:
@@ -173,25 +186,28 @@ class Docent:
173
186
  """
174
187
  url = f"{self._server_url}/rubric/{collection_id}/rubrics"
175
188
  response = self._session.get(url)
176
- response.raise_for_status()
189
+ self._handle_response_errors(response)
177
190
  return response.json()
178
191
 
179
- def get_rubric_run_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
192
+ def get_rubric_run_state(
193
+ self, collection_id: str, rubric_id: str, version: int | None = None
194
+ ) -> dict[str, Any]:
180
195
  """Get rubric run state for a given collection and rubric.
181
196
 
182
197
  Args:
183
198
  collection_id: ID of the Collection.
184
199
  rubric_id: The ID of the rubric to get run state for.
200
+ version: The version of the rubric to get run state for. If None, the latest version is used.
185
201
 
186
202
  Returns:
187
- dict: Dictionary containing rubric run state with results, job_id, and total_agent_runs.
203
+ dict: Dictionary containing rubric run state with results, job_id, and total_results_needed.
188
204
 
189
205
  Raises:
190
206
  requests.exceptions.HTTPError: If the API request fails.
191
207
  """
192
208
  url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/rubric_run_state"
193
- response = self._session.get(url)
194
- response.raise_for_status()
209
+ response = self._session.get(url, params={"version": version})
210
+ self._handle_response_errors(response)
195
211
  return response.json()
196
212
 
197
213
  def get_clustering_state(self, collection_id: str, rubric_id: str) -> dict[str, Any]:
@@ -209,7 +225,7 @@ class Docent:
209
225
  """
210
226
  url = f"{self._server_url}/rubric/{collection_id}/{rubric_id}/clustering_job"
211
227
  response = self._session.get(url)
212
- response.raise_for_status()
228
+ self._handle_response_errors(response)
213
229
  return response.json()
214
230
 
215
231
  def get_cluster_centroids(self, collection_id: str, rubric_id: str) -> list[dict[str, Any]]:
@@ -244,6 +260,114 @@ class Docent:
244
260
  clustering_state = self.get_clustering_state(collection_id, rubric_id)
245
261
  return clustering_state.get("assignments", {})
246
262
 
263
+ def create_label_set(
264
+ self,
265
+ collection_id: str,
266
+ name: str,
267
+ label_schema: dict[str, Any],
268
+ description: str | None = None,
269
+ ) -> str:
270
+ """Create a new label set with a JSON schema.
271
+
272
+ Args:
273
+ collection_id: ID of the collection.
274
+ name: Name of the label set.
275
+ label_schema: JSON schema for validating labels in this set.
276
+ description: Optional description of the label set.
277
+
278
+ Returns:
279
+ str: The ID of the created label set.
280
+
281
+ Raises:
282
+ ValueError: If the response is missing the label_set_id.
283
+ jsonschema.ValidationError: If the label schema is invalid.
284
+ requests.exceptions.HTTPError: If the API request fails.
285
+ """
286
+ validate_judge_result_schema(label_schema)
287
+
288
+ url = f"{self._server_url}/label/{collection_id}/label_set"
289
+ payload = {
290
+ "name": name,
291
+ "label_schema": label_schema,
292
+ "description": description,
293
+ }
294
+ response = self._session.post(url, json=payload)
295
+ self._handle_response_errors(response)
296
+ return response.json()["label_set_id"]
297
+
298
+ def add_label(
299
+ self,
300
+ collection_id: str,
301
+ label: Label,
302
+ ) -> dict[str, str]:
303
+ """Create a label in a label set.
304
+
305
+ Args:
306
+ collection_id: ID of the Collection.
307
+ label: A `Label` object that must comply with the label set's schema.
308
+
309
+ Returns:
310
+ dict: API response containing the label_id.
311
+
312
+ Raises:
313
+ requests.exceptions.HTTPError: If the API request fails or validation errors occur.
314
+ """
315
+ url = f"{self._server_url}/label/{collection_id}/label"
316
+ payload = {"label": label.model_dump(mode="json")}
317
+ response = self._session.post(url, json=payload)
318
+ self._handle_response_errors(response)
319
+ return response.json()
320
+
321
+ def add_labels(
322
+ self,
323
+ collection_id: str,
324
+ labels: list[Label],
325
+ ) -> dict[str, Any]:
326
+ """Create multiple labels.
327
+
328
+ Args:
329
+ collection_id: ID of the Collection.
330
+ labels: List of `Label` objects.
331
+
332
+ Returns:
333
+ dict: API response containing label_ids list and optional errors list.
334
+
335
+ Raises:
336
+ ValueError: If no labels are provided.
337
+ requests.exceptions.HTTPError: If the API request fails.
338
+ """
339
+ if not labels:
340
+ raise ValueError("labels must contain at least one entry")
341
+
342
+ url = f"{self._server_url}/label/{collection_id}/labels"
343
+ payload = {"labels": [label.model_dump(mode="json") for label in labels]}
344
+ response = self._session.post(url, json=payload)
345
+ self._handle_response_errors(response)
346
+ return response.json()
347
+
348
+ def get_labels(
349
+ self, collection_id: str, label_set_id: str, filter_valid_labels: bool = False
350
+ ) -> list[dict[str, Any]]:
351
+ """Retrieve all labels in a label set.
352
+
353
+ Args:
354
+ collection_id: ID of the Collection.
355
+ label_set_id: ID of the label set to fetch labels for.
356
+ filter_valid_labels: If True, only return labels that match the label set schema
357
+ INCLUDING requirements. Default is False (returns all labels).
358
+
359
+ Returns:
360
+ list: List of label dictionaries.
361
+
362
+ Raises:
363
+ requests.exceptions.HTTPError: If the API request fails.
364
+ """
365
+ url = f"{self._server_url}/label/{collection_id}/label_set/{label_set_id}/labels"
366
+ params = {"filter_valid_labels": filter_valid_labels}
367
+ response = self._session.get(url, params=params)
368
+ self._handle_response_errors(response)
369
+ return response.json()
370
+
247
371
  def get_agent_run(self, collection_id: str, agent_run_id: str) -> AgentRun | None:
248
372
  """Get a specific agent run by its ID.
249
373
 
@@ -259,7 +383,7 @@ class Docent:
259
383
  """
260
384
  url = f"{self._server_url}/{collection_id}/agent_run"
261
385
  response = self._session.get(url, params={"agent_run_id": agent_run_id})
262
- response.raise_for_status()
386
+ self._handle_response_errors(response)
263
387
  if response.json() is None:
264
388
  return None
265
389
  else:
@@ -267,6 +391,24 @@ class Docent:
267
391
  # TODO(mengk): kinda hacky
268
392
  return AgentRun.model_validate(response.json())
269
393
 
394
+ def get_chat_sessions(self, collection_id: str, agent_run_id: str) -> list[dict[str, Any]]:
395
+ """Get all chat sessions for an agent run, excluding judge result sessions.
396
+
397
+ Args:
398
+ collection_id: ID of the Collection.
399
+ agent_run_id: The ID of the agent run to retrieve chat sessions for.
400
+
401
+ Returns:
402
+ list: List of chat session dictionaries.
403
+
404
+ Raises:
405
+ requests.exceptions.HTTPError: If the API request fails.
406
+ """
407
+ url = f"{self._server_url}/chat/{collection_id}/{agent_run_id}/sessions"
408
+ response = self._session.get(url)
409
+ self._handle_response_errors(response)
410
+ return response.json()
411
+
270
412
  def make_collection_public(self, collection_id: str) -> dict[str, Any]:
271
413
  """Make a collection publicly accessible to anyone with the link.
272
414
 
@@ -281,7 +423,7 @@ class Docent:
281
423
  """
282
424
  url = f"{self._server_url}/{collection_id}/make_public"
283
425
  response = self._session.post(url)
284
- response.raise_for_status()
426
+ self._handle_response_errors(response)
285
427
 
286
428
  logger.info(f"Successfully made Collection '{collection_id}' public")
287
429
  return response.json()
@@ -303,17 +445,96 @@ class Docent:
303
445
  payload = {"email": email}
304
446
  response = self._session.post(url, json=payload)
305
447
 
306
- try:
307
- response.raise_for_status()
308
- except requests.exceptions.HTTPError:
309
- if response.status_code == 404:
310
- raise ValueError(f"The user you are trying to share with ({email}) does not exist.")
311
- else:
312
- raise # Re-raise the original exception
448
+ self._handle_response_errors(response)
313
449
 
314
450
  logger.info(f"Successfully shared Collection '{collection_id}' with {email}")
315
451
  return response.json()
316
452
 
453
+ def get_dql_schema(self, collection_id: str) -> dict[str, Any]:
454
+ """Retrieve the DQL schema for a collection.
455
+
456
+ Args:
457
+ collection_id: ID of the Collection.
458
+
459
+ Returns:
460
+ dict: Dictionary containing available tables, columns, and metadata for DQL queries.
461
+
462
+ Raises:
463
+ requests.exceptions.HTTPError: If the API request fails.
464
+ """
465
+ url = f"{self._server_url}/dql/{collection_id}/schema"
466
+ response = self._session.get(url)
467
+ self._handle_response_errors(response)
468
+ return response.json()
469
+
470
+ def execute_dql(self, collection_id: str, dql: str) -> dict[str, Any]:
471
+ """Execute a DQL query against a collection.
472
+
473
+ Args:
474
+ collection_id: ID of the Collection.
475
+ dql: The DQL query string to execute.
476
+
477
+ Returns:
478
+ dict: Query execution results including rows, columns, execution metadata, and selected columns.
479
+
480
+ Raises:
481
+ ValueError: If `dql` is empty.
482
+ requests.exceptions.HTTPError: If the API request fails or the query is invalid.
483
+ """
484
+ if not dql.strip():
485
+ raise ValueError("dql must be a non-empty string")
486
+
487
+ url = f"{self._server_url}/dql/{collection_id}/execute"
488
+ response = self._session.post(url, json={"dql": dql})
489
+ self._handle_response_errors(response)
490
+ return response.json()
491
+
492
+ def select_agent_run_ids(
493
+ self,
494
+ collection_id: str,
495
+ where_clause: str | None = None,
496
+ limit: int | None = None,
497
+ ) -> list[str]:
498
+ """Convenience helper to fetch agent run IDs via DQL.
499
+
500
+ Args:
501
+ collection_id: ID of the Collection to query.
502
+ where_clause: Optional DQL WHERE clause applied to the agent_runs table.
503
+ limit: Optional LIMIT applied to the underlying DQL query.
504
+
505
+ Returns:
506
+ list[str]: Agent run IDs matching the criteria.
507
+
508
+ Raises:
509
+ ValueError: If the inputs are invalid.
510
+ requests.exceptions.HTTPError: If the API request fails.
511
+ """
512
+ query = "SELECT agent_runs.id AS agent_run_id FROM agent_runs"
513
+
514
+ if where_clause:
515
+ where_clause = where_clause.strip()
516
+ if not where_clause:
517
+ raise ValueError("where_clause must be a non-empty string when provided")
518
+ query += f" WHERE {where_clause}"
519
+
520
+ if limit is not None:
521
+ if limit <= 0:
522
+ raise ValueError("limit must be a positive integer when provided")
523
+ query += f" LIMIT {limit}"
524
+
525
+ result = self.execute_dql(collection_id, query)
526
+ rows = result.get("rows", [])
527
+ agent_run_ids = [str(row[0]) for row in rows if row]
528
+
529
+ if result.get("truncated"):
530
+ logger.warning(
531
+ "DQL query truncated at applied limit %s; returning %s agent run IDs",
532
+ result.get("applied_limit"),
533
+ len(agent_run_ids),
534
+ )
535
+
536
+ return agent_run_ids
537
+
317
538
  def list_agent_run_ids(self, collection_id: str) -> list[str]:
318
539
  """Get all agent run IDs for a collection.
319
540
 
@@ -328,7 +549,7 @@ class Docent:
328
549
  """
329
550
  url = f"{self._server_url}/{collection_id}/agent_run_ids"
330
551
  response = self._session.get(url)
331
- response.raise_for_status()
552
+ self._handle_response_errors(response)
332
553
  return response.json()
333
554
 
334
555
  def recursively_ingest_inspect_logs(self, collection_id: str, fpath: str):
@@ -393,7 +614,7 @@ class Docent:
393
614
  payload = {"agent_runs": [ar.model_dump(mode="json") for ar in batch_list]}
394
615
 
395
616
  response = self._session.post(url, json=payload)
396
- response.raise_for_status()
617
+ self._handle_response_errors(response)
397
618
 
398
619
  runs_from_file += len(batch_list)
399
620
  file_pbar.update(len(batch_list))
@@ -406,7 +627,7 @@ class Docent:
406
627
  logger.info("Computing embeddings for added runs...")
407
628
  url = f"{self._server_url}/{collection_id}/compute_embeddings"
408
629
  response = self._session.post(url)
409
- response.raise_for_status()
630
+ self._handle_response_errors(response)
410
631
 
411
632
  logger.info(
412
633
  f"Successfully ingested {total_runs_added} total agent runs from {len(eval_files)} files"