langfun 0.1.2.dev202507250804__py3-none-any.whl → 0.1.2.dev202507270804__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic; see the details linked from the original release page.

Files changed (31)
  1. langfun/assistant/capabilities/gui/__init__.py +36 -0
  2. langfun/assistant/capabilities/gui/bounding_box_parser.py +195 -0
  3. langfun/assistant/capabilities/gui/bounding_box_parser_test.py +313 -0
  4. langfun/assistant/capabilities/gui/drawing.py +242 -0
  5. langfun/assistant/capabilities/gui/drawing_test.py +103 -0
  6. langfun/assistant/capabilities/gui/location.py +288 -0
  7. langfun/assistant/capabilities/gui/location_test.py +230 -0
  8. langfun/core/__init__.py +3 -0
  9. langfun/core/agentic/action.py +18 -0
  10. langfun/core/agentic/action_test.py +8 -1
  11. langfun/core/async_support.py +28 -0
  12. langfun/core/async_support_test.py +39 -0
  13. langfun/core/concurrent_test.py +6 -5
  14. langfun/core/language_model.py +54 -0
  15. langfun/core/language_model_test.py +40 -0
  16. langfun/core/structured/__init__.py +7 -0
  17. langfun/core/structured/completion.py +28 -0
  18. langfun/core/structured/completion_test.py +36 -0
  19. langfun/core/structured/parsing.py +77 -0
  20. langfun/core/structured/parsing_test.py +15 -0
  21. langfun/core/structured/querying.py +42 -0
  22. langfun/core/structured/querying_test.py +10 -0
  23. langfun/core/structured/scoring.py +28 -0
  24. langfun/core/structured/scoring_test.py +8 -0
  25. langfun/core/structured/tokenization.py +24 -0
  26. langfun/core/structured/tokenization_test.py +8 -0
  27. {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/METADATA +1 -1
  28. {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/RECORD +31 -22
  29. {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/WHEEL +0 -0
  30. {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/licenses/LICENSE +0 -0
  31. {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,230 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tests for structures for location objects in an image."""
15
+
16
+ import random
17
+ import unittest
18
+ from langfun.assistant.capabilities.gui import location
19
+
20
+
21
+ class CoordinateTest(unittest.TestCase):
22
+
23
+ def test_basics(self):
24
+ pt = location.Coordinate(1, 2)
25
+ self.assertEqual(pt.x, 1)
26
+ self.assertEqual(pt.y, 2)
27
+ self.assertEqual(pt.as_tuple(), (1, 2))
28
+
29
+ def test_random(self):
30
+ bound = location.BBox(0, 0, 10, 10)
31
+ rand = random.Random(0)
32
+ for _ in range(10):
33
+ pt = location.Coordinate.random(bound, rand=rand)
34
+ self.assertIsInstance(pt, location.Coordinate)
35
+ self.assertIn(pt, bound)
36
+
37
+ def test_from_value(self):
38
+ pt = location.Coordinate.from_value((1, 2))
39
+ self.assertEqual(pt.x, 1)
40
+ self.assertEqual(pt.y, 2)
41
+ self.assertIs(location.Coordinate.from_value(pt), pt)
42
+
43
+ def test_arithmetic(self):
44
+ pt1 = location.Coordinate(1, 2)
45
+ pt2 = location.Coordinate(3, 4)
46
+ self.assertEqual(pt1 + pt2, location.Coordinate(4, 6))
47
+ self.assertEqual(pt2 + pt1, location.Coordinate(4, 6))
48
+ self.assertEqual(pt1 - pt2, location.Coordinate(-2, -2))
49
+ self.assertEqual(pt2 - pt1, location.Coordinate(2, 2))
50
+
51
+ pt1 = location.Coordinate(1, 2)
52
+ pt2 = (3, 4)
53
+ self.assertEqual(pt1 + pt2, location.Coordinate(4, 6))
54
+ self.assertEqual(pt2 + pt1, location.Coordinate(4, 6))
55
+ self.assertEqual(pt1 - pt2, location.Coordinate(-2, -2))
56
+ self.assertEqual(pt2 - pt1, location.Coordinate(2, 2))
57
+
58
+ self.assertEqual(pt1 * 2, location.Coordinate(2, 4))
59
+ self.assertEqual(2 * pt1, location.Coordinate(2, 4))
60
+
61
+ def test_distance_to(self):
62
+ pt1 = location.Coordinate(1, 2)
63
+ pt2 = location.Coordinate(3, 4)
64
+ # compare float values with a tolerance of 1e-6
65
+ self.assertAlmostEqual(pt1.distance_to(pt2), 2.8284271247461903, delta=1e-6)
66
+
67
+
68
+ class BBoxTest(unittest.TestCase):
69
+
70
+ def test_invalid_bbox_creation(self):
71
+ with self.assertRaisesRegex(AssertionError, '.*'):
72
+ location.BBox(100, 50, 100, 450) # Zero width.
73
+ with self.assertRaisesRegex(AssertionError, '.*'):
74
+ location.BBox(100, 50, 300, 50) # Zero height.
75
+ with self.assertRaisesRegex(AssertionError, '.*'):
76
+ location.BBox(100, 50, 50, 50) # Zero width and height
77
+
78
+ def test_basics(self):
79
+ bbox = location.BBox(100, 50, 300, 450)
80
+ self.assertEqual(bbox.x, 100)
81
+ self.assertEqual(bbox.y, 50)
82
+ self.assertEqual(bbox.right, 300)
83
+ self.assertEqual(bbox.bottom, 450)
84
+ self.assertEqual(bbox.left, 100)
85
+ self.assertEqual(bbox.top, 50)
86
+ self.assertEqual(bbox.width, 200)
87
+ self.assertEqual(bbox.height, 400)
88
+
89
+ self.assertEqual(bbox.center, location.Coordinate(200, 250))
90
+ self.assertEqual(bbox.top_left, location.Coordinate(100, 50))
91
+ self.assertEqual(bbox.bottom_right, location.Coordinate(300, 450))
92
+ self.assertEqual(bbox.area, 80000)
93
+ self.assertEqual(bbox.as_tuple(), (100, 50, 300, 450))
94
+
95
+ def test_contains(self):
96
+ bbox = location.BBox(100, 50, 300, 450)
97
+ self.assertIn((120, 200), bbox)
98
+ self.assertIn(location.Coordinate(120, 200), bbox)
99
+ self.assertNotIn(location.Coordinate(80, 200), bbox)
100
+ self.assertNotIn(location.Coordinate(120, 20), bbox)
101
+ self.assertNotIn(location.Coordinate(320, 200), bbox)
102
+ self.assertNotIn(location.Coordinate(120, 470), bbox)
103
+
104
+ self.assertIn((100, 50, 300, 450), bbox)
105
+ self.assertIn(location.BBox(100, 50, 300, 450), bbox)
106
+ self.assertIn(location.BBox(120, 70, 260, 410), bbox)
107
+ self.assertNotIn(location.BBox(60, 70, 260, 410), bbox)
108
+ self.assertNotIn(location.BBox(120, 45, 260, 410), bbox)
109
+ self.assertNotIn(location.BBox(120, 60, 310, 410), bbox)
110
+ self.assertNotIn(location.BBox(120, 60, 260, 470), bbox)
111
+ self.assertNotIn(location.BBox(0, 0, 400, 500), bbox)
112
+
113
+ with self.assertRaisesRegex(ValueError, 'Invalid tuple size'):
114
+ _ = (1, 2, 3) in bbox
115
+
116
+ with self.assertRaisesRegex(ValueError, 'Invalid type'):
117
+ _ = 'abc' in bbox # pytype: disable=unsupported-operands
118
+
119
+ def test_intersects(self):
120
+ bbox = location.BBox(100, 50, 300, 450)
121
+ self.assertTrue(bbox.intersects(location.BBox(100, 50, 300, 450)))
122
+ self.assertTrue(bbox.intersects(location.BBox(120, 70, 260, 410)))
123
+ self.assertTrue(bbox.intersects(location.BBox(60, 70, 260, 470)))
124
+ self.assertFalse(bbox.intersects(location.BBox(0, 70, 90, 410)))
125
+ self.assertFalse(bbox.intersects(location.BBox(60, 10, 400, 30)))
126
+
127
+ def test_clip(self):
128
+ # BBox within image bounds
129
+ bbox = location.BBox(100, 50, 300, 450)
130
+ clipped_bbox = bbox.clip((400, 500))
131
+ self.assertEqual(clipped_bbox, bbox)
132
+
133
+ # BBox exceeds right and bottom bounds
134
+ bbox = location.BBox(100, 50, 500, 600)
135
+ clipped_bbox = bbox.clip((400, 500))
136
+ self.assertEqual(clipped_bbox, location.BBox(100, 50, 400, 500))
137
+
138
+ # BBox exceeds left and top bounds
139
+ bbox = location.BBox(-10, -20, 300, 450)
140
+ clipped_bbox = bbox.clip((400, 500))
141
+ self.assertEqual(clipped_bbox, location.BBox(0, 0, 300, 450))
142
+
143
+ # BBox larger than the image
144
+ bbox = location.BBox(-10, -20, 800, 700)
145
+ clipped_bbox = bbox.clip((400, 500))
146
+ self.assertEqual(clipped_bbox, location.BBox(0, 0, 400, 500))
147
+
148
+ # BBox starts outside the image boundaries (bottom/right)
149
+ bbox = location.BBox(500, 600, 700, 800)
150
+ clipped_bbox = bbox.clip((400, 500))
151
+ self.assertIsNone(clipped_bbox)
152
+
153
+ def test_random(self):
154
+ bound = location.BBox(0, 0, 800, 600)
155
+ rand = random.Random(0)
156
+ for _ in range(10):
157
+ bbox = location.BBox.random(bound, rand=rand)
158
+ self.assertIsInstance(bbox, location.BBox)
159
+ self.assertIn(bbox, bound)
160
+
161
+ with self.assertRaisesRegex(
162
+ ValueError, 'Minimum width or height is larger than the bound'
163
+ ):
164
+ _ = location.BBox.random(location.BBox(0, 0, 100, 100), min_width=110)
165
+
166
+ def test_matches(self):
167
+ bbox1 = location.BBox(100, 50, 300, 450)
168
+ self.assertFalse(bbox1.matches(location.BBox(10, 10, 11, 11)))
169
+
170
+ bbox2 = location.BBox(100, 50, 300, 450)
171
+ self.assertTrue(bbox1.matches(bbox2))
172
+ self.assertTrue(bbox2.matches(bbox1))
173
+
174
+ bbox3 = location.BBox(110, 40, 290, 450)
175
+ self.assertTrue(bbox1.matches(bbox3))
176
+ self.assertFalse(bbox3.matches(bbox1, area_diff_threshold=0.05))
177
+ self.assertFalse(bbox3.matches(bbox1, max_center_distance=1))
178
+
179
+ def test_expand(self):
180
+ # Test case 1: Expanding both width and height.
181
+ bbox = location.BBox(100, 50, 300, 450)
182
+ expanded_bbox = bbox.expand(width_scale=1.2, height_scale=1.1)
183
+ self.assertEqual(expanded_bbox.x, 80)
184
+ self.assertEqual(expanded_bbox.y, 30)
185
+ self.assertEqual(expanded_bbox.right, 320)
186
+ self.assertEqual(expanded_bbox.bottom, 470)
187
+ self.assertEqual(expanded_bbox.width, 240)
188
+ self.assertEqual(expanded_bbox.height, 440)
189
+ self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
190
+
191
+ # Test case 2: Expanding only width.
192
+ bbox = location.BBox(100, 50, 300, 450)
193
+ expanded_bbox = bbox.expand(width_scale=1.2)
194
+ self.assertEqual(expanded_bbox.x, 80)
195
+ self.assertEqual(expanded_bbox.y, 50)
196
+ self.assertEqual(expanded_bbox.right, 320)
197
+ self.assertEqual(expanded_bbox.bottom, 450)
198
+ self.assertEqual(expanded_bbox.width, 240)
199
+ self.assertEqual(expanded_bbox.height, 400)
200
+ self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
201
+
202
+ # Test case 3: Expanding only height.
203
+ bbox = location.BBox(100, 50, 300, 450)
204
+ expanded_bbox = bbox.expand(height_scale=1.1)
205
+ self.assertEqual(expanded_bbox.x, 100)
206
+ self.assertEqual(expanded_bbox.y, 30)
207
+ self.assertEqual(expanded_bbox.right, 300)
208
+ self.assertEqual(expanded_bbox.bottom, 470)
209
+ self.assertEqual(expanded_bbox.width, 200)
210
+ self.assertEqual(expanded_bbox.height, 440)
211
+ self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
212
+
213
+ # Test case 4: Shrinking.
214
+ bbox = location.BBox(100, 50, 300, 450)
215
+ expanded_bbox = bbox.expand(width_scale=0.8, height_scale=0.9)
216
+ self.assertEqual(expanded_bbox.x, 120)
217
+ self.assertEqual(expanded_bbox.y, 70)
218
+ self.assertEqual(expanded_bbox.right, 280)
219
+ self.assertEqual(expanded_bbox.bottom, 430)
220
+ self.assertEqual(expanded_bbox.width, 160)
221
+ self.assertEqual(expanded_bbox.height, 360)
222
+ self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
223
+
224
+ # Test case 5: No change.
225
+ bbox = location.BBox(100, 50, 300, 450)
226
+ expanded_bbox = bbox.expand(width_scale=1.0, height_scale=1.0)
227
+ self.assertEqual(expanded_bbox, bbox)
228
+
229
+ if __name__ == '__main__':
230
+ unittest.main()
langfun/core/__init__.py CHANGED
@@ -40,6 +40,9 @@ from langfun.core.component import context
40
40
  as_context = context
41
41
  use_context = context
42
42
 
43
+ # Invoke a callable object asynchronously.
44
+ from langfun.core.async_support import invoke_async
45
+
43
46
  # Shortcut function for overriding components attributes, usually for
44
47
  # override settings.
45
48
  from langfun.core.component import use_settings
@@ -199,6 +199,24 @@ class Action(pg.Object):
199
199
  """Returns last invocation. None if the action is not executed."""
200
200
  return self._invocation
201
201
 
202
+ async def acall(
203
+ self,
204
+ session: Optional['Session'] = None,
205
+ *,
206
+ show_progress: bool = True,
207
+ verbose: bool = False,
208
+ **kwargs
209
+ ) -> Any:
210
+ """Async version of `__call__`."""
211
+ # TODO(daiyip): implement native async calling.
212
+ return await lf.invoke_async(
213
+ self.__call__,
214
+ session,
215
+ show_progress=show_progress,
216
+ verbose=verbose,
217
+ **kwargs
218
+ )
219
+
202
220
  def __call__(
203
221
  self,
204
222
  session: Optional['Session'] = None,
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Tests for base action."""
15
15
 
16
+ import asyncio
16
17
  import unittest
17
18
 
18
19
  import langfun.core as lf
@@ -32,7 +33,7 @@ class Bar(action_lib.Action):
32
33
  session.add_metadata(note='bar')
33
34
  if self.simulate_action_error:
34
35
  raise ValueError('Bar error')
35
- return 2
36
+ return 2 + pg.contextual_value('baz', 0)
36
37
 
37
38
 
38
39
  class Foo(action_lib.Action):
@@ -237,6 +238,12 @@ class SessionTest(unittest.TestCase):
237
238
  json_str = session.to_json_str(save_ref_value=True)
238
239
  self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
239
240
 
241
+ def test_acall(self):
242
+ bar = Bar()
243
+ with lf.context(baz=1):
244
+ r = bar.acall(lm=fake.StaticResponse('lm response'))
245
+ self.assertEqual(asyncio.run(r), 3)
246
+
240
247
  def test_failed_action(self):
241
248
  lm = fake.StaticResponse('lm response')
242
249
  foo = Foo(1, simulate_action_error=True)
@@ -0,0 +1,28 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Utility for async IO in Langfun."""
15
+
16
+ import asyncio
17
+ from typing import Any, Callable
18
+ import pyglove as pg
19
+
20
+
21
async def invoke_async(
    callable_object: Callable[..., Any], *args, **kwargs
) -> Any:
  """Invokes a callable asynchronously with `lf.context` manager enabled.

  Args:
    callable_object: The synchronous callable to run.
    *args: Positional arguments passed to `callable_object`.
    **kwargs: Keyword arguments passed to `callable_object`.

  Returns:
    The value returned by `callable_object`.
  """
  # Enable `lf.context` manager for async calls by wrapping the callable
  # before handing it to `asyncio.to_thread`.
  context_aware_fn = pg.with_contextual_override(callable_object)
  return await asyncio.to_thread(context_aware_fn, *args, **kwargs)
@@ -0,0 +1,39 @@
1
+ # Copyright 2025 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import asyncio
16
+ import time
17
+ import unittest
18
+
19
+ from langfun.core import async_support
20
+ import pyglove as pg
21
+
22
+
23
+ class AsyncSupportTest(unittest.TestCase):
24
+
25
+ def test_invoke_async(self):
26
+
27
+ def foo(x, *, y):
28
+ time.sleep(2)
29
+ return x + y + pg.contextual_value('z', 0)
30
+
31
+ t = time.time()
32
+ r = async_support.invoke_async(foo, 1, y=2)
33
+ self.assertLess(time.time() - t, 1)
34
+ with pg.contextual_override(z=3):
35
+ self.assertEqual(asyncio.run(r), 6)
36
+
37
+
38
+ if __name__ == '__main__':
39
+ unittest.main()
@@ -334,11 +334,12 @@ class ProgressBarTest(unittest.TestCase):
334
334
  sys.stderr.flush()
335
335
  time.sleep(1)
336
336
  self.assertIn('1/4', string_io.getvalue())
337
- self.assertIn('2/4', string_io.getvalue())
338
- self.assertIn('hello', string_io.getvalue())
339
- self.assertNotIn('3/4', string_io.getvalue())
340
- self.assertIn('4/4', string_io.getvalue())
341
- self.assertIn('x=1', string_io.getvalue())
337
+ # TODO(daiyip): Re-enable once flakiness is fixed.
338
+ # self.assertIn('2/4', string_io.getvalue())
339
+ # self.assertIn('hello', string_io.getvalue())
340
+ # self.assertNotIn('3/4', string_io.getvalue())
341
+ # self.assertIn('4/4', string_io.getvalue())
342
+ # self.assertIn('x=1', string_io.getvalue())
342
343
 
343
344
 
344
345
  class ConcurrentMapTest(unittest.TestCase):
@@ -24,6 +24,7 @@ import re
24
24
  import threading
25
25
  import time
26
26
  from typing import Annotated, Any, Callable, ClassVar, Iterator, Literal, Optional, Sequence, Tuple, Type, Union, final
27
+ from langfun.core import async_support
27
28
  from langfun.core import component
28
29
  from langfun.core import concurrent
29
30
  from langfun.core import console
@@ -922,6 +923,59 @@ class LanguageModel(component.Component):
922
923
  # Language model operations.
923
924
  #
924
925
 
926
+ async def asample(
927
+ self,
928
+ prompts: list[str | message_lib.Message],
929
+ *,
930
+ cache_seed: int = 0,
931
+ **kwargs,
932
+ ) -> message_lib.Message:
933
+ """Async version of sample."""
934
+ # TODO(daiyip): implement native async sampling.
935
+ return await async_support.invoke_async(
936
+ self.sample, prompts, cache_seed=cache_seed, **kwargs
937
+ )
938
+
939
+ async def acall(
940
+ self,
941
+ prompt: str | message_lib.Message,
942
+ *,
943
+ cache_seed: int = 0,
944
+ **kwargs
945
+ ) -> message_lib.Message:
946
+ """Async version of __call__."""
947
+ # TODO(daiyip): implement native async calling.
948
+ return await async_support.invoke_async(
949
+ self.__call__,
950
+ prompt,
951
+ cache_seed=cache_seed,
952
+ **kwargs
953
+ )
954
+
955
+ async def ascore(
956
+ self,
957
+ prompt: str | message_lib.Message | list[message_lib.Message],
958
+ completions: list[str | message_lib.Message],
959
+ **kwargs,
960
+ ) -> float:
961
+ """Async version of score."""
962
+ # TODO(daiyip): implement native async scoring.
963
+ return await async_support.invoke_async(
964
+ self.score,
965
+ prompt,
966
+ completions=completions,
967
+ **kwargs
968
+ )
969
+
970
+ async def atokenize(
971
+ self,
972
+ prompt: str | message_lib.Message,
973
+ **kwargs
974
+ ) -> list[int]:
975
+ """Async version of tokenize."""
976
+ # TODO(daiyip): implement native async tokenization.
977
+ return await async_support.invoke_async(self.tokenize, prompt, **kwargs)
978
+
925
979
  def sample(
926
980
  self,
927
981
  prompts: list[str | message_lib.Message],
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Tests for language model."""
15
15
 
16
+ import asyncio
16
17
  import contextlib
17
18
  import io
18
19
  import unittest
@@ -436,6 +437,14 @@ class LanguageModelTest(unittest.TestCase):
436
437
  ]
437
438
  )
438
439
 
440
+ def test_sample_async(self):
441
+ lm = MockModel(top_k=1)
442
+ response = asyncio.run(lm.asample(['foo', 'bar']))
443
+ self.assertIsInstance(response, list)
444
+ self.assertEqual(len(response), 2)
445
+ self.assertIsInstance(response[0], lm_lib.LMSamplingResult)
446
+ self.assertIsInstance(response[1], lm_lib.LMSamplingResult)
447
+
439
448
  def test_call(self):
440
449
  lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
441
450
  response = lm(prompt='foo')
@@ -453,6 +462,11 @@ class LanguageModelTest(unittest.TestCase):
453
462
  # Test override individual flags within sampling_options.
454
463
  self.assertEqual(lm('foo', top_k=2), 'foo' * 2)
455
464
 
465
+ def test_acall(self):
466
+ lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
467
+ response = asyncio.run(lm.acall(prompt='foo'))
468
+ self.assertEqual(response.text, 'foo')
469
+
456
470
  def test_using_cache(self):
457
471
  cache = in_memory.InMemory()
458
472
  lm = MockModel(cache=cache, top_k=1)
@@ -765,6 +779,21 @@ class LanguageModelTest(unittest.TestCase):
765
779
  if debug_mode & lm_lib.LMDebugMode.PROMPT:
766
780
  self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
767
781
 
782
+ def test_ascore(self):
783
+ lm = MockScoringModel()
784
+ self.assertEqual(
785
+ asyncio.run(
786
+ lm.ascore(
787
+ message_lib.UserMessage('hi'),
788
+ ['1', '2']
789
+ )
790
+ ),
791
+ [
792
+ lm_lib.LMScoringResult(score=-0.0),
793
+ lm_lib.LMScoringResult(score=-1.0),
794
+ ],
795
+ )
796
+
768
797
  def test_score_with_unmatched_prompt_and_completions(self):
769
798
  with self.assertRaises(ValueError):
770
799
  MockScoringModel().score(['hi',], ['1', '2', '3'])
@@ -828,6 +857,17 @@ class LanguageModelTest(unittest.TestCase):
828
857
  if debug_mode & lm_lib.LMDebugMode.PROMPT:
829
858
  self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
830
859
 
860
+ def test_atokenize(self):
861
+ lm = MockTokenizeModel()
862
+ self.assertEqual(
863
+ asyncio.run(
864
+ lm.atokenize(
865
+ message_lib.UserMessage('hi')
866
+ )
867
+ ),
868
+ [('hi', 0)],
869
+ )
870
+
831
871
  def test_tokenize_with_unsupported_model(self):
832
872
  with self.assertRaises(NotImplementedError):
833
873
  MockModel().tokenize('hi')
@@ -51,13 +51,16 @@ from langfun.core.structured.mapping import MappingError
51
51
  from langfun.core.structured.mapping import MappingExample
52
52
 
53
53
  from langfun.core.structured.parsing import parse
54
+ from langfun.core.structured.parsing import aparse
54
55
  from langfun.core.structured.parsing import call
56
+ from langfun.core.structured.parsing import acall
55
57
 
56
58
  from langfun.core.structured.querying import track_queries
57
59
  from langfun.core.structured.querying import QueryInvocation
58
60
 
59
61
  from langfun.core.structured.querying import LfQuery
60
62
  from langfun.core.structured.querying import query
63
+ from langfun.core.structured.querying import aquery
61
64
  from langfun.core.structured.querying import query_and_reduce
62
65
  from langfun.core.structured.querying import query_protocol
63
66
 
@@ -67,10 +70,14 @@ from langfun.core.structured.querying import query_reward
67
70
 
68
71
  from langfun.core.structured.description import describe
69
72
  from langfun.core.structured.completion import complete
73
+ from langfun.core.structured.completion import acomplete
70
74
 
71
75
  from langfun.core.structured.scoring import score
76
+ from langfun.core.structured.scoring import ascore
72
77
 
73
78
  from langfun.core.structured.tokenization import tokenize
79
+ from langfun.core.structured.tokenization import atokenize
80
+
74
81
 
75
82
  # Expose default examples for structured operations so users could refer to
76
83
  # them.
@@ -251,3 +251,31 @@ def complete(
251
251
 
252
252
  output = t(lm=lm, cache_seed=cache_seed, autofix_lm=autofix_lm or lm)
253
253
  return output if returns_message else output.result
254
+
255
+
256
async def acomplete(
    input_value: pg.Symbolic,
    default: Any = lf.RAISE_IF_HAS_ERROR,
    *,
    lm: lf.LanguageModel | None = None,
    examples: list[mapping.MappingExample] | None = None,
    cache_seed: int | None = 0,
    autofix: int = 0,
    autofix_lm: lf.LanguageModel | None = None,
    returns_message: bool = False,
    **kwargs,
) -> Any:
  """Async version of `lf.complete`.

  Dispatches the synchronous `complete` through `lf.invoke_async`, forwarding
  every argument unchanged.

  Args:
    input_value: Forwarded to `complete` as its first positional argument.
    default: Forwarded to `complete` as its second positional argument.
    lm: Forwarded to `complete`.
    examples: Forwarded to `complete`.
    cache_seed: Forwarded to `complete`.
    autofix: Forwarded to `complete`.
    autofix_lm: Forwarded to `complete`.
    returns_message: Forwarded to `complete`.
    **kwargs: Additional keyword arguments forwarded to `complete`.

  Returns:
    Whatever `complete` returns.
  """
  # TODO(daiyip): implement native async completion.
  named_options = dict(
      lm=lm,
      examples=examples,
      cache_seed=cache_seed,
      autofix=autofix,
      autofix_lm=autofix_lm,
      returns_message=returns_message,
  )
  return await lf.invoke_async(
      complete, input_value, default, **named_options, **kwargs
  )
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  """Tests for langfun.core.structured.completion."""
15
15
 
16
+ import asyncio
16
17
  import inspect
17
18
  import unittest
18
19
 
@@ -620,6 +621,41 @@ class CompleteStructureTest(unittest.TestCase):
620
621
  ):
621
622
  self.assertIsNone(completion.complete(Activity.partial(), None))
622
623
 
624
+ def test_acomplete(self):
625
+ response = """
626
+ ```python
627
+ TripPlan(
628
+ place='San Francisco',
629
+ itineraries=[
630
+ Itinerary(
631
+ day=1,
632
+ type='daytime',
633
+ activities=[
634
+ Activity(description='Arrive in San Francisco and check into your hotel.'),
635
+ Activity(description='Take a walk around Fisherman\\'s Wharf and have dinner at one of the many seafood restaurants.'),
636
+ Activity(description='Visit Pier 39 and see the sea lions.'),
637
+ ],
638
+ ),
639
+ ]
640
+ )
641
+ ```
642
+ """
643
+ with lf.context(
644
+ lm=fake.StaticSequence(
645
+ [response],
646
+ ),
647
+ override_attrs=True,
648
+ ):
649
+ r = completion.acomplete(
650
+ TripPlan.partial(
651
+ place='San Francisco',
652
+ itineraries=[
653
+ Itinerary.partial(day=1),
654
+ ],
655
+ )
656
+ )
657
+ self.assertIsInstance(asyncio.run(r), TripPlan)
658
+
623
659
 
624
660
  if __name__ == '__main__':
625
661
  unittest.main()