langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to the supported registries. It is provided for informational purposes only.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -14,20 +14,51 @@
 """Interface for language model."""
 
 import abc
+import contextlib
 import dataclasses
 import enum
+import functools
+import math
+import threading
 import time
-from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
+from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
 from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
 from langfun.core import message as message_lib
+
 import pyglove as pg
 
 TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
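The new exception hierarchy lets provider clients normalize transport failures into retryable vs. non-retryable categories before the shared retry logic sees them. Below is a hypothetical sketch (not taken from the diff) of mapping an HTTP status onto these classes; the status-to-class mapping is an assumption for illustration only.

# Hypothetical helper: map an HTTP status onto the error hierarchy above so
# that retry logic only retries RetryableLMError subclasses. The chosen
# mapping is an illustrative assumption, not the library's actual policy.
def error_from_status(status: int, detail: str) -> LMError:
  if status == 429:
    return RateLimitError(f'{status}: {detail}')
  if status in (500, 502, 503, 504):
    return TemporaryLMError(f'{status}: {detail}')
  return LMError(f'{status}: {detail}')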
 
@@ -50,12 +81,140 @@ class LMSample(pg.Object):
   ] = None
 
 
+class RetryStats(pg.Object):
+  """Retry stats, which is aggregated across multiple retry entries."""
+
+  num_occurences: Annotated[
+      int,
+      'Total number of retry attempts on LLM (excluding the first attempt).',
+  ] = 0
+  total_wait_interval: Annotated[
+      float, 'Total wait interval in seconds due to retry.'
+  ] = 0
+  total_call_interval: Annotated[
+      float, 'Total LLM call interval in seconds.'
+  ] = 0
+  errors: Annotated[
+      dict[str, int],
+      'A Counter of error types encountered during the retry attempts.',
+  ] = {}
+
+  @classmethod
+  def from_retry_entries(
+      cls, retry_entries: Sequence[concurrent.RetryEntry]
+  ) -> 'RetryStats':
+    """Creates a RetryStats from a sequence of RetryEntry."""
+    if not retry_entries:
+      return RetryStats()
+    errors = {}
+    for retry in retry_entries:
+      if retry.error is not None:
+        errors[retry.error.__class__.__name__] = (
+            errors.get(retry.error.__class__.__name__, 0) + 1
+        )
+    return RetryStats(
+        num_occurences=len(retry_entries) - 1,
+        total_wait_interval=sum(e.wait_interval for e in retry_entries),
+        total_call_interval=sum(e.call_interval for e in retry_entries),
+        errors=errors,
+    )
+
+  def __add__(self, other: 'RetryStats') -> 'RetryStats':
+    errors = self.errors.copy()
+    for error, count in other.errors.items():
+      errors[error] = errors.get(error, 0) + count
+    return RetryStats(
+        num_occurences=self.num_occurences + other.num_occurences,
+        total_wait_interval=self.total_wait_interval
+        + other.total_wait_interval,
+        total_call_interval=self.total_call_interval
+        + other.total_call_interval,
+        errors=errors,
+    )
+
+  def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+    return self + other
+
+
 class LMSamplingUsage(pg.Object):
   """Usage information per completion."""
 
   prompt_tokens: int
   completion_tokens: int
   total_tokens: int
+  num_requests: int = 1
+  estimated_cost: Annotated[
+      float | None,
+      (
+          'Estimated cost in US dollars. If None, cost estimating is not '
+          'suppported on the model being queried.'
+      ),
+  ] = None
+  retry_stats: RetryStats = RetryStats()
+
+  def __bool__(self) -> bool:
+    return self.num_requests > 0
+
+  @property
+  def average_prompt_tokens(self) -> int:
+    """Returns the average prompt tokens per request."""
+    return self.prompt_tokens // self.num_requests
+
+  @property
+  def average_completion_tokens(self) -> int:
+    """Returns the average completion tokens per request."""
+    return self.completion_tokens // self.num_requests
+
+  @property
+  def average_total_tokens(self) -> int:
+    """Returns the average total tokens per request."""
+    return self.total_tokens // self.num_requests
+
+  @property
+  def average_estimated_cost(self) -> float | None:
+    """Returns the average estimated cost per request."""
+    if self.estimated_cost is None:
+      return None
+    return self.estimated_cost / self.num_requests
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    if other is None:
+      return self
+    if self.estimated_cost is None:
+      estimated_cost = other.estimated_cost
+    elif other.estimated_cost is None:
+      estimated_cost = self.estimated_cost
+    else:
+      estimated_cost = self.estimated_cost + other.estimated_cost
+    return LMSamplingUsage(
+        prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+        completion_tokens=self.completion_tokens + other.completion_tokens,
+        total_tokens=self.total_tokens + other.total_tokens,
+        num_requests=self.num_requests + other.num_requests,
+        estimated_cost=estimated_cost,
+        retry_stats=self.retry_stats + other.retry_stats,
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+    return self + other
+
+
+class UsageNotAvailable(LMSamplingUsage):
+  """Usage information not available."""
+  prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+  estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze()  # pytype: disable=invalid-annotation
+
+  def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    if other is None:
+      return self
+    return UsageNotAvailable(
+        num_requests=self.num_requests + other.num_requests
+    )
+
+  def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+    return self + other
 
 
 class LMSamplingResult(pg.Object):
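Because `LMSamplingUsage` and `RetryStats` now implement `__add__`/`__radd__`, per-request usage records can be reduced with a plain `sum()`, and `UsageNotAvailable` absorbs anything added to it while still counting requests. A minimal sketch, assuming these classes are imported from `langfun.core.language_model`; the token numbers are made up.

# Minimal sketch of aggregating usage records via the new __add__/__radd__.
from langfun.core import language_model as lm_lib

u1 = lm_lib.LMSamplingUsage(
    prompt_tokens=100, completion_tokens=20, total_tokens=120,
    estimated_cost=0.002)
u2 = lm_lib.LMSamplingUsage(
    prompt_tokens=50, completion_tokens=10, total_tokens=60)  # cost unknown

total = sum([u1, u2], start=None)  # start=None so __radd__ handles the seed.
# total: 150/30/180 tokens, num_requests=2, estimated_cost=0.002 (None ignored).

# UsageNotAvailable absorbs whatever is added to it, but keeps counting requests.
na = lm_lib.UsageNotAvailable() + u1  # -> UsageNotAvailable(num_requests=2)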
@@ -70,9 +229,14 @@ class LMSamplingResult(pg.Object):
   ] = []
 
   usage: Annotated[
-      LMSamplingUsage | None,
+      LMSamplingUsage,
       'Usage information. Currently only OpenAI models are supported.',
-  ] = None
+  ] = UsageNotAvailable()
+
+  is_cached: Annotated[
+      bool,
+      'Whether the result is from cache or not.'
+  ] = False
 
 
 class LMSamplingOptions(component.Component):
@@ -166,6 +330,11 @@ class LMScoringResult(pg.Object):
       float,
       'The log likelyhood of the requested completion towards the prompt.',
   ]
+  gradients: Annotated[
+      Any | None,
+      '(Optional) gradients from the score method, w.r.t.' +
+      ' prompt.metadata.weights.',
+  ] = None
 
 
 class LMCache(pg.Object):
@@ -180,6 +349,7 @@ class LMCache(pg.Object):
     num_hit_expires: int = 0
     num_misses: int = 0
     num_updates: int = 0
+    num_deletes: int = 0
 
   @abc.abstractmethod
   def get(
@@ -197,6 +367,15 @@ class LMCache(pg.Object):
   ) -> None:
     """Puts the result of a prompt generated by a language model in cache."""
 
+  @abc.abstractmethod
+  def delete(
+      self,
+      lm: 'LanguageModel',
+      prompt: message_lib.Message,
+      seed: int,
+  ) -> bool:
+    """Deletes the result of a prompt generated by a language model in cache."""
+
   @property
   @abc.abstractmethod
   def stats(self) -> Stats:
@@ -290,6 +469,15 @@ class LanguageModel(component.Component):
       )
   ] = True
 
+  max_retry_interval: Annotated[
+      int,
+      (
+          'The max retry interval in seconds. This is useful when the retry '
+          'interval is exponential, to avoid the wait time to grow '
+          'exponentially.'
+      )
+  ] = 300
+
   debug: Annotated[
       bool | LMDebugMode,
       (
@@ -303,7 +491,10 @@ class LanguageModel(component.Component):
   def __init__(self, *args, **kwargs) -> None:
     """Overrides __init__ to pass through **kwargs to sampling options."""
 
-    sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+    sampling_options = kwargs.pop(
+        'sampling_options',
+        pg.clone(self.__schema__.fields['sampling_options'].default_value)
+    )
     sampling_options_delta = {}
 
     for k, v in kwargs.items():
@@ -361,12 +552,13 @@ class LanguageModel(component.Component):
           response = sample.response
           response.metadata.score = sample.score
           response.metadata.logprobs = sample.logprobs
+          response.metadata.is_cached = result.is_cached
 
           # NOTE(daiyip): Current usage is computed at per-result level,
           # which is accurate when n=1. For n > 1, we average the usage across
           # multiple samples.
           usage = result.usage
-          if len(result.samples) == 1 or usage is None:
+          if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
             response.metadata.usage = usage
           else:
             n = len(result.samples)
@@ -374,8 +566,29 @@ class LanguageModel(component.Component):
                 prompt_tokens=usage.prompt_tokens // n,
                 completion_tokens=usage.completion_tokens // n,
                 total_tokens=usage.total_tokens // n,
+                estimated_cost=(
+                    usage.estimated_cost / n if usage.estimated_cost else None
+                ),
+                retry_stats=RetryStats(
+                    num_occurences=usage.retry_stats.num_occurences // n,
+                    total_wait_interval=usage.retry_stats.total_wait_interval
+                    / n,
+                    total_call_interval=usage.retry_stats.total_call_interval
+                    / n,
+                    errors={
+                        error: count // n
+                        for error, count in usage.retry_stats.errors.items()
+                    },
+                ),
             )
 
+          # Track usage.
+          trackers = component.context_value('__usage_trackers__', [])
+          if trackers:
+            model_id = self.model_id
+            for tracker in trackers:
+              tracker.track(model_id, usage, result.is_cached)
+
           # Track the prompt for corresponding response.
           response.source = prompt
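As the loop above shows, when a result carries n > 1 samples the usage recorded in each `response.metadata.usage` is the result-level usage divided evenly across samples. A quick illustrative check of the arithmetic (the numbers are made up):

# Made-up numbers: a 3-sample result with 300/90/390 tokens and $0.006 cost
# yields per-response metadata.usage of 100/30/130 tokens and $0.002.
n = 3
result_usage = LMSamplingUsage(
    prompt_tokens=300, completion_tokens=90, total_tokens=390,
    estimated_cost=0.006)
per_response = LMSamplingUsage(
    prompt_tokens=result_usage.prompt_tokens // n,          # 100
    completion_tokens=result_usage.completion_tokens // n,  # 30
    total_tokens=result_usage.total_tokens // n,            # 130
    estimated_cost=result_usage.estimated_cost / n,         # 0.002
)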
 
@@ -403,7 +616,9 @@ class LanguageModel(component.Component):
         request_to_result_index[len(requests)] = i
         requests.append(prompt)
       else:
-        results[i] = r.clone()
+        result = r.clone()
+        assert result.is_cached, result
+        results[i] = result
 
     # Sample non-cache-hit prompts.
     if requests:
@@ -420,8 +635,12 @@ class LanguageModel(component.Component):
           sample.response.set('cache_seed', cache_seed)
 
         if cache_seed is not None:
-          self.cache.put(self, prompt, result.clone(), seed=cache_seed)
-
+          self.cache.put(
+              self,
+              prompt,
+              result.clone(override=dict(is_cached=True)),
+              seed=cache_seed
+          )
     return results  # pytype: disable=bad-return-type
 
   @abc.abstractmethod
@@ -433,16 +652,16 @@ class LanguageModel(component.Component):
 
   def _parallel_execute_with_currency_control(
       self,
-      action: Callable[..., Any],
+      action: Callable[..., LMSamplingResult],
       inputs: Sequence[Any],
       retry_on_errors: Union[
           None,
-          Union[Type[Exception], Tuple[Type[Exception], str]],
-          Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-      ] = None,
-  ) -> Any:
+          Union[Type[BaseException], Tuple[Type[BaseException], str]],
+          Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
+      ] = RetryableLMError,
+  ) -> list[Any]:
     """Helper method for subclasses for implementing _sample."""
-    return concurrent.concurrent_execute(
+    executed_jobs = concurrent.concurrent_execute(
         action,
         inputs,
         executor=self.resource_id if self.max_concurrency else None,
@@ -451,7 +670,16 @@ class LanguageModel(component.Component):
         max_attempts=self.max_attempts,
         retry_interval=self.retry_interval,
         exponential_backoff=self.exponential_backoff,
+        max_retry_interval=self.max_retry_interval,
+        return_jobs=True,
     )
+    for job in executed_jobs:
+      if isinstance(job.result, LMSamplingResult):
+        job.result.usage.rebind(
+            retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+            skip_notification=True,
+        )
+    return [job.result for job in executed_jobs]
 
   def __call__(
       self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
@@ -479,7 +707,7 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
       elapse: float,
   ) -> None:
     """Outputs debugging information."""
@@ -497,12 +725,13 @@ class LanguageModel(component.Component):
       self._debug_response(response, call_counter, usage, elapse)
 
   def _debug_model_info(
-      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
+      self, call_counter: int, usage: LMSamplingUsage) -> None:
     """Outputs debugging information about the model."""
     title_suffix = ''
-    if usage and usage.total_tokens != 0:
-      title_suffix = console.colored(
-          f' (total {usage.total_tokens} tokens)', 'red')
+    if usage.total_tokens != 0:
+      title_suffix = pg.colored(
+          f' (total {usage.total_tokens} tokens)', 'red'
+      )
 
     console.write(
         self.format(compact=True, use_inferred=True),
@@ -514,12 +743,12 @@ class LanguageModel(component.Component):
       self,
       prompt: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
   ) -> None:
     """Outputs debugging information about the prompt."""
     title_suffix = ''
-    if usage and usage.prompt_tokens != 0:
-      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+    if usage.prompt_tokens != 0:
+      title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
 
     console.write(
         # We use metadata 'formatted_text' for scenarios where the prompt text
@@ -542,15 +771,15 @@ class LanguageModel(component.Component):
       self,
       response: message_lib.Message,
       call_counter: int,
-      usage: LMSamplingUsage | None,
+      usage: LMSamplingUsage,
       elapse: float
   ) -> None:
     """Outputs debugging information about the response."""
     title_suffix = ' ('
-    if usage and usage.completion_tokens != 0:
+    if usage.completion_tokens != 0:
       title_suffix += f'{usage.completion_tokens} tokens '
     title_suffix += f'in {elapse:.2f} seconds)'
-    title_suffix = console.colored(title_suffix, 'red')
+    title_suffix = pg.colored(title_suffix, 'red')
 
     console.write(
         str(response) + '\n',
@@ -560,12 +789,19 @@ class LanguageModel(component.Component):
 
   def score(
       self,
-      prompt: str | message_lib.Message,
+      prompt: str | message_lib.Message | list[message_lib.Message],
      completions: list[str | message_lib.Message],
       **kwargs,
   ) -> list[LMScoringResult]:
     """Scores the given prompt."""
-    prompt = message_lib.UserMessage.from_value(prompt)
+    if isinstance(prompt, list):
+      if len(prompt) != len(completions):
+        raise ValueError(
+            'prompt and completions must have the same length.'
+        )
+      prompt = [message_lib.UserMessage.from_value(p) for p in prompt]
+    else:
+      prompt = message_lib.UserMessage.from_value(prompt)
     completions = [message_lib.UserMessage.from_value(c) for c in completions]
 
     call_counter = self._call_counter
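`score()` now accepts either a single prompt shared by all completions or one prompt per completion, raising `ValueError` when the lengths differ. Illustrative call shapes; `lm` is a placeholder for any `LanguageModel` subclass that implements `_score()`.

# Illustrative call shapes for the extended score() API.
results = lm.score('1 + 1 =', ['2', '3'])  # one prompt, two completions

# One prompt per completion; mismatched lengths raise ValueError.
results = lm.score(['1 + 1 =', '2 + 2 ='], ['2', '4'])
for r in results:
  print(r.score)  # log likelihood of each completion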
@@ -581,7 +817,8 @@ class LanguageModel(component.Component):
       return scoring_results
 
   def _score(
-      self, prompt: message_lib.Message, completions: list[message_lib.Message]
+      self, prompt: message_lib.Message | list[message_lib.Message],
+      completions: list[message_lib.Message]
   ) -> list[LMScoringResult]:
     """Subclass to implement."""
     raise NotImplementedError(
@@ -590,7 +827,7 @@ class LanguageModel(component.Component):
 
   def _debug_score(
       self,
-      prompt: message_lib.Message,
+      prompt: message_lib.Message | list[message_lib.Message],
       completions: list[message_lib.Message],
       scoring_results: list[LMScoringResult],
       call_counter: int,
@@ -601,7 +838,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter, None)
+      self._debug_model_info(call_counter, UsageNotAvailable())
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -609,15 +846,19 @@ class LanguageModel(component.Component):
           title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
           color='green',
       )
-      referred_modalities = prompt.referred_modalities()
-      if referred_modalities:
-        console.write(
-            pg.object_utils.kvlist_str(
-                [(k, repr(v), None) for k, v in referred_modalities.items()]
-            ),
-            title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
-            color='green',
-        )
+      if isinstance(prompt, list):
+        referred_modalities_lst = [p.referred_modalities() for p in prompt]
+      else:
+        referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )
 
     if debug & LMDebugMode.RESPONSE:
       console.write(
@@ -638,6 +879,72 @@ class LanguageModel(component.Component):
           color='blue',
       )
 
+  def tokenize(
+      self,
+      prompt: str | message_lib.Message,
+      **kwargs,
+  ) -> list[tuple[str | bytes, int]]:
+    """Tokenizes the given prompt."""
+    prompt = message_lib.UserMessage.from_value(prompt)
+    call_counter = self._call_counter
+    self._call_counter += 1
+
+    with component.context(override_attrs=True, **kwargs):
+      request_start = time.time()
+      tokens = self._tokenize(prompt)
+      elapse = time.time() - request_start
+      self._debug_tokenize(prompt, tokens, call_counter, elapse)
+      return tokens
+
+  def _tokenize(
+      self, prompt: message_lib.Message
+  ) -> list[tuple[str | bytes, int]]:
+    """Subclass to implement."""
+    raise NotImplementedError(
+        f'{self.__class__.__name__} does not support tokenization.'
+    )
+
+  def _debug_tokenize(
+      self,
+      prompt: message_lib.Message,
+      tokens: list[tuple[str | bytes, int]],
+      call_counter: int,
+      elapse: float,
+  ):
+    debug = self.debug
+    if isinstance(debug, bool):
+      debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
+
+    if debug & LMDebugMode.INFO:
+      self._debug_model_info(call_counter, UsageNotAvailable())
+
+    if debug & LMDebugMode.PROMPT:
+      console.write(
+          prompt,
+          title=f'\n[{call_counter}] PROMPT TO TOKENIZE:',
+          color='green',
+      )
+      referred_modalities_lst = [prompt.referred_modalities(),]
+      if referred_modalities_lst:
+        for referred_modalities in referred_modalities_lst:
+          console.write(
+              pg.object_utils.kvlist_str(
+                  [(k, repr(v), None) for k, v in referred_modalities.items()]
+              ),
+              title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+              color='green',
+          )
+
+    if debug & LMDebugMode.RESPONSE:
+      console.write(
+          tokens,
+          title=(
+              f'\n[{call_counter}] {len(tokens)} TOKENS RETURNED '
+              f'(in {elapse:.2f} seconds):'
+          ),
+          color='blue',
+      )
+
   def rate_to_max_concurrency(
       self, requests_per_min: float = 0, tokens_per_min: float = 0
   ) -> int:
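The new `tokenize()` entry point mirrors `sample()`/`score()`: it wraps `_tokenize()`, which subclasses may implement; models without tokenizer support raise `NotImplementedError`. A minimal sketch; `lm` is a placeholder for a backend that implements `_tokenize()`.

# Illustrative only: each returned entry is a (str | bytes, int) pair.
tokens = lm.tokenize('hello world')
for text, token_id in tokens:
  print(repr(text), token_id)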
@@ -648,3 +955,234 @@ class LanguageModel(component.Component):
       return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
     else:
       return DEFAULT_MAX_CONCURRENCY  # Default of 1
+
+
+class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
+  """Usage sumary."""
+
+  class AggregatedUsage(pg.Object):
+    """Aggregated usage."""
+
+    total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
+    breakdown: dict[str, LMSamplingUsage] = {}
+
+    def __bool__(self) -> bool:
+      """Returns True if the usage is non-empty."""
+      return bool(self.breakdown)
+
+    def add(
+        self,
+        model_id: str,
+        usage: LMSamplingUsage,
+    ) -> None:
+      """Adds an entry to the breakdown."""
+      aggregated = self.breakdown.get(model_id, None)
+      with pg.notify_on_change(False):
+        self.breakdown[model_id] = usage + aggregated
+        self.rebind(
+            total=self.total + usage,
+            raise_on_no_change=False
+        )
+
+    def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
+      """Merges the usage summary."""
+      with pg.notify_on_change(False):
+        for model_id, usage in other.breakdown.items():
+          self.add(model_id, usage)
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._usage_badge = None
+    self._lock = threading.Lock()
+
+  @property
+  def total(self) -> LMSamplingUsage:
+    return self.cached.total + self.uncached.total
+
+  def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    """Updates the usage summary."""
+    with self._lock:
+      if is_cached:
+        usage.rebind(estimated_cost=0.0, skip_notification=True)
+        self.cached.add(model_id, usage)
+      else:
+        self.uncached.add(model_id, usage)
+      self._update_view()
+
+  def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
+    """Aggregates the usage summary.
+
+    Args:
+      other: The usage summary to merge.
+      as_cached: Whether to merge the usage summary as cached.
+    """
+    with self._lock:
+      self.cached.merge(other.cached)
+      if as_cached:
+        self.cached.merge(other.uncached)
+      else:
+        self.uncached.merge(other.uncached)
+      self._update_view()
+
+  def _sym_nondefault(self) -> dict[str, Any]:
+    """Overrides nondefault values so volatile values are not included."""
+    return dict()
+
+  #
+  # Html views for the usage summary.
+  #
+
+  def _update_view(self):
+    if self._usage_badge is not None:
+      self._usage_badge.update(
+          self._badge_text(),
+          tooltip=pg.format(
+              self, verbose=False, custom_format=self._tooltip_format
+          ),
+          styles=dict(color=self._badge_color()),
+      )
+
+  def _badge_text(self) -> str:
+    if self.total.estimated_cost is not None:
+      return f'{self.total.estimated_cost:.3f}'
+    return '0.000'
+
+  def _badge_color(self) -> str | None:
+    if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
+      return None
+
+    # Step 1: The normal cost range is around 1e-3 to 1e5.
+    # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
+    normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
+
+    # Step 2: Interpolate between green and red
+    red = int(255 * normalized_value)
+    green = int(255 * (1 - normalized_value))
+    return f'rgb({red}, {green}, 0)'
+
+  def _tooltip_format(self, v, root_indent):
+    del root_indent
+    if isinstance(v, int):
+      return f'{v:,}'
+    if isinstance(v, float):
+      return f'{v:,.3f}'
+    return None
+
+  def _html_tree_view(
+      self,
+      *,
+      view: pg.views.HtmlTreeView,
+      extra_flags: dict[str, Any] | None = None,
+      **kwargs
+  ) -> pg.Html:
+    extra_flags = extra_flags or {}
+    as_badge = extra_flags.pop('as_badge', False)
+    interactive = extra_flags.get('interactive', True)
+    if as_badge:
+      usage_badge = self._usage_badge
+      if usage_badge is None:
+        usage_badge = pg.views.html.controls.Badge(
+            self._badge_text(),
+            tooltip=pg.format(
+                self, custom_format=self._tooltip_format, verbose=False
+            ),
+            css_classes=['usage-summary'],
+            styles=dict(color=self._badge_color()),
+            interactive=True,
+        )
+        if interactive:
+          self._usage_badge = usage_badge
+      return usage_badge.to_html()
+    return super()._html_tree_view(
+        view=view,
+        extra_flags=extra_flags,
+        **kwargs
+    )
+
+  @classmethod
+  @functools.cache
+  def _html_tree_view_css_styles(cls) -> list[str]:
+    return super()._html_tree_view_css_styles() + [
+        """
+        .usage-summary.label {
+          display: inline-flex;
+          border-radius: 5px;
+          padding: 5px;
+          background-color: #f1f1f1;
+          color: #CCC;
+        }
+        .usage-summary.label::before {
+          content: '$';
+        }
+        """
+    ]
+
+
+pg.members(
+    dict(
+        cached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for cached LLM calls.'
+        ),
+        uncached=(
+            pg.typing.Object(
+                UsageSummary.AggregatedUsage,
+                default=UsageSummary.AggregatedUsage()
+            ),
+            'Aggregated usages for uncached LLM calls.'
+        ),
+    )
+)(UsageSummary)
+
+
+class _UsageTracker:
+  """Usage tracker."""
+
+  def __init__(self, model_ids: set[str] | None):
+    self.model_ids = model_ids
+    self.usage_summary = UsageSummary()
+
+  def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+    if self.model_ids is None or model_id in self.model_ids:
+      self.usage_summary.add(model_id, usage, is_cached)
+
+
+@contextlib.contextmanager
+def track_usages(
+    *lm: Union[str, LanguageModel]
+) -> Iterator[UsageSummary]:
+  """Context manager to track the usages of all language models in scope.
+
+  `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
+  `lf.concurrent_execute`.
+
+  Example:
+    ```
+    lm = lf.llms.GeminiPro1()
+    with lf.track_usages() as usages:
+      # invoke any code that will call LLMs.
+
+    print(usages[lm.model_id])
+    ```
+
+  Args:
+    *lm: The language model(s) to track. If None, track all models in scope.
+
+  Yields:
+    A dictionary of model ID to usage. If a model does not supports usage
+    counting, the dict entry will be None.
+  """
+  if not lm:
+    model_ids = None
+  else:
+    model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
+
+  trackers = component.context_value('__usage_trackers__', [])
+  tracker = _UsageTracker(set(model_ids) if model_ids else None)
+  with component.context(__usage_trackers__=trackers + [tracker]):
+    try:
+      yield tracker.usage_summary
+    finally:
+      pass
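Beyond the docstring example above, the yielded `UsageSummary` separates cached from uncached calls and exposes a per-model breakdown. A hedged sketch of reading it, following the `UsageSummary` definition in this diff; the model and workload are placeholders.

# Illustrative use of the tracked UsageSummary.
import langfun as lf

lm = ...  # any configured LanguageModel
with lf.track_usages(lm) as usage_summary:
  pass  # invoke lm.sample(...) or other LLM-calling code here.

print(usage_summary.total.total_tokens)    # aggregate across tracked models
print(usage_summary.total.estimated_cost)  # estimated cost in US dollars
for model_id, usage in usage_summary.uncached.breakdown.items():
  print(model_id, usage.num_requests, usage.total_tokens)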