langfun 0.0.2.dev20240327__tar.gz → 0.0.2.dev20240330__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/PKG-INFO +2 -2
  2. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/base.py +39 -18
  3. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/base_test.py +4 -9
  4. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/matching_test.py +2 -4
  5. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/scoring_test.py +1 -2
  6. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/__init__.py +5 -3
  7. langfun-0.0.2.dev20240327/langfun/core/llms/gemini.py → langfun-0.0.2.dev20240330/langfun/core/llms/google_genai.py +117 -15
  8. langfun-0.0.2.dev20240327/langfun/core/llms/gemini_test.py → langfun-0.0.2.dev20240330/langfun/core/llms/google_genai_test.py +76 -13
  9. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema.py +21 -20
  10. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_test.py +38 -21
  11. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/PKG-INFO +2 -2
  12. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/SOURCES.txt +2 -2
  13. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/requires.txt +1 -1
  14. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/LICENSE +0 -0
  15. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/README.md +0 -0
  16. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/__init__.py +0 -0
  17. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/__init__.py +0 -0
  18. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/__init__.py +0 -0
  19. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/__init__.py +0 -0
  20. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/correction.py +0 -0
  21. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/correction_test.py +0 -0
  22. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/errors.py +0 -0
  23. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/errors_test.py +0 -0
  24. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/execution.py +0 -0
  25. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/execution_test.py +0 -0
  26. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/generation.py +0 -0
  27. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/generation_test.py +0 -0
  28. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/parsing.py +0 -0
  29. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/parsing_test.py +0 -0
  30. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/permissions.py +0 -0
  31. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/coding/python/permissions_test.py +0 -0
  32. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/component.py +0 -0
  33. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/component_test.py +0 -0
  34. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/concurrent.py +0 -0
  35. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/concurrent_test.py +0 -0
  36. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/console.py +0 -0
  37. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/console_test.py +0 -0
  38. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/__init__.py +0 -0
  39. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/matching.py +0 -0
  40. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/eval/scoring.py +0 -0
  41. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/langfunc.py +0 -0
  42. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/langfunc_test.py +0 -0
  43. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/language_model.py +0 -0
  44. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/language_model_test.py +0 -0
  45. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/__init__.py +0 -0
  46. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/base.py +0 -0
  47. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/in_memory.py +0 -0
  48. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/cache/in_memory_test.py +0 -0
  49. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/fake.py +0 -0
  50. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/fake_test.py +0 -0
  51. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/llama_cpp.py +0 -0
  52. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/llama_cpp_test.py +0 -0
  53. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/openai.py +0 -0
  54. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/llms/openai_test.py +0 -0
  55. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/__init__.py +0 -0
  56. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/conversation_history.py +0 -0
  57. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memories/conversation_history_test.py +0 -0
  58. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/memory.py +0 -0
  59. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/message.py +0 -0
  60. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/message_test.py +0 -0
  61. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/__init__.py +0 -0
  62. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/image.py +0 -0
  63. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/image_test.py +0 -0
  64. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/mime.py +0 -0
  65. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/mime_test.py +0 -0
  66. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/video.py +0 -0
  67. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modalities/video_test.py +0 -0
  68. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modality.py +0 -0
  69. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/modality_test.py +0 -0
  70. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/natural_language.py +0 -0
  71. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/natural_language_test.py +0 -0
  72. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/sampling.py +0 -0
  73. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/sampling_test.py +0 -0
  74. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/__init__.py +0 -0
  75. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/completion.py +0 -0
  76. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/completion_test.py +0 -0
  77. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/description.py +0 -0
  78. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/description_test.py +0 -0
  79. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/mapping.py +0 -0
  80. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/mapping_test.py +0 -0
  81. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/parsing.py +0 -0
  82. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/parsing_test.py +0 -0
  83. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/prompting.py +0 -0
  84. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/prompting_test.py +0 -0
  85. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_generation.py +0 -0
  86. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/schema_generation_test.py +0 -0
  87. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/scoring.py +0 -0
  88. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/structured/scoring_test.py +0 -0
  89. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/subscription.py +0 -0
  90. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/subscription_test.py +0 -0
  91. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/template.py +0 -0
  92. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/template_test.py +0 -0
  93. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/__init__.py +0 -0
  94. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/completion.py +0 -0
  95. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/completion_test.py +0 -0
  96. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/conversation.py +0 -0
  97. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/conversation_test.py +0 -0
  98. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/demonstration.py +0 -0
  99. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/demonstration_test.py +0 -0
  100. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/selfplay.py +0 -0
  101. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/templates/selfplay_test.py +0 -0
  102. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/text_formatting.py +0 -0
  103. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun/core/text_formatting_test.py +0 -0
  104. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/dependency_links.txt +0 -0
  105. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/langfun.egg-info/top_level.txt +0 -0
  106. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/setup.cfg +0 -0
  107. {langfun-0.0.2.dev20240327 → langfun-0.0.2.dev20240330}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240327
3
+ Version: 0.0.2.dev20240330
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -24,7 +24,7 @@ License-File: LICENSE
24
24
  Requires-Dist: google-generativeai>=0.3.2
25
25
  Requires-Dist: jinja2>=3.1.2
26
26
  Requires-Dist: openai==0.27.2
27
- Requires-Dist: pyglove>=0.4.5.dev20240314
27
+ Requires-Dist: pyglove>=0.4.5.dev20240323
28
28
  Requires-Dist: python-magic>=0.4.27
29
29
  Requires-Dist: requests>=2.31.0
30
30
  Requires-Dist: termcolor==1.1.0
@@ -27,6 +27,7 @@ import time
27
27
  from typing import Annotated, Any, Callable, Iterator, Literal, Optional, Sequence, Type, Union
28
28
 
29
29
  import langfun.core as lf
30
+ import langfun.core.coding as lf_coding
30
31
  from langfun.core.llms.cache import in_memory
31
32
  import langfun.core.structured as lf_structured
32
33
  import pyglove as pg
@@ -41,14 +42,6 @@ class Evaluable(lf.Component):
41
42
  INDEX_HTML = 'index.html'
42
43
  SUMMARY_HTML = 'summary.html'
43
44
 
44
- id: Annotated[
45
- str,
46
- (
47
- 'The ID of the evaluation, which should be unique across all '
48
- 'evaluations.'
49
- ),
50
- ]
51
-
52
45
  root_dir: Annotated[
53
46
  str | None,
54
47
  (
@@ -61,6 +54,18 @@ class Evaluable(lf.Component):
61
54
  int, 'Number of decimals when reporting precision.'
62
55
  ] = lf.contextual(default=1)
63
56
 
57
+ @property
58
+ @abc.abstractmethod
59
+ def id(self) -> str:
60
+ """Returns the ID of the task.
61
+
62
+ Returns:
63
+ Evaluation task ID. Different evaluation task should have their unique
64
+ task IDs, for each task will be stored in sub-directoreis identified by
65
+ their IDs. For suites, the ID could be an empty string as they will not
66
+ produce sub-directories
67
+ """
68
+
64
69
  @property
65
70
  def dir(self) -> str | None:
66
71
  """Returns the directory for saving results and details."""
@@ -578,12 +583,15 @@ class _LeafNode:
578
583
  progress_bar: int | None = None
579
584
 
580
585
 
581
- @pg.use_init_args(['id', 'children'])
586
+ @pg.use_init_args(['children'])
582
587
  class Suite(Evaluable):
583
588
  """Evaluation suite."""
584
589
 
585
590
  children: Annotated[list[Evaluable], 'Child evaluation sets or suites.']
586
591
 
592
+ # Use empty ID as suite is just a container of child evaluations.
593
+ id: str = ''
594
+
587
595
  __kwargs__: Annotated[
588
596
  Any,
589
597
  (
@@ -841,8 +849,10 @@ class Evaluation(Evaluable):
841
849
  kwargs['evaluation'] = self
842
850
  return self.schema_fn(**kwargs)
843
851
 
844
- def _formalize_schema(self, annotation) -> lf_structured.Schema:
852
+ def _formalize_schema(self, annotation) -> lf_structured.Schema | None:
845
853
  """Formalizes schema from annotation."""
854
+ if annotation in (str, None):
855
+ return None
846
856
  if self.method == 'complete':
847
857
  if not hasattr(annotation, '__schema__'):
848
858
  raise TypeError(
@@ -883,6 +893,14 @@ class Evaluation(Evaluable):
883
893
  completion_examples.append(ex)
884
894
  return completion_examples
885
895
 
896
+ @property
897
+ def id(self) -> str:
898
+ """Returns the ID of this evaluation."""
899
+ id_prefix = self.__class__.__name__
900
+ if not self.is_deterministic:
901
+ return id_prefix
902
+ return f'{id_prefix}@{self.hash}'
903
+
886
904
  @functools.cached_property
887
905
  def children(self) -> list['Evaluation']:
888
906
  """Returns the trials as child evaluations if this evaluation is a space."""
@@ -892,7 +910,6 @@ class Evaluation(Evaluable):
892
910
  for i, child in enumerate(pg.iter(self)):
893
911
  child.sym_setparent(self)
894
912
  child.sym_setpath(self.sym_path + f'children[{i}]')
895
- child.rebind(id=f'{self.id}@{child.hash}', skip_notification=True)
896
913
  children.append(child)
897
914
  return children
898
915
 
@@ -1004,7 +1021,11 @@ class Evaluation(Evaluable):
1004
1021
  self._reset()
1005
1022
 
1006
1023
  def _process(example: Any):
1007
- return self.process(example, **(self.additional_args or {}))
1024
+ # NOTE(daiyip): set the `input` symbol of the globals to None, so LLM
1025
+ # generated code with calls to `input` will raise an error, thus not
1026
+ # blocking the evaluation.
1027
+ with lf_coding.context(input=None):
1028
+ return self.process(example, **(self.additional_args or {}))
1008
1029
 
1009
1030
  try:
1010
1031
  for example, message, error in lf.concurrent_map(
@@ -1015,10 +1036,7 @@ class Evaluation(Evaluable):
1015
1036
  status_fn=self._status,
1016
1037
  ):
1017
1038
  if error is not None:
1018
- try:
1019
- self._failures.append((example, str(error)))
1020
- except Exception as e: # pylint: disable=broad-exception-caught
1021
- self._failures.append((example, str(e)))
1039
+ self._failures.append((example, str(error)))
1022
1040
  else:
1023
1041
  output = message.text if self.schema is None else message.result
1024
1042
  self.audit(example, output, message)
@@ -1521,9 +1539,12 @@ class Summary(pg.Object):
1521
1539
  pivot_field = pivot_field or self.pivot_field
1522
1540
  s = io.StringIO()
1523
1541
  s.write('<html><body>')
1524
- for task in self.tasks():
1542
+ for task in sorted(self.tasks(), key=lambda cls: cls.__name__):
1543
+ table_id = task.__name__.lower()
1525
1544
  s.write('<div>')
1526
- s.write(f'<h2>{task.__name__}</h2>')
1545
+ s.write(f'<a id="{table_id}"')
1546
+ s.write(f'<h2><a href="#{table_id}">{task.__name__}</a></h2>')
1547
+ s.write('</a>')
1527
1548
  table = Summary.Table.from_evaluations(
1528
1549
  self.select(task=task).evaluations, pivot_field
1529
1550
  )
@@ -70,8 +70,7 @@ def eval_set(
70
70
  """Creates an evaluation object for testing."""
71
71
  tmp_dir = tempfile.gettempdir()
72
72
  return cls(
73
- id=eval_id,
74
- root_dir=tmp_dir,
73
+ root_dir=os.path.join(tmp_dir, eval_id),
75
74
  inputs=base.as_inputs([
76
75
  pg.Dict(question='Compute 1 + 1'),
77
76
  pg.Dict(question='Compute 1 + 2'),
@@ -210,7 +209,7 @@ class EvaluationTest(unittest.TestCase):
210
209
  s.result,
211
210
  dict(
212
211
  experiment_setup=dict(
213
- id='run_test',
212
+ id='Evaluation@17915dc6',
214
213
  dir=s.dir,
215
214
  model='StaticSequence',
216
215
  prompt_template='{{example.question}}',
@@ -302,7 +301,6 @@ class EvaluationTest(unittest.TestCase):
302
301
  '3',
303
302
  ])
304
303
  s = base.Evaluation(
305
- id='search_space_test',
306
304
  root_dir=tempfile.gettempdir(),
307
305
  inputs=base.as_inputs([
308
306
  pg.Dict(question='Compute 1 + 1'),
@@ -439,7 +437,6 @@ class SuiteTest(unittest.TestCase):
439
437
  '3',
440
438
  ] * 5)
441
439
  s = base.Suite(
442
- 'suite_run_test',
443
440
  [
444
441
  eval_set('run_test_1', 'query', schema_fn=answer_schema()),
445
442
  # A suite of search space. Two of the sub-experiments are identical,
@@ -548,7 +545,6 @@ class SummaryTest(unittest.TestCase):
548
545
  def _eval_set(self, root_dir):
549
546
  return base.Suite(id='select_test', children=[
550
547
  TaskA(
551
- id='task_a',
552
548
  inputs=base.as_inputs([
553
549
  pg.Dict(question='Compute 1 + 1'),
554
550
  ]),
@@ -569,7 +565,6 @@ class SummaryTest(unittest.TestCase):
569
565
  max_workers=1,
570
566
  ),
571
567
  TaskB(
572
- id='task_b',
573
568
  inputs=base.as_inputs([
574
569
  pg.Dict(question='Compute 1 + 1'),
575
570
  ]),
@@ -650,10 +645,10 @@ class SummaryTest(unittest.TestCase):
650
645
  len(base.Summary.from_dirs(root_dir)), 2 * 2 * 2 * 2 + 2 * 1 * 1 * 2
651
646
  )
652
647
  self.assertEqual(
653
- len(base.Summary.from_dirs(root_dir, 'task_b')), 2 * 1 * 1 * 2
648
+ len(base.Summary.from_dirs(root_dir, 'TaskB')), 2 * 1 * 1 * 2
654
649
  )
655
650
  self.assertEqual(
656
- len(base.Summary.from_dirs(root_dir, ('task_a'))), 2 * 2 * 2 * 2
651
+ len(base.Summary.from_dirs(root_dir, ('TaskA'))), 2 * 2 * 2 * 2
657
652
  )
658
653
 
659
654
  def test_monitor(self):
@@ -65,10 +65,8 @@ def eval_set(
65
65
  use_cache: bool = True,
66
66
  ):
67
67
  """Creates an evaluation object for testing."""
68
- tmp_dir = tempfile.gettempdir()
69
68
  return MyTask(
70
- id=eval_id,
71
- root_dir=tmp_dir,
69
+ root_dir=os.path.join(tempfile.gettempdir(), eval_id),
72
70
  inputs=base.as_inputs([
73
71
  pg.Dict(question='Compute 1 + 1', groundtruth=2),
74
72
  pg.Dict(question='Compute 1 + 2', groundtruth=3),
@@ -105,7 +103,7 @@ class MatchingTest(unittest.TestCase):
105
103
  s.result,
106
104
  dict(
107
105
  experiment_setup=dict(
108
- id='match_run_test',
106
+ id='MyTask@3d87f97f',
109
107
  dir=s.dir,
110
108
  model='StaticSequence',
111
109
  prompt_template='{{example.question}}',
@@ -43,7 +43,6 @@ def constrained_by_upperbound(upper_bound: int):
43
43
 
44
44
 
45
45
  class ConstraintFollowing(scoring.Scoring):
46
- id = 'constraint_following'
47
46
  inputs = constrained_by_upperbound(1)
48
47
  prompt = '{{example}}'
49
48
  method = 'query'
@@ -82,7 +81,7 @@ class ScoringTest(unittest.TestCase):
82
81
  s.result,
83
82
  dict(
84
83
  experiment_setup=dict(
85
- id='constraint_following',
84
+ id='ConstraintFollowing@9e51bb9e',
86
85
  dir=s.dir,
87
86
  model='StaticSequence',
88
87
  prompt_template='{{example}}',
@@ -25,9 +25,11 @@ from langfun.core.llms.fake import StaticResponse
25
25
  from langfun.core.llms.fake import StaticSequence
26
26
 
27
27
  # Gemini models.
28
- from langfun.core.llms.gemini import Gemini
29
- from langfun.core.llms.gemini import GeminiPro
30
- from langfun.core.llms.gemini import GeminiProVision
28
+ from langfun.core.llms.google_genai import GenAI
29
+ from langfun.core.llms.google_genai import GeminiPro
30
+ from langfun.core.llms.google_genai import GeminiProVision
31
+ from langfun.core.llms.google_genai import Palm2
32
+ from langfun.core.llms.google_genai import Palm2_IT
31
33
 
32
34
  # OpenAI models.
33
35
  from langfun.core.llms.openai import OpenAI
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Gemini models exposed through Google Generative AI APIs."""
15
15
 
16
+ import abc
16
17
  import functools
17
18
  import os
18
19
  from typing import Annotated, Any, Literal
@@ -20,14 +21,20 @@ from typing import Annotated, Any, Literal
20
21
  import google.generativeai as genai
21
22
  import langfun.core as lf
22
23
  from langfun.core import modalities as lf_modalities
24
+ import pyglove as pg
23
25
 
24
26
 
25
27
  @lf.use_init_args(['model'])
26
- class Gemini(lf.LanguageModel):
27
- """Language model served on VertexAI."""
28
+ class GenAI(lf.LanguageModel):
29
+ """Language models provided by Google GenAI."""
28
30
 
29
31
  model: Annotated[
30
- Literal['gemini-pro', 'gemini-pro-vision', ''],
32
+ Literal[
33
+ 'gemini-pro',
34
+ 'gemini-pro-vision',
35
+ 'text-bison-001',
36
+ 'chat-bison-001',
37
+ ],
31
38
  'Model name.',
32
39
  ]
33
40
 
@@ -35,7 +42,8 @@ class Gemini(lf.LanguageModel):
35
42
  str | None,
36
43
  (
37
44
  'API key. If None, the key will be read from environment variable '
38
- "'GOOGLE_API_KEY'."
45
+ "'GOOGLE_API_KEY'. "
46
+ 'Get an API key at https://ai.google.dev/tutorials/setup'
39
47
  ),
40
48
  ] = None
41
49
 
@@ -43,6 +51,9 @@ class Gemini(lf.LanguageModel):
43
51
  False
44
52
  )
45
53
 
54
+ # Set the default max concurrency to 8 workers.
55
+ max_concurrency = 8
56
+
46
57
  def _on_bound(self):
47
58
  super()._on_bound()
48
59
  self.__dict__.pop('_api_initialized', None)
@@ -67,7 +78,11 @@ class Gemini(lf.LanguageModel):
67
78
  return [
68
79
  m.name.lstrip('models/')
69
80
  for m in genai.list_models()
70
- if 'generateContent' in m.supported_generation_methods
81
+ if (
82
+ 'generateContent' in m.supported_generation_methods
83
+ or 'generateText' in m.supported_generation_methods
84
+ or 'generateMessage' in m.supported_generation_methods
85
+ )
71
86
  ]
72
87
 
73
88
  @property
@@ -80,11 +95,6 @@ class Gemini(lf.LanguageModel):
80
95
  """Returns a string to identify the resource for rate control."""
81
96
  return self.model_id
82
97
 
83
- @property
84
- def max_concurrency(self) -> int:
85
- """Max concurrent requests."""
86
- return 8
87
-
88
98
  def _generation_config(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
89
99
  """Creates generation config from langfun sampling options."""
90
100
  return genai.GenerationConfig(
@@ -117,7 +127,7 @@ class Gemini(lf.LanguageModel):
117
127
  return chunks
118
128
 
119
129
  def _response_to_result(
120
- self, response: genai.types.GenerateContentResponse
130
+ self, response: genai.types.GenerateContentResponse | pg.Dict
121
131
  ) -> lf.LMSamplingResult:
122
132
  """Parses generative response into message."""
123
133
  samples = []
@@ -149,17 +159,97 @@ class Gemini(lf.LanguageModel):
149
159
  return self._response_to_result(response)
150
160
 
151
161
 
162
+ class _LegacyGenerativeModel(pg.Object):
163
+ """Base for legacy GenAI generative model."""
164
+
165
+ model: str
166
+
167
+ def generate_content(
168
+ self,
169
+ input_content: list[str | genai.types.BlobDict],
170
+ generation_config: genai.GenerationConfig,
171
+ ) -> pg.Dict:
172
+ """Generate content."""
173
+ segments = []
174
+ for s in input_content:
175
+ if not isinstance(s, str):
176
+ raise ValueError(f'Unsupported modality: {s!r}')
177
+ segments.append(s)
178
+ return self.generate(' '.join(segments), generation_config)
179
+
180
+ @abc.abstractmethod
181
+ def generate(
182
+ self, prompt: str, generation_config: genai.GenerationConfig) -> pg.Dict:
183
+ """Generate response based on prompt."""
184
+
185
+
186
+ class _LegacyCompletionModel(_LegacyGenerativeModel):
187
+ """Legacy GenAI completion model."""
188
+
189
+ def generate(
190
+ self, prompt: str, generation_config: genai.GenerationConfig
191
+ ) -> pg.Dict:
192
+ completion: genai.types.Completion = genai.generate_text(
193
+ model=f'models/{self.model}',
194
+ prompt=prompt,
195
+ temperature=generation_config.temperature,
196
+ top_k=generation_config.top_k,
197
+ top_p=generation_config.top_p,
198
+ candidate_count=generation_config.candidate_count,
199
+ max_output_tokens=generation_config.max_output_tokens,
200
+ stop_sequences=generation_config.stop_sequences,
201
+ )
202
+ return pg.Dict(
203
+ candidates=[
204
+ pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['output'])]))
205
+ for c in completion.candidates
206
+ ]
207
+ )
208
+
209
+
210
+ class _LegacyChatModel(_LegacyGenerativeModel):
211
+ """Legacy GenAI chat model."""
212
+
213
+ def generate(
214
+ self, prompt: str, generation_config: genai.GenerationConfig
215
+ ) -> pg.Dict:
216
+ response: genai.types.ChatResponse = genai.chat(
217
+ model=f'models/{self.model}',
218
+ messages=prompt,
219
+ temperature=generation_config.temperature,
220
+ top_k=generation_config.top_k,
221
+ top_p=generation_config.top_p,
222
+ candidate_count=generation_config.candidate_count,
223
+ )
224
+ return pg.Dict(
225
+ candidates=[
226
+ pg.Dict(content=pg.Dict(parts=[pg.Dict(text=c['content'])]))
227
+ for c in response.candidates
228
+ ]
229
+ )
230
+
231
+
152
232
  class _ModelHub:
153
233
  """Google Generative AI model hub."""
154
234
 
155
235
  def __init__(self):
156
236
  self._model_cache = {}
157
237
 
158
- def get(self, model_name: str) -> genai.GenerativeModel:
238
+ def get(
239
+ self, model_name: str
240
+ ) -> genai.GenerativeModel | _LegacyGenerativeModel:
159
241
  """Gets a generative model by model id."""
160
242
  model = self._model_cache.get(model_name, None)
161
243
  if model is None:
162
- model = genai.GenerativeModel(model_name)
244
+ model_info = genai.get_model(f'models/{model_name}')
245
+ if 'generateContent' in model_info.supported_generation_methods:
246
+ model = genai.GenerativeModel(model_name)
247
+ elif 'generateText' in model_info.supported_generation_methods:
248
+ model = _LegacyCompletionModel(model_name)
249
+ elif 'generateMessage' in model_info.supported_generation_methods:
250
+ model = _LegacyChatModel(model_name)
251
+ else:
252
+ raise ValueError(f'Unsupported model: {model_name!r}')
163
253
  self._model_cache[model_name] = model
164
254
  return model
165
255
 
@@ -172,14 +262,26 @@ _GOOGLE_GENAI_MODEL_HUB = _ModelHub()
172
262
  #
173
263
 
174
264
 
175
- class GeminiPro(Gemini):
265
+ class GeminiPro(GenAI):
176
266
  """Gemini Pro model."""
177
267
 
178
268
  model = 'gemini-pro'
179
269
 
180
270
 
181
- class GeminiProVision(Gemini):
271
+ class GeminiProVision(GenAI):
182
272
  """Gemini Pro vision model."""
183
273
 
184
274
  model = 'gemini-pro-vision'
185
275
  multimodal = True
276
+
277
+
278
+ class Palm2(GenAI):
279
+ """PaLM2 model."""
280
+
281
+ model = 'text-bison-001'
282
+
283
+
284
+ class Palm2_IT(GenAI): # pylint: disable=invalid-name
285
+ """PaLM2 instruction-tuned model."""
286
+
287
+ model = 'chat-bison-001'
@@ -20,7 +20,7 @@ from unittest import mock
20
20
  from google import generativeai as genai
21
21
  import langfun.core as lf
22
22
  from langfun.core import modalities as lf_modalities
23
- from langfun.core.llms import gemini
23
+ from langfun.core.llms import google_genai
24
24
  import pyglove as pg
25
25
 
26
26
 
@@ -36,6 +36,29 @@ example_image = (
36
36
  )
37
37
 
38
38
 
39
+ def mock_get_model(model_name, *args, **kwargs):
40
+ del args, kwargs
41
+ if 'gemini' in model_name:
42
+ method = 'generateContent'
43
+ elif 'chat' in model_name:
44
+ method = 'generateMessage'
45
+ else:
46
+ method = 'generateText'
47
+ return pg.Dict(supported_generation_methods=[method])
48
+
49
+
50
+ def mock_generate_text(*, model, prompt, **kwargs):
51
+ return pg.Dict(
52
+ candidates=[pg.Dict(output=f'{prompt} to {model} with {kwargs}')]
53
+ )
54
+
55
+
56
+ def mock_chat(*, model, messages, **kwargs):
57
+ return pg.Dict(
58
+ candidates=[pg.Dict(content=f'{messages} to {model} with {kwargs}')]
59
+ )
60
+
61
+
39
62
  def mock_generate_content(content, generation_config, **kwargs):
40
63
  del kwargs
41
64
  c = generation_config
@@ -68,12 +91,12 @@ def mock_generate_content(content, generation_config, **kwargs):
68
91
  )
69
92
 
70
93
 
71
- class GeminiTest(unittest.TestCase):
72
- """Tests for Evergreen language model."""
94
+ class GenAITest(unittest.TestCase):
95
+ """Tests for Google GenAI model."""
73
96
 
74
97
  def test_content_from_message_text_only(self):
75
98
  text = 'This is a beautiful day'
76
- model = gemini.GeminiPro()
99
+ model = google_genai.GeminiPro()
77
100
  chunks = model._content_from_message(lf.UserMessage(text))
78
101
  self.assertEqual(chunks, [text])
79
102
 
@@ -85,9 +108,9 @@ class GeminiTest(unittest.TestCase):
85
108
 
86
109
  # Non-multimodal model.
87
110
  with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
88
- gemini.GeminiPro()._content_from_message(message)
111
+ google_genai.GeminiPro()._content_from_message(message)
89
112
 
90
- model = gemini.GeminiProVision()
113
+ model = google_genai.GeminiProVision()
91
114
  chunks = model._content_from_message(message)
92
115
  self.maxDiff = None
93
116
  self.assertEqual(
@@ -118,7 +141,7 @@ class GeminiTest(unittest.TestCase):
118
141
  ],
119
142
  ),
120
143
  )
121
- model = gemini.GeminiProVision()
144
+ model = google_genai.GeminiProVision()
122
145
  result = model._response_to_result(response)
123
146
  self.assertEqual(
124
147
  result,
@@ -129,26 +152,28 @@ class GeminiTest(unittest.TestCase):
129
152
  )
130
153
 
131
154
  def test_model_hub(self):
132
- model = gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
155
+ model = google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
133
156
  self.assertIsNotNone(model)
134
- self.assertIs(gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
157
+ self.assertIs(google_genai._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)
135
158
 
136
159
  def test_api_key_check(self):
137
160
  with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
138
- _ = gemini.GeminiPro()._api_initialized
161
+ _ = google_genai.GeminiPro()._api_initialized
139
162
 
140
- self.assertTrue(gemini.GeminiPro(api_key='abc')._api_initialized)
163
+ self.assertTrue(google_genai.GeminiPro(api_key='abc')._api_initialized)
141
164
  os.environ['GOOGLE_API_KEY'] = 'abc'
142
- self.assertTrue(gemini.GeminiPro()._api_initialized)
165
+ self.assertTrue(google_genai.GeminiPro()._api_initialized)
143
166
  del os.environ['GOOGLE_API_KEY']
144
167
 
145
168
  def test_call(self):
146
169
  with mock.patch(
147
170
  'google.generativeai.generative_models.GenerativeModel.generate_content'
148
171
  ) as mock_generate:
172
+ orig_get_model = genai.get_model
173
+ genai.get_model = mock_get_model
149
174
  mock_generate.side_effect = mock_generate_content
150
175
 
151
- lm = gemini.GeminiPro(api_key='test_key')
176
+ lm = google_genai.GeminiPro(api_key='test_key')
152
177
  self.maxDiff = None
153
178
  self.assertEqual(
154
179
  lm('hello', temperature=2.0, top_k=20).text,
@@ -157,6 +182,44 @@ class GeminiTest(unittest.TestCase):
157
182
  'top_p=None, top_k=20, max_tokens=1024, stop=None.'
158
183
  ),
159
184
  )
185
+ genai.get_model = orig_get_model
186
+
187
+ def test_call_with_legacy_completion_model(self):
188
+ orig_get_model = genai.get_model
189
+ genai.get_model = mock_get_model
190
+ orig_generate_text = genai.generate_text
191
+ genai.generate_text = mock_generate_text
192
+
193
+ lm = google_genai.Palm2(api_key='test_key')
194
+ self.maxDiff = None
195
+ self.assertEqual(
196
+ lm('hello', temperature=2.0, top_k=20).text,
197
+ (
198
+ "hello to models/text-bison-001 with {'temperature': 2.0, "
199
+ "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
200
+ "'max_output_tokens': 1024, 'stop_sequences': None}"
201
+ ),
202
+ )
203
+ genai.get_model = orig_get_model
204
+ genai.generate_text = orig_generate_text
205
+
206
+ def test_call_with_legacy_chat_model(self):
207
+ orig_get_model = genai.get_model
208
+ genai.get_model = mock_get_model
209
+ orig_chat = genai.chat
210
+ genai.chat = mock_chat
211
+
212
+ lm = google_genai.Palm2_IT(api_key='test_key')
213
+ self.maxDiff = None
214
+ self.assertEqual(
215
+ lm('hello', temperature=2.0, top_k=20).text,
216
+ (
217
+ "hello to models/chat-bison-001 with {'temperature': 2.0, "
218
+ "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
219
+ ),
220
+ )
221
+ genai.get_model = orig_get_model
222
+ genai.chat = orig_chat
160
223
 
161
224
 
162
225
  if __name__ == '__main__':