langfun 0.0.2.dev20240330__py3-none-any.whl → 0.1.2.dev202501140804__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registry.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
@@ -13,95 +13,302 @@
  # limitations under the License.
  """Language models from OpenAI."""

- import collections
- import functools
  import os
- from typing import Annotated, Any, cast
+ from typing import Annotated, Any

  import langfun.core as lf
- from langfun.core import modalities as lf_modalities
- import openai
- from openai import error as openai_error
- from openai import openai_object
+ from langfun.core.llms import openai_compatible
  import pyglove as pg


- class Usage(pg.Object):
-   """Usage information per completion."""
-
-   prompt_tokens: int
-   completion_tokens: int
-   total_tokens: int
-
-
- class LMSamplingResult(lf.LMSamplingResult):
-   """LMSamplingResult with usage information."""
-
-   usage: Usage | None = None
-
-
- SUPPORTED_MODELS_AND_SETTINGS = [
-     # Model name, max concurrent requests.
-     # The concurrent requests is estimated by TPM/RPM from
-     # https://platform.openai.com/account/limits
-     # GPT-4 Turbo models.
-     ('gpt-4-turbo-preview', 1),  # GPT-4 Turbo.
-     ('gpt-4-0125-preview', 1),  # GPT-4 Turbo
-     ('gpt-4-1106-preview', 1),  # GPT-4 Turbo
-     ('gpt-4-vision-preview', 1),  # GPT-4 Turbo with Vision.
-     # GPT-4 models.
-     ('gpt-4', 4),
-     ('gpt-4-0613', 4),
-     ('gpt-4-0314', 4),
-     ('gpt-4-32k', 4),
-     ('gpt-4-32k-0613', 4),
-     ('gpt-4-32k-0314', 4),
-     # GPT-3.5 Turbo models.
-     ('gpt-3.5-turbo', 16),
-     ('gpt-3.5-turbo-0125', 16),
-     ('gpt-3.5-turbo-1106', 16),
-     ('gpt-3.5-turbo-0613', 16),
-     ('gpt-3.5-turbo-0301', 16),
-     ('gpt-3.5-turbo-16k', 16),
-     ('gpt-3.5-turbo-16k-0613', 16),
-     ('gpt-3.5-turbo-16k-0301', 16),
-     # GPT-3.5 models.
-     ('text-davinci-003', 8),  # GPT-3.5, trained with RHLF.
-     ('text-davinci-002', 4),  # Trained with SFT but no RHLF.
-     ('code-davinci-002', 4),
-     # GPT-3 instruction-tuned models.
-     ('text-curie-001', 4),
-     ('text-babbage-001', 4),
-     ('text-ada-001', 4),
-     ('davinci', 4),
-     ('curie', 4),
-     ('babbage', 4),
-     ('ada', 4),
-     # GPT-3 base models without instruction tuning.
-     ('babbage-002', 4),
-     ('davinci-002', 4),
- ]
-
-
- # Model concurreny setting.
- _MODEL_CONCURRENCY = {m[0]: m[1] for m in SUPPORTED_MODELS_AND_SETTINGS}
+ # From https://platform.openai.com/settings/organization/limits
+ _DEFAULT_TPM = 250000
+ _DEFAULT_RPM = 3000
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     # Models from https://platform.openai.com/docs/models
+     # RPM is from https://platform.openai.com/docs/guides/rate-limits
+     # o1 (preview) models.
+     # Pricing in US dollars, from https://openai.com/api/pricing/
+     # as of 2024-10-10.
+     'o1': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.015,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'o1-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.015,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'o1-preview-2024-09-12': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.015,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'o1-mini': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.003,
+         cost_per_1k_output_tokens=0.012,
+     ),
+     'o1-mini-2024-09-12': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.003,
+         cost_per_1k_output_tokens=0.012,
+     ),
+     # GPT-4o models
+     'gpt-4o-mini': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.00015,
+         cost_per_1k_output_tokens=0.0006,
+     ),
+     'gpt-4o-mini-2024-07-18': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.00015,
+         cost_per_1k_output_tokens=0.0006,
+     ),
+     'gpt-4o': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.0025,
+         cost_per_1k_output_tokens=0.01,
+     ),
+     'gpt-4o-2024-11-20': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.0025,
+         cost_per_1k_output_tokens=0.01,
+     ),
+     'gpt-4o-2024-08-06': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.0025,
+         cost_per_1k_output_tokens=0.01,
+     ),
+     'gpt-4o-2024-05-13': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=5000000,
+         cost_per_1k_input_tokens=0.005,
+         cost_per_1k_output_tokens=0.015,
+     ),
+     # GPT-4-Turbo models
+     'gpt-4-turbo': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-turbo-2024-04-09': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-turbo-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-0125-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-1106-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-vision-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     'gpt-4-1106-vision-preview': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.01,
+         cost_per_1k_output_tokens=0.03,
+     ),
+     # GPT-4 models
+     'gpt-4': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.03,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'gpt-4-0613': pg.Dict(
+         in_service=False,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.03,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'gpt-4-0314': pg.Dict(
+         in_service=False,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.03,
+         cost_per_1k_output_tokens=0.06,
+     ),
+     'gpt-4-32k': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.06,
+         cost_per_1k_output_tokens=0.12,
+     ),
+     'gpt-4-32k-0613': pg.Dict(
+         in_service=False,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.06,
+         cost_per_1k_output_tokens=0.12,
+     ),
+     'gpt-4-32k-0314': pg.Dict(
+         in_service=False,
+         rpm=10000,
+         tpm=300000,
+         cost_per_1k_input_tokens=0.06,
+         cost_per_1k_output_tokens=0.12,
+     ),
+     # GPT-3.5-Turbo models
+     'gpt-3.5-turbo': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.0005,
+         cost_per_1k_output_tokens=0.0015,
+     ),
+     'gpt-3.5-turbo-0125': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.0005,
+         cost_per_1k_output_tokens=0.0015,
+     ),
+     'gpt-3.5-turbo-1106': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.001,
+         cost_per_1k_output_tokens=0.002,
+     ),
+     'gpt-3.5-turbo-0613': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.0015,
+         cost_per_1k_output_tokens=0.002,
+     ),
+     'gpt-3.5-turbo-0301': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.0015,
+         cost_per_1k_output_tokens=0.002,
+     ),
+     'gpt-3.5-turbo-16k': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.003,
+         cost_per_1k_output_tokens=0.004,
+     ),
+     'gpt-3.5-turbo-16k-0613': pg.Dict(
+         in_service=True,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.003,
+         cost_per_1k_output_tokens=0.004,
+     ),
+     'gpt-3.5-turbo-16k-0301': pg.Dict(
+         in_service=False,
+         rpm=10000,
+         tpm=2000000,
+         cost_per_1k_input_tokens=0.003,
+         cost_per_1k_output_tokens=0.004,
+     ),
+     # GPT-3.5 models
+     'text-davinci-003': pg.Dict(
+         in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
+     ),
+     'text-davinci-002': pg.Dict(
+         in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
+     ),
+     'code-davinci-002': pg.Dict(
+         in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
+     ),
+     # GPT-3 instruction-tuned models (Deprecated)
+     'text-curie-001': pg.Dict(
+         in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
+     ),
+     'text-babbage-001': pg.Dict(
+         in_service=False,
+         rpm=_DEFAULT_RPM,
+         tpm=_DEFAULT_TPM,
+     ),
+     'text-ada-001': pg.Dict(
+         in_service=False,
+         rpm=_DEFAULT_RPM,
+         tpm=_DEFAULT_TPM,
+     ),
+     'davinci': pg.Dict(
+         in_service=False,
+         rpm=_DEFAULT_RPM,
+         tpm=_DEFAULT_TPM,
+     ),
+     'curie': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+     'babbage': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+     'ada': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+     # GPT-3 base models that are still in service.
+     'babbage-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+     'davinci-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
+ }


  @lf.use_init_args(['model'])
- class OpenAI(lf.LanguageModel):
+ class OpenAI(openai_compatible.OpenAICompatible):
    """OpenAI model."""

    model: pg.typing.Annotated[
        pg.typing.Enum(
-           pg.MISSING_VALUE, [m[0] for m in SUPPORTED_MODELS_AND_SETTINGS]
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
        ),
        'The name of the model to use.',
-   ] = 'gpt-3.5-turbo'
+   ]

-   multimodal: Annotated[
-       bool,
-       'Whether this model has multimodal support.'
-   ] = False
+   api_endpoint: str = 'https://api.openai.com/v1/chat/completions'

    api_key: Annotated[
        str | None,
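The table above replaces the old list of (model, max_concurrent_requests) tuples with one pg.Dict entry per model carrying in_service, rpm, tpm and, where known, per-1k-token costs; max_concurrency and estimate_cost later in this diff are derived from it. A minimal sketch of reading such an entry, assuming only pyglove is installed (the standalone dict below is an illustration, not code from the package):

    import pyglove as pg

    settings = {
        'gpt-4o': pg.Dict(
            in_service=True,
            rpm=10000,
            tpm=5000000,
            cost_per_1k_input_tokens=0.0025,
            cost_per_1k_output_tokens=0.01,
        ),
    }
    entry = settings['gpt-4o']
    print(entry.in_service)                       # pg.Dict supports attribute access...
    print(entry.get('cost_per_1k_input_tokens'))  # ...as well as dict-style .get().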
@@ -120,23 +327,44 @@ class OpenAI(lf.LanguageModel):
        ),
    ] = None

+   project: Annotated[
+       str | None,
+       (
+           'Project. If None, the key will be read from environment '
+           "variable 'OPENAI_PROJECT'. Based on the value, usages from "
+           "these API requests will count against the project's quota. "
+       ),
+   ] = None
+
    def _on_bound(self):
      super()._on_bound()
-     self.__dict__.pop('_api_initialized', None)
+     self._api_key = None
+     self._organization = None
+     self._project = None

-   @functools.cached_property
-   def _api_initialized(self):
+   def _initialize(self):
      api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None)
      if not api_key:
        raise ValueError(
            'Please specify `api_key` during `__init__` or set environment '
            'variable `OPENAI_API_KEY` with your OpenAI API key.'
        )
-     openai.api_key = api_key
-     org = self.organization or os.environ.get('OPENAI_ORGANIZATION', None)
-     if org:
-       openai.organization = org
-     return True
+     self._api_key = api_key
+     self._organization = self.organization or os.environ.get(
+         'OPENAI_ORGANIZATION', None
+     )
+     self._project = self.project or os.environ.get('OPENAI_PROJECT', None)
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     assert self._api_initialized
+     headers = super().headers
+     headers['Authorization'] = f'Bearer {self._api_key}'
+     if self._organization:
+       headers['OpenAI-Organization'] = self._organization
+     if self._project:
+       headers['OpenAI-Project'] = self._project
+     return headers

    @property
    def model_id(self) -> str:
@@ -145,144 +373,73 @@ class OpenAI(lf.LanguageModel):

    @property
    def max_concurrency(self) -> int:
-     return _MODEL_CONCURRENCY[self.model]
+     rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+     tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+     return self.rate_to_max_concurrency(
+         requests_per_min=rpm, tokens_per_min=tpm
+     )
+
+   def estimate_cost(
+       self,
+       num_input_tokens: int,
+       num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_input_tokens', None
+     )
+     cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_output_tokens', None
+     )
+     if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+       return None
+     return (
+         cost_per_1k_input_tokens * num_input_tokens
+         + cost_per_1k_output_tokens * num_output_tokens
+     ) / 1000

    @classmethod
    def dir(cls):
-     return openai.Model.list()
-
-   @property
-   def is_chat_model(self):
-     """Returns True if the model is a chat model."""
-     return self.model.startswith(('gpt-4', 'gpt-3.5-turbo'))
+     return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]

-   def _get_request_args(
+   def _request_args(
        self, options: lf.LMSamplingOptions) -> dict[str, Any]:
-     # Reference:
-     # https://platform.openai.com/docs/api-reference/completions/create
-     # NOTE(daiyip): options.top_k is not applicable.
-     args = dict(
-         n=options.n,
-         temperature=options.temperature,
-         max_tokens=options.max_tokens,
-         stream=False,
-         timeout=self.timeout,
-         logprobs=options.logprobs,
-         top_logprobs=options.top_logprobs,
-     )
-     # Completion and ChatCompletion uses different parameter name for model.
-     args['model' if self.is_chat_model else 'engine'] = self.model
+     # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12.
+     if options.logprobs and self.model.startswith(('o1-', 'o3-')):
+       raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
+     return super()._request_args(options)

-     if options.top_p is not None:
-       args['top_p'] = options.top_p
-     if options.stop:
-       args['stop'] = options.stop
-     return args

-   def _sample(self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
-     assert self._api_initialized
-     if self.is_chat_model:
-       return self._chat_complete_batch(prompts)
-     else:
-       return self._complete_batch(prompts)
-
-   def _complete_batch(
-       self, prompts: list[lf.Message]) -> list[LMSamplingResult]:
-
-     def _open_ai_completion(prompts):
-       response = openai.Completion.create(
-           prompt=[p.text for p in prompts],
-           **self._get_request_args(self.sampling_options),
-       )
-       response = cast(openai_object.OpenAIObject, response)
-       # Parse response.
-       samples_by_index = collections.defaultdict(list)
-       for choice in response.choices:
-         samples_by_index[choice.index].append(
-             lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
-         )
-
-       usage = Usage(
-           prompt_tokens=response.usage.prompt_tokens,
-           completion_tokens=response.usage.completion_tokens,
-           total_tokens=response.usage.total_tokens,
-       )
-       return [
-           LMSamplingResult(
-               samples_by_index[index], usage=usage if index == 0 else None
-           )
-           for index in sorted(samples_by_index.keys())
-       ]
-
-     return self._parallel_execute_with_currency_control(
-         _open_ai_completion,
-         [prompts],
-         retry_on_errors=(
-             openai_error.ServiceUnavailableError,
-             openai_error.RateLimitError,
-         ),
-     )[0]
-
-   def _chat_complete_batch(
-       self, prompts: list[lf.Message]
-   ) -> list[LMSamplingResult]:
-     def _open_ai_chat_completion(prompt: lf.Message):
-       if self.multimodal:
-         content = []
-         for chunk in prompt.chunk():
-           if isinstance(chunk, str):
-             item = dict(type='text', text=chunk)
-           elif isinstance(chunk, lf_modalities.Image) and chunk.uri:
-             item = dict(type='image_url', image_url=chunk.uri)
-           else:
-             raise ValueError(f'Unsupported modality object: {chunk!r}.')
-           content.append(item)
-       else:
-         content = prompt.text
-
-       response = openai.ChatCompletion.create(
-           # TODO(daiyip): support conversation history and system prompt.
-           messages=[{'role': 'user', 'content': content}],
-           **self._get_request_args(self.sampling_options),
-       )
-       response = cast(openai_object.OpenAIObject, response)
-       samples = []
-       for choice in response.choices:
-         logprobs = None
-         if choice.logprobs:
-           logprobs = [
-               (
-                   t.token,
-                   t.logprob,
-                   [(tt.token, tt.logprob) for tt in t.top_logprobs],
-               )
-               for t in choice.logprobs.content
-           ]
-         samples.append(
-             lf.LMSample(
-                 choice.message.content,
-                 score=0.0,
-                 logprobs=logprobs,
-             )
-         )
-
-       return LMSamplingResult(
-           samples=samples,
-           usage=Usage(
-               prompt_tokens=response.usage.prompt_tokens,
-               completion_tokens=response.usage.completion_tokens,
-               total_tokens=response.usage.total_tokens,
-           ),
-       )
+ class GptO1(OpenAI):
+   """GPT-O1."""

-     return self._parallel_execute_with_currency_control(
-         _open_ai_chat_completion,
-         prompts,
-         retry_on_errors=(
-             openai_error.ServiceUnavailableError,
-             openai_error.RateLimitError,
-         ),
-     )
+   model = 'o1'
+   multimodal = True
+   timeout = None
+
+
+ class GptO1Preview(OpenAI):
+   """GPT-O1."""
+   model = 'o1-preview'
+   timeout = None
+
+
+ class GptO1Preview_20240912(OpenAI):  # pylint: disable=invalid-name
+   """GPT O1."""
+   model = 'o1-preview-2024-09-12'
+   timeout = None
+
+
+ class GptO1Mini(OpenAI):
+   """GPT O1-mini."""
+   model = 'o1-mini'
+   timeout = None
+
+
+ class GptO1Mini_20240912(OpenAI):  # pylint: disable=invalid-name
+   """GPT O1-mini."""
+   model = 'o1-mini-2024-09-12'
+   timeout = None


  class Gpt4(OpenAI):
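As a worked example of the estimate_cost arithmetic above, using the gpt-4o rates from the settings table earlier in this diff (the token counts are made up for illustration):

    # gpt-4o: cost_per_1k_input_tokens=0.0025, cost_per_1k_output_tokens=0.01.
    num_input_tokens, num_output_tokens = 2000, 500
    cost = (0.0025 * num_input_tokens + 0.01 * num_output_tokens) / 1000
    print(cost)  # (5.0 + 5.0) / 1000 ≈ 0.01 USD, i.e. about one cent per request.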
@@ -291,27 +448,44 @@ class Gpt4(OpenAI):


  class Gpt4Turbo(Gpt4):
-   """GPT-4 Turbo with 128K context window size. Knowledge up to 4-2023."""
-   model = 'gpt-4-turbo-preview'
+   """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+   model = 'gpt-4-turbo'
+   multimodal = True


- class Gpt4TurboVision(Gpt4Turbo):
-   """GPT-4 Turbo with vision."""
-   model = 'gpt-4-vision-preview'
+ class Gpt4Turbo_20240409(Gpt4Turbo):  # pylint:disable=invalid-name
+   """GPT-4 Turbo with 128K context window. Knowledge up to Dec. 2023."""
+   model = 'gpt-4-turbo-2024-04-09'
    multimodal = True


- class Gpt4Turbo_0125(Gpt4Turbo):  # pylint:disable=invalid-name
-   """GPT-4 Turbo with 128K context window size. Knowledge up to 4-2023."""
+ class Gpt4TurboPreview(Gpt4):
+   """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
+   model = 'gpt-4-turbo-preview'
+
+
+ class Gpt4TurboPreview_20240125(Gpt4TurboPreview):  # pylint: disable=invalid-name
+   """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
    model = 'gpt-4-0125-preview'


- class Gpt4Turbo_1106(Gpt4Turbo):  # pylint:disable=invalid-name
-   """GPT-4 Turbo @20231106. 128K context window. Knowledge up to 4-2023."""
+ class Gpt4TurboPreview_20231106(Gpt4TurboPreview):  # pylint: disable=invalid-name
+   """GPT-4 Turbo Preview with 128k context window. Knowledge up to Apr. 2023."""
    model = 'gpt-4-1106-preview'


- class Gpt4_0613(Gpt4):  # pylint:disable=invalid-name
+ class Gpt4VisionPreview(Gpt4):
+   """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+   model = 'gpt-4-vision-preview'
+   multimodal = True
+
+
+ class Gpt4VisionPreview_20231106(Gpt4):  # pylint: disable=invalid-name
+   """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
+   model = 'gpt-4-1106-vision-preview'
+
+
+ class Gpt4_20230613(Gpt4):  # pylint:disable=invalid-name
    """GPT-4 @20230613. 8K context window. Knowledge up to 9-2021."""
    model = 'gpt-4-0613'

@@ -321,11 +495,47 @@ class Gpt4_32K(Gpt4): # pylint:disable=invalid-name
    model = 'gpt-4-32k'


- class Gpt4_32K_0613(Gpt4_32K):  # pylint:disable=invalid-name
+ class Gpt4_32K_20230613(Gpt4_32K):  # pylint:disable=invalid-name
    """GPT-4 @20230613. 32K context window. Knowledge up to 9-2021."""
    model = 'gpt-4-32k-0613'


+ class Gpt4oMini(OpenAI):
+   """GPT-4o Mini."""
+   model = 'gpt-4o-mini'
+   multimodal = True
+
+
+ class Gpt4oMini_20240718(OpenAI):  # pylint:disable=invalid-name
+   """GPT-4o Mini."""
+   model = 'gpt-4o-mini-2024-07-18'
+   multimodal = True
+
+
+ class Gpt4o(OpenAI):
+   """GPT-4o."""
+   model = 'gpt-4o'
+   multimodal = True
+
+
+ class Gpt4o_20241120(OpenAI):  # pylint:disable=invalid-name
+   """GPT-4o version 2024-11-20."""
+   model = 'gpt-4o-2024-11-20'
+   multimodal = True
+
+
+ class Gpt4o_20240806(OpenAI):  # pylint:disable=invalid-name
+   """GPT-4o version 2024-08-06."""
+   model = 'gpt-4o-2024-08-06'
+   multimodal = True
+
+
+ class Gpt4o_20240513(OpenAI):  # pylint:disable=invalid-name
+   """GPT-4o version 2024-05-13."""
+   model = 'gpt-4o-2024-05-13'
+   multimodal = True
+
+
  class Gpt35(OpenAI):
    """GPT-3.5. 4K max tokens, trained up on data up to Sep, 2021."""
    model = 'text-davinci-003'
@@ -336,17 +546,17 @@ class Gpt35Turbo(Gpt35):
    model = 'gpt-3.5-turbo'


- class Gpt35Turbo_0125(Gpt35Turbo):  # pylint:disable=invalid-name
+ class Gpt35Turbo_20240125(Gpt35Turbo):  # pylint:disable=invalid-name
    """GPT-3.5 Turbo @20240125. 16K context window. Knowledge up to 09/2021."""
    model = 'gpt-3.5-turbo-0125'


- class Gpt35Turbo_1106(Gpt35Turbo):  # pylint:disable=invalid-name
+ class Gpt35Turbo_20231106(Gpt35Turbo):  # pylint:disable=invalid-name
    """Gpt3.5 Turbo @20231106. 16K context window. Knowledge up to 09/2021."""
    model = 'gpt-3.5-turbo-1106'


- class Gpt35Turbo_0613(Gpt35Turbo):  # pylint:disable=invalid-name
+ class Gpt35Turbo_20230613(Gpt35Turbo):  # pylint:disable=invalid-name
    """Gpt3.5 Turbo snapshot at 2023/06/13, with 4K context window size."""
    model = 'gpt-3.5-turbo-0613'

@@ -356,7 +566,7 @@ class Gpt35Turbo16K(Gpt35Turbo):
    model = 'gpt-3.5-turbo-16k'


- class Gpt35Turbo16K_0613(Gpt35Turbo):  # pylint:disable=invalid-name
+ class Gpt35Turbo16K_20230613(Gpt35Turbo):  # pylint:disable=invalid-name
    """Gtp 3.5 Turbo 16K 0613."""
    model = 'gpt-3.5-turbo-16k-0613'
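To close, a short usage sketch of the reworked model aliases defined in this file. It is pieced together only from what the diff shows (the api_key constructor argument, dir(), max_concurrency and estimate_cost) and is illustrative rather than authoritative:

    import os
    from langfun.core.llms import openai as lf_openai

    # Model names that the settings table marks as in service.
    print(lf_openai.OpenAI.dir())

    # A dated GPT-4o alias; the key may also be picked up from $OPENAI_API_KEY.
    lm = lf_openai.Gpt4o_20241120(api_key=os.environ.get('OPENAI_API_KEY'))
    print(lm.model_id)
    print(lm.max_concurrency)  # Derived from the rpm/tpm settings via rate_to_max_concurrency.
    print(lm.estimate_cost(num_input_tokens=2000, num_output_tokens=500))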