langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of langfun might be problematic.

Files changed (59):
  1. langfun/__init__.py +7 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +15 -0
  7. langfun/core/eval/base.py +665 -95
  8. langfun/core/eval/base_test.py +224 -53
  9. langfun/core/eval/matching.py +48 -30
  10. langfun/core/eval/matching_test.py +25 -3
  11. langfun/core/eval/patching.py +130 -0
  12. langfun/core/eval/patching_test.py +170 -0
  13. langfun/core/eval/scoring.py +19 -10
  14. langfun/core/eval/scoring_test.py +21 -3
  15. langfun/core/langfunc.py +1 -22
  16. langfun/core/langfunc_test.py +10 -4
  17. langfun/core/language_model.py +130 -24
  18. langfun/core/language_model_test.py +249 -26
  19. langfun/core/llms/__init__.py +27 -2
  20. langfun/core/llms/anthropic.py +263 -0
  21. langfun/core/llms/anthropic_test.py +167 -0
  22. langfun/core/llms/cache/in_memory_test.py +37 -28
  23. langfun/core/llms/fake.py +34 -25
  24. langfun/core/llms/fake_test.py +122 -11
  25. langfun/core/llms/google_genai.py +8 -0
  26. langfun/core/llms/google_genai_test.py +8 -3
  27. langfun/core/llms/groq.py +260 -0
  28. langfun/core/llms/groq_test.py +170 -0
  29. langfun/core/llms/llama_cpp.py +3 -1
  30. langfun/core/llms/openai.py +100 -81
  31. langfun/core/llms/openai_test.py +287 -60
  32. langfun/core/llms/vertexai.py +291 -0
  33. langfun/core/llms/vertexai_test.py +233 -0
  34. langfun/core/modalities/image.py +1 -3
  35. langfun/core/modalities/mime.py +6 -0
  36. langfun/core/modalities/video.py +6 -5
  37. langfun/core/structured/__init__.py +5 -0
  38. langfun/core/structured/completion_test.py +2 -2
  39. langfun/core/structured/function_generation.py +245 -0
  40. langfun/core/structured/function_generation_test.py +329 -0
  41. langfun/core/structured/mapping.py +61 -3
  42. langfun/core/structured/mapping_test.py +17 -0
  43. langfun/core/structured/parsing_test.py +18 -13
  44. langfun/core/structured/prompting.py +61 -12
  45. langfun/core/structured/prompting_test.py +122 -12
  46. langfun/core/structured/schema.py +38 -6
  47. langfun/core/structured/schema_generation_test.py +2 -2
  48. langfun/core/structured/schema_test.py +36 -7
  49. langfun/core/structured/scoring.py +4 -1
  50. langfun/core/structured/scoring_test.py +6 -0
  51. langfun/core/template.py +147 -11
  52. langfun/core/template_test.py +75 -0
  53. langfun/core/templates/selfplay_test.py +6 -2
  54. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
  55. langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
  56. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  57. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
  58. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
  59. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py

@@ -22,8 +22,12 @@ from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
 from langfun.core import message as message_lib
+
 import pyglove as pg

+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+

 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +51,14 @@ class LMSample(pg.Object):
   ] = None


+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""

@@ -58,19 +70,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []

+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+


 class LMSamplingOptions(component.Component):
   """Language model sampling options."""

   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] = 0.0
-  max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +105,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +114,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +124,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +137,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
      ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
@@ -135,6 +167,11 @@ class LMScoringResult(pg.Object):
       float,
       'The log likelyhood of the requested completion towards the prompt.',
   ]
+  gradients: Annotated[
+      Any | None,
+      '(Optional) gradients from the score method, w.r.t.' +
+      ' prompt.metadata.weights.',
+  ] = None


 class LMCache(pg.Object):
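
The schema changes above introduce token accounting (LMSamplingUsage) and make temperature and max_tokens optional, deferring to each model's own defaults when unset. A minimal illustrative sketch of how the new fields compose (not part of the diff; it only uses the classes shown in this file):

    # Illustrative sketch; mirrors the fields added/changed above.
    from langfun.core import language_model as lm_lib

    # temperature and max_tokens now default to None (use the model's defaults).
    options = lm_lib.LMSamplingOptions()
    assert options.temperature is None and options.max_tokens is None

    # Per-result token accounting introduced by this release.
    usage = lm_lib.LMSamplingUsage(
        prompt_tokens=12, completion_tokens=34, total_tokens=46)
    assert usage.total_tokens == 46
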
@@ -315,9 +352,42 @@ class LanguageModel(component.Component):

     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-        return self._sample(prompts)
+        results = self._sample(prompts)
       else:
-        return self._sample_with_cache_lookup(prompts, cache_seed)
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results

   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -405,12 +475,9 @@ class LanguageModel(component.Component):
     result = self.sample(
         [prompt], sampling_options=sampling_options, cache_seed=cache_seed
     )[0]
-    response = result.samples[0].response
-    logprobs = result.samples[0].logprobs
-    response.set('score', result.samples[0].score)
-    response.metadata.logprobs = logprobs
     elapse = time.time() - request_start
-    self._debug(prompt, response, call_counter, elapse)
+    response = result.samples[0].response
+    self._debug(prompt, response, call_counter, result.usage, elapse)
     return response

   def _debug(
@@ -418,35 +485,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)

     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)

     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)

-  def _debug_model_info(self, call_counter: int):
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )

-  def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +545,22 @@ class LanguageModel(component.Component):
     )

   def _debug_response(
-      self, response: message_lib.Message, call_counter: int, elapse: float
-  ):
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )

@@ -512,7 +607,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)

     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +643,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
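
The hunks above also change the sampling call path: sample() now tags each prompt as LM input, attaches score, logprobs and usage to response.metadata (splitting the per-result usage evenly across samples when n > 1), and the new rate_to_max_concurrency() converts RPM/TPM quotas into a concurrency limit using TOKENS_PER_REQUEST = 250. A worked, illustrative sketch of that conversion (a standalone restatement of the method added above, not an import from langfun):

    # Standalone restatement of the conversion added to LanguageModel above.
    TOKENS_PER_REQUEST = 250
    DEFAULT_MAX_CONCURRENCY = 1

    def rate_to_max_concurrency(requests_per_min=0.0, tokens_per_min=0.0):
      if tokens_per_min > 0:
        return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
      elif requests_per_min > 0:
        return max(int(requests_per_min / 60), 1)  # never zero
      return DEFAULT_MAX_CONCURRENCY

    assert rate_to_max_concurrency(tokens_per_min=1e7) == 666    # 1e7 / 250 / 60
    assert rate_to_max_concurrency(requests_per_min=1e4) == 166  # 1e4 / 60
    assert rate_to_max_concurrency() == 1
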
langfun/core/language_model_test.py

@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
     def fake_sample(prompts):
       if context.attempt >= self.failures_before_attempt:
         return [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
-                response=prompt.text * self.sampling_options.top_k,
-                score=self.sampling_options.temperature)])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                        response=prompt.text * self.sampling_options.top_k,
+                        score=self.sampling_options.temperature or -1.0,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100,
+                    completion_tokens=100,
+                    total_tokens=200,
+                ),
+            )
             for prompt in prompts
         ]
       context.attempt += 1
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))


 class LanguageModelTest(unittest.TestCase):
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.5)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.5)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ],
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=1.0)]),
-        ],
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
+            ),
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('foo' * 2, score=0.7)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [lm_lib.LMSample('bar' * 2, score=0.7)]
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ],
+        ]
     )

   def test_call(self):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score, 0.0)
+    self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))

     # Test override sampling_options.
     self.assertEqual(
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=0.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('bar', cache_seed=0), score=0.0)]),
-        ])
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=1.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('baz', cache_seed=0), score=1.0)]),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])

+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+

 if __name__ == '__main__':
   unittest.main()
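
Two behavioral changes pinned down by the updated tests above are worth calling out: the default cache key becomes (None, None, 1, 40, None, None) because temperature and max_tokens no longer default to 0.0/1024, and when a result carries n samples its usage is split evenly across the per-sample response metadata. An illustrative sketch of that split (not taken from the test file):

    # Illustrative only: per-sample usage when one result contains n samples.
    from langfun.core import language_model as lm_lib

    result_usage = lm_lib.LMSamplingUsage(
        prompt_tokens=100, completion_tokens=100, total_tokens=200)
    n = 2  # hypothetical sample count in a single LMSamplingResult
    per_sample = lm_lib.LMSamplingUsage(
        prompt_tokens=result_usage.prompt_tokens // n,
        completion_tokens=result_usage.completion_tokens // n,
        total_tokens=result_usage.total_tokens // n,
    )
    assert (per_sample.prompt_tokens, per_sample.total_tokens) == (50, 100)
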
langfun/core/llms/__init__.py

@@ -27,6 +27,7 @@ from langfun.core.llms.fake import StaticSequence
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
+from langfun.core.llms.google_genai import GeminiPro1_5
 from langfun.core.llms.google_genai import GeminiProVision
 from langfun.core.llms.google_genai import Palm2
 from langfun.core.llms.google_genai import Palm2_IT
@@ -35,8 +36,12 @@ from langfun.core.llms.google_genai import Palm2_IT
 from langfun.core.llms.openai import OpenAI

 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import Gpt4Turbo_0125
-from langfun.core.llms.openai import Gpt4TurboVision
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
@@ -57,6 +62,26 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada

+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
+from langfun.core.llms.vertexai import VertexAIGeminiPro1
+from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
+from langfun.core.llms.vertexai import VertexAIPalm2
+from langfun.core.llms.vertexai import VertexAIPalm2_32K
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
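
With the re-exports above, the new Anthropic, Groq, and Vertex AI backends become reachable directly from langfun.core.llms. A hedged usage sketch follows; the class names come from the imports above, but the credential handling noted in the comments (API keys from the environment, project/location for Vertex AI) is an assumption and should be checked against the new modules (anthropic.py, groq.py, vertexai.py) in this release:

    # Hedged sketch: class names are taken from the imports above; credential
    # and constructor details are assumptions, not verified against this diff.
    from langfun.core import llms

    claude = llms.Claude3Haiku()    # presumably resolves an Anthropic API key
    llama3 = llms.GroqLlama3_8B()   # presumably resolves a Groq API key
    # gemini = llms.VertexAIGeminiPro1(project='my-project', location='us-central1')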