langfun 0.1.2.dev202507140805__py3-none-any.whl → 0.1.2.dev202507150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/concurrent_test.py +3 -0
- langfun/core/eval/v2/checkpointing_test.py +11 -11
- langfun/core/eval/v2/evaluation_test.py +1 -1
- langfun/core/eval/v2/experiment_test.py +7 -7
- langfun/core/eval/v2/progress_tracking_test.py +3 -3
- langfun/core/eval/v2/reporting_test.py +5 -5
- langfun/core/eval/v2/runners_test.py +11 -11
- langfun/core/language_model.py +5 -5
- langfun/core/language_model_test.py +12 -14
- langfun/core/llms/compositional.py +1 -1
- langfun/core/message.py +2 -4
- langfun/core/message_test.py +5 -5
- langfun/core/structured/mapping.py +11 -1
- langfun/core/structured/querying_test.py +15 -0
- langfun/core/structured/schema.py +10 -0
- {langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/RECORD +20 -20
- {langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/top_level.txt +0 -0
langfun/core/concurrent_test.py
CHANGED
@@ -17,6 +17,7 @@ import collections
|
|
17
17
|
from concurrent import futures
|
18
18
|
import contextlib
|
19
19
|
import io
|
20
|
+
import sys
|
20
21
|
import time
|
21
22
|
import unittest
|
22
23
|
from langfun.core import component
|
@@ -330,6 +331,8 @@ class ProgressBarTest(unittest.TestCase):
|
|
330
331
|
with self.assertRaisesRegex(ValueError, 'Unsupported status'):
|
331
332
|
concurrent.ProgressBar.update(bar_id, 0, status=1)
|
332
333
|
concurrent.ProgressBar.uninstall(bar_id)
|
334
|
+
sys.stderr.flush()
|
335
|
+
time.sleep(1)
|
333
336
|
self.assertIn('1/4', string_io.getvalue())
|
334
337
|
self.assertIn('2/4', string_io.getvalue())
|
335
338
|
self.assertIn('hello', string_io.getvalue())
|
@@ -28,7 +28,7 @@ Example = example_lib.Example
|
|
28
28
|
class SequenceWriterTest(unittest.TestCase):
|
29
29
|
|
30
30
|
def test_basic(self):
|
31
|
-
file = os.path.join(tempfile.
|
31
|
+
file = os.path.join(tempfile.mkdtemp(), 'test.jsonl')
|
32
32
|
writer = checkpointing.SequenceWriter(file)
|
33
33
|
example = Example(id=1, input=pg.Dict(x=1), output=2)
|
34
34
|
writer.add(example)
|
@@ -36,7 +36,7 @@ class SequenceWriterTest(unittest.TestCase):
|
|
36
36
|
self.assertTrue(pg.io.path_exists(file))
|
37
37
|
|
38
38
|
def test_error_handling(self):
|
39
|
-
file = os.path.join(tempfile.
|
39
|
+
file = os.path.join(tempfile.mkdtemp(), 'test_error_handling.jsonl')
|
40
40
|
writer = checkpointing.SequenceWriter(file)
|
41
41
|
writer.add(Example(id=1, input=pg.Dict(x=1), output=2))
|
42
42
|
|
@@ -87,7 +87,7 @@ class CheckpointerTest(unittest.TestCase):
|
|
87
87
|
class PerExampleCheckpointerTest(CheckpointerTest):
|
88
88
|
|
89
89
|
def test_checkpointing(self):
|
90
|
-
root_dir = os.path.join(tempfile.
|
90
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer')
|
91
91
|
experiment = eval_test_helper.test_experiment()
|
92
92
|
checkpoint_filename = 'checkpoint.jsonl'
|
93
93
|
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
|
@@ -119,7 +119,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
119
119
|
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
120
120
|
|
121
121
|
# Test warm start without reprocess.
|
122
|
-
root_dir = os.path.join(tempfile.
|
122
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer2')
|
123
123
|
experiment = eval_test_helper.test_experiment()
|
124
124
|
_ = experiment.run(
|
125
125
|
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
@@ -129,7 +129,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
129
129
|
self.assertEqual(leaf.progress.num_skipped, num_processed[leaf.id])
|
130
130
|
|
131
131
|
# Test warm start with reprocess.
|
132
|
-
root_dir = os.path.join(tempfile.
|
132
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer3')
|
133
133
|
experiment = eval_test_helper.test_experiment()
|
134
134
|
_ = experiment.run(
|
135
135
|
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
@@ -139,7 +139,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
139
139
|
for leaf in experiment.leaf_nodes:
|
140
140
|
self.assertEqual(leaf.progress.num_skipped, 0)
|
141
141
|
|
142
|
-
root_dir = os.path.join(tempfile.
|
142
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'per_example_checkpointer4')
|
143
143
|
experiment = eval_test_helper.test_experiment()
|
144
144
|
_ = experiment.run(
|
145
145
|
root_dir, 'new', runner='sequential', plugins=[checkpointer],
|
@@ -151,7 +151,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
151
151
|
|
152
152
|
def test_loading_corrupted_checkpoint(self):
|
153
153
|
root_dir = os.path.join(
|
154
|
-
tempfile.
|
154
|
+
tempfile.mkdtemp(),
|
155
155
|
'per_example_checkpointer_with_corrupted_checkpoint'
|
156
156
|
)
|
157
157
|
experiment = eval_test_helper.TestEvaluation()
|
@@ -178,7 +178,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
178
178
|
num_processed[example.id] = i + 1
|
179
179
|
|
180
180
|
root_dir = os.path.join(
|
181
|
-
tempfile.
|
181
|
+
tempfile.mkdtemp(),
|
182
182
|
'per_example_checkpointer_with_corrupted_checkpoint_warm_start'
|
183
183
|
)
|
184
184
|
experiment = eval_test_helper.TestEvaluation()
|
@@ -192,7 +192,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
192
192
|
|
193
193
|
def test_checkpointing_error(self):
|
194
194
|
root_dir = os.path.join(
|
195
|
-
tempfile.
|
195
|
+
tempfile.mkdtemp(),
|
196
196
|
'per_example_checkpointer_with_checkpointing_error'
|
197
197
|
)
|
198
198
|
experiment = (eval_test_helper
|
@@ -207,7 +207,7 @@ class PerExampleCheckpointerTest(CheckpointerTest):
|
|
207
207
|
class BulkCheckpointerTest(CheckpointerTest):
|
208
208
|
|
209
209
|
def test_checkpointing(self):
|
210
|
-
root_dir = os.path.join(tempfile.
|
210
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_bulk_checkpointer')
|
211
211
|
experiment = eval_test_helper.test_experiment()
|
212
212
|
checkpoint_filename = 'checkpoint.jsonl'
|
213
213
|
checkpointer = checkpointing.BulkCheckpointer(checkpoint_filename)
|
@@ -238,7 +238,7 @@ class BulkCheckpointerTest(CheckpointerTest):
|
|
238
238
|
|
239
239
|
def test_checkpointing_error(self):
|
240
240
|
root_dir = os.path.join(
|
241
|
-
tempfile.
|
241
|
+
tempfile.mkdtemp(),
|
242
242
|
'bulk_checkpointer_with_checkpointing_error'
|
243
243
|
)
|
244
244
|
experiment = (eval_test_helper
|
@@ -116,7 +116,7 @@ class EvaluationTest(unittest.TestCase):
|
|
116
116
|
self.assertEqual(example.metric_metadata, dict(error='ValueError'))
|
117
117
|
|
118
118
|
def test_evaluate_withstate(self):
|
119
|
-
eval_dir = os.path.join(tempfile.
|
119
|
+
eval_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
|
120
120
|
pg.io.mkdirs(eval_dir, exist_ok=True)
|
121
121
|
state_file = os.path.join(eval_dir, 'state.jsonl')
|
122
122
|
with pg.io.open_sequence(state_file, 'w') as f:
|
@@ -145,7 +145,7 @@ class RunIdTest(unittest.TestCase):
|
|
145
145
|
)
|
146
146
|
|
147
147
|
def test_get_latest(self):
|
148
|
-
root_dir = os.path.join(tempfile.
|
148
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval')
|
149
149
|
pg.io.mkdirs(os.path.join(root_dir, 'run_20241102_0'))
|
150
150
|
pg.io.mkdirs(os.path.join(root_dir, 'run_20241101_0'))
|
151
151
|
self.assertEqual(
|
@@ -153,15 +153,15 @@ class RunIdTest(unittest.TestCase):
|
|
153
153
|
RunId.from_id('20241102_0')
|
154
154
|
)
|
155
155
|
self.assertIsNone(RunId.get_latest('/notexist'))
|
156
|
-
self.assertIsNone(RunId.get_latest(tempfile.
|
156
|
+
self.assertIsNone(RunId.get_latest(tempfile.mkdtemp()))
|
157
157
|
|
158
158
|
def test_new(self):
|
159
159
|
rid = RunId(date=datetime.date.today(), number=1)
|
160
160
|
self.assertEqual(
|
161
|
-
RunId.new(root_dir=os.path.join(tempfile.
|
161
|
+
RunId.new(root_dir=os.path.join(tempfile.mkdtemp(), 'test_new')),
|
162
162
|
rid
|
163
163
|
)
|
164
|
-
root_dir = os.path.join(tempfile.
|
164
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval2')
|
165
165
|
pg.io.mkdirs(rid.dirname(root_dir))
|
166
166
|
self.assertEqual(RunId.new(root_dir), rid.next())
|
167
167
|
|
@@ -185,13 +185,13 @@ class RunIdTest(unittest.TestCase):
|
|
185
185
|
with self.assertRaisesRegex(
|
186
186
|
ValueError, '.* no previous runs'
|
187
187
|
):
|
188
|
-
RunId.from_id('latest', root_dir=tempfile.
|
188
|
+
RunId.from_id('latest', root_dir=tempfile.mkdtemp())
|
189
189
|
|
190
190
|
self.assertEqual(
|
191
191
|
RunId.from_id('20241102_1'),
|
192
192
|
RunId(date=datetime.date(2024, 11, 2), number=1)
|
193
193
|
)
|
194
|
-
root_dir = os.path.join(tempfile.
|
194
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_eval3')
|
195
195
|
rid = RunId.from_id('20241102_1')
|
196
196
|
pg.io.mkdirs(rid.dirname(root_dir))
|
197
197
|
self.assertEqual(
|
@@ -413,7 +413,7 @@ class RunnerTest(unittest.TestCase):
|
|
413
413
|
),
|
414
414
|
TestRunner
|
415
415
|
)
|
416
|
-
root_dir = os.path.join(tempfile.
|
416
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'my_eval')
|
417
417
|
|
418
418
|
# Test standard run.
|
419
419
|
MyEvaluation(replica_id=0).run(
|
@@ -34,7 +34,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
|
|
34
34
|
lf_console._notebook = pg.Dict(
|
35
35
|
display=display
|
36
36
|
)
|
37
|
-
root_dir = os.path.join(tempfile.
|
37
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_html_progress_tracker')
|
38
38
|
experiment = eval_test_helper.test_experiment()
|
39
39
|
_ = experiment.run(root_dir, 'new', plugins=[])
|
40
40
|
self.assertIsInstance(result['view'], pg.Html)
|
@@ -44,7 +44,7 @@ class HtmlProgressTrackerTest(unittest.TestCase):
|
|
44
44
|
class TqdmProgressTrackerTest(unittest.TestCase):
|
45
45
|
|
46
46
|
def test_basic(self):
|
47
|
-
root_dir = os.path.join(tempfile.
|
47
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_tqdm_progress_tracker')
|
48
48
|
experiment = eval_test_helper.test_experiment()
|
49
49
|
string_io = io.StringIO()
|
50
50
|
with contextlib.redirect_stderr(string_io):
|
@@ -53,7 +53,7 @@ class TqdmProgressTrackerTest(unittest.TestCase):
|
|
53
53
|
|
54
54
|
def test_with_example_ids(self):
|
55
55
|
root_dir = os.path.join(
|
56
|
-
tempfile.
|
56
|
+
tempfile.mkdtemp(), 'test_tqdm_progress_tracker_with_example_ids'
|
57
57
|
)
|
58
58
|
experiment = eval_test_helper.test_experiment()
|
59
59
|
string_io = io.StringIO()
|
@@ -25,7 +25,7 @@ import pyglove as pg
|
|
25
25
|
class ReportingTest(unittest.TestCase):
|
26
26
|
|
27
27
|
def test_reporting(self):
|
28
|
-
root_dir = os.path.join(tempfile.
|
28
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_reporting')
|
29
29
|
experiment = eval_test_helper.test_experiment()
|
30
30
|
checkpointer = checkpointing.BulkCheckpointer('checkpoint.jsonl')
|
31
31
|
reporter = reporting.HtmlReporter()
|
@@ -49,7 +49,7 @@ class ReportingTest(unittest.TestCase):
|
|
49
49
|
self.assertTrue(found_generation_log)
|
50
50
|
|
51
51
|
# Test warm start.
|
52
|
-
root_dir = os.path.join(tempfile.
|
52
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_reporting2')
|
53
53
|
experiment = eval_test_helper.test_experiment()
|
54
54
|
run = experiment.run(
|
55
55
|
root_dir, 'new', plugins=[checkpointer, reporter],
|
@@ -75,7 +75,7 @@ class ReportingTest(unittest.TestCase):
|
|
75
75
|
|
76
76
|
def test_index_html_generation_error(self):
|
77
77
|
root_dir = os.path.join(
|
78
|
-
tempfile.
|
78
|
+
tempfile.mkdtemp(),
|
79
79
|
'test_reporting_with_index_html_generation_error'
|
80
80
|
)
|
81
81
|
experiment = (eval_test_helper
|
@@ -98,7 +98,7 @@ class ReportingTest(unittest.TestCase):
|
|
98
98
|
|
99
99
|
def test_example_html_generation_error(self):
|
100
100
|
root_dir = os.path.join(
|
101
|
-
tempfile.
|
101
|
+
tempfile.mkdtemp(),
|
102
102
|
'test_reporting_with_example_html_generation_error'
|
103
103
|
)
|
104
104
|
experiment = (eval_test_helper
|
@@ -126,7 +126,7 @@ class ReportingTest(unittest.TestCase):
|
|
126
126
|
|
127
127
|
# Test warm start.
|
128
128
|
root_dir = os.path.join(
|
129
|
-
tempfile.
|
129
|
+
tempfile.mkdtemp(),
|
130
130
|
'test_reporting_with_example_html_generation_error2'
|
131
131
|
)
|
132
132
|
experiment = (eval_test_helper
|
@@ -103,7 +103,7 @@ class RunnerTest(unittest.TestCase):
|
|
103
103
|
def test_basic(self):
|
104
104
|
plugin = TestPlugin()
|
105
105
|
exp = eval_test_helper.test_experiment()
|
106
|
-
root_dir = os.path.join(tempfile.
|
106
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_sequential_runner')
|
107
107
|
run = exp.run(root_dir, runner='sequential', plugins=[plugin])
|
108
108
|
|
109
109
|
self.assertIsNotNone(plugin.start_time)
|
@@ -143,7 +143,7 @@ class RunnerTest(unittest.TestCase):
|
|
143
143
|
self.assertEqual(node.progress.num_processed, node.progress.num_total)
|
144
144
|
|
145
145
|
def test_raise_if_has_error(self):
|
146
|
-
root_dir = os.path.join(tempfile.
|
146
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_raise_if_has_error')
|
147
147
|
exp = eval_test_helper.TestEvaluation()
|
148
148
|
with self.assertRaisesRegex(ValueError, 'x should not be 5'):
|
149
149
|
exp.run(
|
@@ -154,7 +154,7 @@ class RunnerTest(unittest.TestCase):
|
|
154
154
|
exp.run(root_dir, runner='parallel', plugins=[], raise_if_has_error=True)
|
155
155
|
|
156
156
|
def test_example_ids(self):
|
157
|
-
root_dir = os.path.join(tempfile.
|
157
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_example_ids')
|
158
158
|
exp = eval_test_helper.test_experiment()
|
159
159
|
plugin = TestPlugin()
|
160
160
|
_ = exp.run(
|
@@ -164,7 +164,7 @@ class RunnerTest(unittest.TestCase):
|
|
164
164
|
self.assertEqual(plugin.completed_example_ids, [5, 7, 9] * 6)
|
165
165
|
|
166
166
|
def test_shuffle_inputs(self):
|
167
|
-
root_dir = os.path.join(tempfile.
|
167
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_shuffle_inputs')
|
168
168
|
exp = eval_test_helper.test_experiment()
|
169
169
|
plugin = TestPlugin()
|
170
170
|
run = exp.run(
|
@@ -175,7 +175,7 @@ class RunnerTest(unittest.TestCase):
|
|
175
175
|
def test_filter(self):
|
176
176
|
plugin = TestPlugin()
|
177
177
|
exp = eval_test_helper.test_experiment()
|
178
|
-
root_dir = os.path.join(tempfile.
|
178
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_filter')
|
179
179
|
|
180
180
|
_ = exp.run(
|
181
181
|
root_dir, runner='sequential', plugins=[plugin],
|
@@ -207,7 +207,7 @@ class RunnerTest(unittest.TestCase):
|
|
207
207
|
inputs=test_inputs(num_examples=pg.oneof([2, 4]))
|
208
208
|
)
|
209
209
|
# Global cache.
|
210
|
-
root_dir = os.path.join(tempfile.
|
210
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'global_cache')
|
211
211
|
run = exp.run(
|
212
212
|
root_dir, 'new', runner='sequential', use_cache='global', plugins=[]
|
213
213
|
)
|
@@ -216,7 +216,7 @@ class RunnerTest(unittest.TestCase):
|
|
216
216
|
self.assertEqual(exp.usage_summary.uncached.total.num_requests, 2)
|
217
217
|
|
218
218
|
# Per-dataset cache.
|
219
|
-
root_dir = os.path.join(tempfile.
|
219
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'per_dataset')
|
220
220
|
run = exp.run(
|
221
221
|
root_dir, 'new', runner='sequential',
|
222
222
|
use_cache='per_dataset', plugins=[]
|
@@ -229,7 +229,7 @@ class RunnerTest(unittest.TestCase):
|
|
229
229
|
self.assertEqual(exp.usage_summary.uncached.total.num_requests, 3)
|
230
230
|
|
231
231
|
# No cache.
|
232
|
-
root_dir = os.path.join(tempfile.
|
232
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'no')
|
233
233
|
run = exp.run(root_dir, runner='sequential', use_cache='no', plugins=[])
|
234
234
|
self.assertFalse(pg.io.path_exists(run.output_path_for(exp, 'cache.json')))
|
235
235
|
for leaf in exp.leaf_nodes:
|
@@ -245,7 +245,7 @@ class ParallelRunnerTest(RunnerTest):
|
|
245
245
|
def test_parallel_runner(self):
|
246
246
|
plugin = TestPlugin()
|
247
247
|
exp = eval_test_helper.test_experiment()
|
248
|
-
root_dir = os.path.join(tempfile.
|
248
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_parallel_runner')
|
249
249
|
run = exp.run(root_dir, runner='parallel', plugins=[plugin])
|
250
250
|
|
251
251
|
self.assertIsNotNone(plugin.start_time)
|
@@ -286,7 +286,7 @@ class ParallelRunnerTest(RunnerTest):
|
|
286
286
|
plugin = TestPlugin()
|
287
287
|
exp = eval_test_helper.test_experiment()
|
288
288
|
root_dir = os.path.join(
|
289
|
-
tempfile.
|
289
|
+
tempfile.mkdtemp(), 'test_concurrent_startup_delay'
|
290
290
|
)
|
291
291
|
_ = exp.run(
|
292
292
|
root_dir,
|
@@ -301,7 +301,7 @@ class DebugRunnerTest(RunnerTest):
|
|
301
301
|
def test_debug_runner(self):
|
302
302
|
plugin = TestPlugin()
|
303
303
|
exp = eval_test_helper.test_experiment()
|
304
|
-
root_dir = os.path.join(tempfile.
|
304
|
+
root_dir = os.path.join(tempfile.mkdtemp(), 'test_debug_runner')
|
305
305
|
run = exp.run(root_dir, runner='debug', plugins=[plugin])
|
306
306
|
|
307
307
|
self.assertIsNotNone(plugin.start_time)
|
langfun/core/language_model.py
CHANGED
@@ -667,10 +667,7 @@ class LMDebugMode(enum.IntFlag):
|
|
667
667
|
PROMPT = enum.auto()
|
668
668
|
RESPONSE = enum.auto()
|
669
669
|
|
670
|
-
|
671
|
-
@property
|
672
|
-
def ALL(cls) -> 'LMDebugMode': # pylint: disable=invalid-name
|
673
|
-
return LMDebugMode.INFO | LMDebugMode.PROMPT | LMDebugMode.RESPONSE
|
670
|
+
ALL = INFO | PROMPT | RESPONSE
|
674
671
|
|
675
672
|
|
676
673
|
class LanguageModel(component.Component):
|
@@ -1101,7 +1098,10 @@ class LanguageModel(component.Component):
|
|
1101
1098
|
return [job.result for job in executed_jobs]
|
1102
1099
|
|
1103
1100
|
def __call__(
|
1104
|
-
self,
|
1101
|
+
self,
|
1102
|
+
prompt: str | message_lib.Message,
|
1103
|
+
*,
|
1104
|
+
cache_seed: int = 0, **kwargs
|
1105
1105
|
) -> message_lib.Message:
|
1106
1106
|
"""Returns the first candidate."""
|
1107
1107
|
prompt = message_lib.UserMessage.from_value(prompt)
|
@@ -683,16 +683,14 @@ class LanguageModelTest(unittest.TestCase):
|
|
683
683
|
|
684
684
|
debug_info = string_io.getvalue()
|
685
685
|
expected_included = [
|
686
|
-
debug_prints[f]
|
687
|
-
|
688
|
-
if f != lm_lib.LMDebugMode.NONE and f in debug_mode
|
686
|
+
debug_prints[f] for f in (info_flag, prompt_flag, response_flag)
|
687
|
+
if f in debug_mode
|
689
688
|
]
|
690
689
|
expected_excluded = [
|
691
690
|
debug_prints[f]
|
692
|
-
for f in
|
693
|
-
if f
|
691
|
+
for f in (info_flag, prompt_flag, response_flag)
|
692
|
+
if f not in debug_mode
|
694
693
|
]
|
695
|
-
|
696
694
|
for expected_include in expected_included:
|
697
695
|
self.assertIn('[0] ' + expected_include, debug_info)
|
698
696
|
for expected_exclude in expected_excluded:
|
@@ -750,13 +748,13 @@ class LanguageModelTest(unittest.TestCase):
|
|
750
748
|
debug_info = string_io.getvalue()
|
751
749
|
expected_included = [
|
752
750
|
debug_prints[f]
|
753
|
-
for f in
|
754
|
-
if f
|
751
|
+
for f in (info_flag, prompt_flag, response_flag)
|
752
|
+
if f in debug_mode
|
755
753
|
]
|
756
754
|
expected_excluded = [
|
757
755
|
debug_prints[f]
|
758
|
-
for f in
|
759
|
-
if f
|
756
|
+
for f in (info_flag, prompt_flag, response_flag)
|
757
|
+
if f not in debug_mode
|
760
758
|
]
|
761
759
|
|
762
760
|
for expected_include in expected_included:
|
@@ -813,13 +811,13 @@ class LanguageModelTest(unittest.TestCase):
|
|
813
811
|
debug_info = string_io.getvalue()
|
814
812
|
expected_included = [
|
815
813
|
debug_prints[f]
|
816
|
-
for f in
|
817
|
-
if f
|
814
|
+
for f in (info_flag, prompt_flag, response_flag)
|
815
|
+
if f in debug_mode
|
818
816
|
]
|
819
817
|
expected_excluded = [
|
820
818
|
debug_prints[f]
|
821
|
-
for f in
|
822
|
-
if f
|
819
|
+
for f in (info_flag, prompt_flag, response_flag)
|
820
|
+
if f not in debug_mode
|
823
821
|
]
|
824
822
|
|
825
823
|
for expected_include in expected_included:
|
@@ -78,7 +78,7 @@ class RandomChoice(lf.LanguageModel):
|
|
78
78
|
)
|
79
79
|
|
80
80
|
def __call__(
|
81
|
-
self, prompt: lf.Message, *, cache_seed: int = 0, **kwargs
|
81
|
+
self, prompt: str | lf.Message, *, cache_seed: int = 0, **kwargs
|
82
82
|
) -> lf.Message:
|
83
83
|
return self._select_lm()(prompt, cache_seed=cache_seed, **kwargs)
|
84
84
|
|
langfun/core/message.py
CHANGED
@@ -225,13 +225,11 @@ class Message(
|
|
225
225
|
return MessageConverter.get(format_or_type, **kwargs).to_value(self)
|
226
226
|
|
227
227
|
@classmethod
|
228
|
-
@property
|
229
228
|
def convertible_formats(cls) -> list[str]:
|
230
229
|
"""Returns supported format for message conversion."""
|
231
230
|
return MessageConverter.convertible_formats()
|
232
231
|
|
233
232
|
@classmethod
|
234
|
-
@property
|
235
233
|
def convertible_types(cls) -> list[str]:
|
236
234
|
"""Returns supported types for message conversion."""
|
237
235
|
return MessageConverter.convertible_types()
|
@@ -938,8 +936,8 @@ class MessageConverter(pg.Object):
|
|
938
936
|
"""Converts a Langfun message to other formats."""
|
939
937
|
|
940
938
|
@abc.abstractmethod
|
941
|
-
def from_value(self, value:
|
942
|
-
"""Returns a
|
939
|
+
def from_value(self, value: Any) -> Message:
|
940
|
+
"""Returns a Langfun message from other formats."""
|
943
941
|
|
944
942
|
@classmethod
|
945
943
|
def _safe_read(
|
langfun/core/message_test.py
CHANGED
@@ -521,12 +521,12 @@ class MessageConverterTest(unittest.TestCase):
|
|
521
521
|
def from_value(self, value: tuple[int, ...]) -> message.Message:
|
522
522
|
return message.UserMessage(','.join(str(x) for x in value))
|
523
523
|
|
524
|
-
self.assertIn('test_format1', message.Message.convertible_formats)
|
525
|
-
self.assertIn('test_format2', message.Message.convertible_formats)
|
526
|
-
self.assertIn('test_format3', message.Message.convertible_formats)
|
524
|
+
self.assertIn('test_format1', message.Message.convertible_formats())
|
525
|
+
self.assertIn('test_format2', message.Message.convertible_formats())
|
526
|
+
self.assertIn('test_format3', message.Message.convertible_formats())
|
527
527
|
|
528
|
-
self.assertIn(int, message.Message.convertible_types)
|
529
|
-
self.assertIn(tuple, message.Message.convertible_types)
|
528
|
+
self.assertIn(int, message.Message.convertible_types())
|
529
|
+
self.assertIn(tuple, message.Message.convertible_types())
|
530
530
|
self.assertEqual(
|
531
531
|
message.Message.from_value(1, format='test_format1'),
|
532
532
|
message.UserMessage('1')
|
@@ -433,8 +433,18 @@ class Mapping(lf.LangFunc):
|
|
433
433
|
schema = self.mapping_request.schema
|
434
434
|
if schema is None:
|
435
435
|
return None
|
436
|
+
response_text = lm_output.text
|
437
|
+
|
438
|
+
# For Gemini, we might have tool calls in the metadata, use tool call codes
|
439
|
+
# to construct the response text if it's present.
|
440
|
+
# NOTE(daiyip): This logic is subject to change.
|
441
|
+
if 'tool_calls' in lm_output.metadata:
|
442
|
+
assert lm_output.metadata['tool_calls'], lm_output.metadata
|
443
|
+
response_text = '\n'.join(
|
444
|
+
tc.text for tc in lm_output.metadata['tool_calls']
|
445
|
+
)
|
436
446
|
return schema.parse(
|
437
|
-
|
447
|
+
response_text,
|
438
448
|
protocol=self.protocol,
|
439
449
|
additional_context=self.globals(),
|
440
450
|
autofix=self.autofix,
|
@@ -114,6 +114,21 @@ class QueryTest(unittest.TestCase):
|
|
114
114
|
),
|
115
115
|
'The answer is one.',
|
116
116
|
)
|
117
|
+
# Testing tool calls in the response.
|
118
|
+
self.assertEqual(
|
119
|
+
querying.query(
|
120
|
+
'abc',
|
121
|
+
Activity,
|
122
|
+
lm=fake.StaticResponse(
|
123
|
+
lf.AIMessage(
|
124
|
+
'Here is the answer.',
|
125
|
+
tool_calls=[lf.AIMessage('Activity(description="hello")')],
|
126
|
+
),
|
127
|
+
),
|
128
|
+
),
|
129
|
+
Activity(description='hello'),
|
130
|
+
)
|
131
|
+
# Test completing a partial object.
|
117
132
|
self.assertEqual(
|
118
133
|
querying.query(
|
119
134
|
Activity.partial(),
|
@@ -17,6 +17,7 @@ import abc
|
|
17
17
|
import inspect
|
18
18
|
import io
|
19
19
|
import re
|
20
|
+
import sys
|
20
21
|
import textwrap
|
21
22
|
import typing
|
22
23
|
from typing import Any, Literal, Sequence, Type, Union
|
@@ -451,6 +452,15 @@ def class_definition(
|
|
451
452
|
out.write(f' """{cls.__doc__}"""\n')
|
452
453
|
else:
|
453
454
|
out.write(' """')
|
455
|
+
|
456
|
+
# Since Python 3.13, the indentation of docstring lines is removed.
|
457
|
+
# Therefore, we add two spaces to each non-empty line to keep the
|
458
|
+
# indentation consistent with the class definition.
|
459
|
+
if sys.version_info >= (3, 13):
|
460
|
+
for i in range(1, len(doc_lines)):
|
461
|
+
if doc_lines[i]:
|
462
|
+
doc_lines[i] = ' ' * 2 + doc_lines[i]
|
463
|
+
|
454
464
|
for line in doc_lines:
|
455
465
|
out.write(line)
|
456
466
|
out.write('\n')
|
@@ -3,18 +3,18 @@ langfun/core/__init__.py,sha256=pW4prpiyWNkRbtWBGYF1thn7_0F_TgDVfAIZPvGn6HA,4758
|
|
3
3
|
langfun/core/component.py,sha256=g1kQM0bryYYYWVDrSMnHfc74wIBbpfe5_B3s-UIP5GE,3028
|
4
4
|
langfun/core/component_test.py,sha256=0CxTgjAud3aj8wBauFhG2FHDqrxCTl4OI4gzQTad-40,9254
|
5
5
|
langfun/core/concurrent.py,sha256=zY-pXqlGqss_GI20tM1gXvyW8QepVPUuFNmutcIdhbI,32760
|
6
|
-
langfun/core/concurrent_test.py,sha256=
|
6
|
+
langfun/core/concurrent_test.py,sha256=fjVcxD_OSH9fBqBEpDpuIVfcfoKZWDtwmkoM2ZMHqy8,17628
|
7
7
|
langfun/core/console.py,sha256=cLQEf84aDxItA9fStJV22xJch0TqFLNf9hLqwJ0RHmU,2652
|
8
8
|
langfun/core/console_test.py,sha256=pBOcuNMJdVELywvroptfcRtJMsegMm3wSlHAL2TdxVk,1679
|
9
9
|
langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
|
10
10
|
langfun/core/langfunc_test.py,sha256=CDn-gJCa5EnjN7cotAVCfSCbuzddq2o-HzEt7kV8HbY,8882
|
11
|
-
langfun/core/language_model.py,sha256=
|
12
|
-
langfun/core/language_model_test.py,sha256=
|
11
|
+
langfun/core/language_model.py,sha256=5i0Je5526JO2YY6qExi6Yf7VQVgSVeZIKOjt3I8kxqQ,49573
|
12
|
+
langfun/core/language_model_test.py,sha256=9EofP3_gTH28SNWiOKzTUYMHH0EYtbi9xGuT1KZT1XU,37330
|
13
13
|
langfun/core/logging.py,sha256=7IGAhp7mGokZxxqtL-XZvFLKaZ5k3F5_Xp2NUtR4GwE,9136
|
14
14
|
langfun/core/logging_test.py,sha256=vbVGOQxwMmVSiFfbt2897gUt-8nqDpV64jCAeUG_q5U,6924
|
15
15
|
langfun/core/memory.py,sha256=vyXVvfvSdLLJAzdIupnbn3k26OgclCx-OJ7gddS5e1Y,2070
|
16
|
-
langfun/core/message.py,sha256=
|
17
|
-
langfun/core/message_test.py,sha256=
|
16
|
+
langfun/core/message.py,sha256=Nx9SqEIkPMS5I1RyMQFlWUjZCsdlGamv_wTze2-3R4M,32784
|
17
|
+
langfun/core/message_test.py,sha256=dAA_ZzI5MGyFfXyejxPrB90SbR066mkIgmRtdZ5ZbL4,40803
|
18
18
|
langfun/core/modality.py,sha256=K8pUGuMpfWcOtVcXC_OqVjro1-RhHF6ddQni61DuYzM,4166
|
19
19
|
langfun/core/modality_test.py,sha256=0WL_yd3B4K-FviWdSpDnOwj0f9TQI0v9t6X0vWvvJbo,2415
|
20
20
|
langfun/core/natural_language.py,sha256=3ynSnaYQnjE60LIPK5fyMgdIjubnPYZwzGq4rWPeloE,1177
|
@@ -61,14 +61,14 @@ langfun/core/eval/scoring.py,sha256=_DvnlgI1SdRVaOojao_AkV3pnenfCPOqyhvlg-Sw-5M,
|
|
61
61
|
langfun/core/eval/scoring_test.py,sha256=UcBH0R6vAovZ0A4yM22s5cBHL1qVKASubrbu1t8dYBw,4529
|
62
62
|
langfun/core/eval/v2/__init__.py,sha256=9lNKJwbvl0lcFblAXYT_OHI8fOubJsTOdSkxEqsP1xU,1726
|
63
63
|
langfun/core/eval/v2/checkpointing.py,sha256=t47rBfzGZYgIqWW1N1Ak9yQnNtHd-IRbEO0cZjG2VRo,11755
|
64
|
-
langfun/core/eval/v2/checkpointing_test.py,sha256=
|
64
|
+
langfun/core/eval/v2/checkpointing_test.py,sha256=cuQ1zom5DMXIebxYW6L3N5XRyhfoEEDrs7XQcAxg8Nc,9164
|
65
65
|
langfun/core/eval/v2/eval_test_helper.py,sha256=sKFi_wPYCNmr96WyTduuXY0KnxjFxcJyEhXey-_nGX8,3962
|
66
66
|
langfun/core/eval/v2/evaluation.py,sha256=ihT5dljnUkHM97XS9OwE2wOnYC-oYnHYgG5KN1hmiaU,27037
|
67
|
-
langfun/core/eval/v2/evaluation_test.py,sha256=
|
67
|
+
langfun/core/eval/v2/evaluation_test.py,sha256=46bGjNZmd57NXcJSoaC17DO9B74rpVBOVTEln_4W61c,6916
|
68
68
|
langfun/core/eval/v2/example.py,sha256=v1dIz89pccIqujt7utrk0EbqMWM9kBn-2fYGRTKe358,10890
|
69
69
|
langfun/core/eval/v2/example_test.py,sha256=wsHQD6te7ghROmxe3Xg_NK4TU0xS2MkNfnpo-H0H8xM,3399
|
70
70
|
langfun/core/eval/v2/experiment.py,sha256=fb3RHNOSRftV7ZTBfYVV50iEevqdPwRHCt3mgtLzuFw,33408
|
71
|
-
langfun/core/eval/v2/experiment_test.py,sha256=
|
71
|
+
langfun/core/eval/v2/experiment_test.py,sha256=BYrPYfQfU2jDfAlZcHDT0KUaXOnCnyTWyUEKhDoqXfw,13645
|
72
72
|
langfun/core/eval/v2/metric_values.py,sha256=_B905bC-jxrYPLSEcP2M8MaHZOVMz_bVrUw8YC4arCE,4660
|
73
73
|
langfun/core/eval/v2/metric_values_test.py,sha256=ab2oF_HsIwrSy459108ggyjgefHSPn8UVILR4dRwx14,2634
|
74
74
|
langfun/core/eval/v2/metrics.py,sha256=bl8i6u-ZHRBz4hAc3LzsZ2Dc7ZRQcuTYeUhhH-GxfF0,10628
|
@@ -76,17 +76,17 @@ langfun/core/eval/v2/metrics_test.py,sha256=LibZXvWEJDVRY-Mza_bQT-SbmbXCHUnFhL7Z
|
|
76
76
|
langfun/core/eval/v2/progress.py,sha256=azZgssQgNdv3IgjKEaQBuGI5ucFDNbdi02P4z_nQ8GE,10292
|
77
77
|
langfun/core/eval/v2/progress_test.py,sha256=YU7VHzmy5knPZwj9vpBN3rQQH2tukj9eKHkuBCI62h8,2540
|
78
78
|
langfun/core/eval/v2/progress_tracking.py,sha256=zNhNPGlnJnHELEfFpbTMCSXFn8d1IJ57OOYkfFaBFfM,6097
|
79
|
-
langfun/core/eval/v2/progress_tracking_test.py,sha256=
|
79
|
+
langfun/core/eval/v2/progress_tracking_test.py,sha256=sJhlVfinGsg3Kf2wQ_hT7VMcpQfaI4ZkqyW9ujElkwA,2282
|
80
80
|
langfun/core/eval/v2/reporting.py,sha256=yUIPCAMnp7InIzpv1DDWrcLO-75iiOUTpscj7smkfrA,8335
|
81
|
-
langfun/core/eval/v2/reporting_test.py,sha256=
|
81
|
+
langfun/core/eval/v2/reporting_test.py,sha256=CMK-vwho8cNRJwlbkCqm_v5fykE7Y3V6SaIOCY0CDyA,5671
|
82
82
|
langfun/core/eval/v2/runners.py,sha256=iqbH4jMtnNMhfuv1eHaxJmk1Vvsrz-sAJJFP8U44-tA,16758
|
83
|
-
langfun/core/eval/v2/runners_test.py,sha256=
|
83
|
+
langfun/core/eval/v2/runners_test.py,sha256=spjkmqlls_vyERdZMdjv6dhIN9ZfxsDDvIQAWTj2kMk,11954
|
84
84
|
langfun/core/llms/__init__.py,sha256=CtxUdXohQ8AQk1DqBT6MBy2zdAoPSggNo00SYrj9-AY,9521
|
85
85
|
langfun/core/llms/anthropic.py,sha256=YcQ2VG8iOfXtry_tTpAukmiwXa2hK_9LkpkmXk41Nm0,26226
|
86
86
|
langfun/core/llms/anthropic_test.py,sha256=qA9vByp_cwwXNlXzcwHpPWFnO9lfFo8NKfDi5nBNqgI,9052
|
87
87
|
langfun/core/llms/azure_openai.py,sha256=-KkSLaR54MlsIqz_XIwv0TnsBnvNTAxnjA2Q2O2u5KM,2733
|
88
88
|
langfun/core/llms/azure_openai_test.py,sha256=lkMZkQdJBV97fTM4C4z8qNfvr6spgiN5G4hvVUIVr0M,1735
|
89
|
-
langfun/core/llms/compositional.py,sha256=
|
89
|
+
langfun/core/llms/compositional.py,sha256=W_Fe2BdbkjwTzWW-paCWcEeG9oOR3-IcBG8oc73taSM,2878
|
90
90
|
langfun/core/llms/compositional_test.py,sha256=4eTnOer-DncRKGaIJW2ZQQMLnt5r2R0UIx_DYOvGAQo,2027
|
91
91
|
langfun/core/llms/deepseek.py,sha256=jvTxdXPr-vH6HNakn_Ootx1heDg8Fen2FUkUW36bpCs,5247
|
92
92
|
langfun/core/llms/deepseek_test.py,sha256=DvROWPlDuow5E1lfoSkhyGt_ELA19JoQoDsTnRgDtTg,1847
|
@@ -133,13 +133,13 @@ langfun/core/structured/description.py,sha256=6BztYOiucPkF4CrTQtPLPJo1gN2dwnKmaJ
|
|
133
133
|
langfun/core/structured/description_test.py,sha256=UxaXnKKP7TnyPDPUyf3U-zPE0TvLlIP6DGr8thjcePw,7365
|
134
134
|
langfun/core/structured/function_generation.py,sha256=g7AOR_e8HxFU6n6Df750aGkgMgV1KExLZMAz0yd5Agg,8555
|
135
135
|
langfun/core/structured/function_generation_test.py,sha256=LaXYDXf9GlqUrR6v_gtmK_H4kxzonmU7SYbn7XXMgjU,12128
|
136
|
-
langfun/core/structured/mapping.py,sha256=
|
136
|
+
langfun/core/structured/mapping.py,sha256=1YBW8PKpJKXS7DKukfzKNioL84PrKUcB4KOUudrQ20w,14374
|
137
137
|
langfun/core/structured/mapping_test.py,sha256=OntYvfDitAf0tAnzQty3YS90vyEn6FY1Mi93r_ViEk8,9594
|
138
138
|
langfun/core/structured/parsing.py,sha256=MGvI7ypXlwfzr5XB8_TFU9Ei0_5reYqkWkv64eAy0EA,12015
|
139
139
|
langfun/core/structured/parsing_test.py,sha256=V8Cj1tJK4Lxv_b0YQj6-2hzXZgnYNBa2JR7rOLRBKoQ,22346
|
140
140
|
langfun/core/structured/querying.py,sha256=vE_NOLNlIe4A0DueQfyiBEUh3AsSD8Hhx2dSDHNYpYk,37976
|
141
|
-
langfun/core/structured/querying_test.py,sha256=
|
142
|
-
langfun/core/structured/schema.py,sha256=
|
141
|
+
langfun/core/structured/querying_test.py,sha256=Q0HwmbUI9BqMaeN8vgn_EvX29CzfcomGIKVqKJ6dZyY,50212
|
142
|
+
langfun/core/structured/schema.py,sha256=xtgrr3t5tcYQ2gi_fkTKz2IgDMf84gpiykmBdfnV6Io,29486
|
143
143
|
langfun/core/structured/schema_generation.py,sha256=pEWeTd8tQWYnEHukas6GVl4uGerLsQ2aNybtnm4Qgxc,5352
|
144
144
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
145
145
|
langfun/core/structured/schema_test.py,sha256=H42ZZdPi8CIv7WzrnXwMwQQaPQxlmDSY31pfqQs-Xqw,26567
|
@@ -156,8 +156,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
156
156
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
157
157
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
158
158
|
langfun/core/templates/selfplay_test.py,sha256=Ot__1P1M8oJfoTp-M9-PQ6HUXqZKyMwvZ5f7yQ3yfyM,2326
|
159
|
-
langfun-0.1.2.
|
160
|
-
langfun-0.1.2.
|
161
|
-
langfun-0.1.2.
|
162
|
-
langfun-0.1.2.
|
163
|
-
langfun-0.1.2.
|
159
|
+
langfun-0.1.2.dev202507150805.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
160
|
+
langfun-0.1.2.dev202507150805.dist-info/METADATA,sha256=sO0bfLTiiYrI_s4AbIEF_5v0li1JRZ1IjH492KCQrGU,8178
|
161
|
+
langfun-0.1.2.dev202507150805.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
162
|
+
langfun-0.1.2.dev202507150805.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
163
|
+
langfun-0.1.2.dev202507150805.dist-info/RECORD,,
|
File without changes
|
{langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/licenses/LICENSE
RENAMED
File without changes
|
{langfun-0.1.2.dev202507140805.dist-info → langfun-0.1.2.dev202507150805.dist-info}/top_level.txt
RENAMED
File without changes
|