langfun 0.1.2.dev202501060804__py3-none-any.whl → 0.1.2.dev202501100804__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
- langfun/core/__init__.py +0 -5
- langfun/core/coding/python/correction.py +4 -3
- langfun/core/coding/python/errors.py +10 -9
- langfun/core/coding/python/execution.py +23 -12
- langfun/core/coding/python/execution_test.py +21 -2
- langfun/core/coding/python/generation.py +18 -9
- langfun/core/concurrent.py +2 -3
- langfun/core/console.py +8 -3
- langfun/core/eval/base.py +2 -3
- langfun/core/eval/v2/reporting.py +8 -4
- langfun/core/language_model.py +7 -4
- langfun/core/language_model_test.py +15 -0
- langfun/core/llms/__init__.py +7 -0
- langfun/core/llms/deepseek.py +117 -0
- langfun/core/llms/deepseek_test.py +61 -0
- langfun/core/llms/google_genai.py +1 -0
- langfun/core/llms/groq.py +12 -99
- langfun/core/llms/groq_test.py +31 -137
- langfun/core/llms/llama_cpp.py +17 -54
- langfun/core/llms/llama_cpp_test.py +2 -34
- langfun/core/llms/openai.py +14 -147
- langfun/core/llms/openai_compatible.py +179 -0
- langfun/core/llms/openai_compatible_test.py +480 -0
- langfun/core/llms/openai_test.py +13 -423
- langfun/core/llms/vertexai.py +6 -2
- langfun/core/llms/vertexai_test.py +1 -1
- langfun/core/modalities/mime.py +8 -0
- langfun/core/modalities/mime_test.py +19 -4
- langfun/core/modality_test.py +0 -1
- langfun/core/structured/mapping.py +13 -13
- langfun/core/structured/mapping_test.py +2 -2
- langfun/core/structured/schema.py +16 -8
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/METADATA +13 -2
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/RECORD +37 -35
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/WHEEL +1 -1
- langfun/core/text_formatting.py +0 -168
- langfun/core/text_formatting_test.py +0 -65
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202501060804.dist-info → langfun-0.1.2.dev202501100804.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_test.py
CHANGED
@@ -13,102 +13,9 @@
 # limitations under the License.
 """Tests for OpenAI models."""
 
-from typing import Any
 import unittest
-from unittest import mock
-
 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
 from langfun.core.llms import openai
-import pyglove as pg
-import requests
-
-
-def mock_chat_completion_request(url: str, json: dict[str, Any], **kwargs):
-  del url, kwargs
-  messages = json['messages']
-  if len(messages) > 1:
-    system_message = f' system={messages[0]["content"]}'
-  else:
-    system_message = ''
-
-  if 'response_format' in json:
-    response_format = f' format={json["response_format"]["type"]}'
-  else:
-    response_format = ''
-
-  choices = []
-  for k in range(json['n']):
-    if json.get('logprobs'):
-      logprobs = dict(
-          content=[
-              dict(
-                  token='chosen_token',
-                  logprob=0.5,
-                  top_logprobs=[
-                      dict(
-                          token=f'alternative_token_{i + 1}',
-                          logprob=0.1
-                      ) for i in range(3)
-                  ]
-              )
-          ]
-      )
-    else:
-      logprobs = None
-
-    choices.append(dict(
-        message=dict(
-            content=(
-                f'Sample {k} for message.{system_message}{response_format}'
-            )
-        ),
-        logprobs=logprobs,
-    ))
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str(
-      dict(
-          choices=choices,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=100,
-              completion_tokens=100,
-              total_tokens=200,
-          ),
-      )
-  ).encode()
-  return response
-
-
-def mock_chat_completion_request_vision(
-    url: str, json: dict[str, Any], **kwargs
-):
-  del url, kwargs
-  choices = []
-  urls = [
-      c['image_url']['url']
-      for c in json['messages'][0]['content'] if c['type'] == 'image_url'
-  ]
-  for k in range(json['n']):
-    choices.append(pg.Dict(
-        message=pg.Dict(
-            content=f'Sample {k} for message: {"".join(urls)}'
-        ),
-        logprobs=None,
-    ))
-  response = requests.Response()
-  response.status_code = 200
-  response._content = pg.to_json_str(
-      dict(
-          choices=choices,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=100,
-              completion_tokens=100,
-              total_tokens=200,
-          ),
-      )
-  ).encode()
-  return response
 
 
 class OpenAITest(unittest.TestCase):
@@ -130,6 +37,15 @@ class OpenAITest(unittest.TestCase):
         openai.Gpt35(api_key='test_key').resource_id, 'OpenAI(text-davinci-003)'
     )
 
+  def test_headers(self):
+    self.assertEqual(
+        openai.Gpt35(api_key='test_key').headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer test_key',
+        },
+    )
+
   def test_max_concurrency(self):
     self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
@@ -156,340 +72,14 @@ class OpenAITest(unittest.TestCase):
         )
     )
 
-  def
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = openai.OpenAI(
-          model='gpt-4',
-          api_key='test_key',
-          organization='my_org',
-          project='my_project'
-      )
-      self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for message.',
-      )
-
-  def test_call_chat_completion_with_logprobs(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = openai.OpenAI(
-          model='gpt-4',
-          api_key='test_key',
-          organization='my_org',
-          project='my_project'
-      )
-      results = lm.sample(['hello'], logprobs=True)
-      self.assertEqual(len(results), 1)
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      response=lf.AIMessage(
-                          text='Sample 0 for message.',
-                          metadata={
-                              'score': 0.0,
-                              'logprobs': [(
-                                  'chosen_token',
-                                  0.5,
-                                  [
-                                      ('alternative_token_1', 0.1),
-                                      ('alternative_token_2', 0.1),
-                                      ('alternative_token_3', 0.1),
-                                  ],
-                              )],
-                              'is_cached': False,
-                              'usage': lf.LMSamplingUsage(
-                                  prompt_tokens=100,
-                                  completion_tokens=100,
-                                  total_tokens=200,
-                                  estimated_cost=0.009,
-                              ),
-                          },
-                          tags=['lm-response'],
-                      ),
-                      logprobs=[(
-                          'chosen_token',
-                          0.5,
-                          [
-                              ('alternative_token_1', 0.1),
-                              ('alternative_token_2', 0.1),
-                              ('alternative_token_3', 0.1),
-                          ],
-                      )],
-                  )
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100,
-                  completion_tokens=100,
-                  total_tokens=200,
-                  estimated_cost=0.009,
-              ),
-          ),
-      )
-
-  def test_call_chat_completion_vision(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request_vision
-      lm_1 = openai.Gpt4Turbo(api_key='test_key')
-      lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
-      for lm in (lm_1, lm_2):
-        self.assertEqual(
-            lm(
-                lf.UserMessage(
-                    'hello <<[[image]]>>',
-                    image=lf_modalities.Image.from_uri('https://fake/image')
-                ),
-                sampling_options=lf.LMSamplingOptions(n=2)
-            ),
-            'Sample 0 for message: https://fake/image',
-        )
-      lm_3 = openai.Gpt35Turbo(api_key='test_key')
-      with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
-        lm_3(
-            lf.UserMessage(
-                'hello <<[[image]]>>',
-                image=lf_modalities.Image.from_uri('https://fake/image')
-            ),
-        )
-
-  def test_sample_chat_completion(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      openai.SUPPORTED_MODELS_AND_SETTINGS['gpt-4'].update({
-          'cost_per_1k_input_tokens': 1.0,
-          'cost_per_1k_output_tokens': 1.0,
-      })
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
-      )
-
-      self.assertEqual(len(results), 2)
-      print(results[0])
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                  estimated_cost=0.2,
-              ),
-          ),
-      )
+  def test_estimate_cost(self):
     self.assertEqual(
-
-
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 2 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=33,
-                              completion_tokens=33,
-                              total_tokens=66,
-                              estimated_cost=0.2 / 3,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100, completion_tokens=100, total_tokens=200,
-                  estimated_cost=0.2,
-              ),
+        openai.Gpt4(api_key='test_key').estimate_cost(
+            num_input_tokens=100, num_output_tokens=100
         ),
+        0.009
     )
 
-  def test_sample_with_contextual_options(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
-        results = lm.sample(['hello'])
-
-      self.assertEqual(len(results), 1)
-      self.assertEqual(
-          results[0],
-          lf.LMSamplingResult(
-              [
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 0 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=50,
-                              completion_tokens=50,
-                              total_tokens=100,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-                  lf.LMSample(
-                      lf.AIMessage(
-                          'Sample 1 for message.',
-                          score=0.0,
-                          logprobs=None,
-                          is_cached=False,
-                          usage=lf.LMSamplingUsage(
-                              prompt_tokens=50,
-                              completion_tokens=50,
-                              total_tokens=100,
-                          ),
-                          tags=[lf.Message.TAG_LM_RESPONSE],
-                      ),
-                      score=0.0,
-                      logprobs=None,
-                  ),
-              ],
-              usage=lf.LMSamplingUsage(
-                  prompt_tokens=100, completion_tokens=100, total_tokens=200
-              ),
-          )
-      )
-
-  def test_call_with_system_message(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
-      self.assertEqual(
-          lm(
-              lf.UserMessage(
-                  'hello',
-                  system_message='hi',
-              ),
-              sampling_options=lf.LMSamplingOptions(n=2)
-          ),
-          '''Sample 0 for message. system=[{'type': 'text', 'text': 'hi'}]''',
-      )
-
-  def test_call_with_json_schema(self):
-    with mock.patch('requests.Session.post') as mock_request:
-      mock_request.side_effect = mock_chat_completion_request
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
-      self.assertEqual(
-          lm(
-              lf.UserMessage(
-                  'hello',
-                  json_schema={
-                      'type': 'object',
-                      'properties': {
-                          'name': {'type': 'string'},
-                      },
-                      'required': ['name'],
-                      'title': 'Person',
-                  }
-              ),
-              sampling_options=lf.LMSamplingOptions(n=2)
-          ),
-          'Sample 0 for message. format=json_schema',
-      )
-
-      # Test bad json schema.
-      with self.assertRaisesRegex(ValueError, '`json_schema` must be a dict'):
-        lm(lf.UserMessage('hello', json_schema='foo'))
-
-      with self.assertRaisesRegex(
-          ValueError, 'The root of `json_schema` must have a `title` field'
-      ):
-        lm(lf.UserMessage('hello', json_schema={}))
-
 
 if __name__ == '__main__':
   unittest.main()
langfun/core/llms/vertexai.py
CHANGED
@@ -90,6 +90,8 @@ class VertexAI(gemini.Gemini):
     )
 
     self._project = project
+    self._location = location
+
     credentials = self.credentials
     if credentials is None:
       # Use default credentials.
@@ -114,9 +116,10 @@ class VertexAI(gemini.Gemini):
 
   @property
   def api_endpoint(self) -> str:
+    assert self._api_initialized
     return (
-        f'https://{self.
-        f'{self.
+        f'https://{self._location}-aiplatform.googleapis.com/v1/projects/'
+        f'{self._project}/locations/{self._location}/publishers/google/'
         f'models/{self.model}:generateContent'
     )
 
@@ -126,6 +129,7 @@ class VertexAIGeminiFlash2_0ThinkingExp_20241219(VertexAI):  # pylint: disable=i
 
   api_version = 'v1alpha'
   model = 'gemini-2.0-flash-thinking-exp-1219'
+  timeout = None
 
 
 class VertexAIGeminiFlash2_0Exp(VertexAI):  # pylint: disable=invalid-name
langfun/core/llms/vertexai_test.py
CHANGED
@@ -41,7 +41,7 @@ class VertexAITest(unittest.TestCase):
     os.environ['VERTEXAI_LOCATION'] = 'us-central1'
     model = vertexai.VertexAIGeminiPro1()
     self.assertTrue(model.model_id.startswith('VertexAI('))
-    self.
+    self.assertIn('us-central1', model.api_endpoint)
     self.assertTrue(model._api_initialized)
     self.assertIsNotNone(model._session)
     del os.environ['VERTEXAI_PROJECT']
langfun/core/modalities/mime.py
CHANGED
@@ -138,9 +138,17 @@ class Mime(lf.Modality):
 
   @property
   def content_uri(self) -> str:
+    """Returns the URI with encoded content."""
     base64_content = base64.b64encode(self.to_bytes()).decode()
     return f'data:{self.mime_type};base64,{base64_content}'
 
+  @property
+  def embeddable_uri(self) -> str:
+    """Returns the URI that can be embedded in HTML."""
+    if self.uri and self.uri.lower().startswith(('http:', 'https:', 'ftp:')):
+      return self.uri
+    return self.content_uri
+
   @classmethod
   def from_uri(cls, uri: str, **kwargs) -> 'Mime':
     if cls is Mime:
langfun/core/modalities/mime_test.py
CHANGED
@@ -23,12 +23,12 @@ import pyglove as pg
 
 def mock_request(*args, **kwargs):
   del args, kwargs
-  return pg.Dict(content='foo')
+  return pg.Dict(content=b'foo')
 
 
 def mock_readfile(*args, **kwargs):
   del args, kwargs
-  return 'bar'
+  return b'bar'
 
 
 class CustomMimeTest(unittest.TestCase):
@@ -65,17 +65,32 @@ class CustomMimeTest(unittest.TestCase):
     ):
       mime.Custom('text/plain')
 
+  def test_uri(self):
+    content = mime.Custom.from_uri('http://mock/web/a.txt', mime='text/plain')
+    with mock.patch('requests.get') as mock_requests_stub:
+      mock_requests_stub.side_effect = mock_request
+      self.assertEqual(content.uri, 'http://mock/web/a.txt')
+      self.assertEqual(content.content_uri, 'data:text/plain;base64,Zm9v')
+      self.assertEqual(content.embeddable_uri, 'http://mock/web/a.txt')
+
+    content = mime.Custom.from_uri('a.txt', mime='text/plain')
+    with mock.patch('pyglove.io.readfile') as mock_readfile_stub:
+      mock_readfile_stub.side_effect = mock_readfile
+      self.assertEqual(content.uri, 'a.txt')
+      self.assertEqual(content.content_uri, 'data:text/plain;base64,YmFy')
+      self.assertEqual(content.embeddable_uri, 'data:text/plain;base64,YmFy')
+
   def test_from_uri(self):
     content = mime.Custom.from_uri('http://mock/web/a.txt', mime='text/plain')
     with mock.patch('requests.get') as mock_requests_stub:
      mock_requests_stub.side_effect = mock_request
-      self.assertEqual(content.to_bytes(), 'foo')
+      self.assertEqual(content.to_bytes(), b'foo')
       self.assertEqual(content.mime_type, 'text/plain')
 
     content = mime.Custom.from_uri('a.txt', mime='text/plain')
     with mock.patch('pyglove.io.readfile') as mock_readfile_stub:
       mock_readfile_stub.side_effect = mock_readfile
-      self.assertEqual(content.to_bytes(), 'bar')
+      self.assertEqual(content.to_bytes(), b'bar')
       self.assertEqual(content.mime_type, 'text/plain')
 
   def assert_html_content(self, html, expected):
langfun/core/modality_test.py
CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for modality."""
 from typing import Any
 import unittest
 
langfun/core/structured/mapping.py
CHANGED
@@ -46,15 +46,15 @@ class MappingError(Exception):  # pylint: disable=g-bad-exception-name
     r = io.StringIO()
     error_message = str(self.cause).rstrip()
     r.write(
-
+        pg.colored(
             f'{self.cause.__class__.__name__}: {error_message}', 'magenta'
         )
     )
     if include_lm_response:
       r.write('\n\n')
-      r.write(
+      r.write(pg.colored('[LM Response]', 'blue', styles=['bold']))
       r.write('\n')
-      r.write(
+      r.write(pg.colored(self.lm_response.text, 'blue'))
     return r.getvalue()
 
 
@@ -163,27 +163,27 @@ class MappingExample(lf.NaturalLanguageFormattable,
   def natural_language_format(self) -> str:
     result = io.StringIO()
     if self.context:
-      result.write(
-      result.write(
+      result.write(pg.colored('[CONTEXT]\n', styles=['bold']))
+      result.write(pg.colored(self.context, color='magenta'))
       result.write('\n\n')
 
-    result.write(
-    result.write(
+    result.write(pg.colored('[INPUT]\n', styles=['bold']))
+    result.write(pg.colored(self.input_repr(), color='green'))
 
     if self.schema is not None:
       result.write('\n\n')
-      result.write(
-      result.write(
+      result.write(pg.colored('[SCHEMA]\n', styles=['bold']))
+      result.write(pg.colored(self.schema_repr(), color='red'))
 
     if schema_lib.MISSING != self.output:
       result.write('\n\n')
-      result.write(
-      result.write(
+      result.write(pg.colored('[OUTPUT]\n', styles=['bold']))
+      result.write(pg.colored(self.output_repr(), color='blue'))
 
     if self.metadata:
       result.write('\n\n')
-      result.write(
-      result.write(
+      result.write(pg.colored('[METADATA]\n', styles=['bold']))
+      result.write(pg.colored(str(self.metadata), color='cyan'))
     return result.getvalue().strip()
 
   @classmethod
langfun/core/structured/mapping_test.py
CHANGED
@@ -29,11 +29,11 @@ class MappingErrorTest(unittest.TestCase):
         lf.AIMessage('hi'), ValueError('Cannot parse message.')
     )
     self.assertEqual(
-
+        pg.decolor(str(error)),
         'ValueError: Cannot parse message.\n\n[LM Response]\nhi',
     )
     self.assertEqual(
-
+        pg.decolor(error.format(include_lm_response=False)),
         'ValueError: Cannot parse message.',
     )
 
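
The mapping hunks above replace langfun's removed text_formatting helpers (see the -168 line drop in langfun/core/text_formatting.py) with pyglove's coloring utilities, and the tests strip styling with pg.decolor before comparing. A minimal sketch of that pattern, restricted to the call shapes that appear in the diff:

import pyglove as pg

# Color a heading and a body the same way mapping.py now does.
heading = pg.colored('[LM Response]', 'blue', styles=['bold'])
body = pg.colored('hi', color='magenta')

# Assertions compare plain text, so ANSI styling is stripped first.
assert pg.decolor(heading) == '[LM Response]'
assert pg.decolor(body) == 'hi'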