langfun 0.0.2.dev20240330__py3-none-any.whl → 0.0.2.dev20240429__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. langfun/__init__.py +2 -0
  2. langfun/core/__init__.py +1 -0
  3. langfun/core/coding/python/correction.py +0 -7
  4. langfun/core/component.py +6 -0
  5. langfun/core/component_test.py +1 -0
  6. langfun/core/eval/__init__.py +2 -0
  7. langfun/core/eval/base.py +202 -23
  8. langfun/core/eval/base_test.py +49 -10
  9. langfun/core/eval/matching.py +26 -9
  10. langfun/core/eval/matching_test.py +2 -1
  11. langfun/core/eval/scoring.py +15 -6
  12. langfun/core/eval/scoring_test.py +2 -1
  13. langfun/core/langfunc.py +0 -5
  14. langfun/core/langfunc_test.py +6 -4
  15. langfun/core/language_model.py +124 -24
  16. langfun/core/language_model_test.py +249 -26
  17. langfun/core/llms/__init__.py +19 -2
  18. langfun/core/llms/anthropic.py +263 -0
  19. langfun/core/llms/anthropic_test.py +167 -0
  20. langfun/core/llms/cache/in_memory_test.py +37 -28
  21. langfun/core/llms/fake.py +31 -22
  22. langfun/core/llms/fake_test.py +122 -11
  23. langfun/core/llms/google_genai_test.py +8 -3
  24. langfun/core/llms/groq.py +260 -0
  25. langfun/core/llms/groq_test.py +170 -0
  26. langfun/core/llms/llama_cpp.py +3 -1
  27. langfun/core/llms/openai.py +97 -79
  28. langfun/core/llms/openai_test.py +285 -59
  29. langfun/core/modalities/video.py +5 -2
  30. langfun/core/structured/__init__.py +3 -0
  31. langfun/core/structured/completion_test.py +2 -2
  32. langfun/core/structured/function_generation.py +245 -0
  33. langfun/core/structured/function_generation_test.py +329 -0
  34. langfun/core/structured/mapping.py +56 -2
  35. langfun/core/structured/mapping_test.py +17 -0
  36. langfun/core/structured/parsing_test.py +18 -13
  37. langfun/core/structured/prompting.py +27 -6
  38. langfun/core/structured/prompting_test.py +79 -12
  39. langfun/core/structured/schema.py +4 -2
  40. langfun/core/structured/schema_generation_test.py +2 -2
  41. langfun/core/structured/schema_test.py +4 -6
  42. langfun/core/template.py +125 -10
  43. langfun/core/template_test.py +75 -0
  44. langfun/core/templates/selfplay_test.py +6 -2
  45. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/METADATA +3 -2
  46. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/RECORD +49 -43
  47. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/LICENSE +0 -0
  48. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/WHEEL +0 -0
  49. {langfun-0.0.2.dev20240330.dist-info → langfun-0.0.2.dev20240429.dist-info}/top_level.txt +0 -0
langfun/core/eval/matching_test.py CHANGED
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='MyTask@3d87f97f',
+                id='MyTask@739a174b',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -125,6 +125,7 @@ class MatchingTest(unittest.TestCase):
                 num_mismatches=1,
                 mismatch_rate=0.25,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/eval/scoring.py CHANGED
@@ -61,8 +61,18 @@ class Scoring(base.Evaluation):
     super()._reset()
     self._scored = []
 
-  def audit(self, example: Any, output: Any, message: lf.Message) -> None:
+  def audit_processed(
+      self, example: Any, output: Any, message: lf.Message, dryrun: bool = False
+  ) -> None:
     score = self.score(example, output)
+
+    if dryrun:
+      lf.console.write('')
+      lf.console.write(
+          str(score),
+          title='SCORE',
+          color='blue',
+      )
     self._scored.append((example, output, score, message))
 
   @abc.abstractmethod
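Note: the hook above is now `audit_processed` and takes a `dryrun` flag that echoes each score to the console. A hedged sketch of a subclass under this API (the class name and the `answer` field are illustrative, and the other Evaluation fields needed to actually run it are omitted):

from typing import Any
from langfun.core.eval import scoring

class AnswerAccuracy(scoring.Scoring):  # illustrative subclass, not part of the package

  def score(self, example: Any, output: Any) -> float:
    # Full credit when the structured output equals the expected answer.
    return 1.0 if output == example.answer else 0.0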
@@ -118,19 +128,18 @@ class Scoring(base.Evaluation):
     super().save(definition, result, report)
 
     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save scored.
       pg.save(
           [
               # We force the output to be dict as its type may be defined
               # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output), score=score)
+              pg.Dict(input=input, output=output, score=score)
              for input, output, score, _ in self.scored
          ],
          os.path.join(self.dir, Scoring.SCORED_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
      )
 
     if report:
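Note: a hedged sketch of the serialization change above; the hand-rolled `force_dict(v)` helper is gone and `pg.save` is asked to emit plain dicts directly (the record and path below are made up):

import pyglove as pg

records = [pg.Dict(input='1 + 1 =', output=2, score=1.0)]
pg.save(records, '/tmp/scored.json', force_dict=True)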
langfun/core/eval/scoring_test.py CHANGED
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='ConstraintFollowing@9e51bb9e',
+                id='ConstraintFollowing@5c88a5eb',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
@@ -102,6 +102,7 @@ class ScoringTest(unittest.TestCase):
                 score_rate=1.0,
                 avg_score=0.5,
             ),
+            usage=s.result.usage,
         ),
     )
     self.assertTrue(
langfun/core/langfunc.py CHANGED
@@ -261,7 +261,6 @@ class LangFunc(
     if lm_input is None:
       lm_input = self.render(**kwargs)
 
-    lm_input.tag(message_lib.Message.TAG_LM_INPUT)
     if skip_lm:
       return lm_input
 
@@ -270,10 +269,6 @@ class LangFunc(
     # Send rendered text to LM.
     lm_output = self.lm(lm_input, cache_seed=cache_seed)
 
-    # Track the input as the source of the output.
-    lm_output.source = lm_input
-    lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
-
     # Transform the output message.
     lm_output = self.transform_output(lm_output)
     lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
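Note: the tagging and source tracking removed here are not lost; as the language_model.py hunks further below show, LanguageModel.sample() now applies them. A hedged sketch of the observable behavior, using langfun's built-in fake Echo model:

import langfun as lf

l = lf.LangFunc('Hello', lm=lf.llms.Echo())
r = l()
print(r.tags)         # expected: ['lm-response', 'lm-output']
print(r.source.tags)  # expected: ['rendered', 'lm-input']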
langfun/core/langfunc_test.py CHANGED
@@ -82,7 +82,9 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(i.tags, ['rendered'])
 
     r = l()
-    self.assertEqual(r, message.AIMessage('Hello!!!', score=0.0, logprobs=None))
+    self.assertEqual(
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
+    )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
     self.assertEqual(r.source, message.UserMessage('Hello'))
     self.assertEqual(r.source.tags, ['rendered', 'lm-input'])
@@ -92,8 +94,8 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         repr(l),
         "LangFunc(template_str='Hello', clean=True,"
-        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
-        ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
+        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+        ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
         ' max_concurrency=None, timeout=120.0, max_attempts=5,'
         ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
@@ -106,7 +108,7 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(l.render(), 'Hello')
     r = l()
     self.assertEqual(
-        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None)
+        r, message.AIMessage('Hello!!!', score=0.0, logprobs=None, usage=None)
     )
     self.assertEqual(r.tags, ['lm-response', 'lm-output'])
 
langfun/core/language_model.py CHANGED
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -47,6 +50,14 @@ class LMSample(pg.Object):
   ] = None
 
 
+class LMSamplingUsage(pg.Object):
+  """Usage information per completion."""
+
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LMSamplingResult(pg.Object):
   """Language model response."""
 
@@ -58,19 +69,34 @@ class LMSamplingResult(pg.Object):
       ),
   ] = []
 
+  usage: Annotated[
+      LMSamplingUsage | None,
+      'Usage information. Currently only OpenAI models are supported.',
+  ] = None
+
 
 class LMSamplingOptions(component.Component):
   """Language model sampling options."""
 
   temperature: Annotated[
-      float,
+      float | None,
       (
           'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] = 0.0
-  max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +104,7 @@ class LMSamplingOptions(component.Component):
           'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +113,7 @@ class LMSamplingOptions(component.Component):
           '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +123,11 @@ class LMSamplingOptions(component.Component):
           '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +136,7 @@ class LMSamplingOptions(component.Component):
           'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
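Note: the upshot of the option changes above is that `temperature` and `max_tokens` no longer default to 0.0/1024; leaving them as None defers to the model's own defaults. A hedged sketch (assuming `lf.LMSamplingOptions` re-exports `langfun.core.language_model.LMSamplingOptions`):

import langfun as lf

opts = lf.LMSamplingOptions()
assert opts.temperature is None and opts.max_tokens is None  # defer to the model

pinned = lf.LMSamplingOptions(temperature=0.7, max_tokens=256)  # explicit values still work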
@@ -315,9 +346,42 @@ class LanguageModel(component.Component):
 
     with component.context(override_attrs=True, **kwargs):
       if self.cache is None:
-        return self._sample(prompts)
+        results = self._sample(prompts)
       else:
-        return self._sample_with_cache_lookup(prompts, cache_seed)
+        results = self._sample_with_cache_lookup(prompts, cache_seed)
+
+      for prompt, result in zip(prompts, results):
+
+        # Tag LM input.
+        prompt.tag(message_lib.Message.TAG_LM_INPUT)
+
+        for sample in result.samples:
+          # Update metadata for response message.
+
+          response = sample.response
+          response.metadata.score = sample.score
+          response.metadata.logprobs = sample.logprobs
+
+          # NOTE(daiyip): Current usage is computed at per-result level,
+          # which is accurate when n=1. For n > 1, we average the usage across
+          # multiple samples.
+          usage = result.usage
+          if len(result.samples) == 1 or usage is None:
+            response.metadata.usage = usage
+          else:
+            n = len(result.samples)
+            response.metadata.usage = LMSamplingUsage(
+                prompt_tokens=usage.prompt_tokens // n,
+                completion_tokens=usage.completion_tokens // n,
+                total_tokens=usage.total_tokens // n,
+            )
+
+          # Track the prompt for corresponding response.
+          response.source = prompt
+
+          # Tag LM response.
+          response.tag(message_lib.Message.TAG_LM_RESPONSE)
+      return results
 
   def _sample_with_cache_lookup(
       self, prompts: list[str | message_lib.Message], cache_seed: int
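Note: with the hunk above, token accounting rides along on every response. A hedged sketch of reading it back (assumes an OpenAI-backed model with credentials configured, since per this release only OpenAI models populate usage):

import langfun as lf

lm = lf.llms.Gpt35Turbo()                # illustrative model choice
[result] = lm.sample(['1 + 1 ='])
response = result.samples[0].response

print(result.usage)                      # LMSamplingUsage(prompt_tokens=..., ...)
print(response.metadata.usage)           # per-response copy, averaged when n > 1
print(response.source.text)              # the prompt message, now tagged 'lm-input'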
@@ -405,12 +469,9 @@ class LanguageModel(component.Component):
       result = self.sample(
           [prompt], sampling_options=sampling_options, cache_seed=cache_seed
       )[0]
-      response = result.samples[0].response
-      logprobs = result.samples[0].logprobs
-      response.set('score', result.samples[0].score)
-      response.metadata.logprobs = logprobs
       elapse = time.time() - request_start
-      self._debug(prompt, response, call_counter, elapse)
+      response = result.samples[0].response
+      self._debug(prompt, response, call_counter, result.usage, elapse)
       return response
 
   def _debug(
@@ -418,35 +479,53 @@ class LanguageModel(component.Component):
       prompt: message_lib.Message,
       response: message_lib.Message,
       call_counter: int,
+      usage: LMSamplingUsage | None,
       elapse: float,
-  ):
+  ) -> None:
     """Outputs debugging information."""
     debug = self.debug
     if isinstance(debug, bool):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, usage)
 
     if debug & LMDebugMode.PROMPT:
-      self._debug_prompt(prompt, call_counter)
+      self._debug_prompt(prompt, call_counter, usage)
 
     if debug & LMDebugMode.RESPONSE:
-      self._debug_response(response, call_counter, elapse)
+      self._debug_response(response, call_counter, usage, elapse)
 
-  def _debug_model_info(self, call_counter: int):
+  def _debug_model_info(
+      self, call_counter: int, usage: LMSamplingUsage | None) -> None:
     """Outputs debugging information about the model."""
+    title_suffix = ''
+    if usage and usage.total_tokens != 0:
+      title_suffix = console.colored(
+          f' (total {usage.total_tokens} tokens)', 'red')
+
     console.write(
         self.format(compact=True, use_inferred=True),
-        title=f'[{call_counter}] LM INFO:',
+        title=f'[{call_counter}] LM INFO{title_suffix}:',
         color='magenta',
     )
 
-  def _debug_prompt(self, prompt: message_lib.Message, call_counter: int):
+  def _debug_prompt(
+      self,
+      prompt: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+  ) -> None:
     """Outputs debugging information about the prompt."""
+    title_suffix = ''
+    if usage and usage.prompt_tokens != 0:
+      title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
+
     console.write(
-        prompt,
-        title=f'\n[{call_counter}] PROMPT SENT TO LM:',
+        # We use metadata 'formatted_text' for scenarios where the prompt text
+        # is formatted by the LM.
+        prompt.get('formatted_text', prompt.text),
+        title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
         color='green',
     )
     referred_modalities = prompt.referred_modalities()
@@ -460,12 +539,22 @@ class LanguageModel(component.Component):
     )
 
   def _debug_response(
-      self, response: message_lib.Message, call_counter: int, elapse: float
-  ):
+      self,
+      response: message_lib.Message,
+      call_counter: int,
+      usage: LMSamplingUsage | None,
+      elapse: float
+  ) -> None:
     """Outputs debugging information about the response."""
+    title_suffix = ' ('
+    if usage and usage.completion_tokens != 0:
+      title_suffix += f'{usage.completion_tokens} tokens '
+    title_suffix += f'in {elapse:.2f} seconds)'
+    title_suffix = console.colored(title_suffix, 'red')
+
     console.write(
         str(response) + '\n',
-        title=f'\n[{call_counter}] LM RESPONSE (in {elapse:.2f} seconds):',
+        title=f'\n[{call_counter}] LM RESPONSE{title_suffix}:',
         color='blue',
     )
 
@@ -512,7 +601,7 @@ class LanguageModel(component.Component):
       debug = LMDebugMode.ALL if debug else LMDebugMode.NONE
 
     if debug & LMDebugMode.INFO:
-      self._debug_model_info(call_counter)
+      self._debug_model_info(call_counter, None)
 
     if debug & LMDebugMode.PROMPT:
       console.write(
@@ -548,3 +637,14 @@ class LanguageModel(component.Component):
             f'score: {r.score}',
             color='blue',
         )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
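Note: a hedged sketch of the new rate helper above; the limits are illustrative and Gpt4 merely stands in for any LanguageModel subclass:

import langfun as lf

lm = lf.llms.Gpt4(api_key='<your-key>')

# Suppose the provider allows 300 requests/min and 150,000 tokens/min.
assert lm.rate_to_max_concurrency(requests_per_min=300, tokens_per_min=150_000) == 10
# tokens_per_min takes precedence: 150_000 / TOKENS_PER_REQUEST (250) / 60 -> 10.

assert lm.rate_to_max_concurrency(requests_per_min=300) == 5
# Request rate only: 300 / 60 -> 5 concurrent requests.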