langfun 0.0.2.dev20240429__py3-none-any.whl → 0.1.2.dev202501150804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. langfun/__init__.py +20 -2
  2. langfun/core/__init__.py +16 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -21
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +63 -2
  18. langfun/core/component_test.py +53 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +17 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +16 -1
  24. langfun/core/eval/base.py +622 -174
  25. langfun/core/eval/base_test.py +200 -54
  26. langfun/core/eval/matching.py +63 -76
  27. langfun/core/eval/matching_test.py +17 -8
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +26 -26
  31. langfun/core/eval/scoring_test.py +19 -2
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +4 -17
  55. langfun/core/langfunc_test.py +22 -6
  56. langfun/core/language_model.py +577 -39
  57. langfun/core/language_model_test.py +470 -56
  58. langfun/core/llms/__init__.py +87 -16
  59. langfun/core/llms/anthropic.py +312 -87
  60. langfun/core/llms/anthropic_test.py +71 -3
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +53 -2
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +11 -7
  69. langfun/core/llms/fake_test.py +14 -0
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -202
  74. langfun/core/llms/groq.py +160 -144
  75. langfun/core/llms/groq_test.py +31 -137
  76. langfun/core/llms/llama_cpp.py +15 -42
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +395 -203
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +30 -395
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -26
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +12 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +7 -6
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +60 -27
  112. langfun/core/structured/function_generation_test.py +72 -2
  113. langfun/core/structured/mapping.py +97 -47
  114. langfun/core/structured/mapping_test.py +90 -2
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +53 -9
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +469 -51
  119. langfun/core/structured/schema.py +204 -97
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_test.py +130 -29
  122. langfun/core/structured/scoring.py +125 -19
  123. langfun/core/structured/scoring_test.py +30 -0
  124. langfun/core/structured/tokenization.py +64 -0
  125. langfun/core/structured/tokenization_test.py +48 -0
  126. langfun/core/template.py +115 -1
  127. langfun/core/template_test.py +71 -1
  128. langfun/core/templates/conversation.py +9 -0
  129. langfun/core/templates/conversation_test.py +4 -3
  130. langfun/core/templates/selfplay_test.py +10 -2
  131. langfun-0.1.2.dev202501150804.dist-info/METADATA +225 -0
  132. langfun-0.1.2.dev202501150804.dist-info/RECORD +153 -0
  133. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/WHEEL +1 -1
  134. langfun/core/coding/python/errors.py +0 -108
  135. langfun/core/coding/python/errors_test.py +0 -99
  136. langfun/core/coding/python/permissions.py +0 -90
  137. langfun/core/coding/python/permissions_test.py +0 -86
  138. langfun/core/structured/prompting.py +0 -238
  139. langfun/core/text_formatting.py +0 -162
  140. langfun/core/text_formatting_test.py +0 -47
  141. langfun-0.0.2.dev20240429.dist-info/METADATA +0 -100
  142. langfun-0.0.2.dev20240429.dist-info/RECORD +0 -108
  143. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/LICENSE +0 -0
  144. {langfun-0.0.2.dev20240429.dist-info → langfun-0.1.2.dev202501150804.dist-info}/top_level.txt +0 -0
@@ -13,16 +13,11 @@
13
13
  # limitations under the License.
14
14
  """Language models from OpenAI."""
15
15
 
16
- import collections
17
- import functools
18
16
  import os
19
- from typing import Annotated, Any, cast
17
+ from typing import Annotated, Any
20
18
 
21
19
  import langfun.core as lf
22
- from langfun.core import modalities as lf_modalities
23
- import openai
24
- from openai import error as openai_error
25
- from openai import openai_object
20
+ from langfun.core.llms import openai_compatible
26
21
  import pyglove as pg
27
22
 
28
23
 
@@ -33,52 +28,277 @@ _DEFAULT_RPM = 3000
33
28
  SUPPORTED_MODELS_AND_SETTINGS = {
34
29
  # Models from https://platform.openai.com/docs/models
35
30
  # RPM is from https://platform.openai.com/docs/guides/rate-limits
31
+ # o1 (preview) models.
32
+ # Pricing in US dollars, from https://openai.com/api/pricing/
33
+ # as of 2024-10-10.
34
+ 'o1': pg.Dict(
35
+ in_service=True,
36
+ rpm=10000,
37
+ tpm=5000000,
38
+ cost_per_1k_input_tokens=0.015,
39
+ cost_per_1k_output_tokens=0.06,
40
+ ),
41
+ 'o1-preview': pg.Dict(
42
+ in_service=True,
43
+ rpm=10000,
44
+ tpm=5000000,
45
+ cost_per_1k_input_tokens=0.015,
46
+ cost_per_1k_output_tokens=0.06,
47
+ ),
48
+ 'o1-preview-2024-09-12': pg.Dict(
49
+ in_service=True,
50
+ rpm=10000,
51
+ tpm=5000000,
52
+ cost_per_1k_input_tokens=0.015,
53
+ cost_per_1k_output_tokens=0.06,
54
+ ),
55
+ 'o1-mini': pg.Dict(
56
+ in_service=True,
57
+ rpm=10000,
58
+ tpm=5000000,
59
+ cost_per_1k_input_tokens=0.003,
60
+ cost_per_1k_output_tokens=0.012,
61
+ ),
62
+ 'o1-mini-2024-09-12': pg.Dict(
63
+ in_service=True,
64
+ rpm=10000,
65
+ tpm=5000000,
66
+ cost_per_1k_input_tokens=0.003,
67
+ cost_per_1k_output_tokens=0.012,
68
+ ),
69
+ # GPT-4o models
70
+ 'gpt-4o-mini': pg.Dict(
71
+ in_service=True,
72
+ rpm=10000,
73
+ tpm=5000000,
74
+ cost_per_1k_input_tokens=0.00015,
75
+ cost_per_1k_output_tokens=0.0006,
76
+ ),
77
+ 'gpt-4o-mini-2024-07-18': pg.Dict(
78
+ in_service=True,
79
+ rpm=10000,
80
+ tpm=5000000,
81
+ cost_per_1k_input_tokens=0.00015,
82
+ cost_per_1k_output_tokens=0.0006,
83
+ ),
84
+ 'gpt-4o': pg.Dict(
85
+ in_service=True,
86
+ rpm=10000,
87
+ tpm=5000000,
88
+ cost_per_1k_input_tokens=0.0025,
89
+ cost_per_1k_output_tokens=0.01,
90
+ ),
91
+ 'gpt-4o-2024-11-20': pg.Dict(
92
+ in_service=True,
93
+ rpm=10000,
94
+ tpm=5000000,
95
+ cost_per_1k_input_tokens=0.0025,
96
+ cost_per_1k_output_tokens=0.01,
97
+ ),
98
+ 'gpt-4o-2024-08-06': pg.Dict(
99
+ in_service=True,
100
+ rpm=10000,
101
+ tpm=5000000,
102
+ cost_per_1k_input_tokens=0.0025,
103
+ cost_per_1k_output_tokens=0.01,
104
+ ),
105
+ 'gpt-4o-2024-05-13': pg.Dict(
106
+ in_service=True,
107
+ rpm=10000,
108
+ tpm=5000000,
109
+ cost_per_1k_input_tokens=0.005,
110
+ cost_per_1k_output_tokens=0.015,
111
+ ),
36
112
  # GPT-4-Turbo models
37
- 'gpt-4-turbo': pg.Dict(rpm=10000, tpm=1500000),
38
- 'gpt-4-turbo-2024-04-09': pg.Dict(rpm=10000, tpm=1500000),
39
- 'gpt-4-turbo-preview': pg.Dict(rpm=10000, tpm=1500000),
40
- 'gpt-4-0125-preview': pg.Dict(rpm=10000, tpm=1500000),
41
- 'gpt-4-1106-preview': pg.Dict(rpm=10000, tpm=1500000),
42
- 'gpt-4-vision-preview': pg.Dict(rpm=10000, tpm=1500000),
113
+ 'gpt-4-turbo': pg.Dict(
114
+ in_service=True,
115
+ rpm=10000,
116
+ tpm=2000000,
117
+ cost_per_1k_input_tokens=0.01,
118
+ cost_per_1k_output_tokens=0.03,
119
+ ),
120
+ 'gpt-4-turbo-2024-04-09': pg.Dict(
121
+ in_service=True,
122
+ rpm=10000,
123
+ tpm=2000000,
124
+ cost_per_1k_input_tokens=0.01,
125
+ cost_per_1k_output_tokens=0.03,
126
+ ),
127
+ 'gpt-4-turbo-preview': pg.Dict(
128
+ in_service=True,
129
+ rpm=10000,
130
+ tpm=2000000,
131
+ cost_per_1k_input_tokens=0.01,
132
+ cost_per_1k_output_tokens=0.03,
133
+ ),
134
+ 'gpt-4-0125-preview': pg.Dict(
135
+ in_service=True,
136
+ rpm=10000,
137
+ tpm=2000000,
138
+ cost_per_1k_input_tokens=0.01,
139
+ cost_per_1k_output_tokens=0.03,
140
+ ),
141
+ 'gpt-4-1106-preview': pg.Dict(
142
+ in_service=True,
143
+ rpm=10000,
144
+ tpm=2000000,
145
+ cost_per_1k_input_tokens=0.01,
146
+ cost_per_1k_output_tokens=0.03,
147
+ ),
148
+ 'gpt-4-vision-preview': pg.Dict(
149
+ in_service=True,
150
+ rpm=10000,
151
+ tpm=2000000,
152
+ cost_per_1k_input_tokens=0.01,
153
+ cost_per_1k_output_tokens=0.03,
154
+ ),
43
155
  'gpt-4-1106-vision-preview': pg.Dict(
44
- rpm=10000, tpm=1500000
156
+ in_service=True,
157
+ rpm=10000,
158
+ tpm=2000000,
159
+ cost_per_1k_input_tokens=0.01,
160
+ cost_per_1k_output_tokens=0.03,
45
161
  ),
46
162
  # GPT-4 models
47
- 'gpt-4': pg.Dict(rpm=10000, tpm=300000),
48
- 'gpt-4-0613': pg.Dict(rpm=10000, tpm=300000),
49
- 'gpt-4-0314': pg.Dict(rpm=10000, tpm=300000),
50
- 'gpt-4-32k': pg.Dict(rpm=10000, tpm=300000),
51
- 'gpt-4-32k-0613': pg.Dict(rpm=10000, tpm=300000),
52
- 'gpt-4-32k-0314': pg.Dict(rpm=10000, tpm=300000),
163
+ 'gpt-4': pg.Dict(
164
+ in_service=True,
165
+ rpm=10000,
166
+ tpm=300000,
167
+ cost_per_1k_input_tokens=0.03,
168
+ cost_per_1k_output_tokens=0.06,
169
+ ),
170
+ 'gpt-4-0613': pg.Dict(
171
+ in_service=False,
172
+ rpm=10000,
173
+ tpm=300000,
174
+ cost_per_1k_input_tokens=0.03,
175
+ cost_per_1k_output_tokens=0.06,
176
+ ),
177
+ 'gpt-4-0314': pg.Dict(
178
+ in_service=False,
179
+ rpm=10000,
180
+ tpm=300000,
181
+ cost_per_1k_input_tokens=0.03,
182
+ cost_per_1k_output_tokens=0.06,
183
+ ),
184
+ 'gpt-4-32k': pg.Dict(
185
+ in_service=True,
186
+ rpm=10000,
187
+ tpm=300000,
188
+ cost_per_1k_input_tokens=0.06,
189
+ cost_per_1k_output_tokens=0.12,
190
+ ),
191
+ 'gpt-4-32k-0613': pg.Dict(
192
+ in_service=False,
193
+ rpm=10000,
194
+ tpm=300000,
195
+ cost_per_1k_input_tokens=0.06,
196
+ cost_per_1k_output_tokens=0.12,
197
+ ),
198
+ 'gpt-4-32k-0314': pg.Dict(
199
+ in_service=False,
200
+ rpm=10000,
201
+ tpm=300000,
202
+ cost_per_1k_input_tokens=0.06,
203
+ cost_per_1k_output_tokens=0.12,
204
+ ),
53
205
  # GPT-3.5-Turbo models
54
- 'gpt-3.5-turbo': pg.Dict(rpm=10000, tpm=2000000),
55
- 'gpt-3.5-turbo-0125': pg.Dict(rpm=10000, tpm=2000000),
56
- 'gpt-3.5-turbo-1106': pg.Dict(rpm=10000, tpm=2000000),
57
- 'gpt-3.5-turbo-0613': pg.Dict(rpm=10000, tpm=2000000),
58
- 'gpt-3.5-turbo-0301': pg.Dict(rpm=10000, tpm=2000000),
59
- 'gpt-3.5-turbo-16k': pg.Dict(rpm=10000, tpm=2000000),
60
- 'gpt-3.5-turbo-16k-0613': pg.Dict(rpm=10000, tpm=2000000),
61
- 'gpt-3.5-turbo-16k-0301': pg.Dict(rpm=10000, tpm=2000000),
206
+ 'gpt-3.5-turbo': pg.Dict(
207
+ in_service=True,
208
+ rpm=10000,
209
+ tpm=2000000,
210
+ cost_per_1k_input_tokens=0.0005,
211
+ cost_per_1k_output_tokens=0.0015,
212
+ ),
213
+ 'gpt-3.5-turbo-0125': pg.Dict(
214
+ in_service=True,
215
+ rpm=10000,
216
+ tpm=2000000,
217
+ cost_per_1k_input_tokens=0.0005,
218
+ cost_per_1k_output_tokens=0.0015,
219
+ ),
220
+ 'gpt-3.5-turbo-1106': pg.Dict(
221
+ in_service=True,
222
+ rpm=10000,
223
+ tpm=2000000,
224
+ cost_per_1k_input_tokens=0.001,
225
+ cost_per_1k_output_tokens=0.002,
226
+ ),
227
+ 'gpt-3.5-turbo-0613': pg.Dict(
228
+ in_service=True,
229
+ rpm=10000,
230
+ tpm=2000000,
231
+ cost_per_1k_input_tokens=0.0015,
232
+ cost_per_1k_output_tokens=0.002,
233
+ ),
234
+ 'gpt-3.5-turbo-0301': pg.Dict(
235
+ in_service=True,
236
+ rpm=10000,
237
+ tpm=2000000,
238
+ cost_per_1k_input_tokens=0.0015,
239
+ cost_per_1k_output_tokens=0.002,
240
+ ),
241
+ 'gpt-3.5-turbo-16k': pg.Dict(
242
+ in_service=True,
243
+ rpm=10000,
244
+ tpm=2000000,
245
+ cost_per_1k_input_tokens=0.003,
246
+ cost_per_1k_output_tokens=0.004,
247
+ ),
248
+ 'gpt-3.5-turbo-16k-0613': pg.Dict(
249
+ in_service=True,
250
+ rpm=10000,
251
+ tpm=2000000,
252
+ cost_per_1k_input_tokens=0.003,
253
+ cost_per_1k_output_tokens=0.004,
254
+ ),
255
+ 'gpt-3.5-turbo-16k-0301': pg.Dict(
256
+ in_service=False,
257
+ rpm=10000,
258
+ tpm=2000000,
259
+ cost_per_1k_input_tokens=0.003,
260
+ cost_per_1k_output_tokens=0.004,
261
+ ),
62
262
  # GPT-3.5 models
63
- 'text-davinci-003': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
64
- 'text-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
65
- 'code-davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
66
- # GPT-3 instruction-tuned models
67
- 'text-curie-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
68
- 'text-babbage-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
69
- 'text-ada-001': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
70
- 'davinci': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
71
- 'curie': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
72
- 'babbage': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
73
- 'ada': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
74
- # GPT-3 base models
75
- 'babbage-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
76
- 'davinci-002': pg.Dict(rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
263
+ 'text-davinci-003': pg.Dict(
264
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
265
+ ),
266
+ 'text-davinci-002': pg.Dict(
267
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
268
+ ),
269
+ 'code-davinci-002': pg.Dict(
270
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
271
+ ),
272
+ # GPT-3 instruction-tuned models (Deprecated)
273
+ 'text-curie-001': pg.Dict(
274
+ in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM
275
+ ),
276
+ 'text-babbage-001': pg.Dict(
277
+ in_service=False,
278
+ rpm=_DEFAULT_RPM,
279
+ tpm=_DEFAULT_TPM,
280
+ ),
281
+ 'text-ada-001': pg.Dict(
282
+ in_service=False,
283
+ rpm=_DEFAULT_RPM,
284
+ tpm=_DEFAULT_TPM,
285
+ ),
286
+ 'davinci': pg.Dict(
287
+ in_service=False,
288
+ rpm=_DEFAULT_RPM,
289
+ tpm=_DEFAULT_TPM,
290
+ ),
291
+ 'curie': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
292
+ 'babbage': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
293
+ 'ada': pg.Dict(in_service=False, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
294
+ # GPT-3 base models that are still in service.
295
+ 'babbage-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
296
+ 'davinci-002': pg.Dict(in_service=True, rpm=_DEFAULT_RPM, tpm=_DEFAULT_TPM),
77
297
  }
78
298
 
79
299
 
80
300
  @lf.use_init_args(['model'])
81
- class OpenAI(lf.LanguageModel):
301
+ class OpenAI(openai_compatible.OpenAICompatible):
82
302
  """OpenAI model."""
83
303
 
84
304
  model: pg.typing.Annotated[
@@ -86,12 +306,9 @@ class OpenAI(lf.LanguageModel):
86
306
  pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
87
307
  ),
88
308
  'The name of the model to use.',
89
- ] = 'gpt-3.5-turbo'
309
+ ]
90
310
 
91
- multimodal: Annotated[
92
- bool,
93
- 'Whether this model has multimodal support.'
94
- ] = False
311
+ api_endpoint: str = 'https://api.openai.com/v1/chat/completions'
95
312
 
96
313
  api_key: Annotated[
97
314
  str | None,
@@ -110,23 +327,44 @@ class OpenAI(lf.LanguageModel):
110
327
  ),
111
328
  ] = None
112
329
 
330
+ project: Annotated[
331
+ str | None,
332
+ (
333
+ 'Project. If None, the key will be read from environment '
334
+ "variable 'OPENAI_PROJECT'. Based on the value, usages from "
335
+ "these API requests will count against the project's quota. "
336
+ ),
337
+ ] = None
338
+
113
339
  def _on_bound(self):
114
340
  super()._on_bound()
115
- self.__dict__.pop('_api_initialized', None)
341
+ self._api_key = None
342
+ self._organization = None
343
+ self._project = None
116
344
 
117
- @functools.cached_property
118
- def _api_initialized(self):
345
+ def _initialize(self):
119
346
  api_key = self.api_key or os.environ.get('OPENAI_API_KEY', None)
120
347
  if not api_key:
121
348
  raise ValueError(
122
349
  'Please specify `api_key` during `__init__` or set environment '
123
350
  'variable `OPENAI_API_KEY` with your OpenAI API key.'
124
351
  )
125
- openai.api_key = api_key
126
- org = self.organization or os.environ.get('OPENAI_ORGANIZATION', None)
127
- if org:
128
- openai.organization = org
129
- return True
352
+ self._api_key = api_key
353
+ self._organization = self.organization or os.environ.get(
354
+ 'OPENAI_ORGANIZATION', None
355
+ )
356
+ self._project = self.project or os.environ.get('OPENAI_PROJECT', None)
357
+
358
+ @property
359
+ def headers(self) -> dict[str, Any]:
360
+ assert self._api_initialized
361
+ headers = super().headers
362
+ headers['Authorization'] = f'Bearer {self._api_key}'
363
+ if self._organization:
364
+ headers['OpenAI-Organization'] = self._organization
365
+ if self._project:
366
+ headers['OpenAI-Project'] = self._project
367
+ return headers
130
368
 
131
369
  @property
132
370
  def model_id(self) -> str:
@@ -141,149 +379,67 @@ class OpenAI(lf.LanguageModel):
141
379
  requests_per_min=rpm, tokens_per_min=tpm
142
380
  )
143
381
 
382
+ def estimate_cost(
383
+ self,
384
+ num_input_tokens: int,
385
+ num_output_tokens: int
386
+ ) -> float | None:
387
+ """Estimate the cost based on usage."""
388
+ cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
389
+ 'cost_per_1k_input_tokens', None
390
+ )
391
+ cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
392
+ 'cost_per_1k_output_tokens', None
393
+ )
394
+ if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
395
+ return None
396
+ return (
397
+ cost_per_1k_input_tokens * num_input_tokens
398
+ + cost_per_1k_output_tokens * num_output_tokens
399
+ ) / 1000
400
+
144
401
  @classmethod
145
402
  def dir(cls):
146
- return openai.Model.list()
403
+ return [k for k, v in SUPPORTED_MODELS_AND_SETTINGS.items() if v.in_service]
147
404
 
148
- @property
149
- def is_chat_model(self):
150
- """Returns True if the model is a chat model."""
151
- return self.model.startswith(('gpt-4', 'gpt-3.5-turbo'))
152
-
153
- def _get_request_args(
405
+ def _request_args(
154
406
  self, options: lf.LMSamplingOptions) -> dict[str, Any]:
155
- # Reference:
156
- # https://platform.openai.com/docs/api-reference/completions/create
157
- # NOTE(daiyip): options.top_k is not applicable.
158
- args = dict(
159
- n=options.n,
160
- stream=False,
161
- timeout=self.timeout,
162
- logprobs=options.logprobs,
163
- top_logprobs=options.top_logprobs,
164
- )
165
- # Completion and ChatCompletion uses different parameter name for model.
166
- args['model' if self.is_chat_model else 'engine'] = self.model
167
-
168
- if options.temperature is not None:
169
- args['temperature'] = options.temperature
170
- if options.max_tokens is not None:
171
- args['max_tokens'] = options.max_tokens
172
- if options.top_p is not None:
173
- args['top_p'] = options.top_p
174
- if options.stop:
175
- args['stop'] = options.stop
176
- return args
177
-
178
- def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
179
- assert self._api_initialized
180
- if self.is_chat_model:
181
- return self._chat_complete_batch(prompts)
182
- else:
183
- return self._complete_batch(prompts)
184
-
185
- def _complete_batch(
186
- self, prompts: list[lf.Message]
187
- ) -> list[lf.LMSamplingResult]:
188
-
189
- def _open_ai_completion(prompts):
190
- response = openai.Completion.create(
191
- prompt=[p.text for p in prompts],
192
- **self._get_request_args(self.sampling_options),
193
- )
194
- response = cast(openai_object.OpenAIObject, response)
195
- # Parse response.
196
- samples_by_index = collections.defaultdict(list)
197
- for choice in response.choices:
198
- samples_by_index[choice.index].append(
199
- lf.LMSample(choice.text.strip(), score=choice.logprobs or 0.0)
200
- )
201
-
202
- usage = lf.LMSamplingUsage(
203
- prompt_tokens=response.usage.prompt_tokens,
204
- completion_tokens=response.usage.completion_tokens,
205
- total_tokens=response.usage.total_tokens,
206
- )
207
- return [
208
- lf.LMSamplingResult(
209
- samples_by_index[index], usage=usage if index == 0 else None
210
- )
211
- for index in sorted(samples_by_index.keys())
212
- ]
213
-
214
- return self._parallel_execute_with_currency_control(
215
- _open_ai_completion,
216
- [prompts],
217
- retry_on_errors=(
218
- openai_error.ServiceUnavailableError,
219
- openai_error.RateLimitError,
220
- # Handling transient OpenAI server error (code 500). Check out
221
- # https://platform.openai.com/docs/guides/error-codes/error-codes
222
- (openai_error.APIError,
223
- '.*The server had an error processing your request'),
224
- ),
225
- )[0]
226
-
227
- def _chat_complete_batch(
228
- self, prompts: list[lf.Message]
229
- ) -> list[lf.LMSamplingResult]:
230
- def _open_ai_chat_completion(prompt: lf.Message):
231
- if self.multimodal:
232
- content = []
233
- for chunk in prompt.chunk():
234
- if isinstance(chunk, str):
235
- item = dict(type='text', text=chunk)
236
- elif isinstance(chunk, lf_modalities.Image) and chunk.uri:
237
- item = dict(type='image_url', image_url=chunk.uri)
238
- else:
239
- raise ValueError(f'Unsupported modality object: {chunk!r}.')
240
- content.append(item)
241
- else:
242
- content = prompt.text
243
-
244
- response = openai.ChatCompletion.create(
245
- # TODO(daiyip): support conversation history and system prompt.
246
- messages=[{'role': 'user', 'content': content}],
247
- **self._get_request_args(self.sampling_options),
248
- )
249
- response = cast(openai_object.OpenAIObject, response)
250
- samples = []
251
- for choice in response.choices:
252
- logprobs = None
253
- if choice.logprobs:
254
- logprobs = [
255
- (
256
- t.token,
257
- t.logprob,
258
- [(tt.token, tt.logprob) for tt in t.top_logprobs],
259
- )
260
- for t in choice.logprobs.content
261
- ]
262
- samples.append(
263
- lf.LMSample(
264
- choice.message.content,
265
- score=0.0,
266
- logprobs=logprobs,
267
- )
268
- )
269
-
270
- return lf.LMSamplingResult(
271
- samples=samples,
272
- usage=lf.LMSamplingUsage(
273
- prompt_tokens=response.usage.prompt_tokens,
274
- completion_tokens=response.usage.completion_tokens,
275
- total_tokens=response.usage.total_tokens,
276
- ),
277
- )
407
+ # Reasoning models (o1 series) does not support `logprobs` by 2024/09/12.
408
+ if options.logprobs and self.model.startswith(('o1-', 'o3-')):
409
+ raise RuntimeError('`logprobs` is not supported on {self.model!r}.')
410
+ return super()._request_args(options)
278
411
 
279
- return self._parallel_execute_with_currency_control(
280
- _open_ai_chat_completion,
281
- prompts,
282
- retry_on_errors=(
283
- openai_error.ServiceUnavailableError,
284
- openai_error.RateLimitError,
285
- ),
286
- )
412
+
413
+ class GptO1(OpenAI):
414
+ """GPT-O1."""
415
+
416
+ model = 'o1'
417
+ multimodal = True
418
+ timeout = None
419
+
420
+
421
+ class GptO1Preview(OpenAI):
422
+ """GPT-O1 Preview."""
423
+ model = 'o1-preview'
424
+ timeout = None
425
+
426
+
427
+ class GptO1Preview_20240912(OpenAI): # pylint: disable=invalid-name
428
+ """GPT O1."""
429
+ model = 'o1-preview-2024-09-12'
430
+ timeout = None
431
+
432
+
433
+ class GptO1Mini(OpenAI):
434
+ """GPT O1-mini."""
435
+ model = 'o1-mini'
436
+ timeout = None
437
+
438
+
439
+ class GptO1Mini_20240912(OpenAI): # pylint: disable=invalid-name
440
+ """GPT O1-mini."""
441
+ model = 'o1-mini-2024-09-12'
442
+ timeout = None
287
443
 
288
444
 
289
445
  class Gpt4(OpenAI):
@@ -308,12 +464,12 @@ class Gpt4TurboPreview(Gpt4):
308
464
  model = 'gpt-4-turbo-preview'
309
465
 
310
466
 
311
- class Gpt4TurboPreview_0125(Gpt4TurboPreview): # pylint: disable=invalid-name
467
+ class Gpt4TurboPreview_20240125(Gpt4TurboPreview): # pylint: disable=invalid-name
312
468
  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Dec. 2023."""
313
469
  model = 'gpt-4-0125-preview'
314
470
 
315
471
 
316
- class Gpt4TurboPreview_1106(Gpt4TurboPreview): # pylint: disable=invalid-name
472
+ class Gpt4TurboPreview_20231106(Gpt4TurboPreview): # pylint: disable=invalid-name
317
473
  """GPT-4 Turbo Preview with 128k context window. Knowledge up to Apr. 2023."""
318
474
  model = 'gpt-4-1106-preview'
319
475
 
@@ -324,12 +480,12 @@ class Gpt4VisionPreview(Gpt4):
324
480
  multimodal = True
325
481
 
326
482
 
327
- class Gpt4VisionPreview_1106(Gpt4): # pylint: disable=invalid-name
483
+ class Gpt4VisionPreview_20231106(Gpt4): # pylint: disable=invalid-name
328
484
  """GPT-4 Turbo vision preview. 128k context window. Knowledge to Apr. 2023."""
329
485
  model = 'gpt-4-1106-vision-preview'
330
486
 
331
487
 
332
- class Gpt4_0613(Gpt4): # pylint:disable=invalid-name
488
+ class Gpt4_20230613(Gpt4): # pylint:disable=invalid-name
333
489
  """GPT-4 @20230613. 8K context window. Knowledge up to 9-2021."""
334
490
  model = 'gpt-4-0613'
335
491
 
@@ -339,11 +495,47 @@ class Gpt4_32K(Gpt4): # pylint:disable=invalid-name
339
495
  model = 'gpt-4-32k'
340
496
 
341
497
 
342
- class Gpt4_32K_0613(Gpt4_32K): # pylint:disable=invalid-name
498
+ class Gpt4_32K_20230613(Gpt4_32K): # pylint:disable=invalid-name
343
499
  """GPT-4 @20230613. 32K context window. Knowledge up to 9-2021."""
344
500
  model = 'gpt-4-32k-0613'
345
501
 
346
502
 
503
+ class Gpt4oMini(OpenAI):
504
+ """GPT-4o Mini."""
505
+ model = 'gpt-4o-mini'
506
+ multimodal = True
507
+
508
+
509
+ class Gpt4oMini_20240718(OpenAI): # pylint:disable=invalid-name
510
+ """GPT-4o Mini."""
511
+ model = 'gpt-4o-mini-2024-07-18'
512
+ multimodal = True
513
+
514
+
515
+ class Gpt4o(OpenAI):
516
+ """GPT-4o."""
517
+ model = 'gpt-4o'
518
+ multimodal = True
519
+
520
+
521
+ class Gpt4o_20241120(OpenAI): # pylint:disable=invalid-name
522
+ """GPT-4o version 2024-11-20."""
523
+ model = 'gpt-4o-2024-11-20'
524
+ multimodal = True
525
+
526
+
527
+ class Gpt4o_20240806(OpenAI): # pylint:disable=invalid-name
528
+ """GPT-4o version 2024-08-06."""
529
+ model = 'gpt-4o-2024-08-06'
530
+ multimodal = True
531
+
532
+
533
+ class Gpt4o_20240513(OpenAI): # pylint:disable=invalid-name
534
+ """GPT-4o version 2024-05-13."""
535
+ model = 'gpt-4o-2024-05-13'
536
+ multimodal = True
537
+
538
+
347
539
  class Gpt35(OpenAI):
348
540
  """GPT-3.5. 4K max tokens, trained up on data up to Sep, 2021."""
349
541
  model = 'text-davinci-003'
@@ -354,17 +546,17 @@ class Gpt35Turbo(Gpt35):
354
546
  model = 'gpt-3.5-turbo'
355
547
 
356
548
 
357
- class Gpt35Turbo_0125(Gpt35Turbo): # pylint:disable=invalid-name
549
+ class Gpt35Turbo_20240125(Gpt35Turbo): # pylint:disable=invalid-name
358
550
  """GPT-3.5 Turbo @20240125. 16K context window. Knowledge up to 09/2021."""
359
551
  model = 'gpt-3.5-turbo-0125'
360
552
 
361
553
 
362
- class Gpt35Turbo_1106(Gpt35Turbo): # pylint:disable=invalid-name
554
+ class Gpt35Turbo_20231106(Gpt35Turbo): # pylint:disable=invalid-name
363
555
  """Gpt3.5 Turbo @20231106. 16K context window. Knowledge up to 09/2021."""
364
556
  model = 'gpt-3.5-turbo-1106'
365
557
 
366
558
 
367
- class Gpt35Turbo_0613(Gpt35Turbo): # pylint:disable=invalid-name
559
+ class Gpt35Turbo_20230613(Gpt35Turbo): # pylint:disable=invalid-name
368
560
  """Gpt3.5 Turbo snapshot at 2023/06/13, with 4K context window size."""
369
561
  model = 'gpt-3.5-turbo-0613'
370
562
 
@@ -374,7 +566,7 @@ class Gpt35Turbo16K(Gpt35Turbo):
374
566
  model = 'gpt-3.5-turbo-16k'
375
567
 
376
568
 
377
- class Gpt35Turbo16K_0613(Gpt35Turbo): # pylint:disable=invalid-name
569
+ class Gpt35Turbo16K_20230613(Gpt35Turbo): # pylint:disable=invalid-name
378
570
  """Gpt 3.5 Turbo 16K 0613."""
379
571
  model = 'gpt-3.5-turbo-16k-0613'
380
572