langfun 0.0.2.dev20240319__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +240 -37
- langfun/core/eval/base_test.py +52 -18
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +3 -4
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -2
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +24 -5
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/{gemini.py → google_genai.py} +117 -15
- langfun/core/llms/{gemini_test.py → google_genai_test.py} +83 -15
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +59 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing.py +2 -1
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +25 -22
- langfun/core/structured/schema_generation.py +2 -3
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +42 -27
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +52 -46
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240319.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +50,14 @@ class LMSample(pg.Object):
   ] = None
 
 
+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""
 
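Note: LMSamplingUsage is a plain pyglove value type, so token counts can be constructed and aggregated directly. A minimal sketch (the helper below is illustrative, not part of this diff):

from langfun.core.language_model import LMSamplingUsage

def total_usage(usages: list[LMSamplingUsage]) -> LMSamplingUsage:
  # Sum token counts across several sampling results for reporting.
  return LMSamplingUsage(
      prompt_tokens=sum(u.prompt_tokens for u in usages),
      completion_tokens=sum(u.completion_tokens for u in usages),
      total_tokens=sum(u.total_tokens for u in usages),
  )

u = LMSamplingUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
assert total_usage([u, u]).total_tokens == 300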
@@ -58,19 +69,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []
 
+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+
 
 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
 
   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] =
-
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +104,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +113,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +123,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +136,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
@@ -315,9 +346,42 @@ class LanguageModel(component.Component):
 
     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-
+        results = self._sample(prompts)
       else:
-
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results
 
   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -405,12 +469,9 @@ class LanguageModel(component.Component):
       result = self.sample(
           [prompt], sampling_options=sampling_options, cache_seed=cache_seed
       )[0]
-      response = result.samples[0].response
-      logprobs = result.samples[0].logprobs
-      response.set('score', result.samples[0].score)
-      response.metadata.logprobs = logprobs
       elapse = time.time() - request_start
-
+      response = result.samples[0].response
+      self._debug(prompt, response, call_counter, result.usage, elapse)
       return response
 
   def _debug(
@@ -418,35 +479,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)
 
     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)
 
     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)
 
-  def _debug_model_info(
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
        color='magenta',
     )
 
-  def _debug_prompt(
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt
-
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
        color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +539,22 @@ class LanguageModel(component.Component):
     )
 
   def _debug_response(
-      self,
-
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
        color='blue',
     )
 
@@ -512,7 +601,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +637,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
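Note: rate_to_max_concurrency prefers the token budget when one is given (tokens_per_min / TOKENS_PER_REQUEST / 60), otherwise falls back to requests_per_min / 60, and never returns less than 1. A quick check of the arithmetic against the constants above:

# 1e7 tokens/min at 250 tokens/request over 60 s -> 666 concurrent requests.
assert int(1e7 / 250 / 60) == 666
# 10,000 requests/min over 60 s -> 166 concurrent requests.
assert int(10_000 / 60) == 166
# Tiny or unspecified rates fall back to a concurrency of 1.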
langfun/core/language_model_test.py
CHANGED
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
     def fake_sample(prompts):
       if context.attempt >= self.failures_before_attempt:
         return [
-            lm_lib.LMSamplingResult(
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                        response=prompt.text * self.sampling_options.top_k,
+                        score=self.sampling_options.temperature or -1.0,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100,
+                    completion_tokens=100,
+                    total_tokens=200,
+                ),
+            )
             for prompt in prompts
         ]
       context.attempt += 1
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))
 
     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
 
 
 class LanguageModelTest(unittest.TestCase):
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
+            ),
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
            ),
            lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
            ),
-        ]
+        ]
     )
 
   def test_call(self):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score,
+    self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))
 
     # Test override sampling_options.
     self.assertEqual(
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/__init__.py
CHANGED
@@ -25,16 +25,22 @@ from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
 # Gemini models.
-from langfun.core.llms.
-from langfun.core.llms.
-from langfun.core.llms.
+from langfun.core.llms.google_genai import GenAI
+from langfun.core.llms.google_genai import GeminiPro
+from langfun.core.llms.google_genai import GeminiProVision
+from langfun.core.llms.google_genai import Palm2
+from langfun.core.llms.google_genai import Palm2_IT
 
 # OpenAI models.
 from langfun.core.llms.openai import OpenAI
 
 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import
-from langfun.core.llms.openai import
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
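Note: following the gemini.py -> google_genai.py rename, Gemini/PaLM wrappers are imported from the new module path (or via the classes re-exported above). A minimal sketch, assuming the constructor still accepts api_key:

from langfun.core.llms.google_genai import GeminiPro

lm = GeminiPro(api_key='<your-key>')  # placeholder key
print(lm('Say hi in one word.'))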
@@ -55,6 +61,19 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada
 
+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
 
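Note: the newly exported Anthropic and Groq wrappers follow the same calling convention as the other LanguageModel subclasses. A minimal sketch (keys are placeholders; check anthropic.py and groq.py for the exact constructor fields):

from langfun.core.llms import Claude3Opus, GroqLlama3_70B

claude = Claude3Opus(api_key='<anthropic-key>')
llama = GroqLlama3_70B(api_key='<groq-key>')
print(claude('Say hi in one word.'))
print(llama('Say hi in one word.'))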