langfun-0.0.2.dev20240429-py3-none-any.whl → langfun-0.1.2.dev202501140804-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (144)
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
langfun/core/llms/openai_test.py
@@ -14,80 +14,20 @@
 """Tests for OpenAI models."""
 
 import unittest
-from unittest import mock
-
 import langfun.core as lf
-from langfun.core import modalities as lf_modalities
 from langfun.core.llms import openai
-import pyglove as pg
-
-
-def mock_completion_query(prompt, *, n=1, **kwargs):
-  del kwargs
-  choices = []
-  for i, _ in enumerate(prompt):
-    for k in range(n):
-      choices.append(pg.Dict(
-          index=i,
-          text=f'Sample {k} for prompt {i}.',
-          logprobs=k / 10,
-      ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
-
-
-def mock_chat_completion_query(messages, *, n=1, **kwargs):
-  del messages, kwargs
-  choices = []
-  for k in range(n):
-    choices.append(pg.Dict(
-        message=pg.Dict(
-            content=f'Sample {k} for message.'
-        ),
-        logprobs=None,
-    ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
-
-
-def mock_chat_completion_query_vision(messages, *, n=1, **kwargs):
-  del kwargs
-  choices = []
-  urls = [
-      c['image_url'] for c in messages[0]['content'] if c['type'] == 'image_url'
-  ]
-  for k in range(n):
-    choices.append(pg.Dict(
-        message=pg.Dict(
-            content=f'Sample {k} for message: {"".join(urls)}'
-        ),
-        logprobs=None,
-    ))
-  return pg.Dict(
-      choices=choices,
-      usage=lf.LMSamplingUsage(
-          prompt_tokens=100,
-          completion_tokens=100,
-          total_tokens=200,
-      ),
-  )
 
 
 class OpenAITest(unittest.TestCase):
   """Tests for OpenAI language model."""
 
+  def test_dir(self):
+    self.assertIn('gpt-4-turbo', openai.OpenAI.dir())
+
+  def test_key(self):
+    with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+      openai.Gpt4()('hi')
+
   def test_model_id(self):
     self.assertEqual(
         openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
@@ -97,352 +37,47 @@ class OpenAITest(unittest.TestCase):
         openai.Gpt35(api_key='test_key').resource_id, 'OpenAI(text-davinci-003)'
     )
 
+  def test_headers(self):
+    self.assertEqual(
+        openai.Gpt35(api_key='test_key').headers,
+        {
+            'Content-Type': 'application/json',
+            'Authorization': 'Bearer test_key',
+        },
+    )
+
   def test_max_concurrency(self):
     self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
 
-  def test_get_request_args(self):
+  def test_request_args(self):
     self.assertEqual(
-        openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
+        openai.Gpt4(api_key='test_key')._request_args(
             lf.LMSamplingOptions(
-                temperature=2.0,
-                n=2,
-                max_tokens=4096,
-                top_p=1.0)),
-        dict(
-            engine='text-davinci-003',
-            logprobs=False,
-            top_logprobs=None,
-            n=2,
-            temperature=2.0,
-            max_tokens=4096,
-            stream=False,
-            timeout=90.0,
-            top_p=1.0,
-        )
-    )
-    self.assertEqual(
-        openai.Gpt4(api_key='test_key')._get_request_args(
-            lf.LMSamplingOptions(temperature=1.0, stop=['\n'], n=1)
+                temperature=1.0, stop=['\n'], n=1, random_seed=123
+            )
         ),
         dict(
             model='gpt-4',
-            logprobs=False,
            top_logprobs=None,
            n=1,
            temperature=1.0,
-            stream=False,
-            timeout=120.0,
            stop=['\n'],
+            seed=123,
        ),
    )
-
-  def test_call_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for prompt 0.',
-      )
-
-  def test_call_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
-      self.assertEqual(
-          lm('hello', sampling_options=lf.LMSamplingOptions(n=2)),
-          'Sample 0 for message.',
+    with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
+      openai.GptO1Preview(api_key='test_key')._request_args(
+          lf.LMSamplingOptions(
+              temperature=1.0, logprobs=True
+          )
      )
 
-  def test_call_chat_completion_vision(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query_vision
-      lm_1 = openai.Gpt4Turbo(api_key='test_key')
-      lm_2 = openai.Gpt4VisionPreview(api_key='test_key')
-      for lm in (lm_1, lm_2):
-        self.assertEqual(
-            lm(
-                lf.UserMessage(
-                    'hello {{image}}',
-                    image=lf_modalities.Image.from_uri('https://fake/image')
-                ),
-                sampling_options=lf.LMSamplingOptions(n=2)
-            ),
-            'Sample 0 for message: https://fake/image',
-        )
-
-  def test_sample_completion(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
-      )
-
-    self.assertEqual(len(results), 2)
-    self.assertEqual(
-        results[0],
-        lf.LMSamplingResult(
-            [
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 0 for prompt 0.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 1 for prompt 0.',
-                        score=0.1,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.1,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 2 for prompt 0.',
-                        score=0.2,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.2,
-                    logprobs=None,
-                ),
-            ],
-            usage=lf.LMSamplingUsage(
-                prompt_tokens=100, completion_tokens=100, total_tokens=200
-            ),
-        ),
-    )
-    self.assertEqual(
-        results[1],
-        lf.LMSamplingResult(
-            [
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 0 for prompt 1.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=None,
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 1 for prompt 1.',
-                        score=0.1,
-                        logprobs=None,
-                        usage=None,
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.1,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 2 for prompt 1.',
-                        score=0.2,
-                        logprobs=None,
-                        usage=None,
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.2,
-                    logprobs=None,
-                ),
-            ],
-        ),
-    )
-
-  def test_sample_chat_completion(self):
-    with mock.patch('openai.ChatCompletion.create') as mock_chat_completion:
-      mock_chat_completion.side_effect = mock_chat_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='gpt-4')
-      results = lm.sample(
-          ['hello', 'bye'], sampling_options=lf.LMSamplingOptions(n=3)
-      )
-
-    self.assertEqual(len(results), 2)
-    self.assertEqual(
-        results[0],
-        lf.LMSamplingResult(
-            [
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 0 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 1 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 2 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-            ],
-            usage=lf.LMSamplingUsage(
-                prompt_tokens=100, completion_tokens=100, total_tokens=200
-            ),
-        ),
-    )
-    self.assertEqual(
-        results[1],
-        lf.LMSamplingResult(
-            [
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 0 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 1 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 2 for message.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=33,
-                            completion_tokens=33,
-                            total_tokens=66
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-            ],
-            usage=lf.LMSamplingUsage(
-                prompt_tokens=100, completion_tokens=100, total_tokens=200
-            ),
-        ),
-    )
-
-  def test_sample_with_contextual_options(self):
-    with mock.patch('openai.Completion.create') as mock_completion:
-      mock_completion.side_effect = mock_completion_query
-      lm = openai.OpenAI(api_key='test_key', model='text-davinci-003')
-      with lf.use_settings(sampling_options=lf.LMSamplingOptions(n=2)):
-        results = lm.sample(['hello'])
-
-    self.assertEqual(len(results), 1)
+  def test_estimate_cost(self):
     self.assertEqual(
-        results[0],
-        lf.LMSamplingResult(
-            [
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 0 for prompt 0.',
-                        score=0.0,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=50,
-                            completion_tokens=50,
-                            total_tokens=100,
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.0,
-                    logprobs=None,
-                ),
-                lf.LMSample(
-                    lf.AIMessage(
-                        'Sample 1 for prompt 0.',
-                        score=0.1,
-                        logprobs=None,
-                        usage=lf.LMSamplingUsage(
-                            prompt_tokens=50,
-                            completion_tokens=50,
-                            total_tokens=100,
-                        ),
-                        tags=[lf.Message.TAG_LM_RESPONSE],
-                    ),
-                    score=0.1,
-                    logprobs=None,
-                ),
-            ],
-            usage=lf.LMSamplingUsage(
-                prompt_tokens=100, completion_tokens=100, total_tokens=200
-            ),
+        openai.Gpt4(api_key='test_key').estimate_cost(
+            num_input_tokens=100, num_output_tokens=100
        ),
+        0.009
    )
 
 
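Note: the 0.009 expected by test_estimate_cost is consistent with GPT-4 list pricing of $0.03 per 1K input tokens and $0.06 per 1K output tokens; those rates are an assumption here, not stated anywhere in the diff. A minimal sketch of the arithmetic:

    # Hypothetical breakdown of the 0.009 expected by test_estimate_cost,
    # assuming GPT-4 pricing of $0.03/1K input and $0.06/1K output tokens.
    cost = (100 / 1000) * 0.03 + (100 / 1000) * 0.06  # 0.003 + 0.006
    assert round(cost, 6) == 0.009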
langfun/core/llms/rest.py
@@ -0,0 +1,113 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base class for language models through REST APIs."""
+
+import functools
+from typing import Annotated, Any, Callable
+
+import langfun.core as lf
+import requests
+
+
+class REST(lf.LanguageModel):
+  """REST-based language model."""
+
+  api_endpoint: Annotated[
+      str,
+      'The endpoint of the REST API.'
+  ]
+
+  request: Annotated[
+      Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
+      'A function to convert a Langfun message to a JSON request.'
+  ]
+
+  result: Annotated[
+      Callable[[dict[str, Any]], lf.LMSamplingResult],
+      'A function to convert a JSON response to an LMSamplingResult.'
+  ]
+
+  model: Annotated[
+      str | None,
+      'Model ID.'
+  ] = None
+
+  headers: Annotated[
+      dict[str, Any] | None,
+      'The headers for the REST API.'
+  ] = None
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model or 'unknown'
+
+  @functools.cached_property
+  def _api_initialized(self) -> bool:
+    """Returns whether the API is initialized."""
+    self._initialize()
+    return True
+
+  def _initialize(self) -> None:
+    """Initializes the API. Subclasses can override."""
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update(self.headers or {})
+    return s
+
+  def _on_bound(self):
+    super()._on_bound()
+    self.__dict__.pop('_session', None)
+    self.__dict__.pop('_api_initialized', None)
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single, prompts
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    try:
+      response = self._session.post(
+          self.api_endpoint,
+          json=self.request(prompt, self.sampling_options),
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise lf.LMError(str(e)) from e
+
+  def _error(self, status_code: int, content: str) -> lf.LMError:
+    if status_code == 429:
+      error_cls = lf.RateLimitError
+    elif status_code in (
+        500,  # Server side issue (might be bug).
+        502,  # Bad gateway (upstream issue, might retry).
+        503,  # Servers currently under load, retry after a brief wait.
+        529,  # Overloaded, retry after a brief wait.
+    ):
+      error_cls = lf.TemporaryLMError
+    else:
+      error_cls = lf.LMError
+    return error_cls(f'{status_code}: {content}')
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Anthropic's response."""
+    if response.status_code == 200:
+      return self.result(response.json())
+    else:
+      raise self._error(response.status_code, response.content)
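The new REST base class delegates the provider-specific parts to two callables: request builds the JSON payload from a Langfun message plus sampling options, and result turns the JSON response back into an LMSamplingResult. A minimal sketch of wiring it to an arbitrary JSON completion endpoint, mirroring the setUp of rest_test.py below (the endpoint URL and payload fields are illustrative, not a real API):

    import langfun.core as lf
    from langfun.core.llms import rest

    # Hypothetical endpoint and JSON shape; any completion API whose request
    # and response can be expressed this way could be plugged in similarly.
    lm = rest.REST(
        api_endpoint='https://example.com/v1/complete',
        request=lambda prompt, options: dict(
            prompt=prompt.text,
            temperature=options.temperature,
            max_tokens=options.max_tokens,
        ),
        result=lambda data: lf.LMSamplingResult(
            [lf.LMSample(c) for c in data['content']]
        ),
        headers={'api_key': 'my-key'},
    )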
langfun/core/llms/rest_test.py
@@ -0,0 +1,111 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for REST models."""
+
+from typing import Any
+import unittest
+from unittest import mock
+import langfun.core as lf
+from langfun.core.llms import rest
+import pyglove as pg
+import requests
+
+
+def mock_requests_post(url: str, json: dict[str, Any], **kwargs):
+  del url, kwargs
+  response = requests.Response()
+  response.status_code = 200
+  response._content = pg.to_json_str({
+      'content': [(
+          f'hello with temperature={json.get("temperature")}, '
+          f'top_k={json.get("top_k")}, '
+          f'top_p={json.get("top_p")}, '
+          f'max_tokens={json.get("max_tokens")}, '
+          f'stop={json.get("stop_sequences")}.'
+      )],
+  }).encode()
+  return response
+
+
+def mock_requests_post_error(status_code, error_type, error_message):
+  def _mock_requests(url: str, json: dict[str, Any], **kwargs):
+    del url, json, kwargs
+    response = requests.Response()
+    response.status_code = status_code
+    response._content = pg.to_json_str(
+        {
+            'error': {
+                'type': error_type,
+                'message': error_message,
+            }
+        }
+    ).encode()
+    return response
+
+  return _mock_requests
+
+
+class RestTest(unittest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self._lm = rest.REST(
+        api_endpoint='https://fake-api.com',
+        request=lambda x, o: dict(
+            model='test-model',
+            prompt=x.text,
+            temperature=0.0,
+            top_k=0.1,
+            top_p=0.2,
+            stop_sequences=['\n'],
+            max_tokens=4096,
+        ),
+        result=lambda x: lf.LMSamplingResult(
+            [lf.LMSample(c) for c in x['content']]),
+        headers=dict(api_key='fake_key'),
+    )
+
+  def test_call(self):
+    with mock.patch('requests.Session.post') as mock_request:
+      mock_request.side_effect = mock_requests_post
+      self.assertEqual(self._lm.model_id, 'unknown')
+      response = self._lm(
+          'hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
+      self.assertEqual(
+          response.text,
+          (
+              'hello with temperature=0.0, top_k=0.1, top_p=0.2, '
+              "max_tokens=4096, stop=['\\n']."
+          ),
+      )
+      self.assertIsInstance(response.usage, lf.UsageNotAvailable)
+
+  def test_call_errors(self):
+    for status_code, error_type, error_message in [
+        (429, 'rate_limit', 'Rate limit exceeded.'),
+        (529, 'service_unavailable', 'Service unavailable.'),
+        (500, 'bad_request', 'Bad request.'),
+    ]:
+      with mock.patch('requests.Session.post') as mock_mm_request:
+        mock_mm_request.side_effect = mock_requests_post_error(
+            status_code, error_type, error_message
+        )
+        with self.assertRaisesRegex(
+            Exception, f'.*{status_code}: .*{error_message}'
+        ):
+          self._lm('hello', max_attempts=1)
+
+
+if __name__ == '__main__':
+  unittest.main()