langfun 0.0.2.dev20240425__tar.gz → 0.0.2.dev20240428__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/PKG-INFO +1 -1
  2. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/base.py +1 -1
  3. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/langfunc.py +0 -5
  4. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/language_model.py +39 -9
  5. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/language_model_test.py +156 -18
  6. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/fake_test.py +91 -7
  7. langfun-0.0.2.dev20240428/langfun/core/llms/openai_test.py +450 -0
  8. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/prompting.py +14 -4
  9. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/prompting_test.py +33 -0
  10. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/template.py +99 -2
  11. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/template_test.py +66 -0
  12. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun.egg-info/PKG-INFO +1 -1
  13. langfun-0.0.2.dev20240425/langfun/core/llms/openai_test.py +0 -265
  14. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/LICENSE +0 -0
  15. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/README.md +0 -0
  16. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/__init__.py +0 -0
  17. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/__init__.py +0 -0
  18. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/__init__.py +0 -0
  19. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/__init__.py +0 -0
  20. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/correction.py +0 -0
  21. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/correction_test.py +0 -0
  22. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/errors.py +0 -0
  23. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/errors_test.py +0 -0
  24. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/execution.py +0 -0
  25. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/execution_test.py +0 -0
  26. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/generation.py +0 -0
  27. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/generation_test.py +0 -0
  28. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/parsing.py +0 -0
  29. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/parsing_test.py +0 -0
  30. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/permissions.py +0 -0
  31. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/coding/python/permissions_test.py +0 -0
  32. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/component.py +0 -0
  33. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/component_test.py +0 -0
  34. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/concurrent.py +0 -0
  35. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/concurrent_test.py +0 -0
  36. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/console.py +0 -0
  37. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/console_test.py +0 -0
  38. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/__init__.py +0 -0
  39. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/base_test.py +0 -0
  40. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/matching.py +0 -0
  41. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/matching_test.py +0 -0
  42. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/scoring.py +0 -0
  43. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/eval/scoring_test.py +0 -0
  44. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/langfunc_test.py +0 -0
  45. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/__init__.py +0 -0
  46. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/anthropic.py +0 -0
  47. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/anthropic_test.py +0 -0
  48. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/cache/__init__.py +0 -0
  49. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/cache/base.py +0 -0
  50. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/cache/in_memory.py +0 -0
  51. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/cache/in_memory_test.py +0 -0
  52. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/fake.py +0 -0
  53. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/google_genai.py +0 -0
  54. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/google_genai_test.py +0 -0
  55. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/groq.py +0 -0
  56. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/groq_test.py +0 -0
  57. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/llama_cpp.py +0 -0
  58. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/llama_cpp_test.py +0 -0
  59. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/llms/openai.py +0 -0
  60. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/memories/__init__.py +0 -0
  61. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/memories/conversation_history.py +0 -0
  62. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/memories/conversation_history_test.py +0 -0
  63. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/memory.py +0 -0
  64. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/message.py +0 -0
  65. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/message_test.py +0 -0
  66. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/__init__.py +0 -0
  67. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/image.py +0 -0
  68. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/image_test.py +0 -0
  69. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/mime.py +0 -0
  70. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/mime_test.py +0 -0
  71. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/video.py +0 -0
  72. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modalities/video_test.py +0 -0
  73. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modality.py +0 -0
  74. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/modality_test.py +0 -0
  75. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/natural_language.py +0 -0
  76. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/natural_language_test.py +0 -0
  77. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/sampling.py +0 -0
  78. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/sampling_test.py +0 -0
  79. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/__init__.py +0 -0
  80. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/completion.py +0 -0
  81. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/completion_test.py +0 -0
  82. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/description.py +0 -0
  83. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/description_test.py +0 -0
  84. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/function_generation.py +0 -0
  85. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/function_generation_test.py +0 -0
  86. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/mapping.py +0 -0
  87. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/mapping_test.py +0 -0
  88. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/parsing.py +0 -0
  89. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/parsing_test.py +0 -0
  90. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/schema.py +0 -0
  91. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/schema_generation.py +0 -0
  92. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/schema_generation_test.py +0 -0
  93. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/schema_test.py +0 -0
  94. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/scoring.py +0 -0
  95. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/structured/scoring_test.py +0 -0
  96. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/subscription.py +0 -0
  97. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/subscription_test.py +0 -0
  98. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/__init__.py +0 -0
  99. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/completion.py +0 -0
  100. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/completion_test.py +0 -0
  101. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/conversation.py +0 -0
  102. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/conversation_test.py +0 -0
  103. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/demonstration.py +0 -0
  104. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/demonstration_test.py +0 -0
  105. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/selfplay.py +0 -0
  106. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/templates/selfplay_test.py +0 -0
  107. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/text_formatting.py +0 -0
  108. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun/core/text_formatting_test.py +0 -0
  109. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun.egg-info/SOURCES.txt +0 -0
  110. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun.egg-info/dependency_links.txt +0 -0
  111. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun.egg-info/requires.txt +0 -0
  112. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/langfun.egg-info/top_level.txt +0 -0
  113. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/setup.cfg +0 -0
  114. {langfun-0.0.2.dev20240425 → langfun-0.0.2.dev20240428}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240425
3
+ Version: 0.0.2.dev20240428
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -540,7 +540,7 @@ class Evaluable(lf.Component):
540
540
  f'<div style="color: {text_color}; white-space: pre-wrap;'
541
541
  'padding: 10px; border: 1px solid; margin-top: 10px">'
542
542
  )
543
- s.write(m.text)
543
+ s.write(m.get('formatted_text', m.text))
544
544
  if m.result is not None:
545
545
  s.write(
546
546
  '<div style="color: magenta; white-space: pre-wrap;'
@@ -261,7 +261,6 @@ class LangFunc(
261
261
  if lm_input is None:
262
262
  lm_input = self.render(**kwargs)
263
263
 
264
- lm_input.tag(message_lib.Message.TAG_LM_INPUT)
265
264
  if skip_lm:
266
265
  return lm_input
267
266
 
@@ -270,10 +269,6 @@ class LangFunc(
270
269
  # Send rendered text to LM.
271
270
  lm_output = self.lm(lm_input, cache_seed=cache_seed)
272
271
 
273
- # Track the input as the source of the output.
274
- lm_output.source = lm_input
275
- lm_output.tag(message_lib.Message.TAG_LM_RESPONSE)
276
-
277
272
  # Transform the output message.
278
273
  lm_output = self.transform_output(lm_output)
279
274
  lm_output.tag(message_lib.Message.TAG_LM_OUTPUT)
@@ -346,9 +346,42 @@ class LanguageModel(component.Component):
346
346
 
347
347
  with component.context(override_attrs=True, **kwargs):
348
348
  if self.cache is None:
349
- return self._sample(prompts)
349
+ results = self._sample(prompts)
350
350
  else:
351
- return self._sample_with_cache_lookup(prompts, cache_seed)
351
+ results = self._sample_with_cache_lookup(prompts, cache_seed)
352
+
353
+ for prompt, result in zip(prompts, results):
354
+
355
+ # Tag LM input.
356
+ prompt.tag(message_lib.Message.TAG_LM_INPUT)
357
+
358
+ for sample in result.samples:
359
+ # Update metadata for response message.
360
+
361
+ response = sample.response
362
+ response.metadata.score = sample.score
363
+ response.metadata.logprobs = sample.logprobs
364
+
365
+ # NOTE(daiyip): Current usage is computed at per-result level,
366
+ # which is accurate when n=1. For n > 1, we average the usage across
367
+ # multiple samples.
368
+ usage = result.usage
369
+ if len(result.samples) == 1 or usage is None:
370
+ response.metadata.usage = usage
371
+ else:
372
+ n = len(result.samples)
373
+ response.metadata.usage = LMSamplingUsage(
374
+ prompt_tokens=usage.prompt_tokens // n,
375
+ completion_tokens=usage.completion_tokens // n,
376
+ total_tokens=usage.total_tokens // n,
377
+ )
378
+
379
+ # Track the prompt for corresponding response.
380
+ response.source = prompt
381
+
382
+ # Tag LM response.
383
+ response.tag(message_lib.Message.TAG_LM_RESPONSE)
384
+ return results
352
385
 
353
386
  def _sample_with_cache_lookup(
354
387
  self, prompts: list[str | message_lib.Message], cache_seed: int
@@ -436,13 +469,8 @@ class LanguageModel(component.Component):
436
469
  result = self.sample(
437
470
  [prompt], sampling_options=sampling_options, cache_seed=cache_seed
438
471
  )[0]
439
- response = result.samples[0].response
440
- logprobs = result.samples[0].logprobs
441
- response.set('score', result.samples[0].score)
442
- response.metadata.logprobs = logprobs
443
- response.metadata.usage = result.usage
444
-
445
472
  elapse = time.time() - request_start
473
+ response = result.samples[0].response
446
474
  self._debug(prompt, response, call_counter, result.usage, elapse)
447
475
  return response
448
476
 
@@ -494,7 +522,9 @@ class LanguageModel(component.Component):
494
522
  title_suffix = console.colored(f' ({usage.prompt_tokens} tokens)', 'red')
495
523
 
496
524
  console.write(
497
- prompt,
525
+ # We use metadata 'formatted_text' for scenarios where the prompt text
526
+ # is formatted by the LM.
527
+ prompt.get('formatted_text', prompt.text),
498
528
  title=f'\n[{call_counter}] PROMPT SENT TO LM{title_suffix}:',
499
529
  color='green',
500
530
  )
@@ -111,11 +111,35 @@ class LanguageModelTest(unittest.TestCase):
111
111
  lm.sample(prompts=['foo', 'bar']),
112
112
  [
113
113
  lm_lib.LMSamplingResult(
114
- [lm_lib.LMSample('foo', score=-1.0)],
114
+ [
115
+ lm_lib.LMSample(
116
+ message_lib.AIMessage(
117
+ 'foo',
118
+ score=-1.0,
119
+ logprobs=None,
120
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
121
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
122
+ ),
123
+ score=-1.0,
124
+ logprobs=None,
125
+ )
126
+ ],
115
127
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
116
128
  ),
117
129
  lm_lib.LMSamplingResult(
118
- [lm_lib.LMSample('bar', score=-1.0)],
130
+ [
131
+ lm_lib.LMSample(
132
+ message_lib.AIMessage(
133
+ 'bar',
134
+ score=-1.0,
135
+ logprobs=None,
136
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
137
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
138
+ ),
139
+ score=-1.0,
140
+ logprobs=None,
141
+ )
142
+ ],
119
143
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
120
144
  ),
121
145
  ],
@@ -128,41 +152,119 @@ class LanguageModelTest(unittest.TestCase):
128
152
  ),
129
153
  [
130
154
  lm_lib.LMSamplingResult(
131
- [lm_lib.LMSample('foo' * 2, score=0.5)],
155
+ [
156
+ lm_lib.LMSample(
157
+ message_lib.AIMessage(
158
+ 'foo' * 2,
159
+ score=0.5,
160
+ logprobs=None,
161
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
162
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
163
+ ),
164
+ score=0.5,
165
+ logprobs=None,
166
+ ),
167
+ ],
132
168
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
133
169
  ),
134
170
  lm_lib.LMSamplingResult(
135
- [lm_lib.LMSample('bar' * 2, score=0.5)],
136
- usage=lm_lib.LMSamplingUsage(100, 100, 200),
171
+ [
172
+ lm_lib.LMSample(
173
+ message_lib.AIMessage(
174
+ 'bar' * 2,
175
+ score=0.5,
176
+ logprobs=None,
177
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
178
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
179
+ ),
180
+ score=0.5,
181
+ logprobs=None,
182
+ ),
183
+ ],
184
+ usage=lm_lib.LMSamplingUsage(
185
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
186
+ ),
137
187
  ),
138
- ],
188
+ ]
139
189
  )
140
190
  # Test override individual flags within sampling_options.
141
191
  self.assertEqual(
142
192
  lm.sample(prompts=['foo', 'bar'], temperature=1.0),
143
193
  [
144
194
  lm_lib.LMSamplingResult(
145
- [lm_lib.LMSample('foo', score=1.0)],
195
+ [
196
+ lm_lib.LMSample(
197
+ message_lib.AIMessage(
198
+ 'foo',
199
+ score=1.0,
200
+ logprobs=None,
201
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
202
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
203
+ ),
204
+ score=1.0,
205
+ logprobs=None,
206
+ ),
207
+ ],
146
208
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
147
209
  ),
148
210
  lm_lib.LMSamplingResult(
149
- [lm_lib.LMSample('bar', score=1.0)],
150
- usage=lm_lib.LMSamplingUsage(100, 100, 200),
211
+ [
212
+ lm_lib.LMSample(
213
+ message_lib.AIMessage(
214
+ 'bar',
215
+ score=1.0,
216
+ logprobs=None,
217
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
218
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
219
+ ),
220
+ score=1.0,
221
+ logprobs=None,
222
+ ),
223
+ ],
224
+ usage=lm_lib.LMSamplingUsage(
225
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
226
+ ),
151
227
  ),
152
- ],
228
+ ]
153
229
  )
154
230
  self.assertEqual(
155
231
  lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
156
232
  [
157
233
  lm_lib.LMSamplingResult(
158
- [lm_lib.LMSample('foo' * 2, score=0.7)],
234
+ [
235
+ lm_lib.LMSample(
236
+ message_lib.AIMessage(
237
+ 'foo' * 2,
238
+ score=0.7,
239
+ logprobs=None,
240
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
241
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
242
+ ),
243
+ score=0.7,
244
+ logprobs=None,
245
+ ),
246
+ ],
159
247
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
160
248
  ),
161
249
  lm_lib.LMSamplingResult(
162
- [lm_lib.LMSample('bar' * 2, score=0.7)],
163
- usage=lm_lib.LMSamplingUsage(100, 100, 200),
250
+ [
251
+ lm_lib.LMSample(
252
+ message_lib.AIMessage(
253
+ 'bar' * 2,
254
+ score=0.7,
255
+ logprobs=None,
256
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
257
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
258
+ ),
259
+ score=0.7,
260
+ logprobs=None,
261
+ ),
262
+ ],
263
+ usage=lm_lib.LMSamplingUsage(
264
+ prompt_tokens=100, completion_tokens=100, total_tokens=200
265
+ ),
164
266
  ),
165
- ],
267
+ ]
166
268
  )
167
269
 
168
270
  def test_call(self):
@@ -189,7 +291,16 @@ class LanguageModelTest(unittest.TestCase):
189
291
  lm_lib.LMSamplingResult(
190
292
  [
191
293
  lm_lib.LMSample(
192
- message_lib.AIMessage('foo', cache_seed=0), score=-1.0
294
+ message_lib.AIMessage(
295
+ 'foo',
296
+ cache_seed=0,
297
+ score=-1.0,
298
+ logprobs=None,
299
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
300
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
301
+ ),
302
+ score=-1.0,
303
+ logprobs=None,
193
304
  )
194
305
  ],
195
306
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -197,7 +308,16 @@ class LanguageModelTest(unittest.TestCase):
197
308
  lm_lib.LMSamplingResult(
198
309
  [
199
310
  lm_lib.LMSample(
200
- message_lib.AIMessage('bar', cache_seed=0), score=-1.0
311
+ message_lib.AIMessage(
312
+ 'bar',
313
+ cache_seed=0,
314
+ score=-1.0,
315
+ logprobs=None,
316
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
317
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
318
+ ),
319
+ score=-1.0,
320
+ logprobs=None,
201
321
  )
202
322
  ],
203
323
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -225,7 +345,16 @@ class LanguageModelTest(unittest.TestCase):
225
345
  lm_lib.LMSamplingResult(
226
346
  [
227
347
  lm_lib.LMSample(
228
- message_lib.AIMessage('foo', cache_seed=0), score=1.0
348
+ message_lib.AIMessage(
349
+ 'foo',
350
+ cache_seed=0,
351
+ score=1.0,
352
+ logprobs=None,
353
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
354
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
355
+ ),
356
+ score=1.0,
357
+ logprobs=None,
229
358
  )
230
359
  ],
231
360
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -233,7 +362,16 @@ class LanguageModelTest(unittest.TestCase):
233
362
  lm_lib.LMSamplingResult(
234
363
  [
235
364
  lm_lib.LMSample(
236
- message_lib.AIMessage('baz', cache_seed=0), score=1.0
365
+ message_lib.AIMessage(
366
+ 'baz',
367
+ cache_seed=0,
368
+ score=1.0,
369
+ logprobs=None,
370
+ usage=lm_lib.LMSamplingUsage(100, 100, 200),
371
+ tags=[message_lib.Message.TAG_LM_RESPONSE],
372
+ ),
373
+ score=1.0,
374
+ logprobs=None,
237
375
  )
238
376
  ],
239
377
  usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -28,7 +28,19 @@ class EchoTest(unittest.TestCase):
28
28
  lm.sample(['hi']),
29
29
  [
30
30
  lf.LMSamplingResult(
31
- [lf.LMSample('hi', 1.0)],
31
+ [
32
+ lf.LMSample(
33
+ lf.AIMessage(
34
+ 'hi',
35
+ score=1.0,
36
+ logprobs=None,
37
+ usage=lf.LMSamplingUsage(2, 2, 4),
38
+ tags=[lf.Message.TAG_LM_RESPONSE],
39
+ ),
40
+ score=1.0,
41
+ logprobs=None,
42
+ )
43
+ ],
32
44
  lf.LMSamplingUsage(2, 2, 4))
33
45
  ]
34
46
  )
@@ -60,7 +72,19 @@ class StaticResponseTest(unittest.TestCase):
60
72
  lm.sample(['hi']),
61
73
  [
62
74
  lf.LMSamplingResult(
63
- [lf.LMSample(canned_response, 1.0)],
75
+ [
76
+ lf.LMSample(
77
+ lf.AIMessage(
78
+ canned_response,
79
+ score=1.0,
80
+ logprobs=None,
81
+ usage=lf.LMSamplingUsage(2, 38, 40),
82
+ tags=[lf.Message.TAG_LM_RESPONSE],
83
+ ),
84
+ score=1.0,
85
+ logprobs=None,
86
+ )
87
+ ],
64
88
  usage=lf.LMSamplingUsage(2, 38, 40)
65
89
  )
66
90
  ],
@@ -69,7 +93,19 @@ class StaticResponseTest(unittest.TestCase):
69
93
  lm.sample(['Tell me a joke.']),
70
94
  [
71
95
  lf.LMSamplingResult(
72
- [lf.LMSample(canned_response, 1.0)],
96
+ [
97
+ lf.LMSample(
98
+ lf.AIMessage(
99
+ canned_response,
100
+ score=1.0,
101
+ logprobs=None,
102
+ usage=lf.LMSamplingUsage(15, 38, 53),
103
+ tags=[lf.Message.TAG_LM_RESPONSE],
104
+ ),
105
+ score=1.0,
106
+ logprobs=None,
107
+ )
108
+ ],
73
109
  usage=lf.LMSamplingUsage(15, 38, 53)
74
110
  )
75
111
  ],
@@ -101,11 +137,35 @@ class StaticMappingTest(unittest.TestCase):
101
137
  lm.sample(['Hi', 'How are you?']),
102
138
  [
103
139
  lf.LMSamplingResult(
104
- [lf.LMSample('Hello', 1.0)],
140
+ [
141
+ lf.LMSample(
142
+ lf.AIMessage(
143
+ 'Hello',
144
+ score=1.0,
145
+ logprobs=None,
146
+ usage=lf.LMSamplingUsage(2, 5, 7),
147
+ tags=[lf.Message.TAG_LM_RESPONSE],
148
+ ),
149
+ score=1.0,
150
+ logprobs=None,
151
+ )
152
+ ],
105
153
  usage=lf.LMSamplingUsage(2, 5, 7)
106
154
  ),
107
155
  lf.LMSamplingResult(
108
- [lf.LMSample('I am fine, how about you?', 1.0)],
156
+ [
157
+ lf.LMSample(
158
+ lf.AIMessage(
159
+ 'I am fine, how about you?',
160
+ score=1.0,
161
+ logprobs=None,
162
+ usage=lf.LMSamplingUsage(12, 25, 37),
163
+ tags=[lf.Message.TAG_LM_RESPONSE],
164
+ ),
165
+ score=1.0,
166
+ logprobs=None,
167
+ )
168
+ ],
109
169
  usage=lf.LMSamplingUsage(12, 25, 37)
110
170
  )
111
171
  ]
@@ -126,11 +186,35 @@ class StaticSequenceTest(unittest.TestCase):
126
186
  lm.sample(['Hi', 'How are you?']),
127
187
  [
128
188
  lf.LMSamplingResult(
129
- [lf.LMSample('Hello', 1.0)],
189
+ [
190
+ lf.LMSample(
191
+ lf.AIMessage(
192
+ 'Hello',
193
+ score=1.0,
194
+ logprobs=None,
195
+ usage=lf.LMSamplingUsage(2, 5, 7),
196
+ tags=[lf.Message.TAG_LM_RESPONSE],
197
+ ),
198
+ score=1.0,
199
+ logprobs=None,
200
+ )
201
+ ],
130
202
  usage=lf.LMSamplingUsage(2, 5, 7)
131
203
  ),
132
204
  lf.LMSamplingResult(
133
- [lf.LMSample('I am fine, how about you?', 1.0)],
205
+ [
206
+ lf.LMSample(
207
+ lf.AIMessage(
208
+ 'I am fine, how about you?',
209
+ score=1.0,
210
+ logprobs=None,
211
+ usage=lf.LMSamplingUsage(12, 25, 37),
212
+ tags=[lf.Message.TAG_LM_RESPONSE],
213
+ ),
214
+ score=1.0,
215
+ logprobs=None,
216
+ )
217
+ ],
134
218
  usage=lf.LMSamplingUsage(12, 25, 37)
135
219
  )
136
220
  ]