langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240511__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/__init__.py +7 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +15 -0
- langfun/core/eval/base.py +665 -95
- langfun/core/eval/base_test.py +224 -53
- langfun/core/eval/matching.py +48 -30
- langfun/core/eval/matching_test.py +25 -3
- langfun/core/eval/patching.py +130 -0
- langfun/core/eval/patching_test.py +170 -0
- langfun/core/eval/scoring.py +19 -10
- langfun/core/eval/scoring_test.py +21 -3
- langfun/core/langfunc.py +1 -22
- langfun/core/langfunc_test.py +10 -4
- langfun/core/language_model.py +130 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +27 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +34 -25
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai.py +8 -0
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +100 -81
- langfun/core/llms/openai_test.py +287 -60
- langfun/core/llms/vertexai.py +291 -0
- langfun/core/llms/vertexai_test.py +233 -0
- langfun/core/modalities/image.py +1 -3
- langfun/core/modalities/mime.py +6 -0
- langfun/core/modalities/video.py +6 -5
- langfun/core/structured/__init__.py +5 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +61 -3
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +61 -12
- langfun/core/structured/prompting_test.py +122 -12
- langfun/core/structured/schema.py +38 -6
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +36 -7
- langfun/core/structured/scoring.py +4 -1
- langfun/core/structured/scoring_test.py +6 -0
- langfun/core/template.py +147 -11
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/METADATA +3 -2
- langfun-0.0.2.dev20240511.dist-info/RECORD +112 -0
- langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240511.dist-info}/top_level.txt +0 -0
langfun/core/language_model.py
CHANGED
@@ -22,8 +22,12 @@ from langfun.core import component
 from langfun.core import concurrent
 from langfun.core import console
 from langfun.core import message as message_lib
+
 import pyglove as pg

+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+

 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +51,14 @@ class LMSample(pg.Object):
   ] = None


+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""

@@ -58,19 +70,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []

+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+

 class LMSamplingOptions(component.Component):
   """Language model sampling options."""

   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] =
-
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +105,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +114,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +124,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +137,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
@@ -135,6 +167,11 @@ class LMScoringResult(pg.Object):
       float,
       'The log likelyhood of the requested completion towards the prompt.',
   ]
+  gradients: Annotated[
+      Any | None,
+      '(Optional) gradients from the score method, w.r.t.' +
+      ' prompt.metadata.weights.',
+  ] = None


 class LMCache(pg.Object):
@@ -315,9 +352,42 @@ class LanguageModel(component.Component):

     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-
+        results = self._sample(prompts)
       else:
-
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results

   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -405,12 +475,9 @@ class LanguageModel(component.Component):
     result = self.sample(
         [prompt], sampling_options=sampling_options, cache_seed=cache_seed
     )[0]
-    response = result.samples[0].response
-    logprobs = result.samples[0].logprobs
-    response.set('score', result.samples[0].score)
-    response.metadata.logprobs = logprobs
     elapse = time.time() - request_start
-
+    response = result.samples[0].response
+    self._debug(prompt, response, call_counter, result.usage, elapse)
     return response

   def _debug(
@@ -418,35 +485,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)

     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)

     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)

-  def _debug_model_info(
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )

-  def _debug_prompt(
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt
-
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +545,22 @@ class LanguageModel(component.Component):
     )

   def _debug_response(
-      self,
-
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )

@@ -512,7 +607,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)

     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +643,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
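The usage plumbing added above is the core behavioral change in this file: an `LMSamplingResult` may now carry an `LMSamplingUsage`, and `LanguageModel.sample()` copies it onto every response message, splitting it by integer division when a result holds more than one sample. The snippet below is a standalone sketch of just that splitting rule, written with a plain dataclass rather than the langfun API:

```python
# Standalone sketch (not langfun API): the per-sample usage split that the
# updated LanguageModel.sample() applies when a result holds n > 1 samples.
import dataclasses


@dataclasses.dataclass
class Usage:
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


def per_sample_usage(result_usage: Usage | None, n: int) -> Usage | None:
  """Mirrors the integer-division averaging added in sample()."""
  if result_usage is None or n <= 1:
    return result_usage
  return Usage(
      prompt_tokens=result_usage.prompt_tokens // n,
      completion_tokens=result_usage.completion_tokens // n,
      total_tokens=result_usage.total_tokens // n,
  )


print(per_sample_usage(Usage(100, 100, 200), 2))
# Usage(prompt_tokens=50, completion_tokens=50, total_tokens=100)
```

With the 100/100/200 usage that the tests below stub in and two samples per result, each sample would report 50/50/100.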
langfun/core/language_model_test.py
CHANGED
@@ -38,9 +38,19 @@ class MockModel(lm_lib.LanguageModel):
     def fake_sample(prompts):
       if context.attempt >= self.failures_before_attempt:
         return [
-            lm_lib.LMSamplingResult(
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
+                        response=prompt.text * self.sampling_options.top_k,
+                        score=self.sampling_options.temperature or -1.0,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100,
+                    completion_tokens=100,
+                    total_tokens=200,
+                ),
+            )
             for prompt in prompts
         ]
       context.attempt += 1
@@ -73,13 +83,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))

     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (
+    self.assertEqual(key1, (None, None, 1, 40, None, None))


 class LanguageModelTest(unittest.TestCase):
@@ -100,8 +110,38 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     # Test override sampling_options.
@@ -112,38 +152,128 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
+            ),
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )

   def test_call(self):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score,
+    self.assertEqual(response.score, -1.0)
+    self.assertIsNone(response.logprobs)
+    self.assertEqual(response.usage, lm_lib.LMSamplingUsage(100, 100, 200))

     # Test override sampling_options.
     self.assertEqual(
@@ -158,11 +288,42 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult(
-
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+        ],
+    )
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
     self.assertEqual(cache.stats.num_updates, 2)
@@ -181,10 +342,40 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'baz'], temperature=1.0),
         [
-            lm_lib.LMSamplingResult(
-
-
-
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
+            lm_lib.LMSamplingResult(
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
+                usage=lm_lib.LMSamplingUsage(100, 100, 200),
+            ),
         ],
     )
     self.assertEqual(cache.stats.num_queries, 6)
@@ -341,6 +532,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])

+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+

 if __name__ == '__main__':
   unittest.main()
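The `rate_to_max_concurrency()` tests above pin down the arithmetic of the new method. The snippet below is a standalone trace of the same formula, reusing the module constants from the diff (`TOKENS_PER_REQUEST = 250`, `DEFAULT_MAX_CONCURRENCY = 1`); it is illustrative only, not the installed API:

```python
# Standalone trace of the concurrency formula added in language_model.py,
# evaluated at the rates used in the tests above.
TOKENS_PER_REQUEST = 250      # Estimated num tokens for a single request.
DEFAULT_MAX_CONCURRENCY = 1   # Fallback when no RPM or TPM data is given.


def rate_to_max_concurrency(
    requests_per_min: float = 0, tokens_per_min: float = 0) -> int:
  if tokens_per_min > 0:
    return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
  elif requests_per_min > 0:
    return max(int(requests_per_min / 60), 1)  # Never returns zero.
  return DEFAULT_MAX_CONCURRENCY


print(rate_to_max_concurrency(tokens_per_min=1e7))    # 666
print(rate_to_max_concurrency(requests_per_min=1e4))  # 166
print(rate_to_max_concurrency(tokens_per_min=1))      # 1 (floor)
print(rate_to_max_concurrency())                      # 1 (default)
```

Note that tokens-per-minute takes precedence over requests-per-minute, and both paths floor the result at 1 so the scheduler always gets a usable concurrency.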
langfun/core/llms/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from langfun.core.llms.fake import StaticSequence
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
+from langfun.core.llms.google_genai import GeminiPro1_5
 from langfun.core.llms.google_genai import GeminiProVision
 from langfun.core.llms.google_genai import Palm2
 from langfun.core.llms.google_genai import Palm2_IT
@@ -35,8 +36,12 @@ from langfun.core.llms.google_genai import Palm2_IT
 from langfun.core.llms.openai import OpenAI

 from langfun.core.llms.openai import Gpt4Turbo
-from langfun.core.llms.openai import
-from langfun.core.llms.openai import
+from langfun.core.llms.openai import Gpt4Turbo_20240409
+from langfun.core.llms.openai import Gpt4TurboPreview
+from langfun.core.llms.openai import Gpt4TurboPreview_0125
+from langfun.core.llms.openai import Gpt4TurboPreview_1106
+from langfun.core.llms.openai import Gpt4VisionPreview
+from langfun.core.llms.openai import Gpt4VisionPreview_1106
 from langfun.core.llms.openai import Gpt4
 from langfun.core.llms.openai import Gpt4_0613
 from langfun.core.llms.openai import Gpt4_32K
@@ -57,6 +62,26 @@ from langfun.core.llms.openai import Gpt3Curie
 from langfun.core.llms.openai import Gpt3Babbage
 from langfun.core.llms.openai import Gpt3Ada

+from langfun.core.llms.anthropic import Anthropic
+from langfun.core.llms.anthropic import Claude3Opus
+from langfun.core.llms.anthropic import Claude3Sonnet
+from langfun.core.llms.anthropic import Claude3Haiku
+
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
+from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
+from langfun.core.llms.vertexai import VertexAIGeminiPro1
+from langfun.core.llms.vertexai import VertexAIGeminiPro1Vision
+from langfun.core.llms.vertexai import VertexAIPalm2
+from langfun.core.llms.vertexai import VertexAIPalm2_32K
+
+
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
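Once this version is installed, the Anthropic, Groq and Vertex AI classes imported above become reachable through `langfun.core.llms`. A minimal check, assuming only the class names shown in the diff and that any optional provider dependencies these modules import are available; constructor arguments such as API keys or project IDs are provider-specific and not shown:

```python
# Assumes langfun >= 0.0.2.dev20240511; class names are taken verbatim from
# the imports added above. No model is instantiated or called.
from langfun.core import llms

for cls in (llms.Claude3Haiku, llms.GroqLlama3_8B):
  print(cls.__name__)
```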