langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (52)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +240 -37
  8. langfun/core/eval/base_test.py +52 -18
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +3 -4
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -2
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +24 -5
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/{gemini.py → google_genai.py} +117 -15
  24. langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
  25. langfun/core/llms/groq.py +260 -0
  26. langfun/core/llms/groq_test.py +170 -0
  27. langfun/core/llms/llama_cpp.py +3 -1
  28. langfun/core/llms/openai.py +97 -79
  29. langfun/core/llms/openai_test.py +285 -59
  30. langfun/core/modalities/video.py +5 -2
  31. langfun/core/structured/__init__.py +3 -0
  32. langfun/core/structured/completion_test.py +2 -2
  33. langfun/core/structured/function_generation.py +245 -0
  34. langfun/core/structured/function_generation_test.py +329 -0
  35. langfun/core/structured/mapping.py +59 -3
  36. langfun/core/structured/mapping_test.py +17 -0
  37. langfun/core/structured/parsing.py +2 -1
  38. langfun/core/structured/parsing_test.py +18 -13
  39. langfun/core/structured/prompting.py +27 -6
  40. langfun/core/structured/prompting_test.py +79 -12
  41. langfun/core/structured/schema.py +25 -22
  42. langfun/core/structured/schema_generation.py +2 -3
  43. langfun/core/structured/schema_generation_test.py +2 -2
  44. langfun/core/structured/schema_test.py +42 -27
  45. langfun/core/template.py +125 -10
  46. langfun/core/template_test.py +75 -0
  47. langfun/core/templates/selfplay_test.py +6 -2
  48. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  49. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
  50. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  51. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  52. {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py (+124 -24)

@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +50,14 @@ class LMSample(pg.Object):
   ] = None
 
 
+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""
 
@@ -58,19 +69,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []
 
+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+
 
 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
 
   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] = 0.0
-  max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +104,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +113,7 @@ class LMSamplingOptions(component.Component):
          '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +123,11 @@ class LMSamplingOptions(component.Component):
          '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +136,7 @@ class LMSamplingOptions(component.Component):
          'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
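Note on the options change above: temperature and max_tokens now default to None (meaning "use the model's default") instead of 0.0 and 1024, which also changes the cache key of default options. A minimal sketch of the observable difference, mirroring the updated cache-key test later in this diff; the explicit-options lines are illustrative:

from langfun.core import language_model as lm_lib

opts = lm_lib.LMSamplingOptions()
# New defaults: both fields are None, so the default cache key changes too.
assert opts.temperature is None and opts.max_tokens is None
assert opts.cache_key() == (None, None, 1, 40, None, None)

# Callers that relied on the old implicit defaults must now request them
# explicitly, e.g. lm.sample(prompts, temperature=0.0, max_tokens=1024).
explicit = lm_lib.LMSamplingOptions(temperature=0.0, max_tokens=1024)
assert explicit.cache_key() == (0.0, 1024, 1, 40, None, None)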
@@ -315,9 +346,42 @@ class LanguageModel(component.Component):
 
     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-        return self._sample(prompts)
+        results = self._sample(prompts)
       else:
-        return self._sample_with_cache_lookup(prompts, cache_seed)
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results
 
   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -405,12 +469,9 @@ class LanguageModel(component.Component):
       result = self.sample(
           [prompt], sampling_options=sampling_options, cache_seed=cache_seed
       )[0]
-      response = result.samples[0].response
-      logprobs = result.samples[0].logprobs
-      response.set('score', result.samples[0].score)
-      response.metadata.logprobs = logprobs
       elapse = time.time() - request_start
-      self._debug(prompt, response, call_counter, elapse)
+      response = result.samples[0].response
+      self._debug(prompt, response, call_counter, result.usage, elapse)
       return response
 
   def _debug(
@@ -418,35 +479,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)
 
     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)
 
     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)
 
-  def _debug_model_info(self, call_counter: int):
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )
 
-  def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +539,22 @@ class LanguageModel(component.Component):
     )
 
   def _debug_response(
-      self, response: message_lib.Message, call_counter: int, elapse: float
-  ):
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )
 
@@ -512,7 +601,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +637,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
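From the caller's side, the new rate_to_max_concurrency() helper and the per-response usage metadata are the user-visible additions here. A minimal sketch, assuming the built-in Echo fake model is exported as lf.llms.Echo; the OpenAI line in the trailing comment is illustrative and would need a real API key:

import langfun as lf

# rate_to_max_concurrency() lives on the LanguageModel base class, so any
# model works; Echo is a fake that needs no API key.
lm = lf.llms.Echo()

# TPM takes precedence and is divided by TOKENS_PER_REQUEST (250) per minute;
# otherwise RPM / 60 is used; the result never drops below DEFAULT_MAX_CONCURRENCY.
assert lm.rate_to_max_concurrency(tokens_per_min=1e7) == 666    # int(1e7 / 250 / 60)
assert lm.rate_to_max_concurrency(requests_per_min=1e4) == 166  # int(1e4 / 60)
assert lm.rate_to_max_concurrency() == 1

# On backends that report token counts (currently OpenAI), each response now
# carries usage metadata alongside score and logprobs, e.g.:
#   r = lf.llms.Gpt35Turbo(api_key='<openai-key>')(prompt='hello')
#   r.usage  # -> LMSamplingUsage(prompt_tokens=..., completion_tokens=..., total_tokens=...)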
langfun/core/language_model_test.py (+249 -26)

@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
     def fake_sample(prompts):
       if context.attempt >= self.failures_before_attempt:
         return [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
-                response=prompt.text * self.sampling_options.top_k,
-                score=self.sampling_options.temperature)])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                        response=prompt.text * self.sampling_options.top_k,
+                        score=self.sampling_options.temperature or -1.0,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100,
+                    completion_tokens=100,
+                    total_tokens=200,
+                ),
+            )
             for prompt in prompts
         ]
       context.attempt += 1
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))
 
     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
 
 
 class LanguageModelTest(unittest.TestCase):
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.5)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.5)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ],
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
-        ],
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
+            ),
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.7)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.7)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ],
+        ]
     )
 
   def test_call(self):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score, 0.0)
+    self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
 
     # Test override sampling_options.
     self.assertEqual(
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=0.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('bar', cache_seed=0), score=0.0)]),
-        ])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
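The MockModel changes above show the contract custom backends follow after this release: _sample() returns LMSamplingResults that may carry an LMSamplingUsage, and the base class propagates it onto each response message. A minimal sketch; EchoUsageModel and its token arithmetic are hypothetical:

from langfun.core import language_model as lm_lib


class EchoUsageModel(lm_lib.LanguageModel):
  """Hypothetical backend that echoes the prompt and reports token counts."""

  def _sample(self, prompts):
    return [
        lm_lib.LMSamplingResult(
            [lm_lib.LMSample(response=p.text, score=0.0)],
            # When a result carries usage, sample() copies it onto each
            # response message (splitting it across samples when n > 1).
            usage=lm_lib.LMSamplingUsage(
                prompt_tokens=len(p.text),
                completion_tokens=len(p.text),
                total_tokens=2 * len(p.text),
            ),
        )
        for p in prompts
    ]


response = EchoUsageModel()(prompt='hello')
assert response.text == 'hello'
assert response.usage.total_tokens == 10  # 2 * len('hello')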
langfun/core/llms/__init__.py (+24 -5)

@@ -25,16 +25,22 @@ from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
 # Gemini models.
-from langfun.core.llms.gemini import Gemini
-from langfun.core.llms.gemini import GeminiPro
-from langfun.core.llms.gemini import GeminiProVision
+from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import GeminiPro
+from langfun.core.llms.google_genai import GeminiProVision
+from langfun.core.llms.google_genai import Palm2
+from langfun.core.llms.google_genai import Palm2_IT
 
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
 
 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import Gpt4Turbo_0125
-from langfun.core.llms.openai import Gpt4TurboVision
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
@@ -55,6 +61,19 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada
 
+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
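For orientation, a brief sketch of the provider entry points this __init__ change exposes. The api_key arguments and the prompt are placeholders (keyword names are assumed to follow the existing backends), and each call hits the corresponding remote service:

import langfun as lf

gemini = lf.llms.GeminiPro(api_key='<google-api-key>')        # module renamed gemini -> google_genai
claude = lf.llms.Claude3Haiku(api_key='<anthropic-api-key>')  # new Anthropic backend
llama3 = lf.llms.GroqLlama3_8B(api_key='<groq-api-key>')      # new Groq backend

for lm in (gemini, claude, llama3):
  print(type(lm).__name__, lm(prompt='Say hello in one word.'))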