langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- langfun/__init__.py +2 -0
- langfun/core/__init__.py +1 -0
- langfun/core/coding/python/correction.py +0 -7
- langfun/core/component.py +6 -0
- langfun/core/component_test.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +202 -23
- langfun/core/eval/base_test.py +49 -10
- langfun/core/eval/matching.py +26 -9
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +15 -6
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/langfunc_test.py +6 -4
- langfun/core/language_model.py +124 -24
- langfun/core/language_model_test.py +249 -26
- langfun/core/llms/__init__.py +19 -2
- langfun/core/llms/anthropic.py +263 -0
- langfun/core/llms/anthropic_test.py +167 -0
- langfun/core/llms/cache/in_memory_test.py +37 -28
- langfun/core/llms/fake.py +31 -22
- langfun/core/llms/fake_test.py +122 -11
- langfun/core/llms/google_genai_test.py +8 -3
- langfun/core/llms/groq.py +260 -0
- langfun/core/llms/groq_test.py +170 -0
- langfun/core/llms/llama_cpp.py +3 -1
- langfun/core/llms/openai.py +97 -79
- langfun/core/llms/openai_test.py +285 -59
- langfun/core/modalities/video.py +5 -2
- langfun/core/structured/__init__.py +3 -0
- langfun/core/structured/completion_test.py +2 -2
- langfun/core/structured/function_generation.py +245 -0
- langfun/core/structured/function_generation_test.py +329 -0
- langfun/core/structured/mapping.py +56 -2
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +18 -13
- langfun/core/structured/prompting.py +27 -6
- langfun/core/structured/prompting_test.py +79 -12
- langfun/core/structured/schema.py +4 -2
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/structured/schema_test.py +4 -6
- langfun/core/template.py +125 -10
- langfun/core/template_test.py +75 -0
- langfun/core/templates/selfplay_test.py +6 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/eval/matching_test.py CHANGED
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='MyTask@
+                id='MyTask@739a174b',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -125,6 +125,7 @@ class MatchingTest(unittest.TestCase):
                 num_mismatches=1,
                 mismatch_rate=0.25,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/eval/scoring.py CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
     super()._reset()
     self._scored = []

-  def
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     score = self.score(example, output)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(score),
+          title='SCORE',
+          color='blue',
+      )
     self._scored.append((example, output, score, message))

   @abc.abstractmethod
@@ -118,19 +128,18 @@ class Scoring(base.Evaluation):
     super().save(definition, result, report)

     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save scored.
       pg.save(
           [
               # We force the output to be dict as its type may be defined
               # within functors which could be deserialized.
-              pg.Dict(input=input, output=
+              pg.Dict(input=input, output=output, score=score)
               for input, output, score, _ in self.scored
           ],
           os.path.join(self.dir, Scoring.SCORED_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )

     if report:
langfun/core/eval/scoring_test.py CHANGED
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='ConstraintFollowing@
+                id='ConstraintFollowing@5c88a5eb',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
@@ -102,6 +102,7 @@ class ScoringTest(unittest.TestCase):
                 score_rate=1.0,
                 avg_score=0.5,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/langfunc.py CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
     if lm_input is None:
       lm_input = self.render(**kwargs)

-    lm_input.tag(message_lib.Message.TAG_LM_INPUT)
     if skip_lm:
       return lm_input

@@ -270,10 +269,6 @@ class LangFunc(
     # Send rendered text to LM.
     lm_output = self.lm(lm_input, cache_seed=cache_seed)

-    # Track the input as the source of the output.
-    lm_output.source = lm_input
-    lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
     # Transform the output message.
     lm_output = self.transform_output(lm_output)
     lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
langfun/core/langfunc_test.py CHANGED
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(i.tags, ['rendered'])

     r = l()
-    self.assertEqual(
+    self.assertEqual(
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+    )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
     self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -92,8 +94,8 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         repr(l),
         "LangFunc(template_str='Hello', clean=True,"
-        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=
-        ' max_tokens=
+        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+        ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
         ' max_concurrency=None, timeout=120.0, max_attempts=5,'
         ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])

langfun/core/language_model.py CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg

+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+

 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +50,14 @@ class LMSample(pg.Object):
   ] = None


+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""

@@ -58,19 +69,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []

+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+

 class LMSamplingOptions(component.Component):
   """Language model sampling options."""

   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] =
-
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +104,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +113,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +123,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +136,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
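Two behavioral changes stand out in the hunks above: `temperature` and `max_tokens` now default to `None` (deferring to each model's own defaults), and sampling results gain an optional `LMSamplingUsage` record. The sketch below shows what this looks like from the caller side; it assumes both classes are re-exported from `langfun.core`, which is an assumption rather than something this diff shows.

```python
# Sketch only; assumes langfun.core re-exports LMSamplingOptions/LMSamplingUsage.
import langfun.core as lf

# New defaults: temperature/max_tokens are None, so the model's own defaults
# apply unless the caller overrides them (n=1 and top_k=40 are unchanged).
opts = lf.LMSamplingOptions()
assert opts.temperature is None and opts.max_tokens is None

tuned = lf.LMSamplingOptions(temperature=0.7, max_tokens=256)  # explicit override

# Usage record introduced in this release (currently populated for OpenAI models).
usage = lf.LMSamplingUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46)
```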
@@ -315,9 +346,42 @@ class LanguageModel(component.Component):

     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-
+        results = self._sample(prompts)
       else:
-
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results

   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
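The NOTE in the hunk above spells out how per-sample usage is derived when a result carries more than one sample: the result-level counters are split evenly with integer division. A standalone sketch of that arithmetic with made-up token counts (plain Python, not the library code):

```python
# Result-level usage reported by the backend (numbers are illustrative).
prompt_tokens, completion_tokens, total_tokens = 30, 91, 121
n = 3  # number of samples in the result

# Same integer-division split as the hunk above; remainders are dropped.
per_sample = dict(
    prompt_tokens=prompt_tokens // n,          # 10
    completion_tokens=completion_tokens // n,  # 30
    total_tokens=total_tokens // n,            # 40
)
assert per_sample == dict(prompt_tokens=10, completion_tokens=30, total_tokens=40)
```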
@@ -405,12 +469,9 @@ class LanguageModel(component.Component):
       result = self.sample(
           [prompt], sampling_options=sampling_options, cache_seed=cache_seed
       )[0]
-      response = result.samples[0].response
-      logprobs = result.samples[0].logprobs
-      response.set('score', result.samples[0].score)
-      response.metadata.logprobs = logprobs
       elapse = time.time() - request_start
-
+      response = result.samples[0].response
+      self._debug(prompt, response, call_counter, result.usage, elapse)
       return response

   def _debug(
@@ -418,35 +479,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)

     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)

     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)

-  def _debug_model_info(
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )

-  def _debug_prompt(
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt
-
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +539,22 @@ class LanguageModel(component.Component):
     )

   def _debug_response(
-      self,
-
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )

@@ -512,7 +601,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE

     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)

     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +637,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
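The new `rate_to_max_concurrency` helper turns per-minute quota figures into a concurrency bound: a token quota is divided by the estimated 250 tokens per request and by 60 seconds, a request quota by 60, and the result is floored at 1. A self-contained restatement of that arithmetic with illustrative quota numbers (the quotas themselves are not from this diff):

```python
TOKENS_PER_REQUEST = 250     # estimate used by the helper above
DEFAULT_MAX_CONCURRENCY = 1  # fallback when no RPM/TPM data is available


def rate_to_max_concurrency(requests_per_min: float = 0, tokens_per_min: float = 0) -> int:
  """Mirrors the logic added in this diff, for illustration only."""
  if tokens_per_min > 0:
    return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
  elif requests_per_min > 0:
    return max(int(requests_per_min / 60), 1)
  return DEFAULT_MAX_CONCURRENCY


# Illustrative quotas: 1M tokens/min -> 66, 300 requests/min -> 5, no data -> 1.
assert rate_to_max_concurrency(tokens_per_min=1_000_000) == 66
assert rate_to_max_concurrency(requests_per_min=300) == 5
assert rate_to_max_concurrency() == 1
```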