langfun 0.0.2.dev20240423__py3-none-any.whl → 0.0.2.dev20240428__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in the public registry.
- langfun/__init__.py +1 -0
- langfun/core/eval/__init__.py +2 -0
- langfun/core/eval/base.py +176 -18
- langfun/core/eval/base_test.py +34 -6
- langfun/core/eval/matching.py +18 -1
- langfun/core/eval/matching_test.py +2 -1
- langfun/core/eval/scoring.py +11 -1
- langfun/core/eval/scoring_test.py +2 -1
- langfun/core/langfunc.py +0 -5
- langfun/core/language_model.py +39 -9
- langfun/core/language_model_test.py +156 -18
- langfun/core/llms/fake_test.py +91 -7
- langfun/core/llms/openai_test.py +202 -17
- langfun/core/structured/__init__.py +1 -0
- langfun/core/structured/completion_test.py +1 -2
- langfun/core/structured/mapping.py +38 -1
- langfun/core/structured/mapping_test.py +17 -0
- langfun/core/structured/parsing_test.py +2 -4
- langfun/core/structured/prompting.py +14 -4
- langfun/core/structured/prompting_test.py +35 -4
- langfun/core/structured/schema_generation_test.py +2 -2
- langfun/core/template.py +99 -2
- langfun/core/template_test.py +66 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/METADATA +3 -2
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/RECORD +28 -28
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/LICENSE +0 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/WHEEL +0 -0
- {langfun-0.0.2.dev20240423.dist-info → langfun-0.0.2.dev20240428.dist-info}/top_level.txt +0 -0
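The hunks below show the main behavioral change in this release: sampling results now carry fully populated LMSample entries whose AIMessage responses include per-message usage and a TAG_LM_RESPONSE tag. A minimal sketch of reading that structure with the Echo fake model exercised in fake_test.py; the attribute names samples, response, and usage are assumptions inferred from how the constructors are called in these tests, not confirmed by the diff itself:

from langfun.core.llms import fake

lm = fake.Echo()
[result] = lm.sample(['hi'])      # one LMSamplingResult per prompt
sample = result.samples[0]        # assumed field name for the positional sample list
print(sample.response.text)       # 'hi' - the echoed prompt
print(sample.response.usage)      # per-message LMSamplingUsage, new in this version
print(result.usage.total_tokens)  # aggregate usage for the whole result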
langfun/core/language_model_test.py
CHANGED
@@ -111,11 +111,35 @@ class LanguageModelTest(unittest.TestCase):
         lm.sample(prompts=['foo', 'bar']),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
         ],
@@ -128,41 +152,119 @@ class LanguageModelTest(unittest.TestCase):
         ),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
-
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.5,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.5,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )
     # Test override individual flags within sampling_options.
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], temperature=1.0),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
-
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )
     self.assertEqual(
         lm.sample(prompts=['foo', 'bar'], top_k=2, temperature=0.7),
         [
             lm_lib.LMSamplingResult(
-                [
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'foo' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
             ),
             lm_lib.LMSamplingResult(
-                [
-
+                [
+                    lm_lib.LMSample(
+                        message_lib.AIMessage(
+                            'bar' * 2,
+                            score=0.7,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=0.7,
+                        logprobs=None,
+                    ),
+                ],
+                usage=lm_lib.LMSamplingUsage(
+                    prompt_tokens=100, completion_tokens=100, total_tokens=200
+                ),
             ),
-        ]
+        ]
     )

   def test_call(self):
@@ -189,7 +291,16 @@ class LanguageModelTest(unittest.TestCase):
             lm_lib.LMSamplingResult(
                 [
                     lm_lib.LMSample(
-                        message_lib.AIMessage(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
                     )
                 ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -197,7 +308,16 @@ class LanguageModelTest(unittest.TestCase):
             lm_lib.LMSamplingResult(
                 [
                     lm_lib.LMSample(
-                        message_lib.AIMessage(
+                        message_lib.AIMessage(
+                            'bar',
+                            cache_seed=0,
+                            score=-1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=-1.0,
+                        logprobs=None,
                     )
                 ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -225,7 +345,16 @@ class LanguageModelTest(unittest.TestCase):
             lm_lib.LMSamplingResult(
                 [
                     lm_lib.LMSample(
-                        message_lib.AIMessage(
+                        message_lib.AIMessage(
+                            'foo',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
                     )
                 ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
@@ -233,7 +362,16 @@ class LanguageModelTest(unittest.TestCase):
             lm_lib.LMSamplingResult(
                 [
                     lm_lib.LMSample(
-                        message_lib.AIMessage(
+                        message_lib.AIMessage(
+                            'baz',
+                            cache_seed=0,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lm_lib.LMSamplingUsage(100, 100, 200),
+                            tags=[message_lib.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
                     )
                 ],
                 usage=lm_lib.LMSamplingUsage(100, 100, 200),
langfun/core/llms/fake_test.py
CHANGED
@@ -28,7 +28,19 @@ class EchoTest(unittest.TestCase):
         lm.sample(['hi']),
         [
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'hi',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 2, 4),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 lf.LMSamplingUsage(2, 2, 4))
         ]
     )
@@ -60,7 +72,19 @@ class StaticResponseTest(unittest.TestCase):
         lm.sample(['hi']),
         [
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 38, 40),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(2, 38, 40)
             )
         ],
@@ -69,7 +93,19 @@ class StaticResponseTest(unittest.TestCase):
         lm.sample(['Tell me a joke.']),
         [
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            canned_response,
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(15, 38, 53),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(15, 38, 53)
             )
         ],
@@ -101,11 +137,35 @@ class StaticMappingTest(unittest.TestCase):
         lm.sample(['Hi', 'How are you?']),
         [
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(2, 5, 7)
             ),
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(12, 25, 37)
             )
         ]
@@ -126,11 +186,35 @@ class StaticSequenceTest(unittest.TestCase):
         lm.sample(['Hi', 'How are you?']),
         [
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'Hello',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(2, 5, 7),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(2, 5, 7)
             ),
             lf.LMSamplingResult(
-                [
+                [
+                    lf.LMSample(
+                        lf.AIMessage(
+                            'I am fine, how about you?',
+                            score=1.0,
+                            logprobs=None,
+                            usage=lf.LMSamplingUsage(12, 25, 37),
+                            tags=[lf.Message.TAG_LM_RESPONSE],
+                        ),
+                        score=1.0,
+                        logprobs=None,
+                    )
+                ],
                 usage=lf.LMSamplingUsage(12, 25, 37)
             )
         ]
langfun/core/llms/openai_test.py
CHANGED
@@ -184,23 +184,96 @@ class OpenAITest(unittest.TestCase):
         results[0],
         lf.LMSamplingResult(
             [
-                lf.LMSample(
-
-
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 0 for prompt 0.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 1 for prompt 0.',
+                        score=0.1,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.1,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 2 for prompt 0.',
+                        score=0.2,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.2,
+                    logprobs=None,
+                ),
             ],
             usage=lf.LMSamplingUsage(
                 prompt_tokens=100, completion_tokens=100, total_tokens=200
             ),
         ),
     )
-
     self.assertEqual(
         results[1],
-        lf.LMSamplingResult(
-
-
-
-
+        lf.LMSamplingResult(
+            [
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 0 for prompt 1.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=None,
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 1 for prompt 1.',
+                        score=0.1,
+                        logprobs=None,
+                        usage=None,
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.1,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 2 for prompt 1.',
+                        score=0.2,
+                        logprobs=None,
+                        usage=None,
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.2,
+                    logprobs=None,
+                ),
+            ],
+        ),
     )

   def test_sample_chat_completion(self):
@@ -216,9 +289,51 @@
         results[0],
         lf.LMSamplingResult(
             [
-                lf.LMSample(
-
-
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 0 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 1 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 2 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
             ],
             usage=lf.LMSamplingUsage(
                 prompt_tokens=100, completion_tokens=100, total_tokens=200
@@ -229,9 +344,51 @@
         results[1],
         lf.LMSamplingResult(
             [
-                lf.LMSample(
-
-
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 0 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 1 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 2 for message.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=33,
+                            completion_tokens=33,
+                            total_tokens=66
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
             ],
             usage=lf.LMSamplingUsage(
                 prompt_tokens=100, completion_tokens=100, total_tokens=200
@@ -251,8 +408,36 @@
         results[0],
         lf.LMSamplingResult(
             [
-                lf.LMSample(
-
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 0 for prompt 0.',
+                        score=0.0,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=50,
+                            completion_tokens=50,
+                            total_tokens=100,
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.0,
+                    logprobs=None,
+                ),
+                lf.LMSample(
+                    lf.AIMessage(
+                        'Sample 1 for prompt 0.',
+                        score=0.1,
+                        logprobs=None,
+                        usage=lf.LMSamplingUsage(
+                            prompt_tokens=50,
+                            completion_tokens=50,
+                            total_tokens=100,
+                        ),
+                        tags=[lf.Message.TAG_LM_RESPONSE],
+                    ),
+                    score=0.1,
+                    logprobs=None,
+                ),
             ],
             usage=lf.LMSamplingUsage(
                 prompt_tokens=100, completion_tokens=100, total_tokens=200
langfun/core/structured/__init__.py
CHANGED
@@ -51,6 +51,7 @@ from langfun.core.structured.schema_generation import default_classgen_examples
 from langfun.core.structured.function_generation import function_gen

 from langfun.core.structured.mapping import Mapping
+from langfun.core.structured.mapping import MappingError
 from langfun.core.structured.mapping import MappingExample

 from langfun.core.structured.parsing import ParseStructure
langfun/core/structured/completion_test.py
CHANGED
@@ -17,7 +17,6 @@ import inspect
 import unittest

 import langfun.core as lf
-from langfun.core import coding
 from langfun.core import modalities
 from langfun.core.llms import fake
 from langfun.core.structured import completion
@@ -608,7 +607,7 @@ class CompleteStructureTest(unittest.TestCase):
         override_attrs=True,
     ):
       with self.assertRaisesRegex(
-
+          mapping.MappingError,
           'Expect .* but encountered .*',
       ):
         completion.complete(Activity.partial(), autofix=0)