langfun 0.1.2.dev202507250804__py3-none-any.whl → 0.1.2.dev202507270804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/assistant/capabilities/gui/__init__.py +36 -0
- langfun/assistant/capabilities/gui/bounding_box_parser.py +195 -0
- langfun/assistant/capabilities/gui/bounding_box_parser_test.py +313 -0
- langfun/assistant/capabilities/gui/drawing.py +242 -0
- langfun/assistant/capabilities/gui/drawing_test.py +103 -0
- langfun/assistant/capabilities/gui/location.py +288 -0
- langfun/assistant/capabilities/gui/location_test.py +230 -0
- langfun/core/__init__.py +3 -0
- langfun/core/agentic/action.py +18 -0
- langfun/core/agentic/action_test.py +8 -1
- langfun/core/async_support.py +28 -0
- langfun/core/async_support_test.py +39 -0
- langfun/core/concurrent_test.py +6 -5
- langfun/core/language_model.py +54 -0
- langfun/core/language_model_test.py +40 -0
- langfun/core/structured/__init__.py +7 -0
- langfun/core/structured/completion.py +28 -0
- langfun/core/structured/completion_test.py +36 -0
- langfun/core/structured/parsing.py +77 -0
- langfun/core/structured/parsing_test.py +15 -0
- langfun/core/structured/querying.py +42 -0
- langfun/core/structured/querying_test.py +10 -0
- langfun/core/structured/scoring.py +28 -0
- langfun/core/structured/scoring_test.py +8 -0
- langfun/core/structured/tokenization.py +24 -0
- langfun/core/structured/tokenization_test.py +8 -0
- {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/RECORD +31 -22
- {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202507250804.dist-info → langfun-0.1.2.dev202507270804.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Structures for location objects in an image."""
|
|
15
|
+
|
|
16
|
+
import random
|
|
17
|
+
import unittest
|
|
18
|
+
from langfun.assistant.capabilities.gui import location
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CoordinateTest(unittest.TestCase):
|
|
22
|
+
|
|
23
|
+
def test_basics(self):
|
|
24
|
+
pt = location.Coordinate(1, 2)
|
|
25
|
+
self.assertEqual(pt.x, 1)
|
|
26
|
+
self.assertEqual(pt.y, 2)
|
|
27
|
+
self.assertEqual(pt.as_tuple(), (1, 2))
|
|
28
|
+
|
|
29
|
+
def test_random(self):
|
|
30
|
+
bound = location.BBox(0, 0, 10, 10)
|
|
31
|
+
rand = random.Random(0)
|
|
32
|
+
for _ in range(10):
|
|
33
|
+
pt = location.Coordinate.random(bound, rand=rand)
|
|
34
|
+
self.assertIsInstance(pt, location.Coordinate)
|
|
35
|
+
self.assertIn(pt, bound)
|
|
36
|
+
|
|
37
|
+
def test_from_value(self):
|
|
38
|
+
pt = location.Coordinate.from_value((1, 2))
|
|
39
|
+
self.assertEqual(pt.x, 1)
|
|
40
|
+
self.assertEqual(pt.y, 2)
|
|
41
|
+
self.assertIs(location.Coordinate.from_value(pt), pt)
|
|
42
|
+
|
|
43
|
+
def test_arithmetic(self):
|
|
44
|
+
pt1 = location.Coordinate(1, 2)
|
|
45
|
+
pt2 = location.Coordinate(3, 4)
|
|
46
|
+
self.assertEqual(pt1 + pt2, location.Coordinate(4, 6))
|
|
47
|
+
self.assertEqual(pt2 + pt1, location.Coordinate(4, 6))
|
|
48
|
+
self.assertEqual(pt1 - pt2, location.Coordinate(-2, -2))
|
|
49
|
+
self.assertEqual(pt2 - pt1, location.Coordinate(2, 2))
|
|
50
|
+
|
|
51
|
+
pt1 = location.Coordinate(1, 2)
|
|
52
|
+
pt2 = (3, 4)
|
|
53
|
+
self.assertEqual(pt1 + pt2, location.Coordinate(4, 6))
|
|
54
|
+
self.assertEqual(pt2 + pt1, location.Coordinate(4, 6))
|
|
55
|
+
self.assertEqual(pt1 - pt2, location.Coordinate(-2, -2))
|
|
56
|
+
self.assertEqual(pt2 - pt1, location.Coordinate(2, 2))
|
|
57
|
+
|
|
58
|
+
self.assertEqual(pt1 * 2, location.Coordinate(2, 4))
|
|
59
|
+
self.assertEqual(2 * pt1, location.Coordinate(2, 4))
|
|
60
|
+
|
|
61
|
+
def test_distance_to(self):
|
|
62
|
+
pt1 = location.Coordinate(1, 2)
|
|
63
|
+
pt2 = location.Coordinate(3, 4)
|
|
64
|
+
# compare float values with a tolerance of 1e-6
|
|
65
|
+
self.assertAlmostEqual(pt1.distance_to(pt2), 2.8284271247461903, delta=1e-6)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class BBoxTest(unittest.TestCase):
|
|
69
|
+
|
|
70
|
+
def test_invalid_bbox_creation(self):
|
|
71
|
+
with self.assertRaisesRegex(AssertionError, '.*'):
|
|
72
|
+
location.BBox(100, 50, 100, 450) # Zero width.
|
|
73
|
+
with self.assertRaisesRegex(AssertionError, '.*'):
|
|
74
|
+
location.BBox(100, 50, 300, 50) # Zero height.
|
|
75
|
+
with self.assertRaisesRegex(AssertionError, '.*'):
|
|
76
|
+
location.BBox(100, 50, 50, 50) # Zero width and height
|
|
77
|
+
|
|
78
|
+
def test_basics(self):
|
|
79
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
80
|
+
self.assertEqual(bbox.x, 100)
|
|
81
|
+
self.assertEqual(bbox.y, 50)
|
|
82
|
+
self.assertEqual(bbox.right, 300)
|
|
83
|
+
self.assertEqual(bbox.bottom, 450)
|
|
84
|
+
self.assertEqual(bbox.left, 100)
|
|
85
|
+
self.assertEqual(bbox.top, 50)
|
|
86
|
+
self.assertEqual(bbox.width, 200)
|
|
87
|
+
self.assertEqual(bbox.height, 400)
|
|
88
|
+
|
|
89
|
+
self.assertEqual(bbox.center, location.Coordinate(200, 250))
|
|
90
|
+
self.assertEqual(bbox.top_left, location.Coordinate(100, 50))
|
|
91
|
+
self.assertEqual(bbox.bottom_right, location.Coordinate(300, 450))
|
|
92
|
+
self.assertEqual(bbox.area, 80000)
|
|
93
|
+
self.assertEqual(bbox.as_tuple(), (100, 50, 300, 450))
|
|
94
|
+
|
|
95
|
+
def test_contains(self):
|
|
96
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
97
|
+
self.assertIn((120, 200), bbox)
|
|
98
|
+
self.assertIn(location.Coordinate(120, 200), bbox)
|
|
99
|
+
self.assertNotIn(location.Coordinate(80, 200), bbox)
|
|
100
|
+
self.assertNotIn(location.Coordinate(120, 20), bbox)
|
|
101
|
+
self.assertNotIn(location.Coordinate(320, 200), bbox)
|
|
102
|
+
self.assertNotIn(location.Coordinate(120, 470), bbox)
|
|
103
|
+
|
|
104
|
+
self.assertIn((100, 50, 300, 450), bbox)
|
|
105
|
+
self.assertIn(location.BBox(100, 50, 300, 450), bbox)
|
|
106
|
+
self.assertIn(location.BBox(120, 70, 260, 410), bbox)
|
|
107
|
+
self.assertNotIn(location.BBox(60, 70, 260, 410), bbox)
|
|
108
|
+
self.assertNotIn(location.BBox(120, 45, 260, 410), bbox)
|
|
109
|
+
self.assertNotIn(location.BBox(120, 60, 310, 410), bbox)
|
|
110
|
+
self.assertNotIn(location.BBox(120, 60, 260, 470), bbox)
|
|
111
|
+
self.assertNotIn(location.BBox(0, 0, 400, 500), bbox)
|
|
112
|
+
|
|
113
|
+
with self.assertRaisesRegex(ValueError, 'Invalid tuple size'):
|
|
114
|
+
_ = (1, 2, 3) in bbox
|
|
115
|
+
|
|
116
|
+
with self.assertRaisesRegex(ValueError, 'Invalid type'):
|
|
117
|
+
_ = 'abc' in bbox # pytype: disable=unsupported-operands
|
|
118
|
+
|
|
119
|
+
def test_intersects(self):
|
|
120
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
121
|
+
self.assertTrue(bbox.intersects(location.BBox(100, 50, 300, 450)))
|
|
122
|
+
self.assertTrue(bbox.intersects(location.BBox(120, 70, 260, 410)))
|
|
123
|
+
self.assertTrue(bbox.intersects(location.BBox(60, 70, 260, 470)))
|
|
124
|
+
self.assertFalse(bbox.intersects(location.BBox(0, 70, 90, 410)))
|
|
125
|
+
self.assertFalse(bbox.intersects(location.BBox(60, 10, 400, 30)))
|
|
126
|
+
|
|
127
|
+
def test_clip(self):
|
|
128
|
+
# BBox within image bounds
|
|
129
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
130
|
+
clipped_bbox = bbox.clip((400, 500))
|
|
131
|
+
self.assertEqual(clipped_bbox, bbox)
|
|
132
|
+
|
|
133
|
+
# BBox exceeds right and bottom bounds
|
|
134
|
+
bbox = location.BBox(100, 50, 500, 600)
|
|
135
|
+
clipped_bbox = bbox.clip((400, 500))
|
|
136
|
+
self.assertEqual(clipped_bbox, location.BBox(100, 50, 400, 500))
|
|
137
|
+
|
|
138
|
+
# BBox exceeds left and top bounds
|
|
139
|
+
bbox = location.BBox(-10, -20, 300, 450)
|
|
140
|
+
clipped_bbox = bbox.clip((400, 500))
|
|
141
|
+
self.assertEqual(clipped_bbox, location.BBox(0, 0, 300, 450))
|
|
142
|
+
|
|
143
|
+
# BBox larger than the image
|
|
144
|
+
bbox = location.BBox(-10, -20, 800, 700)
|
|
145
|
+
clipped_bbox = bbox.clip((400, 500))
|
|
146
|
+
self.assertEqual(clipped_bbox, location.BBox(0, 0, 400, 500))
|
|
147
|
+
|
|
148
|
+
# BBox starts outside the image boundaries (bottom/right)
|
|
149
|
+
bbox = location.BBox(500, 600, 700, 800)
|
|
150
|
+
clipped_bbox = bbox.clip((400, 500))
|
|
151
|
+
self.assertIsNone(clipped_bbox)
|
|
152
|
+
|
|
153
|
+
def test_random(self):
|
|
154
|
+
bound = location.BBox(0, 0, 800, 600)
|
|
155
|
+
rand = random.Random(0)
|
|
156
|
+
for _ in range(10):
|
|
157
|
+
bbox = location.BBox.random(bound, rand=rand)
|
|
158
|
+
self.assertIsInstance(bbox, location.BBox)
|
|
159
|
+
self.assertIn(bbox, bound)
|
|
160
|
+
|
|
161
|
+
with self.assertRaisesRegex(
|
|
162
|
+
ValueError, 'Minimum width or height is larger than the bound'
|
|
163
|
+
):
|
|
164
|
+
_ = location.BBox.random(location.BBox(0, 0, 100, 100), min_width=110)
|
|
165
|
+
|
|
166
|
+
def test_matches(self):
|
|
167
|
+
bbox1 = location.BBox(100, 50, 300, 450)
|
|
168
|
+
self.assertFalse(bbox1.matches(location.BBox(10, 10, 11, 11)))
|
|
169
|
+
|
|
170
|
+
bbox2 = location.BBox(100, 50, 300, 450)
|
|
171
|
+
self.assertTrue(bbox1.matches(bbox2))
|
|
172
|
+
self.assertTrue(bbox2.matches(bbox1))
|
|
173
|
+
|
|
174
|
+
bbox3 = location.BBox(110, 40, 290, 450)
|
|
175
|
+
self.assertTrue(bbox1.matches(bbox3))
|
|
176
|
+
self.assertFalse(bbox3.matches(bbox1, area_diff_threshold=0.05))
|
|
177
|
+
self.assertFalse(bbox3.matches(bbox1, max_center_distance=1))
|
|
178
|
+
|
|
179
|
+
def test_expand(self):
|
|
180
|
+
# Test case 1: Expanding both width and height.
|
|
181
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
182
|
+
expanded_bbox = bbox.expand(width_scale=1.2, height_scale=1.1)
|
|
183
|
+
self.assertEqual(expanded_bbox.x, 80)
|
|
184
|
+
self.assertEqual(expanded_bbox.y, 30)
|
|
185
|
+
self.assertEqual(expanded_bbox.right, 320)
|
|
186
|
+
self.assertEqual(expanded_bbox.bottom, 470)
|
|
187
|
+
self.assertEqual(expanded_bbox.width, 240)
|
|
188
|
+
self.assertEqual(expanded_bbox.height, 440)
|
|
189
|
+
self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
|
|
190
|
+
|
|
191
|
+
# Test case 2: Expanding only width.
|
|
192
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
193
|
+
expanded_bbox = bbox.expand(width_scale=1.2)
|
|
194
|
+
self.assertEqual(expanded_bbox.x, 80)
|
|
195
|
+
self.assertEqual(expanded_bbox.y, 50)
|
|
196
|
+
self.assertEqual(expanded_bbox.right, 320)
|
|
197
|
+
self.assertEqual(expanded_bbox.bottom, 450)
|
|
198
|
+
self.assertEqual(expanded_bbox.width, 240)
|
|
199
|
+
self.assertEqual(expanded_bbox.height, 400)
|
|
200
|
+
self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
|
|
201
|
+
|
|
202
|
+
# Test case 3: Expanding only height.
|
|
203
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
204
|
+
expanded_bbox = bbox.expand(height_scale=1.1)
|
|
205
|
+
self.assertEqual(expanded_bbox.x, 100)
|
|
206
|
+
self.assertEqual(expanded_bbox.y, 30)
|
|
207
|
+
self.assertEqual(expanded_bbox.right, 300)
|
|
208
|
+
self.assertEqual(expanded_bbox.bottom, 470)
|
|
209
|
+
self.assertEqual(expanded_bbox.width, 200)
|
|
210
|
+
self.assertEqual(expanded_bbox.height, 440)
|
|
211
|
+
self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
|
|
212
|
+
|
|
213
|
+
# Test case 4: Shrinking.
|
|
214
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
215
|
+
expanded_bbox = bbox.expand(width_scale=0.8, height_scale=0.9)
|
|
216
|
+
self.assertEqual(expanded_bbox.x, 120)
|
|
217
|
+
self.assertEqual(expanded_bbox.y, 70)
|
|
218
|
+
self.assertEqual(expanded_bbox.right, 280)
|
|
219
|
+
self.assertEqual(expanded_bbox.bottom, 430)
|
|
220
|
+
self.assertEqual(expanded_bbox.width, 160)
|
|
221
|
+
self.assertEqual(expanded_bbox.height, 360)
|
|
222
|
+
self.assertEqual(expanded_bbox.center, location.Coordinate(200, 250))
|
|
223
|
+
|
|
224
|
+
# Test case 5: No change.
|
|
225
|
+
bbox = location.BBox(100, 50, 300, 450)
|
|
226
|
+
expanded_bbox = bbox.expand(width_scale=1.0, height_scale=1.0)
|
|
227
|
+
self.assertEqual(expanded_bbox, bbox)
|
|
228
|
+
|
|
229
|
+
if __name__ == '__main__':
|
|
230
|
+
unittest.main()
|
langfun/core/__init__.py
CHANGED
|
@@ -40,6 +40,9 @@ from langfun.core.component import context
|
|
|
40
40
|
as_context = context
|
|
41
41
|
use_context = context
|
|
42
42
|
|
|
43
|
+
# Invoke a callable object asynchronously.
|
|
44
|
+
from langfun.core.async_support import invoke_async
|
|
45
|
+
|
|
43
46
|
# Shortcut function for overriding components attributes, usually for
|
|
44
47
|
# override settings.
|
|
45
48
|
from langfun.core.component import use_settings
|
langfun/core/agentic/action.py
CHANGED
|
@@ -199,6 +199,24 @@ class Action(pg.Object):
|
|
|
199
199
|
"""Returns last invocation. None if the action is not executed."""
|
|
200
200
|
return self._invocation
|
|
201
201
|
|
|
202
|
+
async def acall(
|
|
203
|
+
self,
|
|
204
|
+
session: Optional['Session'] = None,
|
|
205
|
+
*,
|
|
206
|
+
show_progress: bool = True,
|
|
207
|
+
verbose: bool = False,
|
|
208
|
+
**kwargs
|
|
209
|
+
) -> Any:
|
|
210
|
+
"""Async version of `__call__`."""
|
|
211
|
+
# TODO(daiyip): implement native async calling.
|
|
212
|
+
return await lf.invoke_async(
|
|
213
|
+
self.__call__,
|
|
214
|
+
session,
|
|
215
|
+
show_progress=show_progress,
|
|
216
|
+
verbose=verbose,
|
|
217
|
+
**kwargs
|
|
218
|
+
)
|
|
219
|
+
|
|
202
220
|
def __call__(
|
|
203
221
|
self,
|
|
204
222
|
session: Optional['Session'] = None,
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Tests for base action."""
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import unittest
|
|
17
18
|
|
|
18
19
|
import langfun.core as lf
|
|
@@ -32,7 +33,7 @@ class Bar(action_lib.Action):
|
|
|
32
33
|
session.add_metadata(note='bar')
|
|
33
34
|
if self.simulate_action_error:
|
|
34
35
|
raise ValueError('Bar error')
|
|
35
|
-
return 2
|
|
36
|
+
return 2 + pg.contextual_value('baz', 0)
|
|
36
37
|
|
|
37
38
|
|
|
38
39
|
class Foo(action_lib.Action):
|
|
@@ -237,6 +238,12 @@ class SessionTest(unittest.TestCase):
|
|
|
237
238
|
json_str = session.to_json_str(save_ref_value=True)
|
|
238
239
|
self.assertIsInstance(pg.from_json_str(json_str), action_lib.Session)
|
|
239
240
|
|
|
241
|
+
def test_acall(self):
|
|
242
|
+
bar = Bar()
|
|
243
|
+
with lf.context(baz=1):
|
|
244
|
+
r = bar.acall(lm=fake.StaticResponse('lm response'))
|
|
245
|
+
self.assertEqual(asyncio.run(r), 3)
|
|
246
|
+
|
|
240
247
|
def test_failed_action(self):
|
|
241
248
|
lm = fake.StaticResponse('lm response')
|
|
242
249
|
foo = Foo(1, simulate_action_error=True)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""Utility for async IO in Langfun."""
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
from typing import Any, Callable
|
|
18
|
+
import pyglove as pg
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
async def invoke_async(
|
|
22
|
+
callable_object: Callable[..., Any], *args, **kwargs
|
|
23
|
+
) -> Any:
|
|
24
|
+
"""Invokes a callable asynchronously with `lf.context` manager enabled."""
|
|
25
|
+
return await asyncio.to_thread(
|
|
26
|
+
# Enable `lf.context` manager for async calls.
|
|
27
|
+
pg.with_contextual_override(callable_object), *args, **kwargs
|
|
28
|
+
)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright 2025 The Langfun Authors
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import time
|
|
17
|
+
import unittest
|
|
18
|
+
|
|
19
|
+
from langfun.core import async_support
|
|
20
|
+
import pyglove as pg
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AsyncSupportTest(unittest.TestCase):
|
|
24
|
+
|
|
25
|
+
def test_invoke_async(self):
|
|
26
|
+
|
|
27
|
+
def foo(x, *, y):
|
|
28
|
+
time.sleep(2)
|
|
29
|
+
return x + y + pg.contextual_value('z', 0)
|
|
30
|
+
|
|
31
|
+
t = time.time()
|
|
32
|
+
r = async_support.invoke_async(foo, 1, y=2)
|
|
33
|
+
self.assertLess(time.time() - t, 1)
|
|
34
|
+
with pg.contextual_override(z=3):
|
|
35
|
+
self.assertEqual(asyncio.run(r), 6)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
if __name__ == '__main__':
|
|
39
|
+
unittest.main()
|
langfun/core/concurrent_test.py
CHANGED
|
@@ -334,11 +334,12 @@ class ProgressBarTest(unittest.TestCase):
|
|
|
334
334
|
sys.stderr.flush()
|
|
335
335
|
time.sleep(1)
|
|
336
336
|
self.assertIn('1/4', string_io.getvalue())
|
|
337
|
-
|
|
338
|
-
self.assertIn('
|
|
339
|
-
self.
|
|
340
|
-
self.
|
|
341
|
-
self.assertIn('
|
|
337
|
+
# TODO(daiyip): Re-enable once flakiness is fixed.
|
|
338
|
+
# self.assertIn('2/4', string_io.getvalue())
|
|
339
|
+
# self.assertIn('hello', string_io.getvalue())
|
|
340
|
+
# self.assertNotIn('3/4', string_io.getvalue())
|
|
341
|
+
# self.assertIn('4/4', string_io.getvalue())
|
|
342
|
+
# self.assertIn('x=1', string_io.getvalue())
|
|
342
343
|
|
|
343
344
|
|
|
344
345
|
class ConcurrentMapTest(unittest.TestCase):
|
langfun/core/language_model.py
CHANGED
|
@@ -24,6 +24,7 @@ import re
|
|
|
24
24
|
import threading
|
|
25
25
|
import time
|
|
26
26
|
from typing import Annotated, Any, Callable, ClassVar, Iterator, Literal, Optional, Sequence, Tuple, Type, Union, final
|
|
27
|
+
from langfun.core import async_support
|
|
27
28
|
from langfun.core import component
|
|
28
29
|
from langfun.core import concurrent
|
|
29
30
|
from langfun.core import console
|
|
@@ -922,6 +923,59 @@ class LanguageModel(component.Component):
|
|
|
922
923
|
# Language model operations.
|
|
923
924
|
#
|
|
924
925
|
|
|
926
|
+
async def asample(
|
|
927
|
+
self,
|
|
928
|
+
prompts: list[str | message_lib.Message],
|
|
929
|
+
*,
|
|
930
|
+
cache_seed: int = 0,
|
|
931
|
+
**kwargs,
|
|
932
|
+
) -> message_lib.Message:
|
|
933
|
+
"""Async version of sample."""
|
|
934
|
+
# TODO(daiyip): implement native async sampling.
|
|
935
|
+
return await async_support.invoke_async(
|
|
936
|
+
self.sample, prompts, cache_seed=cache_seed, **kwargs
|
|
937
|
+
)
|
|
938
|
+
|
|
939
|
+
async def acall(
|
|
940
|
+
self,
|
|
941
|
+
prompt: str | message_lib.Message,
|
|
942
|
+
*,
|
|
943
|
+
cache_seed: int = 0,
|
|
944
|
+
**kwargs
|
|
945
|
+
) -> message_lib.Message:
|
|
946
|
+
"""Async version of __call__."""
|
|
947
|
+
# TODO(daiyip): implement native async calling.
|
|
948
|
+
return await async_support.invoke_async(
|
|
949
|
+
self.__call__,
|
|
950
|
+
prompt,
|
|
951
|
+
cache_seed=cache_seed,
|
|
952
|
+
**kwargs
|
|
953
|
+
)
|
|
954
|
+
|
|
955
|
+
async def ascore(
|
|
956
|
+
self,
|
|
957
|
+
prompt: str | message_lib.Message | list[message_lib.Message],
|
|
958
|
+
completions: list[str | message_lib.Message],
|
|
959
|
+
**kwargs,
|
|
960
|
+
) -> float:
|
|
961
|
+
"""Async version of score."""
|
|
962
|
+
# TODO(daiyip): implement native async scoring.
|
|
963
|
+
return await async_support.invoke_async(
|
|
964
|
+
self.score,
|
|
965
|
+
prompt,
|
|
966
|
+
completions=completions,
|
|
967
|
+
**kwargs
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
async def atokenize(
|
|
971
|
+
self,
|
|
972
|
+
prompt: str | message_lib.Message,
|
|
973
|
+
**kwargs
|
|
974
|
+
) -> list[int]:
|
|
975
|
+
"""Async version of tokenize."""
|
|
976
|
+
# TODO(daiyip): implement native async tokenization.
|
|
977
|
+
return await async_support.invoke_async(self.tokenize, prompt, **kwargs)
|
|
978
|
+
|
|
925
979
|
def sample(
|
|
926
980
|
self,
|
|
927
981
|
prompts: list[str | message_lib.Message],
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Tests for language model."""
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import contextlib
|
|
17
18
|
import io
|
|
18
19
|
import unittest
|
|
@@ -436,6 +437,14 @@ class LanguageModelTest(unittest.TestCase):
|
|
|
436
437
|
]
|
|
437
438
|
)
|
|
438
439
|
|
|
440
|
+
def test_sample_async(self):
|
|
441
|
+
lm = MockModel(top_k=1)
|
|
442
|
+
response = asyncio.run(lm.asample(['foo', 'bar']))
|
|
443
|
+
self.assertIsInstance(response, list)
|
|
444
|
+
self.assertEqual(len(response), 2)
|
|
445
|
+
self.assertIsInstance(response[0], lm_lib.LMSamplingResult)
|
|
446
|
+
self.assertIsInstance(response[1], lm_lib.LMSamplingResult)
|
|
447
|
+
|
|
439
448
|
def test_call(self):
|
|
440
449
|
lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
|
|
441
450
|
response = lm(prompt='foo')
|
|
@@ -453,6 +462,11 @@ class LanguageModelTest(unittest.TestCase):
|
|
|
453
462
|
# Test override individual flags within sampling_options.
|
|
454
463
|
self.assertEqual(lm('foo', top_k=2), 'foo' * 2)
|
|
455
464
|
|
|
465
|
+
def test_acall(self):
|
|
466
|
+
lm = MockModel(sampling_options=lm_lib.LMSamplingOptions(top_k=1))
|
|
467
|
+
response = asyncio.run(lm.acall(prompt='foo'))
|
|
468
|
+
self.assertEqual(response.text, 'foo')
|
|
469
|
+
|
|
456
470
|
def test_using_cache(self):
|
|
457
471
|
cache = in_memory.InMemory()
|
|
458
472
|
lm = MockModel(cache=cache, top_k=1)
|
|
@@ -765,6 +779,21 @@ class LanguageModelTest(unittest.TestCase):
|
|
|
765
779
|
if debug_mode & lm_lib.LMDebugMode.PROMPT:
|
|
766
780
|
self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
|
|
767
781
|
|
|
782
|
+
def test_ascore(self):
|
|
783
|
+
lm = MockScoringModel()
|
|
784
|
+
self.assertEqual(
|
|
785
|
+
asyncio.run(
|
|
786
|
+
lm.ascore(
|
|
787
|
+
message_lib.UserMessage('hi'),
|
|
788
|
+
['1', '2']
|
|
789
|
+
)
|
|
790
|
+
),
|
|
791
|
+
[
|
|
792
|
+
lm_lib.LMScoringResult(score=-0.0),
|
|
793
|
+
lm_lib.LMScoringResult(score=-1.0),
|
|
794
|
+
],
|
|
795
|
+
)
|
|
796
|
+
|
|
768
797
|
def test_score_with_unmatched_prompt_and_completions(self):
|
|
769
798
|
with self.assertRaises(ValueError):
|
|
770
799
|
MockScoringModel().score(['hi',], ['1', '2', '3'])
|
|
@@ -828,6 +857,17 @@ class LanguageModelTest(unittest.TestCase):
|
|
|
828
857
|
if debug_mode & lm_lib.LMDebugMode.PROMPT:
|
|
829
858
|
self.assertIn('[0] MODALITY OBJECTS SENT TO LM', debug_info)
|
|
830
859
|
|
|
860
|
+
def test_atokenize(self):
|
|
861
|
+
lm = MockTokenizeModel()
|
|
862
|
+
self.assertEqual(
|
|
863
|
+
asyncio.run(
|
|
864
|
+
lm.atokenize(
|
|
865
|
+
message_lib.UserMessage('hi')
|
|
866
|
+
)
|
|
867
|
+
),
|
|
868
|
+
[('hi', 0)],
|
|
869
|
+
)
|
|
870
|
+
|
|
831
871
|
def test_tokenize_with_unsupported_model(self):
|
|
832
872
|
with self.assertRaises(NotImplementedError):
|
|
833
873
|
MockModel().tokenize('hi')
|
|
@@ -51,13 +51,16 @@ from langfun.core.structured.mapping import MappingError
|
|
|
51
51
|
from langfun.core.structured.mapping import MappingExample
|
|
52
52
|
|
|
53
53
|
from langfun.core.structured.parsing import parse
|
|
54
|
+
from langfun.core.structured.parsing import aparse
|
|
54
55
|
from langfun.core.structured.parsing import call
|
|
56
|
+
from langfun.core.structured.parsing import acall
|
|
55
57
|
|
|
56
58
|
from langfun.core.structured.querying import track_queries
|
|
57
59
|
from langfun.core.structured.querying import QueryInvocation
|
|
58
60
|
|
|
59
61
|
from langfun.core.structured.querying import LfQuery
|
|
60
62
|
from langfun.core.structured.querying import query
|
|
63
|
+
from langfun.core.structured.querying import aquery
|
|
61
64
|
from langfun.core.structured.querying import query_and_reduce
|
|
62
65
|
from langfun.core.structured.querying import query_protocol
|
|
63
66
|
|
|
@@ -67,10 +70,14 @@ from langfun.core.structured.querying import query_reward
|
|
|
67
70
|
|
|
68
71
|
from langfun.core.structured.description import describe
|
|
69
72
|
from langfun.core.structured.completion import complete
|
|
73
|
+
from langfun.core.structured.completion import acomplete
|
|
70
74
|
|
|
71
75
|
from langfun.core.structured.scoring import score
|
|
76
|
+
from langfun.core.structured.scoring import ascore
|
|
72
77
|
|
|
73
78
|
from langfun.core.structured.tokenization import tokenize
|
|
79
|
+
from langfun.core.structured.tokenization import atokenize
|
|
80
|
+
|
|
74
81
|
|
|
75
82
|
# Expose default examples for structured operations so users could refer to
|
|
76
83
|
# them.
|
|
@@ -251,3 +251,31 @@ def complete(
|
|
|
251
251
|
|
|
252
252
|
output = t(lm=lm, cache_seed=cache_seed, autofix_lm=autofix_lm or lm)
|
|
253
253
|
return output if returns_message else output.result
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
async def acomplete(
|
|
257
|
+
input_value: pg.Symbolic,
|
|
258
|
+
default: Any = lf.RAISE_IF_HAS_ERROR,
|
|
259
|
+
*,
|
|
260
|
+
lm: lf.LanguageModel | None = None,
|
|
261
|
+
examples: list[mapping.MappingExample] | None = None,
|
|
262
|
+
cache_seed: int | None = 0,
|
|
263
|
+
autofix: int = 0,
|
|
264
|
+
autofix_lm: lf.LanguageModel | None = None,
|
|
265
|
+
returns_message: bool = False,
|
|
266
|
+
**kwargs,
|
|
267
|
+
) -> Any:
|
|
268
|
+
"""Async version of `lf.complete`."""
|
|
269
|
+
# TODO(daiyip): implement native async completion.
|
|
270
|
+
return await lf.invoke_async(
|
|
271
|
+
complete,
|
|
272
|
+
input_value,
|
|
273
|
+
default,
|
|
274
|
+
lm=lm,
|
|
275
|
+
examples=examples,
|
|
276
|
+
cache_seed=cache_seed,
|
|
277
|
+
autofix=autofix,
|
|
278
|
+
autofix_lm=autofix_lm,
|
|
279
|
+
returns_message=returns_message,
|
|
280
|
+
**kwargs
|
|
281
|
+
)
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
"""Tests for langfun.core.structured.completion."""
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import inspect
|
|
17
18
|
import unittest
|
|
18
19
|
|
|
@@ -620,6 +621,41 @@ class CompleteStructureTest(unittest.TestCase):
|
|
|
620
621
|
):
|
|
621
622
|
self.assertIsNone(completion.complete(Activity.partial(), None))
|
|
622
623
|
|
|
624
|
+
def test_acomplete(self):
|
|
625
|
+
response = """
|
|
626
|
+
```python
|
|
627
|
+
TripPlan(
|
|
628
|
+
place='San Francisco',
|
|
629
|
+
itineraries=[
|
|
630
|
+
Itinerary(
|
|
631
|
+
day=1,
|
|
632
|
+
type='daytime',
|
|
633
|
+
activities=[
|
|
634
|
+
Activity(description='Arrive in San Francisco and check into your hotel.'),
|
|
635
|
+
Activity(description='Take a walk around Fisherman\\'s Wharf and have dinner at one of the many seafood restaurants.'),
|
|
636
|
+
Activity(description='Visit Pier 39 and see the sea lions.'),
|
|
637
|
+
],
|
|
638
|
+
),
|
|
639
|
+
]
|
|
640
|
+
)
|
|
641
|
+
```
|
|
642
|
+
"""
|
|
643
|
+
with lf.context(
|
|
644
|
+
lm=fake.StaticSequence(
|
|
645
|
+
[response],
|
|
646
|
+
),
|
|
647
|
+
override_attrs=True,
|
|
648
|
+
):
|
|
649
|
+
r = completion.acomplete(
|
|
650
|
+
TripPlan.partial(
|
|
651
|
+
place='San Francisco',
|
|
652
|
+
itineraries=[
|
|
653
|
+
Itinerary.partial(day=1),
|
|
654
|
+
],
|
|
655
|
+
)
|
|
656
|
+
)
|
|
657
|
+
self.assertIsInstance(asyncio.run(r), TripPlan)
|
|
658
|
+
|
|
623
659
|
|
|
624
660
|
if __name__ == '__main__':
|
|
625
661
|
unittest.main()
|