langfun 0.0.2.dev20240408__tar.gz → 0.0.2.dev20240413__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/PKG-INFO +1 -1
  2. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/base.py +25 -0
  3. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/base_test.py +17 -7
  4. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/matching.py +8 -8
  5. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/matching_test.py +1 -1
  6. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/scoring.py +4 -5
  7. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/scoring_test.py +1 -1
  8. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/langfunc_test.py +2 -2
  9. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/language_model.py +19 -4
  10. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/language_model_test.py +8 -8
  11. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/cache/in_memory_test.py +24 -24
  12. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/google_genai_test.py +8 -3
  13. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/llama_cpp.py +3 -1
  14. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/openai.py +8 -2
  15. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/openai_test.py +0 -1
  16. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/prompting.py +3 -1
  17. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/prompting_test.py +26 -8
  18. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/schema_test.py +3 -3
  19. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun.egg-info/PKG-INFO +1 -1
  20. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/LICENSE +0 -0
  21. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/README.md +0 -0
  22. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/__init__.py +0 -0
  23. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/__init__.py +0 -0
  24. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/__init__.py +0 -0
  25. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/__init__.py +0 -0
  26. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/correction.py +0 -0
  27. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/correction_test.py +0 -0
  28. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/errors.py +0 -0
  29. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/errors_test.py +0 -0
  30. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/execution.py +0 -0
  31. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/execution_test.py +0 -0
  32. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/generation.py +0 -0
  33. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/generation_test.py +0 -0
  34. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/parsing.py +0 -0
  35. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/parsing_test.py +0 -0
  36. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/permissions.py +0 -0
  37. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/coding/python/permissions_test.py +0 -0
  38. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/component.py +0 -0
  39. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/component_test.py +0 -0
  40. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/concurrent.py +0 -0
  41. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/concurrent_test.py +0 -0
  42. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/console.py +0 -0
  43. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/console_test.py +0 -0
  44. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/eval/__init__.py +0 -0
  45. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/langfunc.py +0 -0
  46. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/__init__.py +0 -0
  47. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/cache/__init__.py +0 -0
  48. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/cache/base.py +0 -0
  49. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/cache/in_memory.py +0 -0
  50. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/fake.py +0 -0
  51. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/fake_test.py +0 -0
  52. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/google_genai.py +0 -0
  53. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/llms/llama_cpp_test.py +0 -0
  54. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/memories/__init__.py +0 -0
  55. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/memories/conversation_history.py +0 -0
  56. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/memories/conversation_history_test.py +0 -0
  57. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/memory.py +0 -0
  58. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/message.py +0 -0
  59. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/message_test.py +0 -0
  60. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/__init__.py +0 -0
  61. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/image.py +0 -0
  62. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/image_test.py +0 -0
  63. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/mime.py +0 -0
  64. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/mime_test.py +0 -0
  65. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/video.py +0 -0
  66. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modalities/video_test.py +0 -0
  67. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modality.py +0 -0
  68. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/modality_test.py +0 -0
  69. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/natural_language.py +0 -0
  70. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/natural_language_test.py +0 -0
  71. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/sampling.py +0 -0
  72. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/sampling_test.py +0 -0
  73. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/__init__.py +0 -0
  74. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/completion.py +0 -0
  75. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/completion_test.py +0 -0
  76. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/description.py +0 -0
  77. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/description_test.py +0 -0
  78. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/mapping.py +0 -0
  79. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/mapping_test.py +0 -0
  80. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/parsing.py +0 -0
  81. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/parsing_test.py +0 -0
  82. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/schema.py +0 -0
  83. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/schema_generation.py +0 -0
  84. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/schema_generation_test.py +0 -0
  85. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/scoring.py +0 -0
  86. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/structured/scoring_test.py +0 -0
  87. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/subscription.py +0 -0
  88. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/subscription_test.py +0 -0
  89. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/template.py +0 -0
  90. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/template_test.py +0 -0
  91. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/__init__.py +0 -0
  92. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/completion.py +0 -0
  93. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/completion_test.py +0 -0
  94. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/conversation.py +0 -0
  95. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/conversation_test.py +0 -0
  96. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/demonstration.py +0 -0
  97. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/demonstration_test.py +0 -0
  98. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/selfplay.py +0 -0
  99. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/templates/selfplay_test.py +0 -0
  100. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/text_formatting.py +0 -0
  101. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun/core/text_formatting_test.py +0 -0
  102. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun.egg-info/SOURCES.txt +0 -0
  103. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun.egg-info/dependency_links.txt +0 -0
  104. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun.egg-info/requires.txt +0 -0
  105. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/langfun.egg-info/top_level.txt +0 -0
  106. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/setup.cfg +0 -0
  107. {langfun-0.0.2.dev20240408 → langfun-0.0.2.dev20240413}/setup.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240408
+Version: 0.0.2.dev20240413
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
langfun/core/eval/base.py
@@ -1551,8 +1551,33 @@ class Summary(pg.Object):
   def _repr_html_(self) -> str:
     return self.html()
 
+  def json(
+      self,
+  ) -> dict[
+      str,            # Task name
+      list[pg.Dict],  # List of pg.Dict with `experiment` and `metrics`.
+  ]:
+    """Returns the JSON representation of the summary."""
+    task_results = {}
+    for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
+      results = []
+      for entry in self.select(task=task).evaluations:
+        results.append(
+            pg.Dict(
+                experiment=entry,
+                metrics=entry.result.metrics if entry.result else None,
+            )
+        )
+      task_results[task.__name__] = results
+    return task_results
+
   def save(self, file: str, pivot_field: str | None = None) -> None:
     pg.save(self.html(pivot_field), file, file_format='txt')
+    if file.endswith('.html'):
+      json_file = file.replace('.html', '.json')
+    else:
+      json_file = os.path.join(file, '.json')
+    pg.save(self.json(), json_file)
 
   @classmethod
   def from_dirs(
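For readers of this diff, a hedged usage sketch of the new `Summary.json()` method and the extra JSON file written by `Summary.save()`. Here `experiment` is a placeholder for an evaluation configured elsewhere (hypothetical); only the `json()` and `save()` calls come from the hunk above.

```python
# Hypothetical sketch; `experiment` stands in for an lf.eval Evaluation/Suite
# defined elsewhere. run() returns a Summary, as in the tests below.
summary = experiment.run()

# json() groups results by task name; each entry is a pg.Dict carrying the
# evaluation (`experiment`) and its metrics (None if there is no result yet).
for task_name, entries in summary.json().items():
  for entry in entries:
    print(task_name, entry.metrics)

# Saving to an '.html' path now also writes a sibling '.json' file,
# e.g. /tmp/run/summary.html plus /tmp/run/summary.json.
summary.save('/tmp/run/summary.html')
```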
langfun/core/eval/base_test.py
@@ -101,7 +101,7 @@ class EvaluationTest(unittest.TestCase):
     self.assertEqual(s.dir, os.path.join(s.root_dir, s.id))
     self.assertEqual(s.hash, s.clone().hash)
     # Test persistent hash.
-    self.assertEqual(s.hash, 'abc7c29a')
+    self.assertEqual(s.hash, '436dc80c')
     self.assertEqual(
         s.hash, s.clone(override={'max_workers': 2, 'lm.timeout': 20}).hash
     )
@@ -209,7 +209,7 @@
         s.result,
         dict(
             experiment_setup=dict(
-                id='Evaluation@17915dc6',
+                id='Evaluation@f1aa5126',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
@@ -228,13 +228,23 @@
         os.path.exists(os.path.join(s.dir, base.Evaluation.RESULT_JSON)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.CACHE_JSON)))
-    self.assertTrue(
-        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
-    )
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.INDEX_HTML)))
     self.assertTrue(
         os.path.exists(os.path.join(s.dir, base.Evaluation.FAILURES_HTML)))
+    self.assertTrue(
+        os.path.exists(os.path.join(s.root_dir, base.Evaluation.SUMMARY_HTML))
+    )
+    # Check summary JSON.
+    summary_json = os.path.join(
+        s.root_dir, base.Evaluation.SUMMARY_HTML.replace('.html', '.json')
+    )
+    self.assertTrue(os.path.exists(summary_json))
+    summary = pg.load(summary_json, force_dict=True)
+    self.assertIn('Evaluation', summary)
+    self.assertEqual(len(summary['Evaluation']), 1)
+    self.assertIsNotNone(summary['Evaluation'][0].experiment)
+    self.assertIsNotNone(summary['Evaluation'][0].metrics)
 
   def test_run_wihtout_save(self):
     lm = fake.StaticSequence([
@@ -321,7 +331,7 @@
         s.children[0].dir, os.path.join(s.root_dir, s.children[0].id)
     )
     # Test persistent hash.
-    self.assertEqual(s.hash, 'ca7f722b')
+    self.assertEqual(s.hash, 'b66a4e88')
 
     summary = s.run(verbose=True)
     self.assertEqual(len(summary.evaluations), 2)
@@ -448,7 +458,7 @@ class SuiteTest(unittest.TestCase):
         lm=lm
     )
     # Test for persistent hash.
-    self.assertEqual(s.hash, '7285e52b')
+    self.assertEqual(s.hash, 'bbfdc7a8')
     s.run()
     expected = {
         s.children[0].id: dict(
langfun/core/eval/matching.py
@@ -155,19 +155,16 @@ class Matching(base.Evaluation):
     super().save(definition, result, report)
 
     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save matches.
       pg.save(
           [
-              # We force the output to be dict as its type may be defined
-              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
              for input, output, _ in self.matches
           ],
           os.path.join(self.dir, Matching.MATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )
 
       # Save mismatches.
@@ -175,10 +172,13 @@
           [
              # We force the output to be dict as its type may be defined
              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output))
+              pg.Dict(input=input, output=output)
              for input, output, _ in self.mismatches
           ],
           os.path.join(self.dir, Matching.MISMATCHES_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )
 
     if report:
langfun/core/eval/matching_test.py
@@ -103,7 +103,7 @@ class MatchingTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='MyTask@3d87f97f',
+                id='MyTask@acd56a61',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example.question}}',
langfun/core/eval/scoring.py
@@ -118,19 +118,18 @@ class Scoring(base.Evaluation):
     super().save(definition, result, report)
 
     if result:
-
-      def force_dict(v):
-        return pg.object_utils.json_conversion.strip_types(pg.to_json(v))
-
       # Save scored.
       pg.save(
           [
              # We force the output to be dict as its type may be defined
              # within functors which could be deserialized.
-              pg.Dict(input=input, output=force_dict(output), score=score)
+              pg.Dict(input=input, output=output, score=score)
              for input, output, score, _ in self.scored
           ],
           os.path.join(self.dir, Scoring.SCORED_JSON),
+          # We force the input and output to be dict so it does not depend on
+          # the downstream to serialize.
+          force_dict=True,
       )
 
     if report:
langfun/core/eval/scoring_test.py
@@ -81,7 +81,7 @@ class ScoringTest(unittest.TestCase):
         s.result,
         dict(
             experiment_setup=dict(
-                id='ConstraintFollowing@9e51bb9e',
+                id='ConstraintFollowing@a44d8b89',
                 dir=s.dir,
                 model='StaticSequence',
                 prompt_template='{{example}}',
langfun/core/langfunc_test.py
@@ -92,8 +92,8 @@ class LangFuncCallTest(unittest.TestCase):
     self.assertEqual(
         repr(l),
         "LangFunc(template_str='Hello', clean=True,"
-        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
-        ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
+        ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=None,'
+        ' max_tokens=None, n=1, top_k=40, top_p=None, stop=None,'
         ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
         ' max_concurrency=None, timeout=120.0, max_attempts=5,'
        ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
langfun/core/language_model.py
@@ -63,14 +63,24 @@ class LMSamplingOptions(component.Component):
   """Language model sampling options."""
 
   temperature: Annotated[
-      float,
+      float | None,
       (
          'Model temperature, which is usually between 0 and 1.0. '
-          'OpenAI models have temperature range from 0.0 to 2.0.'
+          'OpenAI models have temperature range from 0.0 to 2.0. '
+          'If None (default), honor the model\'s default behavior. '
       )
-  ] = 0.0
-  max_tokens: Annotated[int, 'Per example max tokens to generate.'] = 1024
+  ] = None
+
+  max_tokens: Annotated[
+      int | None,
+      (
+          'Per example max tokens to generate. '
+          'If None, use the model default.'
+      )
+  ] = None
+
   n: Annotated[int | None, 'Max number of samples to return.'] = 1
+
   top_k: Annotated[
       int | None,
       (
@@ -78,6 +88,7 @@ class LMSamplingOptions(component.Component):
          'Not applicable to OpenAI models.'
       )
   ] = 40
+
   top_p: Annotated[
       float | None,
       (
@@ -86,6 +97,7 @@ class LMSamplingOptions(component.Component):
          '`top_p` but not both.'
       ),
   ] = None
+
   stop: Annotated[
       list[str] | None,
       (
@@ -95,9 +107,11 @@ class LMSamplingOptions(component.Component):
          '`Model:` is reached.'
       ),
   ] = None
+
   random_seed: Annotated[
       int | None, 'A fixed random seed used during model inference.'
   ] = None
+
   logprobs: Annotated[
       bool,
       (
@@ -106,6 +120,7 @@ class LMSamplingOptions(component.Component):
          'in the content of message.'
       ),
   ] = False
+
   top_logprobs: Annotated[
       int | None,
       (
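A minimal sketch of the default change above, assuming `LMSamplingOptions` is exposed on the top-level `langfun` package as in the library's public API: unset `temperature` and `max_tokens` now stay `None` and defer to the model's own defaults, and the cache key reflects that.

```python
import langfun as lf  # assumes this release of langfun is installed

opts = lf.LMSamplingOptions()
# temperature and max_tokens now default to None, i.e. the model's own
# defaults apply unless explicitly overridden.
assert opts.temperature is None and opts.max_tokens is None
print(opts.cache_key())  # (None, None, 1, 40, None, None)

with opts.override(temperature=1.0, max_tokens=256):
  print(opts.cache_key())  # (1.0, 256, 1, 40, None, None)
```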
langfun/core/language_model_test.py
@@ -40,7 +40,7 @@ class MockModel(lm_lib.LanguageModel):
      return [
          lm_lib.LMSamplingResult([lm_lib.LMSample(  # pylint: disable=g-complex-comprehension
              response=prompt.text * self.sampling_options.top_k,
-              score=self.sampling_options.temperature)])
+              score=self.sampling_options.temperature or -1.0)])
          for prompt in prompts
      ]
     context.attempt += 1
@@ -73,13 +73,13 @@ class LMSamplingOptionsTest(unittest.TestCase):
   def test_cache_key(self):
     options = lm_lib.LMSamplingOptions()
     key1 = options.cache_key()
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
     with options.override(temperature=1.0, max_tokens=256):
       key2 = options.cache_key()
       self.assertEqual(key2, (1.0, 256, 1, 40, None, None))
 
     # Make sure key1 does not change upon override.
-    self.assertEqual(key1, (0.0, 1024, 1, 40, None, None))
+    self.assertEqual(key1, (None, None, 1, 40, None, None))
 
 
 class LanguageModelTest(unittest.TestCase):
@@ -100,8 +100,8 @@ class LanguageModelTest(unittest.TestCase):
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar']),
         [
-            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=0.0)]),
-            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=0.0)]),
+            lm_lib.LMSamplingResult([lm_lib.LMSample('foo', score=-1.0)]),
+            lm_lib.LMSamplingResult([lm_lib.LMSample('bar', score=-1.0)]),
         ],
     )
     # Test override sampling_options.
@@ -143,7 +143,7 @@ class LanguageModelTest(unittest.TestCase):
     lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
     response = lm(prompt='foo')
     self.assertEqual(response.text, 'foo')
-    self.assertEqual(response.score, 0.0)
+    self.assertEqual(response.score, -1.0)
 
     # Test override sampling_options.
     self.assertEqual(
@@ -159,9 +159,9 @@ class LanguageModelTest(unittest.TestCase):
         lm.sample(prompts=['foo', 'bar']),
         [
             lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('foo', cache_seed=0), score=0.0)]),
+                message_lib.AIMessage('foo', cache_seed=0), score=-1.0)]),
             lm_lib.LMSamplingResult([lm_lib.LMSample(
-                message_lib.AIMessage('bar', cache_seed=0), score=0.0)]),
+                message_lib.AIMessage('bar', cache_seed=0), score=-1.0)]),
         ])
     self.assertEqual(cache.stats.num_queries, 2)
     self.assertEqual(cache.stats.num_hits, 0)
langfun/core/llms/cache/in_memory_test.py
@@ -44,19 +44,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (0.0, 1024, 1, 40, None, None), 0),
-            ('a', (0.0, 1024, 1, 40, None, None), 1),
-            ('b', (0.0, 1024, 1, 40, None, None), 0),
-            ('c', (0.0, 1024, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (0.0, 1024, 1, 40, None, None), 0),
-            ('a', (0.0, 1024, 1, 40, None, None), 1),
-            ('b', (0.0, 1024, 1, 40, None, None), 0),
-            ('c', (0.0, 1024, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 1),
+            ('b', (None, None, 1, 40, None, None), 0),
+            ('c', (None, None, 1, 40, None, None), 0),
         ],
     )
 
@@ -90,19 +90,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items()),
         [
            (
-                ('a', (0.0, 1024, 1, 40, None, None), 0),
+                ('a', (None, None, 1, 40, None, None), 0),
                cache_entry('1'),
            ),
            (
-                ('a', (0.0, 1024, 1, 40, None, None), 1),
+                ('a', (None, None, 1, 40, None, None), 1),
                cache_entry('2', 1),
            ),
            (
-                ('b', (0.0, 1024, 1, 40, None, None), 0),
+                ('b', (None, None, 1, 40, None, None), 0),
                cache_entry('3'),
            ),
            (
-                ('c', (0.0, 1024, 1, 40, None, None), 0),
+                ('c', (None, None, 1, 40, None, None), 0),
                cache_entry('4'),
            ),
         ],
@@ -111,19 +111,19 @@ class InMemoryLMCacheTest(unittest.TestCase):
         list(cache.items('StaticSequence')),
         [
            (
-                ('a', (0.0, 1024, 1, 40, None, None), 0),
+                ('a', (None, None, 1, 40, None, None), 0),
                cache_entry('1'),
            ),
            (
-                ('a', (0.0, 1024, 1, 40, None, None), 1),
+                ('a', (None, None, 1, 40, None, None), 1),
                cache_entry('2', 1),
            ),
            (
-                ('b', (0.0, 1024, 1, 40, None, None), 0),
+                ('b', (None, None, 1, 40, None, None), 0),
                cache_entry('3'),
            ),
            (
-                ('c', (0.0, 1024, 1, 40, None, None), 0),
+                ('c', (None, None, 1, 40, None, None), 0),
                cache_entry('4'),
            ),
         ],
@@ -161,15 +161,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys()),
         [
-            ('a', (0.0, 1024, 1, 40, None, None), 0),
-            ('a', (1.0, 1024, 1, 40, None, None), 0),
+            ('a', (None, None, 1, 40, None, None), 0),
+            ('a', (1.0, None, 1, 40, None, None), 0),
         ],
     )
 
   def test_different_model(self):
     cache = in_memory.InMemory()
-    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache)
-    lm2 = fake.Echo(cache=cache)
+    lm1 = fake.StaticSequence(['1', '2', '3'], cache=cache, temperature=0.0)
+    lm2 = fake.Echo(cache=cache, temperature=0.0)
 
     self.assertEqual(lm1('a'), '1')
     self.assertEqual(lm2('a'), 'a')
@@ -180,15 +180,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
     self.assertEqual(
         list(cache.keys('StaticSequence')),
         [
-            ('a', (0.0, 1024, 1, 40, None, None), 0),
-            ('b', (0.0, 1024, 1, 40, None, None), 0),
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(
         list(cache.keys('Echo')),
         [
-            ('a', (0.0, 1024, 1, 40, None, None), 0),
-            ('b', (0.0, 1024, 1, 40, None, None), 0),
+            ('a', (0.0, None, 1, 40, None, None), 0),
+            ('b', (0.0, None, 1, 40, None, None), 0),
         ],
     )
     self.assertEqual(len(cache), 4)
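The cache tests above key entries on the tuple (prompt, sampling-options key, cache seed). A small sketch, assuming the fake models and in-memory cache these tests use, of how an explicit temperature now produces a separate cache entry:

```python
from langfun.core.llms import fake
from langfun.core.llms.cache import in_memory

cache = in_memory.InMemory()
lm = fake.StaticSequence(['1', '2'], cache=cache)

lm('a')                   # cached under temperature=None (model default)
lm('a', temperature=1.0)  # a distinct entry keyed under temperature=1.0

# Expected keys, along the lines of the assertions above:
# [('a', (None, None, 1, 40, None, None), 0),
#  ('a', (1.0, None, 1, 40, None, None), 0)]
print(list(cache.keys()))
```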
langfun/core/llms/google_genai_test.py
@@ -152,10 +152,15 @@ class GenAITest(unittest.TestCase):
     )
 
   def test_model_hub(self):
+    orig_get_model = genai.get_model
+    genai.get_model = mock_get_model
+
     model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
     self.assertIsNotNone(model)
     self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
 
+    genai.get_model = orig_get_model
+
   def test_api_key_check(self):
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       _ = google_genai.GeminiPro()._api_initialized
@@ -167,7 +172,7 @@
 
   def test_call(self):
     with mock.patch(
-        'google.generativeai.generative_models.GenerativeModel.generate_content'
+        'google.generativeai.GenerativeModel.generate_content',
     ) as mock_generate:
       orig_get_model = genai.get_model
       genai.get_model = mock_get_model
@@ -176,7 +181,7 @@
       lm = google_genai.GeminiPro(api_key='test_key')
       self.maxDiff = None
       self.assertEqual(
-          lm('hello', temperature=2.0, top_k=20).text,
+          lm('hello', temperature=2.0, top_k=20, max_tokens=1024).text,
          (
              'This is a response to hello with n=1, temperature=2.0, '
              'top_p=None, top_k=20, max_tokens=1024, stop=None.'
@@ -197,7 +202,7 @@
          (
              "hello to models/text-bison-001 with {'temperature': 2.0, "
              "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
-              "'max_output_tokens': 1024, 'stop_sequences': None}"
+              "'max_output_tokens': None, 'stop_sequences': None}"
          ),
      )
      genai.get_model = orig_get_model
langfun/core/llms/llama_cpp.py
@@ -51,10 +51,12 @@ class LlamaCppRemote(lf.LanguageModel):
      data = {
          "prompt": prompt.text,
          "n_predict": self.sampling_options.max_tokens,
-          "temperature": self.sampling_options.temperature,
          "top_k": self.sampling_options.top_k or 50,
          "top_p": self.sampling_options.top_p or 0.95,
      }
+      if self.sampling_options.temperature is not None:
+        data["temperature"] = self.sampling_options.temperature
+
      response = requests.post(
          f"{self.url}/completion",
          json=data,
langfun/core/llms/openai.py
@@ -163,8 +163,6 @@ class OpenAI(lf.LanguageModel):
     # NOTE(daiyip): options.top_k is not applicable.
     args = dict(
         n=options.n,
-        temperature=options.temperature,
-        max_tokens=options.max_tokens,
         stream=False,
         timeout=self.timeout,
         logprobs=options.logprobs,
@@ -173,6 +171,10 @@ class OpenAI(lf.LanguageModel):
     # Completion and ChatCompletion uses different parameter name for model.
     args['model' if self.is_chat_model else 'engine'] = self.model
 
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.max_tokens is not None:
+      args['max_tokens'] = options.max_tokens
     if options.top_p is not None:
       args['top_p'] = options.top_p
     if options.stop:
@@ -220,6 +222,10 @@ class OpenAI(lf.LanguageModel):
         retry_on_errors=(
             openai_error.ServiceUnavailableError,
             openai_error.RateLimitError,
+            # Handling transient OpenAI server error (code 500). Check out
+            # https://platform.openai.com/docs/guides/error-codes/error-codes
+            (openai_error.APIError,
+             '.*The server had an error processing your request'),
         ),
     )[0]
 
langfun/core/llms/openai_test.py
@@ -121,7 +121,6 @@ class OpenaiTest(unittest.TestCase):
         top_logprobs=None,
         n=1,
         temperature=1.0,
-        max_tokens=1024,
         stream=False,
         timeout=120.0,
         stop=['\n'],
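A sketch (not the library's code) of the request-building pattern introduced by the openai.py hunks above: optional sampling fields are forwarded to the OpenAI API only when explicitly set, so `None` falls back to the API's own defaults; the dropped `max_tokens=1024` in this test's expected arguments follows from that.

```python
def build_request_args(options):
  """Illustrative only: mirrors the conditional forwarding shown above."""
  args = dict(n=options.n, stream=False)
  if options.temperature is not None:
    args['temperature'] = options.temperature
  if options.max_tokens is not None:
    args['max_tokens'] = options.max_tokens
  if options.top_p is not None:
    args['top_p'] = options.top_p
  if options.stop:
    args['stop'] = options.stop
  return args
```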
langfun/core/structured/prompting.py
@@ -78,7 +78,9 @@ class QueryStructurePython(QueryStructure):
 
      {{ output_title }}:
      ```python
-      Answer(final_answer=2)
+      Answer(
+        final_answer=2
+      )
      ```
   """
   protocol = 'python'
langfun/core/structured/prompting_test.py
@@ -116,12 +116,26 @@ class QueryTest(unittest.TestCase):
         y=2,
         lm=lm.clone(),
         expected_snippet=(
-            'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT'
-            ' according to OUTPUT_TYPE.\n\nINPUT_OBJECT:\n 1 + 1'
-            ' =\n\nOUTPUT_TYPE:\n Answer\n\n ```python\n class Answer:\n '
-            ' final_answer: int\n ```\n\nOUTPUT_OBJECT:\n ```python\n '
-            ' Answer(final_answer=2)\n ```\n\nINPUT_OBJECT:\n What is 1 +'
-            ' 2?\n\nOUTPUT_TYPE:\n int\n\nOUTPUT_OBJECT:'
+            'Please respond to the last INPUT_OBJECT with OUTPUT_OBJECT '
+            'according to OUTPUT_TYPE.\n\n'
+            'INPUT_OBJECT:\n 1 + 1 =\n\n'
+            'OUTPUT_TYPE:\n'
+            ' Answer\n\n'
+            ' ```python\n'
+            ' class Answer:\n'
+            ' final_answer: int\n'
+            ' ```\n\n'
+            'OUTPUT_OBJECT:\n'
+            ' ```python\n'
+            ' Answer(\n'
+            ' final_answer=2\n'
+            ' )\n'
+            ' ```\n\n'
+            'INPUT_OBJECT:\n'
+            ' What is 1 + 2?\n\n'
+            'OUTPUT_TYPE:\n'
+            ' int\n\n'
+            'OUTPUT_OBJECT:'
         ),
     )
 
@@ -264,7 +278,9 @@ class QueryStructurePythonTest(unittest.TestCase):
 
        OUTPUT_OBJECT:
        ```python
-        Answer(final_answer=2)
+        Answer(
+          final_answer=2
+        )
        ```
 
        INPUT_OBJECT:
@@ -308,7 +324,9 @@
 
        OUTPUT_OBJECT:
        ```python
-        Answer(final_answer=2)
+        Answer(
+          final_answer=2
+        )
        ```
 
        INPUT_OBJECT:
langfun/core/structured/schema_test.py
@@ -192,9 +192,9 @@ class SchemaTest(unittest.TestCase):
     self.assertEqual(schema.parse('{"result": 1}'), 1)
     schema = schema_lib.Schema(dict[str, int])
     self.assertEqual(
-        schema.parse(
-            '{"result": {"_type": "Unknown", "x": 1}}}', force_dict=True),
-        dict(x=1))
+        schema.parse('{"result": {"x": 1}}}'),
+        dict(x=1)
+    )
     with self.assertRaisesRegex(
         schema_lib.SchemaError, 'Expect .* but encountered .*'):
       schema.parse('{"result": "def"}')
langfun.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240408
+Version: 0.0.2.dev20240413
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors