langfun 0.0.2.dev20240109__tar.gz → 0.0.2.dev20240111__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/PKG-INFO +3 -2
  2. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution.py +47 -35
  3. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/execution_test.py +15 -6
  4. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/generation_test.py +2 -1
  5. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/langfunc.py +6 -4
  6. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/langfunc_test.py +2 -2
  7. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/__init__.py +5 -0
  8. langfun-0.0.2.dev20240111/langfun/core/llms/gemini.py +190 -0
  9. langfun-0.0.2.dev20240111/langfun/core/llms/gemini_test.py +163 -0
  10. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/message.py +1 -1
  11. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/message_test.py +3 -2
  12. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modality.py +5 -10
  13. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modality_test.py +29 -0
  14. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/completion.py +15 -22
  15. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/completion_test.py +34 -23
  16. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/description.py +3 -5
  17. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/description_test.py +18 -16
  18. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/mapping.py +16 -8
  19. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/parsing.py +5 -18
  20. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/parsing_test.py +2 -3
  21. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/prompting.py +23 -23
  22. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/prompting_test.py +212 -56
  23. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/PKG-INFO +3 -2
  24. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/SOURCES.txt +2 -0
  25. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/requires.txt +2 -1
  26. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/LICENSE +0 -0
  27. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/README.md +0 -0
  28. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/__init__.py +0 -0
  29. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/__init__.py +0 -0
  30. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/__init__.py +0 -0
  31. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/__init__.py +0 -0
  32. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/correction.py +0 -0
  33. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/correction_test.py +0 -0
  34. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/errors.py +0 -0
  35. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/errors_test.py +0 -0
  36. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/generation.py +0 -0
  37. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/parsing.py +0 -0
  38. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/parsing_test.py +0 -0
  39. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/permissions.py +0 -0
  40. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/coding/python/permissions_test.py +0 -0
  41. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/component.py +0 -0
  42. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/component_test.py +0 -0
  43. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/concurrent.py +0 -0
  44. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/concurrent_test.py +0 -0
  45. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/console.py +0 -0
  46. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/console_test.py +0 -0
  47. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/__init__.py +0 -0
  48. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/base.py +0 -0
  49. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/base_test.py +0 -0
  50. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/matching.py +0 -0
  51. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/matching_test.py +0 -0
  52. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/scoring.py +0 -0
  53. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/eval/scoring_test.py +0 -0
  54. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/language_model.py +0 -0
  55. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/language_model_test.py +0 -0
  56. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/__init__.py +0 -0
  57. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/base.py +0 -0
  58. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/in_memory.py +0 -0
  59. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/cache/in_memory_test.py +0 -0
  60. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/fake.py +0 -0
  61. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/fake_test.py +0 -0
  62. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/llama_cpp.py +0 -0
  63. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/llama_cpp_test.py +0 -0
  64. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/openai.py +0 -0
  65. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/llms/openai_test.py +0 -0
  66. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/__init__.py +0 -0
  67. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/conversation_history.py +0 -0
  68. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memories/conversation_history_test.py +0 -0
  69. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/memory.py +0 -0
  70. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/__init__.py +0 -0
  71. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/image.py +0 -0
  72. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/modalities/image_test.py +0 -0
  73. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/natural_language.py +0 -0
  74. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/natural_language_test.py +0 -0
  75. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/sampling.py +0 -0
  76. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/sampling_test.py +0 -0
  77. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/__init__.py +0 -0
  78. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/mapping_test.py +0 -0
  79. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/schema.py +0 -0
  80. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/structured/schema_test.py +0 -0
  81. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/subscription.py +0 -0
  82. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/subscription_test.py +0 -0
  83. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/template.py +0 -0
  84. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/template_test.py +0 -0
  85. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/__init__.py +0 -0
  86. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/completion.py +0 -0
  87. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/completion_test.py +0 -0
  88. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/conversation.py +0 -0
  89. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/conversation_test.py +0 -0
  90. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/demonstration.py +0 -0
  91. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/demonstration_test.py +0 -0
  92. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/selfplay.py +0 -0
  93. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/templates/selfplay_test.py +0 -0
  94. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/text_formatting.py +0 -0
  95. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun/core/text_formatting_test.py +0 -0
  96. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/dependency_links.txt +0 -0
  97. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/langfun.egg-info/top_level.txt +0 -0
  98. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/setup.cfg +0 -0
  99. {langfun-0.0.2.dev20240109 → langfun-0.0.2.dev20240111}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240109
3
+ Version: 0.0.2.dev20240111
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -21,9 +21,10 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
21
21
  Classifier: Topic :: Software Development :: Libraries
22
22
  Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
+ Requires-Dist: google-generativeai>=0.3.2
24
25
  Requires-Dist: jinja2>=3.1.2
25
26
  Requires-Dist: openai==0.27.2
26
- Requires-Dist: pyglove>=0.4.5.dev20240105
27
+ Requires-Dist: pyglove>=0.4.5.dev20240109
27
28
  Requires-Dist: requests>=2.31.0
28
29
  Requires-Dist: termcolor==1.1.0
29
30
  Requires-Dist: tqdm>=4.64.1
@@ -15,6 +15,7 @@
15
15
 
16
16
  import ast
17
17
  import contextlib
18
+ import io
18
19
  import multiprocessing
19
20
  from typing import Any, Callable
20
21
 
@@ -24,6 +25,9 @@ from langfun.core.coding.python import permissions
24
25
  import pyglove as pg
25
26
 
26
27
 
28
+ # Key in returned dict that captures stdout.
29
+ STDOUT_KEY = '__stdout__'
30
+
27
31
  # Key in the returned dict that represents the final result.
28
32
  RESULT_KEY = '__result__'
29
33
  _TLS_CODE_RUN_CONTEXT = '__code_run_context__'
@@ -86,45 +90,51 @@ def evaluate(
86
90
  code, code_block = parsing.PythonCodeParser().parse(code, permission)
87
91
  global_vars, orig_global_vars = ctx, ctx.copy()
88
92
 
89
- if hasattr(code_block.body[-1], 'value'):
90
- last_expr = code_block.body.pop() # pytype: disable=attribute-error
91
- result_vars = [RESULT_KEY]
93
+ # No code.
94
+ if not code_block.body:
95
+ return {} if outputs_intermediate else None
92
96
 
93
- if isinstance(last_expr, ast.Assign):
94
- for name_node in last_expr.targets:
95
- result_vars.append(name_node.id)
97
+ stdout = io.StringIO()
98
+ with contextlib.redirect_stdout(stdout):
99
+ if hasattr(code_block.body[-1], 'value'):
100
+ last_expr = code_block.body.pop() # pytype: disable=attribute-error
101
+ result_vars = [RESULT_KEY]
96
102
 
97
- last_expr = ast.Expression(last_expr.value) # pytype: disable=attribute-error
103
+ if isinstance(last_expr, ast.Assign):
104
+ for name_node in last_expr.targets:
105
+ result_vars.append(name_node.id)
98
106
 
99
- try:
100
- # Execute the lines before the last expression.
101
- # NOTE(daiyip): Only a `globals` dict is specified here, which will also
102
- # be used to output intermediate values by `exec`. We do not specify a
103
- # separate `locals` dict here, for - "If exec gets two separate objects as
104
- # globals and locals, the code will be executed as if it were embedded in
105
- # a class definition." - as the Python document explains. The outcome is
106
- # that new functions defined in the code block could not be called by
107
- # other newly defined functions.
108
- # Refer to https://stackoverflow.com/questions/
109
- # 73940751/why-cant-i-call-a-function-from-another-function-using-exec
110
- # for more details.
111
- exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
112
-
113
- # Evaluate the last expression.
114
- result = eval( # pylint: disable=eval-used
115
- compile(last_expr, '', mode='eval'), global_vars
116
- )
117
- except Exception as e:
118
- raise errors.CodeError(code, e) from e
107
+ last_expr = ast.Expression(last_expr.value) # pytype: disable=attribute-error
119
108
 
120
- for result_var in result_vars:
121
- global_vars[result_var] = result
122
- else:
123
- try:
124
- exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
125
- except Exception as e:
126
- raise errors.CodeError(code, e) from e
127
- global_vars[RESULT_KEY] = list(global_vars.values())[-1]
109
+ try:
110
+ # Execute the lines before the last expression.
111
+ # NOTE(daiyip): Only a `globals` dict is specified here, which will also
112
+ # be used to output intermediate values by `exec`. We do not specify a
113
+ # separate `locals` dict here, for - "If exec gets two separate objects
114
+ # as globals and locals, the code will be executed as if it were
115
+ # embedded in a class definition." - as the Python document explains.
116
+ # The outcome is that new functions defined in the code block could not
117
+ # be called by other newly defined functions.
118
+ # Refer to https://stackoverflow.com/questions/
119
+ # 73940751/why-cant-i-call-a-function-from-another-function-using-exec
120
+ # for more details.
121
+ exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
122
+
123
+ # Evaluate the last expression.
124
+ result = eval( # pylint: disable=eval-used
125
+ compile(last_expr, '', mode='eval'), global_vars
126
+ )
127
+ except Exception as e:
128
+ raise errors.CodeError(code, e) from e
129
+
130
+ for result_var in result_vars:
131
+ global_vars[result_var] = result
132
+ else:
133
+ try:
134
+ exec(compile(code_block, '', mode='exec'), global_vars) # pylint: disable=exec-used
135
+ except Exception as e:
136
+ raise errors.CodeError(code, e) from e
137
+ global_vars[RESULT_KEY] = list(global_vars.values())[-1]
128
138
 
129
139
  if outputs_intermediate:
130
140
  outputs = {}
@@ -133,6 +143,8 @@ def evaluate(
133
143
  continue
134
144
  if k not in orig_global_vars or v is not orig_global_vars[k]:
135
145
  outputs[k] = v
146
+ # Add stdout to outputs.
147
+ outputs[STDOUT_KEY] = stdout.getvalue()
136
148
  return outputs
137
149
  return global_vars[RESULT_KEY]
138
150
 
@@ -36,7 +36,7 @@ class EvaluateTest(unittest.TestCase):
36
36
  global_vars=dict(z=3),
37
37
  outputs_intermediate=True,
38
38
  ),
39
- dict(p=2 + 0 + 3, __result__=2 + 0 + 3),
39
+ dict(p=2 + 0 + 3, __result__=2 + 0 + 3, __stdout__=''),
40
40
  )
41
41
 
42
42
  def test_basics(self):
@@ -45,17 +45,19 @@ class EvaluateTest(unittest.TestCase):
45
45
  """
46
46
  x = 1
47
47
  y = x + 1
48
+ print(y)
48
49
  z = x + y
49
50
  """,
50
51
  outputs_intermediate=True,
51
52
  ),
52
- dict(x=1, y=2, z=3, __result__=3),
53
+ dict(x=1, y=2, z=3, __result__=3, __stdout__='2\n'),
53
54
  )
54
55
  self.assertEqual(
55
56
  execution.evaluate(
56
57
  """
57
58
  x = 1
58
59
  y = x + 1
60
+ print(y)
59
61
  z = x + y
60
62
  """,
61
63
  ),
@@ -75,9 +77,10 @@ class EvaluateTest(unittest.TestCase):
75
77
  global_vars=dict(pg=pg),
76
78
  outputs_intermediate=True,
77
79
  )
78
- self.assertEqual(list(ret.keys()), ['A', '__result__'])
80
+ self.assertEqual(list(ret.keys()), ['A', '__result__', '__stdout__'])
79
81
  self.assertTrue(issubclass(ret['A'], pg.Object))
80
82
  self.assertIs(ret['__result__'], ret['A'])
83
+ self.assertEqual(ret['__stdout__'], '')
81
84
 
82
85
  def test_function_def(self):
83
86
  ret = execution.evaluate(
@@ -91,7 +94,9 @@ class EvaluateTest(unittest.TestCase):
91
94
  permission=permissions.CodePermission.ALL,
92
95
  outputs_intermediate=True,
93
96
  )
94
- self.assertEqual(list(ret.keys()), ['foo', 'bar', '__result__'])
97
+ self.assertEqual(
98
+ list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
99
+ )
95
100
  self.assertTrue(inspect.isfunction(ret['foo']))
96
101
  self.assertTrue(inspect.isfunction(ret['bar']))
97
102
  self.assertIs(ret['__result__'], ret['bar'])
@@ -110,7 +115,9 @@ class EvaluateTest(unittest.TestCase):
110
115
  permission=permissions.CodePermission.ALL,
111
116
  outputs_intermediate=True,
112
117
  )
113
- self.assertEqual(list(ret.keys()), ['foo', 'bar', '__result__'])
118
+ self.assertEqual(
119
+ list(ret.keys()), ['foo', 'bar', '__result__', '__stdout__']
120
+ )
114
121
  self.assertEqual(ret['__result__'], 3)
115
122
 
116
123
  def test_complex(self):
@@ -131,7 +138,9 @@ class EvaluateTest(unittest.TestCase):
131
138
  global_vars=dict(pg=pg),
132
139
  outputs_intermediate=True,
133
140
  )
134
- self.assertEqual(list(ret.keys()), ['A', 'foo', 'k', '__result__'])
141
+ self.assertEqual(
142
+ list(ret.keys()), ['A', 'foo', 'k', '__result__', '__stdout__']
143
+ )
135
144
  self.assertTrue(issubclass(ret['A'], pg.Object))
136
145
  self.assertTrue(inspect.isfunction(ret['foo']))
137
146
  self.assertIsInstance(ret['k'], pg.Object)
@@ -41,9 +41,10 @@ class PythonCodeTest(unittest.TestCase):
41
41
  generation.PythonCode("""
42
42
  x = 1
43
43
  y = x + 1
44
+ print(y)
44
45
  z = x + y
45
46
  """).eval(),
46
- dict(x=1, y=2, z=3, __result__=3),
47
+ dict(x=1, y=2, z=3, __result__=3, __stdout__='2\n'),
47
48
  )
48
49
 
49
50
  def test_call(self):
@@ -326,18 +326,20 @@ class LangFunc(
326
326
  return lm_output
327
327
 
328
328
  @classmethod
329
- def from_value(cls, value: Union[str, template_lib.Template]) -> 'LangFunc':
329
+ def from_value(
330
+ cls, value: Union[str, template_lib.Template], **kwargs
331
+ ) -> 'LangFunc':
330
332
  """Create a LangFunc object from a string or template."""
331
333
  if isinstance(value, LangFunc):
332
334
  return value
333
335
  if isinstance(value, template_lib.Template):
334
- lfun = LangFunc(value.template_str)
336
+ lfun = LangFunc(value.template_str, **kwargs)
335
337
  # So lfun could acccess all attributes from value.
336
338
  lfun.sym_setparent(value)
337
339
  return lfun
338
340
  if isinstance(value, str):
339
- return LangFunc(template_str=value)
340
- raise TypeError(f'Unsupported input type: {value!r}.')
341
+ return LangFunc(template_str=value, **kwargs)
342
+ return LangFunc('{{input}}', input=value, **kwargs)
341
343
 
342
344
 
343
345
  # Register converter from str to LangFunc, therefore we can always
@@ -64,8 +64,8 @@ class BasicTest(unittest.TestCase):
64
64
  l3 = LangFunc.from_value(c.l)
65
65
  self.assertEqual(l3.render(), '1 + 2')
66
66
 
67
- with self.assertRaisesRegex(TypeError, 'Unsupported input type'):
68
- LangFunc.from_value(1)
67
+ l4 = LangFunc.from_value(1)
68
+ self.assertEqual(l4.render(), '1')
69
69
 
70
70
 
71
71
  class LangFuncCallTest(unittest.TestCase):
@@ -23,6 +23,11 @@ from langfun.core.llms.fake import StaticMapping
23
23
  from langfun.core.llms.fake import StaticResponse
24
24
  from langfun.core.llms.fake import StaticSequence
25
25
 
26
+ # Gemini models.
27
+ from langfun.core.llms.gemini import Gemini
28
+ from langfun.core.llms.gemini import GeminiPro
29
+ from langfun.core.llms.gemini import GeminiProVision
30
+
26
31
  # OpenAI models.
27
32
  from langfun.core.llms.openai import OpenAI
28
33
 
@@ -0,0 +1,190 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Gemini models exposed through Google Generative AI APIs."""
15
+
16
+ import functools
17
+ import os
18
+ from typing import Annotated, Any, Literal
19
+
20
+ import google.generativeai as genai
21
+ import langfun.core as lf
22
+ from langfun.core import modalities as lf_modalities
23
+
24
+
25
@lf.use_init_args(['model'])
class Gemini(lf.LanguageModel):
  """Gemini language model served through the Google Generative AI API.

  The API key comes from `api_key` or, when that is unset, from the
  `GOOGLE_API_KEY` environment variable. `genai` is configured with it lazily,
  on the first sampling call.
  """

  model: Annotated[
      Literal['gemini-pro', 'gemini-pro-vision', ''],
      'Model name.',
  ]

  api_key: Annotated[
      str | None,
      (
          'API key. If None, the key will be read from environment variable '
          "'GOOGLE_API_KEY'."
      ),
  ] = None

  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
      False
  )

  def _on_bound(self):
    super()._on_bound()
    # Invalidate the cached `_api_initialized` value. The (possibly changed)
    # API key is then re-resolved, and `genai.configure` re-run, on next use.
    self.__dict__.pop('_api_initialized', None)

  @functools.cached_property
  def _api_initialized(self) -> bool:
    """Configures `genai` with the resolved API key (cached per instance).

    Returns:
      True once `genai.configure` has been called.

    Raises:
      ValueError: If neither `api_key` nor `GOOGLE_API_KEY` is set.
    """
    api_key = self.api_key or os.environ.get('GOOGLE_API_KEY', None)
    if not api_key:
      raise ValueError(
          'Please specify `api_key` during `__init__` or set environment '
          'variable `GOOGLE_API_KEY` with your Google Cloud API key. '
          'Check out '
          'https://cloud.google.com/api-keys/docs/create-manage-api-keys '
          'for more details.'
      )
    genai.configure(api_key=api_key)
    return True

  @classmethod
  def dir(cls) -> list[str]:
    """Lists names of models that support `generateContent`.

    Names are returned without their leading `models/` resource prefix.
    """
    return [
        # NOTE: `str.lstrip('models/')` would strip any leading characters in
        # the set {m, o, d, e, l, s, /}, eating the start of names such as
        # 'models/embedding-001' -> 'bedding-001'. `removeprefix` removes
        # only the exact prefix.
        m.name.removeprefix('models/')
        for m in genai.list_models()
        if 'generateContent' in m.supported_generation_methods
    ]

  @property
  def model_id(self) -> str:
    """Returns a string to identify the model."""
    return self.model

  @property
  def resource_id(self) -> str:
    """Returns a string to identify the resource for rate control."""
    return self.model_id

  @property
  def max_concurrency(self) -> int:
    """Max concurrent requests."""
    return 8

  def _generation_config(
      self, options: lf.LMSamplingOptions
  ) -> 'genai.GenerationConfig':
    """Creates a `genai.GenerationConfig` from langfun sampling options."""
    return genai.GenerationConfig(
        candidate_count=options.n,
        temperature=options.temperature,
        top_p=options.top_p,
        top_k=options.top_k,
        max_output_tokens=options.max_tokens,
        stop_sequences=options.stop,
    )

  def _content_from_message(
      self, prompt: lf.Message
  ) -> list[str | genai.types.BlobDict]:
    """Converts a langfun message into Gemini API content chunks.

    Text chunks are passed through as strings. Image chunks become
    `BlobDict`s, but only when the model is `multimodal`.

    Raises:
      ValueError: On any non-text chunk the model cannot accept.
    """
    formatted = lf.UserMessage(prompt.text)
    formatted.source = prompt

    chunks = []
    for lf_chunk in formatted.chunk():
      if isinstance(lf_chunk, str):
        chunk = lf_chunk
      elif self.multimodal and isinstance(lf_chunk, lf_modalities.Image):
        chunk = genai.types.BlobDict(
            data=lf_chunk.to_bytes(), mime_type=f'image/{lf_chunk.image_format}'
        )
      else:
        raise ValueError(f'Unsupported modality: {lf_chunk!r}')
      chunks.append(chunk)
    return chunks

  def _response_to_result(
      self, response: genai.types.GenerateContentResponse
  ) -> lf.LMSamplingResult:
    """Converts a Gemini response into an `LMSamplingResult` (one sample per candidate)."""
    samples = []
    for candidate in response.candidates:
      chunks = []
      for part in candidate.content.parts:
        # TODO(daiyip): support multi-modal parts when they are available via
        # Gemini API.
        if hasattr(part, 'text'):
          chunks.append(part.text)
      samples.append(lf.LMSample(lf.AIMessage.from_chunks(chunks), score=0.0))
    return lf.LMSamplingResult(samples)

  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
    """Samples all prompts concurrently, one API call per prompt."""
    # Touch the cached property for its side effect of configuring `genai`,
    # which raises ValueError when no API key is available. This must not be
    # an `assert`: asserts are stripped under `python -O`, which would skip
    # the configuration entirely.
    _ = self._api_initialized
    return lf.concurrent_execute(
        self._sample_single,
        prompts,
        executor=self.resource_id,
        max_workers=self.max_concurrency,
        # NOTE(daiyip): The Gemini API is expected to apply its own rate-limit
        # policy, so we do not retry on errors.
        retry_on_errors=None,
    )

  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
    """Samples a single prompt."""
    model = _GOOGLE_GENAI_MODEL_HUB.get(self.model)
    input_content = self._content_from_message(prompt)
    response = model.generate_content(
        input_content,
        generation_config=self._generation_config(self.sampling_options),
    )
    return self._response_to_result(response)
155
+
156
+
157
class _ModelHub:
  """Caches `genai.GenerativeModel` objects, keyed by model name."""

  def __init__(self):
    self._model_cache = {}

  def get(self, model_name: str) -> genai.GenerativeModel:
    """Returns the model for `model_name`, creating and caching it on first use."""
    cached = self._model_cache.get(model_name)
    if cached is not None:
      return cached
    created = genai.GenerativeModel(model_name)
    self._model_cache[model_name] = created
    return created


# Module-level hub used by `Gemini._sample_single`.
_GOOGLE_GENAI_MODEL_HUB = _ModelHub()
173
+
174
+
175
+ #
176
+ # Public Gemini models.
177
+ #
178
+
179
+
180
class GeminiPro(Gemini):
  """Gemini Pro model (`gemini-pro`); accepts text-only input."""

  model = 'gemini-pro'


class GeminiProVision(Gemini):
  """Gemini Pro vision model (`gemini-pro-vision`); accepts text and images."""

  model = 'gemini-pro-vision'
  # Enables image chunks in `Gemini._content_from_message`.
  multimodal = True
@@ -0,0 +1,163 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for Gemini models."""
15
+
16
+ import os
17
+ import unittest
18
+ from unittest import mock
19
+
20
+ from google import generativeai as genai
21
+ import langfun.core as lf
22
+ from langfun.core import modalities as lf_modalities
23
+ from langfun.core.llms import gemini
24
+ import pyglove as pg
25
+
26
+
27
# A tiny 24x24 palette PNG used as multimodal test input, assembled
# chunk by chunk: signature, IHDR, PLTE, tRNS, IDAT, IEND.
example_image = b''.join((
    # PNG signature.
    b'\x89PNG\r\n\x1a\n',
    # IHDR: length, type, data, CRC.
    b'\x00\x00\x00\rIHDR',
    b'\x00\x00\x00\x18\x00\x00\x00\x18\x04\x03\x00\x00\x00',
    b'\x12Y \xcb',
    # PLTE: length, type, 8 RGB entries, CRC.
    b'\x00\x00\x00\x18PLTE',
    b'\x00\x00\x00fff_chaag_cg_ch^ci_ci',
    b'C\xedb\x94',
    # tRNS: length, type, data, CRC.
    b'\x00\x00\x00\x08tRNS',
    b'\x00\n\x9f*\xd4\xff_\xf4',
    b'\xe4\x8b\xf3a',
    # IDAT: length, type, compressed data, CRC.
    b'\x00\x00\x00>IDAT',
    b'x\x01c \x05\x08)"\xd8\xcc\xae!\x06pNz\x88k\x19\\Q\xa8"\x10',
    b'\xc1\x14\x95\x01%\xc1\n\xa143Ta\xa8"D-\x84\x03QM\x98\xc3',
    b'\x1a\x1a\x1a@5\x0e\x04\xa0q\x88\x05\x00\x07\xf8\x18\xf9',
    b'\xdao\xd0|',
    # IEND: length, type, CRC.
    b'\x00\x00\x00\x00IEND',
    b'\xaeB`\x82',
))
37
+
38
+
39
def mock_generate_content(content, generation_config, **kwargs):
  """Fake `GenerativeModel.generate_content` that echoes its inputs.

  Returns a single-candidate response whose text reports the first content
  chunk and every field of the generation config it received.
  """
  del kwargs  # Unused.
  cfg = generation_config
  reply = (
      f'This is a response to {content[0]} with '
      f'n={cfg.candidate_count}, '
      f'temperature={cfg.temperature}, '
      f'top_p={cfg.top_p}, '
      f'top_k={cfg.top_k}, '
      f'max_tokens={cfg.max_output_tokens}, '
      f'stop={cfg.stop_sequences}.'
  )
  only_candidate = pg.Dict(content=pg.Dict(parts=[pg.Dict(text=reply)]))
  payload = pg.Dict(
      prompt_feedback=pg.Dict(block_reason=None),
      candidates=[only_candidate],
  )
  return genai.types.GenerateContentResponse(
      done=True, iterator=None, chunks=[], result=payload
  )
69
+
70
+
71
class GeminiTest(unittest.TestCase):
  """Tests for the Gemini language models."""

  def test_content_from_message_text_only(self):
    text = 'This is a beautiful day'
    model = gemini.GeminiPro()
    chunks = model._content_from_message(lf.UserMessage(text))
    self.assertEqual(chunks, [text])

  def test_content_from_message_mm(self):
    message = lf.UserMessage(
        'This is an {{image}}, what is it?',
        image=lf_modalities.Image.from_bytes(example_image),
    )

    # Non-multimodal model.
    with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
      gemini.GeminiPro()._content_from_message(message)

    model = gemini.GeminiProVision()
    chunks = model._content_from_message(message)
    self.maxDiff = None
    self.assertEqual(
        chunks,
        [
            'This is an',
            genai.types.BlobDict(mime_type='image/png', data=example_image),
            ', what is it?',
        ],
    )

  def test_response_to_result_text_only(self):
    response = genai.types.GenerateContentResponse(
        done=True,
        iterator=None,
        chunks=[],
        result=pg.Dict(
            prompt_feedback=pg.Dict(block_reason=None),
            candidates=[
                pg.Dict(
                    content=pg.Dict(
                        parts=[pg.Dict(text='This is response 1.')]
                    ),
                ),
                pg.Dict(
                    content=pg.Dict(parts=[pg.Dict(text='This is response 2.')])
                ),
            ],
        ),
    )
    model = gemini.GeminiProVision()
    result = model._response_to_result(response)
    self.assertEqual(
        result,
        lf.LMSamplingResult([
            lf.LMSample(lf.AIMessage('This is response 1.'), score=0.0),
            lf.LMSample(lf.AIMessage('This is response 2.'), score=0.0),
        ]),
    )

  def test_model_hub(self):
    model = gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro')
    self.assertIsNotNone(model)
    self.assertIs(gemini._GOOGLE_GENAI_MODEL_HUB.get('gemini-pro'), model)

  def test_api_key_check(self):
    # Isolate the test from any GOOGLE_API_KEY already set in the environment.
    # `patch.dict` also restores os.environ on exit, even when an assertion
    # fails midway.
    with mock.patch.dict(os.environ):
      os.environ.pop('GOOGLE_API_KEY', None)
      with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
        _ = gemini.GeminiPro()._api_initialized

      self.assertTrue(gemini.GeminiPro(api_key='abc')._api_initialized)
      os.environ['GOOGLE_API_KEY'] = 'abc'
      self.assertTrue(gemini.GeminiPro()._api_initialized)

  def test_call(self):
    with mock.patch(
        'google.generativeai.generative_models.GenerativeModel.generate_content'
    ) as mock_generate:
      mock_generate.side_effect = mock_generate_content

      lm = gemini.GeminiPro(api_key='test_key')
      self.maxDiff = None
      self.assertEqual(
          lm('hello', temperature=2.0, top_k=20).text,
          (
              'This is a response to hello with n=1, temperature=2.0, '
              'top_p=None, top_k=20, max_tokens=1024, stop=None.'
          ),
      )


if __name__ == '__main__':
  unittest.main()
@@ -195,7 +195,7 @@ class Message(natural_language.NaturalLanguageFormattable, pg.Object):
195
195
  if key_path == Message.PATH_TEXT:
196
196
  return self.text
197
197
  else:
198
- v = self.metadata.sym_get(key_path, default)
198
+ v = self.metadata.sym_get(key_path, default, use_inferred=True)
199
199
  return v.value if isinstance(v, pg.Ref) else v
200
200
 
201
201
  #
@@ -110,16 +110,17 @@ class MessageTest(unittest.TestCase):
110
110
  def test_get(self):
111
111
 
112
112
  class A(pg.Object):
113
- pass
113
+ p: int
114
114
 
115
115
  # Create a symbolic object and assign it to a container, so we could test
116
116
  # pg.Ref.
117
- a = A()
117
+ a = A(1)
118
118
  d = pg.Dict(x=a)
119
119
 
120
120
  m = message.UserMessage('hi', x=pg.Ref(a), y=dict(z=[0, 1, 2]))
121
121
  self.assertEqual(m.get('text'), 'hi')
122
122
  self.assertIs(m.get('x'), a)
123
+ self.assertIs(m.get('x.p'), 1)
123
124
  self.assertEqual(m.get('y'), dict(z=[0, 1, 2]))
124
125
  self.assertEqual(m.get('y.z'), [0, 1, 2])
125
126
  self.assertEqual(m.get('y.z[0]'), 0)