docent-python 0.1.22a0__py3-none-any.whl → 0.1.24a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -96,6 +96,7 @@ class LLMOutput:
     errors: list[LLMException] = field(default_factory=list)
     usage: UsageMetrics = field(default_factory=UsageMetrics)
     from_cache: bool = False
+    duration: float | None = None

     @property
     def non_empty(self) -> bool:
@@ -140,6 +141,7 @@ class LLMOutput:
             "errors": [e.error_type_id for e in self.errors],
             "usage": self.usage.to_dict(),
             "from_cache": self.from_cache,
+            "duration": self.duration,
         }

     @classmethod
@@ -161,6 +163,7 @@ class LLMOutput:
             errors=errors,
             usage=UsageMetrics(**usage),
             from_cache=bool(data.get("from_cache", False)),
+            duration=data.get("duration"),
         )

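
Together, these three hunks give LLMOutput an optional wall-clock duration (in seconds) that now survives serialization. A minimal sketch of the round trip, assuming the dict literal above is built by a to_dict-style method; the constructor arguments mirror calls that appear later in this diff, and the concrete values are made up:

    # duration stays None unless the call actually reached a provider
    output = LLMOutput(model="some-model", completions=[], duration=1.37)
    payload = output.to_dict()          # now carries "duration": 1.37
    assert payload["duration"] == 1.37  # and data.get("duration") restores it on load
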
@@ -55,7 +55,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -86,7 +86,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -121,7 +121,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -154,7 +154,7 @@ class LLMCache:
         *,
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        reasoning_effort: Literal["low", "medium", "high"] | None = None,
+        reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
@@ -1,17 +1,8 @@
-"""
-At some point we'll want to do a refactor to support different types of provider/key swapping
-due to different scenarios. However, this'll probably be a breaking change, which is why I'm
-not doing it now.
-
-- mengk
-"""
-
+import time
 import traceback
-from contextlib import nullcontext
 from functools import partial
 from typing import (
     Any,
-    AsyncContextManager,
     Literal,
     Protocol,
     Sequence,
@@ -20,6 +11,7 @@ from typing import (
 )

 import anyio
+from anyio import Lock, Semaphore
 from anyio.abc import TaskGroup
 from tqdm.auto import tqdm

@@ -44,10 +36,12 @@ from docent._llm_util.providers.provider_registry import (
 from docent._log_util import get_logger
 from docent.data_models.chat import ChatMessage, ToolInfo, parse_chat_message

-MAX_VALIDATION_ATTEMPTS = 3
-
 logger = get_logger(__name__)

+MAX_VALIDATION_ATTEMPTS = 3
+DEFAULT_MAX_CONCURRENCY = 100
+DEFAULT_SVC_MAX_CONCURRENCY = 100
+

 @runtime_checkable
 class MessageResolver(Protocol):
@@ -87,11 +81,11 @@ async def _parallelize_calls(
     tool_choice: Literal["auto", "required"] | None,
     max_new_tokens: int,
     temperature: float,
-    reasoning_effort: Literal["low", "medium", "high"] | None,
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None,
     logprobs: bool,
     top_logprobs: int | None,
     timeout: float,
-    semaphore: AsyncContextManager[anyio.Semaphore] | None,
+    semaphore: Semaphore,
     # use_tqdm: bool,
     cache: LLMCache | None = None,
 ):
@@ -122,17 +116,19 @@ async def _parallelize_calls(
     # Save resolved messages to avoid multiple resolutions
     resolved_messages: list[list[ChatMessage] | None] = [None] * len(inputs)

-    cancelled_due_to_usage_limit: bool = False
+    # Not sure why the cast is necessary for the type checker
+    cancelled_due_to_usage_limit: bool = cast(bool, False)

     async def _limited_task(i: int, cur_input: MessagesInput, tg: TaskGroup):
         nonlocal responses, pbar, resolved_messages, cancelled_due_to_usage_limit

-        async with semaphore or nullcontext():
+        async with semaphore:
             messages = _resolve_messages_input(cur_input)
             resolved_messages[i] = messages

             retry_count = 0
             result = None
+            call_started_at: float | None = None

             # Check if there's a cached result
             cached_result = (
@@ -154,6 +150,7 @@ async def _parallelize_calls(
                 if streaming_callback is not None:
                     await streaming_callback(i, result)
             else:
+                call_started_at = time.perf_counter()
                 while retry_count < MAX_VALIDATION_ATTEMPTS:
                     try:
                         if streaming_callback is None:
@@ -187,7 +184,7 @@ async def _parallelize_calls(
                             errors=[e],
                         )
                         break
-                    except DocentUsageLimitException as e:
+                    except DocentUsageLimitException as _:
                         result = LLMOutput(
                             model=model_name,
                             completions=[],
@@ -219,6 +216,10 @@ async def _parallelize_calls(
                         )
                         break

+            # Only store the elapsed time if we didn't hit the cache and the call was successful
+            if cached_result is None and result is not None and call_started_at is not None:
+                result.duration = time.perf_counter() - call_started_at
+
             # Always call completion callback with final result (success or error)
             if completion_callback and result is not None:
                 try:
@@ -273,21 +274,24 @@ async def _parallelize_calls(
     # Cache what we have so far if something got cancelled
     except anyio.get_cancelled_exc_class():
         num_cached = _cache_responses()
-        logger.info(
-            f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls; cached {num_cached} completed responses"
-        )
-        raise
+        if num_cached:
+            logger.info(
+                f"Cancelled {len(inputs) - num_cached} unfinished LLM API calls, but cached {num_cached} completed responses"
+            )

-    if cancelled_due_to_usage_limit:
-        for i in range(len(responses)):
-            if responses[i] is None:
-                responses[i] = LLMOutput(
-                    model=model_name,
-                    completions=[],
-                    errors=[DocentUsageLimitException()],
-                )
-            else:
-                responses[i].errors.append(DocentUsageLimitException())
+        # If the task was cancelled due to usage limit, set the response to a usage limit exception
+        if cancelled_due_to_usage_limit:
+            for i, response in enumerate(responses):
+                if response is None:
+                    responses[i] = LLMOutput(
+                        model=model_name,
+                        completions=[],
+                        errors=[DocentUsageLimitException()],
+                    )
+                else:
+                    response.errors.append(DocentUsageLimitException())
+
+        raise

     # Cache results if available
     _cache_responses()
@@ -300,51 +304,88 @@ async def _parallelize_calls(
     return cast(list[LLMOutput], responses)


-class LLMManager:
-    def __init__(
-        self,
-        model_options: list[ModelOption],
-        api_key_overrides: dict[str, str] | None = None,
-        use_cache: bool = False,
-    ):
-        # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
-        try:
-            self.cache = LLMCache() if use_cache else None
-        except ValueError as e:
-            logger.warning(f"Disabling LLM cache due to init error: {e}")
-            self.cache = None
+class BaseLLMService:
+    def __init__(self, max_concurrency: int = DEFAULT_SVC_MAX_CONCURRENCY):
+        self._semaphore = Semaphore(max_concurrency)
+        self._client_cache: dict[tuple[str, str | None], Any] = {}  # (provider, api_key) -> client
+        self._client_cache_lock = Lock()
+
+    async def _get_cached_client(self, provider: str, override_key: str | None) -> Any:
+        """Return a cached client for the provider/api-key tuple, creating one if needed."""
+        cache_key = (provider, override_key)
+        async with self._client_cache_lock:
+            cached = self._client_cache.get(cache_key)
+            if cached is not None:
+                return cached

-        self.model_options = model_options
-        self.current_model_option_index = 0
-        self.api_key_overrides = api_key_overrides or {}
+            client_factory = PROVIDERS[provider]["async_client_getter"]
+            new_client = client_factory(override_key)
+            self._client_cache[cache_key] = new_client
+            return new_client

     async def get_completions(
         self,
+        *,
         inputs: list[MessagesInput],
+        model_options: list[ModelOption],
         tools: list[ToolInfo] | None = None,
         tool_choice: Literal["auto", "required"] | None = None,
-        max_new_tokens: int = 32,
+        max_new_tokens: int = 1024,
         temperature: float = 1.0,
         logprobs: bool = False,
         top_logprobs: int | None = None,
-        max_concurrency: int | None = None,
-        timeout: float = 5.0,
+        timeout: float = 120.0,
         streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
         validation_callback: AsyncLLMOutputStreamingCallback | None = None,
         completion_callback: AsyncLLMOutputStreamingCallback | None = None,
+        use_cache: bool = False,
+        _api_key_overrides: dict[str, str] = dict(),
     ) -> list[LLMOutput]:
+        """Request completions from a configured LLM provider."""
+
+        # We don't support logprobs for Anthropic yet
+        if logprobs:
+            for model_option in model_options:
+                if model_option.provider == "anthropic":
+                    raise ValueError(
+                        f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
+                    )
+
+        # Instantiate cache
+        # TODO(mengk): make this more robust, possibly move to a NoSQL database or something
+        try:
+            cache = LLMCache() if use_cache else None
+        except ValueError as e:
+            logger.warning(f"Disabling LLM cache due to init error: {e}")
+            cache = None
+
+        # Initialize pointer to which model we're using; used for model rotation after failures
+        current_model_option_index = 0
+
+        def _rotate_model_option() -> ModelOption | None:
+            nonlocal current_model_option_index
+
+            current_model_option_index += 1
+            if current_model_option_index >= len(model_options):
+                logger.error("All model options are exhausted")
+                return None
+
+            new_model_option = model_options[current_model_option_index]
+            logger.warning(f"Switched to next model {new_model_option.model_name}")
+            return new_model_option
+
         while True:
             # Parse the current model option
-            cur_option = self.model_options[self.current_model_option_index]
+            cur_option = model_options[current_model_option_index]
             provider, model_name, reasoning_effort = (
                 cur_option.provider,
                 cur_option.model_name,
                 cur_option.reasoning_effort,
             )

-            override_key = self.api_key_overrides.get(provider)
+            override_key = _api_key_overrides.get(provider)

-            client = PROVIDERS[provider]["async_client_getter"](override_key)
+            client = await self._get_cached_client(provider, override_key)
             single_output_getter = PROVIDERS[provider]["single_output_getter"]
             single_streaming_output_getter = PROVIDERS[provider]["single_streaming_output_getter"]

@@ -369,10 +410,8 @@ class LLMManager:
                 logprobs=logprobs,
                 top_logprobs=top_logprobs,
                 timeout=timeout,
-                semaphore=(
-                    anyio.Semaphore(max_concurrency) if max_concurrency is not None else None
-                ),
-                cache=self.cache,
+                semaphore=self._semaphore,
+                cache=cache,
             )
             assert len(outputs) == len(inputs), "Number of outputs must match number of messages"

@@ -388,23 +427,13 @@ class LLMManager:
             )
             if num_rotation_errors > 0:
                 logger.warning(f"{model_name}: {num_rotation_errors} API errors")
-                if not self._rotate_model_option():
+                if not _rotate_model_option():
                     break
             else:
                 break

         return outputs

-    def _rotate_model_option(self) -> ModelOption | None:
-        self.current_model_option_index += 1
-        if self.current_model_option_index >= len(self.model_options):
-            logger.error("All model options are exhausted")
-            return None
-
-        new_model_option = self.model_options[self.current_model_option_index]
-        logger.warning(f"Switched to next model {new_model_option.model_name}")
-        return new_model_option
-

 async def get_llm_completions_async(
     inputs: list[MessagesInput],
@@ -415,40 +444,29 @@ async def get_llm_completions_async(
     temperature: float = 1.0,
     logprobs: bool = False,
     top_logprobs: int | None = None,
-    max_concurrency: int = 100,
     timeout: float = 120.0,
     streaming_callback: AsyncLLMOutputStreamingCallback | None = None,
     validation_callback: AsyncLLMOutputStreamingCallback | None = None,
     completion_callback: AsyncLLMOutputStreamingCallback | None = None,
     use_cache: bool = False,
-    api_key_overrides: dict[str, str] | None = None,
+    _api_key_overrides: dict[str, str] = dict(),
 ) -> list[LLMOutput]:
-    # We don't support logprobs for Anthropic yet
-    if logprobs:
-        for model_option in model_options:
-            if model_option.provider == "anthropic":
-                raise ValueError(
-                    f"Logprobs are not supported for Anthropic, so we can't use model {model_option.model_name}"
-                )
+    """Convenience method for backward compatibility"""

-    # Create the LLM manager
-    llm_manager = LLMManager(
+    svc = BaseLLMService()
+    return await svc.get_completions(
+        inputs=inputs,
         model_options=model_options,
-        api_key_overrides=api_key_overrides,
-        use_cache=use_cache,
-    )
-
-    return await llm_manager.get_completions(
-        inputs,
         tools=tools,
         tool_choice=tool_choice,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         logprobs=logprobs,
         top_logprobs=top_logprobs,
-        max_concurrency=max_concurrency,
         timeout=timeout,
         streaming_callback=streaming_callback,
         validation_callback=validation_callback,
         completion_callback=completion_callback,
+        use_cache=use_cache,
+        _api_key_overrides=_api_key_overrides,
     )
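
Net effect of this file: the per-call LLMManager is replaced by a reusable BaseLLMService that owns one shared Semaphore (DEFAULT_SVC_MAX_CONCURRENCY = 100) and a locked per-(provider, API key) client cache, while get_llm_completions_async becomes a thin backward-compatibility wrapper. A hedged usage sketch built only from the signatures above; the inputs and model names are placeholders:

    svc = BaseLLMService(max_concurrency=20)   # one semaphore shared by every call on this service
    outputs = await svc.get_completions(
        inputs=my_inputs,                      # list[MessagesInput]; placeholder
        model_options=[ModelOption(provider="openai", model_name="gpt-4o-mini")],  # illustrative
        max_new_tokens=256,
        use_cache=True,
    )
    # get_llm_completions_async(...) now simply does BaseLLMService().get_completions(...)
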
@@ -22,7 +22,7 @@ class ModelOption(BaseModel):

     provider: str
     model_name: str
-    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None


 class ModelOptionWithContext(BaseModel):
@@ -39,7 +39,7 @@ class ModelOptionWithContext(BaseModel):

     provider: str
     model_name: str
-    reasoning_effort: Literal["low", "medium", "high"] | None = None
+    reasoning_effort: Literal["minimal", "low", "medium", "high"] | None = None
     context_window: int
     uses_byok: bool

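
Both option models now admit the extra level, so a "minimal" effort can be requested per model. A hedged sketch (provider and model name are illustrative; only the field names come from the hunks above):

    option = ModelOption(
        provider="openai",
        model_name="gpt-5-mini",        # illustrative
        reasoning_effort="minimal",     # rejected by the old Literal, accepted in 0.1.24a0
    )
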
@@ -1,13 +1,13 @@
 from docent.data_models.agent_run import AgentRun
 from docent.data_models.citation import Citation
-from docent.data_models.judge import JudgeRunLabel
+from docent.data_models.judge import Label
 from docent.data_models.regex import RegexSnippet
 from docent.data_models.transcript import Transcript, TranscriptGroup

 __all__ = [
     "AgentRun",
     "Citation",
-    "JudgeRunLabel",
+    "Label",
     "RegexSnippet",
     "Transcript",
     "TranscriptGroup",
@@ -6,11 +6,14 @@ from uuid import uuid4
 from pydantic import BaseModel, Field


-class JudgeRunLabel(BaseModel):
+class Label(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
+
+    label_set_id: str
+
+    label_value: dict[str, Any]
+
     agent_run_id: str
-    rubric_id: str
-    label: dict[str, Any]


-__all__ = ["JudgeRunLabel"]
+__all__ = ["Label"]
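
This is a breaking rename for the judge-label model: JudgeRunLabel becomes Label, and its rubric_id / label fields give way to label_set_id / label_value. A hedged before-and-after sketch, assuming label_set_id plays the role rubric_id did (all values illustrative):

    # 0.1.22a0
    JudgeRunLabel(agent_run_id="run-123", rubric_id="rubric-1", label={"score": 1})

    # 0.1.24a0
    Label(agent_run_id="run-123", label_set_id="set-1", label_value={"score": 1})
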
docent/judges/__init__.py CHANGED
@@ -3,6 +3,7 @@ from docent.judges.types import (
     JudgeResult,
     JudgeResultCompletionCallback,
     JudgeResultWithCitations,
+    JudgeVariant,
     ResultType,
     Rubric,
 )
@@ -18,4 +19,5 @@ __all__ = [
     "JudgeResultWithCitations",
     "JudgeResultCompletionCallback",
     "ResultType",
+    "JudgeVariant",
 ]
@@ -0,0 +1,77 @@
+import json
+from pathlib import Path
+from typing import Any
+
+import anyio
+from pydantic import BaseModel
+from pydantic_core import to_jsonable_python
+from tqdm.auto import tqdm
+
+from docent._log_util import get_logger
+from docent.data_models.agent_run import AgentRun
+from docent.judges.impl import BaseJudge
+from docent.judges.util.voting import JudgeOutputDistribution
+
+logger = get_logger(__name__)
+
+
+class MultiReflectRollouts(BaseModel):
+    """Object is associated with a single agent run"""
+
+    agent_run_id: str
+
+    first_step_rollouts: list[dict[str, Any]]
+    first_step_rollout_metadata: list[dict[str, Any] | None]
+    # Each index in second_step_rollouts corresponds to an index in first_step_combinations
+    # Step 2 rollouts are computed by passing each step 1 combo into the judge several times
+    first_step_combinations: list[list[dict[str, Any]]] | None = None
+    second_step_rollouts: list[list[dict[str, Any]]] | None = None
+    second_step_rollout_metadata: list[list[dict[str, Any] | None]] | None = None
+
+    distributions: dict[str, JudgeOutputDistribution]
+
+
+async def collect_judge_pvs(
+    judge: BaseJudge,
+    agent_runs: list[AgentRun],
+    *,
+    results_path: Path,
+    estimate_output_distrs_kwargs: dict[str, Any],
+):
+    if results_path.exists():
+        raise FileExistsError(f"Results path already exists: {results_path}")
+    results_path.parent.mkdir(parents=True, exist_ok=True)
+
+    results = dict[str, MultiReflectRollouts]()
+    persist_lock = anyio.Lock()
+    pbar = tqdm(total=len(agent_runs), desc="Processing agent runs")
+
+    async def _persist():
+        async with persist_lock:
+            with open(str(results_path), "w") as f:
+                json.dump(to_jsonable_python(results), f, indent=2)
+
+    async def _execute_for_agent_run(agent_run: AgentRun):
+        result = await judge.estimate_output_distrs(agent_run, **estimate_output_distrs_kwargs)
+        if result is None:
+            pbar.update(1)
+            return
+
+        distrs, metadata = result
+        results[agent_run.id] = MultiReflectRollouts.model_validate(
+            {
+                "agent_run_id": agent_run.id,
+                "distributions": distrs,
+                **metadata,
+            }
+        )
+        await _persist()
+        pbar.update(1)
+
+    async with anyio.create_task_group() as tg_outer:
+        for agent_run in agent_runs:
+            tg_outer.start_soon(_execute_for_agent_run, agent_run)
+
+    pbar.close()
+
+    return results
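
A hedged sketch of driving this new helper; constructing the BaseJudge and choosing the estimate_output_distrs kwargs are outside this diff, so they appear as placeholders:

    rollouts = await collect_judge_pvs(
        judge,                                   # a BaseJudge instance (construction not shown here)
        agent_runs,                              # list[AgentRun]
        results_path=Path("out/rollouts.json"),  # must not already exist
        estimate_output_distrs_kwargs={},        # forwarded to judge.estimate_output_distrs
    )
    # returns {agent_run_id: MultiReflectRollouts}; the same data is also written
    # incrementally to results_path after each completed run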