langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -14,16 +14,50 @@
  """Interface for language model."""

  import abc
+ import contextlib
  import dataclasses
  import enum
+ import functools
+ import math
+ import threading
  import time
- from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
+ from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
  from langfun.core import component
  from langfun.core import concurrent
  from langfun.core import console
  from langfun.core import message as message_lib
+
  import pyglove as pg

+ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+ DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
+
+ #
+ # Common errors during calling language models.
+ #
+
+
+ class LMError(RuntimeError):
+   """Base class for language model errors."""
+
+
+ class RetryableLMError(LMError):
+   """Base class for LLM errors that can be solved by retrying."""
+
+
+ class RateLimitError(RetryableLMError):
+   """Error for rate limit reached."""
+
+
+ class TemporaryLMError(RetryableLMError):
+   """Error for temporary service issues that can be retried."""
+
+
+ #
+ # Language model input/output interfaces.
+ #
+

  class LMSample(pg.Object):
    """Response candidate."""
@@ -47,6 +81,142 @@ class LMSample(pg.Object):
    ] = None


+ class RetryStats(pg.Object):
+   """Retry stats, which is aggregated across multiple retry entries."""
+
+   num_occurences: Annotated[
+       int,
+       'Total number of retry attempts on LLM (excluding the first attempt).',
+   ] = 0
+   total_wait_interval: Annotated[
+       float, 'Total wait interval in seconds due to retry.'
+   ] = 0
+   total_call_interval: Annotated[
+       float, 'Total LLM call interval in seconds.'
+   ] = 0
+   errors: Annotated[
+       dict[str, int],
+       'A Counter of error types encountered during the retry attempts.',
+   ] = {}
+
+   @classmethod
+   def from_retry_entries(
+       cls, retry_entries: Sequence[concurrent.RetryEntry]
+   ) -> 'RetryStats':
+     """Creates a RetryStats from a sequence of RetryEntry."""
+     if not retry_entries:
+       return RetryStats()
+     errors = {}
+     for retry in retry_entries:
+       if retry.error is not None:
+         errors[retry.error.__class__.__name__] = (
+             errors.get(retry.error.__class__.__name__, 0) + 1
+         )
+     return RetryStats(
+         num_occurences=len(retry_entries) - 1,
+         total_wait_interval=sum(e.wait_interval for e in retry_entries),
+         total_call_interval=sum(e.call_interval for e in retry_entries),
+         errors=errors,
+     )
+
+   def __add__(self, other: 'RetryStats') -> 'RetryStats':
+     errors = self.errors.copy()
+     for error, count in other.errors.items():
+       errors[error] = errors.get(error, 0) + count
+     return RetryStats(
+         num_occurences=self.num_occurences + other.num_occurences,
+         total_wait_interval=self.total_wait_interval
+         + other.total_wait_interval,
+         total_call_interval=self.total_call_interval
+         + other.total_call_interval,
+         errors=errors,
+     )
+
+   def __radd__(self, other: 'RetryStats') -> 'RetryStats':
+     return self + other
+
+
+ class LMSamplingUsage(pg.Object):
+   """Usage information per completion."""
+
+   prompt_tokens: int
+   completion_tokens: int
+   total_tokens: int
+   num_requests: int = 1
+   estimated_cost: Annotated[
+       float | None,
+       (
+           'Estimated cost in US dollars. If None, cost estimating is not '
+           'suppported on the model being queried.'
+       ),
+   ] = None
+   retry_stats: RetryStats = RetryStats()
+
+   def __bool__(self) -> bool:
+     return self.num_requests > 0
+
+   @property
+   def average_prompt_tokens(self) -> int:
+     """Returns the average prompt tokens per request."""
+     return self.prompt_tokens // self.num_requests
+
+   @property
+   def average_completion_tokens(self) -> int:
+     """Returns the average completion tokens per request."""
+     return self.completion_tokens // self.num_requests
+
+   @property
+   def average_total_tokens(self) -> int:
+     """Returns the average total tokens per request."""
+     return self.total_tokens // self.num_requests
+
+   @property
+   def average_estimated_cost(self) -> float | None:
+     """Returns the average estimated cost per request."""
+     if self.estimated_cost is None:
+       return None
+     return self.estimated_cost / self.num_requests
+
+   def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+     if other is None:
+       return self
+     if self.estimated_cost is None:
+       estimated_cost = other.estimated_cost
+     elif other.estimated_cost is None:
+       estimated_cost = self.estimated_cost
+     else:
+       estimated_cost = self.estimated_cost + other.estimated_cost
+     return LMSamplingUsage(
+         prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+         completion_tokens=self.completion_tokens + other.completion_tokens,
+         total_tokens=self.total_tokens + other.total_tokens,
+         num_requests=self.num_requests + other.num_requests,
+         estimated_cost=estimated_cost,
+         retry_stats=self.retry_stats + other.retry_stats,
+     )
+
+   def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
+     return self + other
+
+
+ class UsageNotAvailable(LMSamplingUsage):
+   """Usage information not available."""
+   prompt_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+   completion_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+   total_tokens: pg.typing.Int(0).freeze()  # pytype: disable=invalid-annotation
+   estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze()  # pytype: disable=invalid-annotation
+
+   def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+     if other is None:
+       return self
+     return UsageNotAvailable(
+         num_requests=self.num_requests + other.num_requests
+     )
+
+   def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
+     return self + other
+
+
  class LMSamplingResult(pg.Object):
    """Language model response."""

@@ -58,19 +228,39 @@ class LMSamplingResult(pg.Object):
        ),
    ] = []

+   usage: Annotated[
+       LMSamplingUsage,
+       'Usage information. Currently only OpenAI models are supported.',
+   ] = UsageNotAvailable()
+
+   is_cached: Annotated[
+       bool,
+       'Whether the result is from cache or not.'
+   ] = False
+

  class LMSamplingOptions(component.Component):
    """Language model sampling options."""

    temperature: Annotated[
-       float,
+       float | None,
        (
            'Model temperature, which is usually between 0 and 1.0. '
-           'OpenAI models have temperature range from 0.0 to 2.0.'
+           'OpenAI models have temperature range from 0.0 to 2.0. '
+           'If None (default), honor the model\'s default behavior. '
        )
-   ] = 0.0
-   max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+   ] = None
+
+   max_tokens: Annotated[
+       int | None,
+       (
+           'Per example max tokens to generate. '
+           'If None, use the model default.'
+       )
+   ] = None
+
    n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
    top_k: Annotated[
        int | None,
        (
@@ -78,6 +268,7 @@ class LMSamplingOptions(component.Component):
            'Not applicable to OpenAI models.'
        )
    ] = 40
+
    top_p: Annotated[
        float | None,
        (
@@ -86,6 +277,7 @@ class LMSamplingOptions(component.Component):
            '`top_p` but not both.'
        ),
    ] = None
+
    stop: Annotated[
        list[str] | None,
        (
@@ -95,9 +287,11 @@ class LMSamplingOptions(component.Component):
            '`Model:` is reached.'
        ),
    ] = None
+
    random_seed: Annotated[
        int | None, 'A fixed random seed used during model inference.'
    ] = None
+
    logprobs: Annotated[
        bool,
        (
@@ -106,6 +300,7 @@ class LMSamplingOptions(component.Component):
            'in the content of message.'
        ),
    ] = False
+
    top_logprobs: Annotated[
        int | None,
        (
@@ -135,6 +330,11 @@ class LMScoringResult(pg.Object):
        float,
        'The log likelyhood of the requested completion towards the prompt.',
    ]
+   gradients: Annotated[
+       Any | None,
+       '(Optional) gradients from the score method, w.r.t.' +
+       ' prompt.metadata.weights.',
+   ] = None


  class LMCache(pg.Object):
@@ -149,6 +349,7 @@ class LMCache(pg.Object):
      num_hit_expires: int = 0
      num_misses: int = 0
      num_updates: int = 0
+     num_deletes: int = 0

    @abc.abstractmethod
    def get(
@@ -166,6 +367,15 @@ class LMCache(pg.Object):
    ) -> None:
      """Puts the result of a prompt generated by a language model in cache."""

+   @abc.abstractmethod
+   def delete(
+       self,
+       lm: 'LanguageModel',
+       prompt: message_lib.Message,
+       seed: int,
+   ) -> bool:
+     """Deletes the result of a prompt generated by a language model in cache."""
+
    @property
    @abc.abstractmethod
    def stats(self) -> Stats:
@@ -259,6 +469,15 @@ class LanguageModel(component.Component):
        )
    ] = True

+   max_retry_interval: Annotated[
+       int,
+       (
+           'The max retry interval in seconds. This is useful when the retry '
+           'interval is exponential, to avoid the wait time to grow '
+           'exponentially.'
+       )
+   ] = 300
+
    debug: Annotated[
        bool | LMDebugMode,
        (
@@ -272,7 +491,10 @@ class LanguageModel(component.Component):
    def __init__(self, *args, **kwargs) -> None:
      """Overrides __init__ to pass through **kwargs to sampling options."""

-     sampling_options = kwargs.pop('sampling_options', LMSamplingOptions())
+     sampling_options = kwargs.pop(
+         'sampling_options',
+         pg.clone(self.__schema__.fields['sampling_options'].default_value)
+     )
      sampling_options_delta = {}

      for k, v in kwargs.items():
@@ -315,9 +537,64 @@ class LanguageModel(component.Component):

      with component.context(override_attrs=True, **kwargs):
        if self.cache is None:
-         return self._sample(prompts)
+         results = self._sample(prompts)
        else:
-         return self._sample_with_cache_lookup(prompts, cache_seed)
+         results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+       for prompt, result in zip(prompts, results):
+
+         # Tag LM input.
+         prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+         for sample in result.samples:
+           # Update metadata for response message.
+
+           response = sample.response
+           response.metadata.score = sample.score
+           response.metadata.logprobs = sample.logprobs
+           response.metadata.is_cached = result.is_cached
+
+           # NOTE(daiyip): Current usage is computed at per-result level,
+           # which is accurate when n=1. For n > 1, we average the usage across
+           # multiple samples.
+           usage = result.usage
+           if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
+             response.metadata.usage = usage
+           else:
+             n = len(result.samples)
+             response.metadata.usage = LMSamplingUsage(
+                 prompt_tokens=usage.prompt_tokens // n,
+                 completion_tokens=usage.completion_tokens // n,
+                 total_tokens=usage.total_tokens // n,
+                 estimated_cost=(
+                     usage.estimated_cost / n if usage.estimated_cost else None
+                 ),
+                 retry_stats=RetryStats(
+                     num_occurences=usage.retry_stats.num_occurences // n,
+                     total_wait_interval=usage.retry_stats.total_wait_interval
+                     / n,
+                     total_call_interval=usage.retry_stats.total_call_interval
+                     / n,
+                     errors={
+                         error: count // n
+                         for error, count in usage.retry_stats.errors.items()
+                     },
+                 ),
+             )
+
+           # Track usage.
+           trackers = component.context_value('__usage_trackers__', [])
+           if trackers:
+             model_id = self.model_id
+             for tracker in trackers:
+               tracker.track(model_id, usage, result.is_cached)
+
+           # Track the prompt for corresponding response.
+           response.source = prompt
+
+           # Tag LM response.
+           response.tag(message_lib.Message.TAG_LM_RESPONSE)
+       return results

    def _sample_with_cache_lookup(
        self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -339,7 +616,9 @@ class LanguageModel(component.Component):
          request_to_result_index[len(requests)] = i
          requests.append(prompt)
        else:
-         results[i] = r.clone()
+         result = r.clone()
+         assert result.is_cached, result
+         results[i] = result

      # Sample non-cache-hit prompts.
      if requests:
@@ -356,8 +635,12 @@ class LanguageModel(component.Component):
            sample.response.set('cache_seed', cache_seed)

          if cache_seed is not None:
-           self.cache.put(self, prompt, result.clone(), seed=cache_seed)
-
+           self.cache.put(
+               self,
+               prompt,
+               result.clone(override=dict(is_cached=True)),
+               seed=cache_seed
+           )
      return results  # pytype: disable=bad-return-type

    @abc.abstractmethod
@@ -369,16 +652,16 @@ class LanguageModel(component.Component):

    def _parallel_execute_with_currency_control(
        self,
-       action: Callable[..., Any],
+       action: Callable[..., LMSamplingResult],
        inputs: Sequence[Any],
        retry_on_errors: Union[
            None,
-           Union[Type[Exception], Tuple[Type[Exception], str]],
-           Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-       ] = None,
-   ) -> Any:
+           Union[Type[BaseException], Tuple[Type[BaseException], str]],
+           Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
+       ] = RetryableLMError,
+   ) -> list[Any]:
      """Helper method for subclasses for implementing _sample."""
-     return concurrent.concurrent_execute(
+     executed_jobs = concurrent.concurrent_execute(
          action,
          inputs,
          executor=self.resource_id if self.max_concurrency else None,
@@ -387,7 +670,16 @@ class LanguageModel(component.Component):
          max_attempts=self.max_attempts,
          retry_interval=self.retry_interval,
          exponential_backoff=self.exponential_backoff,
+         max_retry_interval=self.max_retry_interval,
+         return_jobs=True,
      )
+     for job in executed_jobs:
+       if isinstance(job.result, LMSamplingResult):
+         job.result.usage.rebind(
+             retry_stats=RetryStats.from_retry_entries(job.retry_entries),
+             skip_notification=True,
+         )
+     return [job.result for job in executed_jobs]

    def __call__(
        self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
@@ -405,12 +697,9 @@ class LanguageModel(component.Component):
        result = self.sample(
            [prompt], sampling_options=sampling_options, cache_seed=cache_seed
        )[0]
-       response = result.samples[0].response
-       logprobs = result.samples[0].logprobs
-       response.set('score', result.samples[0].score)
-       response.metadata.logprobs = logprobs
        elapse = time.time() - request_start
-       self._debug(prompt, response, call_counter, elapse)
+       response = result.samples[0].response
+       self._debug(prompt, response, call_counter, result.usage, elapse)
        return response

    def _debug(
@@ -418,35 +707,54 @@ class LanguageModel(component.Component):
        prompt: message_lib.Message,
        response: message_lib.Message,
        call_counter: int,
+       usage: LMSamplingUsage,
        elapse: float,
-   ):
+   ) -> None:
      """Outputs debugging information."""
      debug = self.debug
      if isinstance(debug, bool):
        debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

      if debug & LMDebugMode.INFO:
-       self._debug_model_info(call_counter)
+       self._debug_model_info(call_counter, usage)

      if debug & LMDebugMode.PROMPT:
-       self._debug_prompt(prompt, call_counter)
+       self._debug_prompt(prompt, call_counter, usage)

      if debug & LMDebugMode.RESPONSE:
-       self._debug_response(response, call_counter, elapse)
+       self._debug_response(response, call_counter, usage, elapse)

-   def _debug_model_info(self, call_counter: int):
+   def _debug_model_info(
+       self, call_counter: int, usage: LMSamplingUsage) -> None:
      """Outputs debugging information about the model."""
+     title_suffix = ''
+     if usage.total_tokens != 0:
+       title_suffix = pg.colored(
+           f' (total {usage.total_tokens} tokens)', 'red'
+       )
+
      console.write(
          self.format(compact=True, use_inferred=True),
-         title=f'[{call_counter}] LM INFO:',
+         title=f'[{call_counter}] LM INFO{title_suffix}:',
          color='magenta',
      )

-   def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+   def _debug_prompt(
+       self,
+       prompt: message_lib.Message,
+       call_counter: int,
+       usage: LMSamplingUsage,
+   ) -> None:
      """Outputs debugging information about the prompt."""
+     title_suffix = ''
+     if usage.prompt_tokens != 0:
+       title_suffix = pg.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
      console.write(
-         prompt,
-         title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+         # We use metadata 'formatted_text' for scenarios where the prompt text
+         # is formatted by the LM.
+         prompt.get('formatted_text', prompt.text),
+         title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
          color='green',
      )
      referred_modalities = prompt.referred_modalities()
@@ -460,23 +768,40 @@ class LanguageModel(component.Component):
        )

    def _debug_response(
-       self, response: message_lib.Message, call_counter: int, elapse: float
-   ):
+       self,
+       response: message_lib.Message,
+       call_counter: int,
+       usage: LMSamplingUsage,
+       elapse: float
+   ) -> None:
      """Outputs debugging information about the response."""
+     title_suffix = ' ('
+     if usage.completion_tokens != 0:
+       title_suffix += f'{usage.completion_tokens} tokens '
+     title_suffix += f'in {elapse:.2f} seconds)'
+     title_suffix = pg.colored(title_suffix, 'red')
+
      console.write(
          str(response) + '\n',
-         title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+         title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
          color='blue',
      )

    def score(
        self,
-       prompt: str | message_lib.Message,
+       prompt: str | message_lib.Message | list[message_lib.Message],
        completions: list[str | message_lib.Message],
        **kwargs,
    ) -> list[LMScoringResult]:
      """Scores the given prompt."""
-     prompt = message_lib.UserMessage.from_value(prompt)
+     if isinstance(prompt, list):
+       if len(prompt) != len(completions):
+         raise ValueError(
+             'prompt and completions must have the same length.'
+         )
+       prompt = [message_lib.UserMessage.from_value(p) for p in prompt]
+     else:
+       prompt = message_lib.UserMessage.from_value(prompt)
      completions = [message_lib.UserMessage.from_value(c) for c in completions]

      call_counter = self._call_counter
@@ -492,7 +817,8 @@ class LanguageModel(component.Component):
        return scoring_results

    def _score(
-       self, prompt: message_lib.Message, completions: list[message_lib.Message]
+       self, prompt: message_lib.Message | list[message_lib.Message],
+       completions: list[message_lib.Message]
    ) -> list[LMScoringResult]:
      """Subclass to implement."""
      raise NotImplementedError(
@@ -501,7 +827,7 @@ class LanguageModel(component.Component):

    def _debug_score(
        self,
-       prompt: message_lib.Message,
+       prompt: message_lib.Message | list[message_lib.Message],
        completions: list[message_lib.Message],
        scoring_results: list[LMScoringResult],
        call_counter: int,
@@ -512,7 +838,7 @@ class LanguageModel(component.Component):
        debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

      if debug & LMDebugMode.INFO:
-       self._debug_model_info(call_counter)
+       self._debug_model_info(call_counter, UsageNotAvailable())

      if debug & LMDebugMode.PROMPT:
        console.write(
@@ -520,15 +846,19 @@ class LanguageModel(component.Component):
            title=f'\n[{call_counter}] SCORING LM WITH PROMPT:',
            color='green',
        )
-       referred_modalities = prompt.referred_modalities()
-       if referred_modalities:
-         console.write(
-             pg.object_utils.kvlist_str(
-                 [(k, repr(v), None) for k, v in referred_modalities.items()]
-             ),
-             title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
-             color='green',
-         )
+       if isinstance(prompt, list):
+         referred_modalities_lst = [p.referred_modalities() for p in prompt]
+       else:
+         referred_modalities_lst = [prompt.referred_modalities(),]
+       if referred_modalities_lst:
+         for referred_modalities in referred_modalities_lst:
+           console.write(
+               pg.object_utils.kvlist_str(
+                   [(k, repr(v), None) for k, v in referred_modalities.items()]
+               ),
+               title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+               color='green',
+           )

      if debug & LMDebugMode.RESPONSE:
        console.write(
@@ -548,3 +878,311 @@ class LanguageModel(component.Component):
            f'score: {r.score}',
            color='blue',
        )
+
+   def tokenize(
+       self,
+       prompt: str | message_lib.Message,
+       **kwargs,
+   ) -> list[tuple[str | bytes, int]]:
+     """Tokenizes the given prompt."""
+     prompt = message_lib.UserMessage.from_value(prompt)
+     call_counter = self._call_counter
+     self._call_counter += 1
+
+     with component.context(override_attrs=True, **kwargs):
+       request_start = time.time()
+       tokens = self._tokenize(prompt)
+       elapse = time.time() - request_start
+       self._debug_tokenize(prompt, tokens, call_counter, elapse)
+       return tokens
+
+   def _tokenize(
+       self, prompt: message_lib.Message
+   ) -> list[tuple[str | bytes, int]]:
+     """Subclass to implement."""
+     raise NotImplementedError(
+         f'{self.__class__.__name__} does not support tokenization.'
+     )
+
+   def _debug_tokenize(
+       self,
+       prompt: message_lib.Message,
+       tokens: list[tuple[str | bytes, int]],
+       call_counter: int,
+       elapse: float,
+   ):
+     debug = self.debug
+     if isinstance(debug, bool):
+       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
+
+     if debug & LMDebugMode.INFO:
+       self._debug_model_info(call_counter, UsageNotAvailable())
+
+     if debug & LMDebugMode.PROMPT:
+       console.write(
+           prompt,
+           title=f'\n[{call_counter}] PROMPT TO TOKENIZE:',
+           color='green',
+       )
+       referred_modalities_lst = [prompt.referred_modalities(),]
+       if referred_modalities_lst:
+         for referred_modalities in referred_modalities_lst:
+           console.write(
+               pg.object_utils.kvlist_str(
+                   [(k, repr(v), None) for k, v in referred_modalities.items()]
+               ),
+               title=f'\n[{call_counter}] MODALITY OBJECTS SENT TO LM:',
+               color='green',
+           )
+
+     if debug & LMDebugMode.RESPONSE:
+       console.write(
+           tokens,
+           title=(
+               f'\n[{call_counter}] {len(tokens)} TOKENS RETURNED '
+               f'(in {elapse:.2f} seconds):'
+           ),
+           color='blue',
+       )
+
+   def rate_to_max_concurrency(
+       self, requests_per_min: float = 0, tokens_per_min: float = 0
+   ) -> int:
+     """Converts a rate to a max concurrency."""
+     if tokens_per_min > 0:
+       return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+     elif requests_per_min > 0:
+       return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+     else:
+       return DEFAULT_MAX_CONCURRENCY  # Default of 1
+
+
+ class UsageSummary(pg.Object, pg.views.HtmlTreeView.Extension):
+   """Usage sumary."""
+
+   class AggregatedUsage(pg.Object):
+     """Aggregated usage."""
+
+     total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
+     breakdown: dict[str, LMSamplingUsage] = {}
+
+     def __bool__(self) -> bool:
+       """Returns True if the usage is non-empty."""
+       return bool(self.breakdown)
+
+     def add(
+         self,
+         model_id: str,
+         usage: LMSamplingUsage,
+     ) -> None:
+       """Adds an entry to the breakdown."""
+       aggregated = self.breakdown.get(model_id, None)
+       with pg.notify_on_change(False):
+         self.breakdown[model_id] = usage + aggregated
+         self.rebind(
+             total=self.total + usage,
+             raise_on_no_change=False
+         )
+
+     def merge(self, other: 'UsageSummary.AggregatedUsage') -> None:
+       """Merges the usage summary."""
+       with pg.notify_on_change(False):
+         for model_id, usage in other.breakdown.items():
+           self.add(model_id, usage)
+
+   def _on_bound(self):
+     super()._on_bound()
+     self._usage_badge = None
+     self._lock = threading.Lock()
+
+   @property
+   def total(self) -> LMSamplingUsage:
+     return self.cached.total + self.uncached.total
+
+   def add(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+     """Updates the usage summary."""
+     with self._lock:
+       if is_cached:
+         usage.rebind(estimated_cost=0.0, skip_notification=True)
+         self.cached.add(model_id, usage)
+       else:
+         self.uncached.add(model_id, usage)
+       self._update_view()
+
+   def merge(self, other: 'UsageSummary', as_cached: bool = False) -> None:
+     """Aggregates the usage summary.
+
+     Args:
+       other: The usage summary to merge.
+       as_cached: Whether to merge the usage summary as cached.
+     """
+     with self._lock:
+       self.cached.merge(other.cached)
+       if as_cached:
+         self.cached.merge(other.uncached)
+       else:
+         self.uncached.merge(other.uncached)
+       self._update_view()
+
+   def _sym_nondefault(self) -> dict[str, Any]:
+     """Overrides nondefault values so volatile values are not included."""
+     return dict()
+
+   #
+   # Html views for the usage summary.
+   #
+
+   def _update_view(self):
+     if self._usage_badge is not None:
+       self._usage_badge.update(
+           self._badge_text(),
+           tooltip=pg.format(
+               self, verbose=False, custom_format=self._tooltip_format
+           ),
+           styles=dict(color=self._badge_color()),
+       )
+
+   def _badge_text(self) -> str:
+     if self.total.estimated_cost is not None:
+       return f'{self.total.estimated_cost:.3f}'
+     return '0.000'
+
+   def _badge_color(self) -> str | None:
+     if self.total.estimated_cost is None or self.total.estimated_cost < 1.0:
+       return None
+
+     # Step 1: The normal cost range is around 1e-3 to 1e5.
+     # Therefore we normalize the log10 value from [-3, 5] to [0, 1].
+     normalized_value = (math.log10(self.total.estimated_cost) + 3) / (5 + 3)
+
+     # Step 2: Interpolate between green and red
+     red = int(255 * normalized_value)
+     green = int(255 * (1 - normalized_value))
+     return f'rgb({red}, {green}, 0)'
+
+   def _tooltip_format(self, v, root_indent):
+     del root_indent
+     if isinstance(v, int):
+       return f'{v:,}'
+     if isinstance(v, float):
+       return f'{v:,.3f}'
+     return None
+
+   def _html_tree_view(
+       self,
+       *,
+       view: pg.views.HtmlTreeView,
+       extra_flags: dict[str, Any] | None = None,
+       **kwargs
+   ) -> pg.Html:
+     extra_flags = extra_flags or {}
+     as_badge = extra_flags.pop('as_badge', False)
+     interactive = extra_flags.get('interactive', True)
+     if as_badge:
+       usage_badge = self._usage_badge
+       if usage_badge is None:
+         usage_badge = pg.views.html.controls.Badge(
+             self._badge_text(),
+             tooltip=pg.format(
+                 self, custom_format=self._tooltip_format, verbose=False
+             ),
+             css_classes=['usage-summary'],
+             styles=dict(color=self._badge_color()),
+             interactive=True,
+         )
+         if interactive:
+           self._usage_badge = usage_badge
+       return usage_badge.to_html()
+     return super()._html_tree_view(
+         view=view,
+         extra_flags=extra_flags,
+         **kwargs
+     )
+
+   @classmethod
+   @functools.cache
+   def _html_tree_view_css_styles(cls) -> list[str]:
+     return super()._html_tree_view_css_styles() + [
+         """
+         .usage-summary.label {
+           display: inline-flex;
+           border-radius: 5px;
+           padding: 5px;
+           background-color: #f1f1f1;
+           color: #CCC;
+         }
+         .usage-summary.label::before {
+           content: '$';
+         }
+         """
+     ]
+
+ pg.members(
+     dict(
+         cached=(
+             pg.typing.Object(
+                 UsageSummary.AggregatedUsage,
+                 default=UsageSummary.AggregatedUsage()
+             ),
+             'Aggregated usages for cached LLM calls.'
+         ),
+         uncached=(
+             pg.typing.Object(
+                 UsageSummary.AggregatedUsage,
+                 default=UsageSummary.AggregatedUsage()
+             ),
+             'Aggregated usages for uncached LLM calls.'
+         ),
+     )
+ )(UsageSummary)
+
+
+ class _UsageTracker:
+   """Usage tracker."""
+
+   def __init__(self, model_ids: set[str] | None):
+     self.model_ids = model_ids
+     self.usage_summary = UsageSummary()
+
+   def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
+     if self.model_ids is None or model_id in self.model_ids:
+       self.usage_summary.add(model_id, usage, is_cached)
+
+
+ @contextlib.contextmanager
+ def track_usages(
+     *lm: Union[str, LanguageModel]
+ ) -> Iterator[UsageSummary]:
+   """Context manager to track the usages of all language models in scope.
+
+   `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
+   `lf.concurrent_execute`.
+
+   Example:
+     ```
+     lm = lf.llms.GeminiPro1()
+     with lf.track_usages() as usages:
+       # invoke any code that will call LLMs.
+
+     print(usages[lm.model_id])
+     ```
+
+   Args:
+     *lm: The language model(s) to track. If None, track all models in scope.
+
+   Yields:
+     A dictionary of model ID to usage. If a model does not supports usage
+     counting, the dict entry will be None.
+   """
+   if not lm:
+     model_ids = None
+   else:
+     model_ids = [m.model_id if isinstance(m, LanguageModel) else m for m in lm]
+
+   trackers = component.context_value('__usage_trackers__', [])
+   tracker = _UsageTracker(set(model_ids) if model_ids else None)
+   with component.context(__usage_trackers__=trackers + [tracker]):
+     try:
+       yield tracker.usage_summary
+     finally:
+       pass
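
The hunks above introduce per-call usage accounting (`RetryStats`, `LMSamplingUsage`, `UsageSummary`) and the `lf.track_usages` context manager, which registers a tracker that every `LanguageModel.sample` call in scope reports into. Below is a minimal sketch of how these pieces compose; it is not part of the diff, and the `lf.llms.Echo` fake model and its usage reporting are assumptions about the surrounding library rather than something shown in this file.

# Hypothetical usage sketch for the new usage-accounting API (assumptions noted above).
import langfun as lf
from langfun.core import language_model as lm_lib

# LMSamplingUsage values support `+`: token counts, request counts, estimated
# cost and retry stats are summed, so per-call usage can be aggregated.
a = lm_lib.LMSamplingUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)
b = lm_lib.LMSamplingUsage(prompt_tokens=20, completion_tokens=10, total_tokens=30)
print((a + b).total_tokens)           # 45
print((a + b).average_prompt_tokens)  # 15, averaged over the 2 requests

# track_usages() yields a UsageSummary that LM calls made in scope update.
with lf.track_usages() as usage_summary:
  lf.llms.Echo()('hello')  # assumed: the fake Echo model reports usage
print(usage_summary.total.num_requests)
print(usage_summary.uncached.breakdown)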