langfun 0.1.2.dev202410100804__py3-none-any.whl → 0.1.2.dev202410120803__py3-none-any.whl

This diff compares the contents of two publicly released versions of the langfun package, as published to their public registry. It is provided for informational purposes only.
Files changed (42)
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/eval/base_test.py +1 -0
  3. langfun/core/langfunc_test.py +2 -2
  4. langfun/core/language_model.py +140 -24
  5. langfun/core/language_model_test.py +166 -36
  6. langfun/core/llms/__init__.py +8 -1
  7. langfun/core/llms/anthropic.py +72 -7
  8. langfun/core/llms/cache/in_memory_test.py +3 -2
  9. langfun/core/llms/fake_test.py +7 -0
  10. langfun/core/llms/groq.py +154 -6
  11. langfun/core/llms/openai.py +300 -42
  12. langfun/core/llms/openai_test.py +35 -8
  13. langfun/core/llms/vertexai.py +121 -16
  14. langfun/core/logging.py +150 -43
  15. langfun/core/logging_test.py +33 -0
  16. langfun/core/message.py +249 -70
  17. langfun/core/message_test.py +70 -45
  18. langfun/core/modalities/audio.py +1 -1
  19. langfun/core/modalities/audio_test.py +1 -1
  20. langfun/core/modalities/image.py +1 -1
  21. langfun/core/modalities/image_test.py +9 -3
  22. langfun/core/modalities/mime.py +39 -3
  23. langfun/core/modalities/mime_test.py +39 -0
  24. langfun/core/modalities/ms_office.py +2 -5
  25. langfun/core/modalities/ms_office_test.py +1 -1
  26. langfun/core/modalities/pdf_test.py +1 -1
  27. langfun/core/modalities/video.py +1 -1
  28. langfun/core/modalities/video_test.py +2 -2
  29. langfun/core/structured/completion_test.py +1 -0
  30. langfun/core/structured/mapping.py +38 -0
  31. langfun/core/structured/mapping_test.py +55 -0
  32. langfun/core/structured/parsing_test.py +2 -1
  33. langfun/core/structured/prompting_test.py +1 -0
  34. langfun/core/structured/schema.py +34 -0
  35. langfun/core/template.py +110 -1
  36. langfun/core/template_test.py +37 -0
  37. langfun/core/templates/selfplay_test.py +4 -2
  38. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/METADATA +1 -1
  39. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/RECORD +42 -42
  40. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/LICENSE +0 -0
  41. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/WHEEL +0 -0
  42. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/top_level.txt +0 -0
langfun/core/llms/__init__.py CHANGED
@@ -95,11 +95,18 @@ from langfun.core.llms.anthropic import Claude3Sonnet
 from langfun.core.llms.anthropic import Claude3Haiku
 
 from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_2_3B
+from langfun.core.llms.groq import GroqLlama3_2_1B
+from langfun.core.llms.groq import GroqLlama3_1_70B
+from langfun.core.llms.groq import GroqLlama3_1_8B
 from langfun.core.llms.groq import GroqLlama3_70B
 from langfun.core.llms.groq import GroqLlama3_8B
 from langfun.core.llms.groq import GroqLlama2_70B
 from langfun.core.llms.groq import GroqMistral_8x7B
-from langfun.core.llms.groq import GroqGemma7B_IT
+from langfun.core.llms.groq import GroqGemma2_9B_IT
+from langfun.core.llms.groq import GroqGemma_7B_IT
+from langfun.core.llms.groq import GroqWhisper_Large_v3
+from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo
 
 from langfun.core.llms.vertexai import VertexAI
 from langfun.core.llms.vertexai import VertexAIGemini1_5
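The new classes become importable directly from `langfun.core.llms` (exposed as `lf.llms` at the package top level). A minimal sketch of using one of the new exports; a `GROQ_API_KEY` environment variable (or an explicit `api_key` argument) is assumed, per the existing Groq wrapper:

```python
# Sketch: instantiating one of the newly exported Groq model classes.
# Assumes GROQ_API_KEY is set in the environment (or pass api_key=...).
import langfun as lf

lm = lf.llms.GroqLlama3_1_8B()  # wraps model='llama-3.1-8b-instant'
print(lm.model)
```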
langfun/core/llms/anthropic.py CHANGED
@@ -28,15 +28,57 @@ SUPPORTED_MODELS_AND_SETTINGS = {
     # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
     # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
     # as RPM/TPM of the largest-available model (Claude-3-Opus).
+    # Price in US dollars at https://www.anthropic.com/pricing
+    # as of 2024-10-10.
     'claude-3-5-sonnet-20240620': pg.Dict(
-        max_tokens=4096, rpm=4000, tpm=400000
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.003,
+        cost_per_1k_output_tokens=0.015,
+    ),
+    'claude-3-opus-20240229': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.015,
+        cost_per_1k_output_tokens=0.075,
+    ),
+    'claude-3-sonnet-20240229': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.003,
+        cost_per_1k_output_tokens=0.015,
+    ),
+    'claude-3-haiku-20240307': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.00025,
+        cost_per_1k_output_tokens=0.00125,
+    ),
+    'claude-2.1': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.008,
+        cost_per_1k_output_tokens=0.024,
+    ),
+    'claude-2.0': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.008,
+        cost_per_1k_output_tokens=0.024,
+    ),
+    'claude-instant-1.2': pg.Dict(
+        max_tokens=4096,
+        rpm=4000,
+        tpm=400000,
+        cost_per_1k_input_tokens=0.0008,
+        cost_per_1k_output_tokens=0.0024,
     ),
-    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
-    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
-    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
-    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
-    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
-    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
 }
 
 
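As a sanity check on the table above, a call's cost can be computed by hand; for example, a hypothetical call consuming 1,000 input tokens and producing 2,000 output tokens on `claude-3-5-sonnet-20240620`:

```python
# Worked example (hypothetical token counts) of the per-1k pricing above.
input_cost = 0.003 * 1000 / 1000    # $0.0030 for 1,000 input tokens
output_cost = 0.015 * 2000 / 1000   # $0.0300 for 2,000 output tokens
total_cost = input_cost + output_cost
assert round(total_cost, 4) == 0.033
```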
@@ -107,6 +149,25 @@ class Anthropic(rest.REST):
         requests_per_min=rpm, tokens_per_min=tpm
     )
 
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_tokens', None
+    )
+    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_tokens', None
+    )
+    if cost_per_1k_output_tokens is None or cost_per_1k_input_tokens is None:
+      return None
+    return (
+        cost_per_1k_input_tokens * num_input_tokens
+        + cost_per_1k_output_tokens * num_output_tokens
+    ) / 1000
+
   def request(
       self,
       prompt: lf.Message,
@@ -181,6 +242,10 @@ class Anthropic(rest.REST):
             prompt_tokens=input_tokens,
             completion_tokens=output_tokens,
             total_tokens=input_tokens + output_tokens,
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=input_tokens,
+                num_output_tokens=output_tokens,
+            ),
         ),
     )
 
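With `estimated_cost` now populated on `LMSamplingUsage`, callers can read a dollar estimate straight off a response's usage metadata. A hedged sketch (assumes `ANTHROPIC_API_KEY` is set; the field is `None` for models without pricing entries):

```python
# Sketch: reading the new estimated_cost field from a live response.
# Assumes ANTHROPIC_API_KEY is set in the environment.
import langfun as lf

lm = lf.llms.Claude3Haiku()
r = lm('Say hi in five words.')
print(r.usage.prompt_tokens, r.usage.completion_tokens)
print(r.usage.estimated_cost)  # In US dollars; None if pricing is unknown.
```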
langfun/core/llms/cache/in_memory_test.py CHANGED
@@ -66,14 +66,15 @@ class InMemoryLMCacheTest(unittest.TestCase):
           [
               lf.LMSample(
                   lf.AIMessage(response_text, cache_seed=cache_seed),
-                  score=1.0
+                  score=1.0,
               )
           ],
           usage=lf.LMSamplingUsage(
               1,
               len(response_text),
               len(response_text) + 1,
-          )
+          ),
+          is_cached=True,
       )
   )
 
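The test now also asserts the new `is_cached` flag on `LMSamplingResult`, which marks results replayed from a cache. A hedged sketch of the behavior under test, using the deterministic fake `Echo` model so no API key is needed:

```python
# Sketch: results replayed from an LM cache are marked is_cached=True.
import langfun as lf
from langfun.core.llms.cache import InMemory

lm = lf.llms.Echo(cache=InMemory())
first = lm.sample(['hello'])[0]    # computed, then written to the cache
second = lm.sample(['hello'])[0]   # replayed from the cache
print(first.is_cached, second.is_cached)  # Expected: False True
```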
langfun/core/llms/fake_test.py CHANGED
@@ -34,6 +34,7 @@ class EchoTest(unittest.TestCase):
             'hi',
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(2, 2, 4),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -85,6 +86,7 @@ class StaticResponseTest(unittest.TestCase):
             canned_response,
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(2, 38, 40),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -106,6 +108,7 @@ class StaticResponseTest(unittest.TestCase):
             canned_response,
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(15, 38, 53),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -150,6 +153,7 @@ class StaticMappingTest(unittest.TestCase):
             'Hello',
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(2, 5, 7),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -166,6 +170,7 @@ class StaticMappingTest(unittest.TestCase):
             'I am fine, how about you?',
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(12, 25, 37),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -199,6 +204,7 @@ class StaticSequenceTest(unittest.TestCase):
             'Hello',
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(2, 5, 7),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
@@ -215,6 +221,7 @@ class StaticSequenceTest(unittest.TestCase):
             'I am fine, how about you?',
             score=1.0,
             logprobs=None,
+            is_cached=False,
             usage=lf.LMSamplingUsage(12, 25, 37),
             tags=[lf.Message.TAG_LM_RESPONSE],
         ),
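These hunks capture the complementary expectation: fresh, non-cache-served samples now report `is_cached=False`. A compact sketch of what the updated assertions exercise:

```python
# Sketch: a fresh sample from the fake Echo model is not cache-served.
import langfun as lf

result = lf.llms.Echo().sample(['hi'])[0]
assert result.is_cached is False
assert result.usage.total_tokens == 4  # 2 prompt + 2 completion, as above
```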
langfun/core/llms/groq.py CHANGED
@@ -24,11 +24,73 @@ import pyglove as pg
 
 SUPPORTED_MODELS_AND_SETTINGS = {
     # Refer https://console.groq.com/docs/models
-    'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
-    'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
-    'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
-    'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+    # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10.
+    'llama-3.2-3b-preview': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=64,
+        cost_per_1k_input_tokens=0.00006,
+        cost_per_1k_output_tokens=0.00006,
+    ),
+    'llama-3.2-1b-preview': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=64,
+        cost_per_1k_input_tokens=0.00004,
+        cost_per_1k_output_tokens=0.00004,
+    ),
+    'llama-3.1-70b-versatile': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00059,
+        cost_per_1k_output_tokens=0.00079,
+    ),
+    'llama-3.1-8b-instant': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00005,
+        cost_per_1k_output_tokens=0.00008,
+    ),
+    'llama3-70b-8192': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00059,
+        cost_per_1k_output_tokens=0.00079,
+    ),
+    'llama3-8b-8192': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00005,
+        cost_per_1k_output_tokens=0.00008,
+    ),
+    'llama2-70b-4096': pg.Dict(
+        max_tokens=4096,
+        max_concurrency=16,
+    ),
+    'mixtral-8x7b-32768': pg.Dict(
+        max_tokens=32768,
+        max_concurrency=16,
+        cost_per_1k_input_tokens=0.00024,
+        cost_per_1k_output_tokens=0.00024,
+    ),
+    'gemma2-9b-it': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.0002,
+        cost_per_1k_output_tokens=0.0002,
+    ),
+    'gemma-7b-it': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=32,
+        cost_per_1k_input_tokens=0.00007,
+        cost_per_1k_output_tokens=0.00007,
+    ),
+    'whisper-large-v3': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+    ),
+    'whisper-large-v3-turbo': pg.Dict(
+        max_tokens=8192,
+        max_concurrency=16,
+    )
 }
 
 
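Since the settings are plain `pg.Dict` values, the table can be queried programmatically; for instance, to find the cheapest Groq model with published pricing (a small sketch, not part of the package):

```python
# Sketch: querying the settings table; pg.Dict supports .get() like a dict.
from langfun.core.llms import groq

priced = {
    name: settings.cost_per_1k_input_tokens
    for name, settings in groq.SUPPORTED_MODELS_AND_SETTINGS.items()
    if settings.get('cost_per_1k_input_tokens') is not None
}
print(min(priced, key=priced.get))  # 'llama-3.2-1b-preview' ($0.00004/1k)
```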
@@ -89,6 +151,25 @@ class Groq(rest.REST):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_tokens', None
+    )
+    cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_tokens', None
+    )
+    if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
+      return None
+    return (
+        cost_per_1k_input_tokens * num_input_tokens
+        + cost_per_1k_output_tokens * num_output_tokens
+    ) / 1000
+
   def request(
       self,
       prompt: lf.Message,
@@ -156,6 +237,10 @@ class Groq(rest.REST):
             prompt_tokens=usage['prompt_tokens'],
             completion_tokens=usage['completion_tokens'],
             total_tokens=usage['total_tokens'],
+            estimated_cost=self.estimate_cost(
+                num_input_tokens=usage['prompt_tokens'],
+                num_output_tokens=usage['completion_tokens'],
+            ),
         ),
     )
 
@@ -170,6 +255,24 @@ class Groq(rest.REST):
     )
 
 
+class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-3B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-3.2-3B
+  """
+
+  model = 'llama-3.2-3b-preview'
+
+
+class GroqLlama3_2_1B(Groq):  # pylint: disable=invalid-name
+  """Llama3.2-1B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-3.2-1B
+  """
+
+  model = 'llama-3.2-1b-preview'
+
+
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   """Llama3-8B with 8K context window.
 
@@ -179,6 +282,24 @@ class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   model = 'llama3-8b-8192'
 
 
+class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
+  """Llama3.1-70B with 8K context window.
+
+  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+  """
+
+  model = 'llama-3.1-70b-versatile'
+
+
+class GroqLlama3_1_8B(Groq):  # pylint: disable=invalid-name
+  """Llama3.1-8B with 8K context window.
+
+  See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+  """
+
+  model = 'llama-3.1-8b-instant'
+
+
 class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
   """Llama3-70B with 8K context window.
 
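Each new subclass only pins the `model` field; concurrency limits and cost estimation come from the `Groq` base class and the settings table. A hedged usage sketch (no request is actually sent here):

```python
# Sketch: the subclasses inherit everything except the pinned model id.
import langfun as lf

lm = lf.llms.GroqLlama3_1_70B()   # model='llama-3.1-70b-versatile'
print(lm.max_concurrency)         # 16, from SUPPORTED_MODELS_AND_SETTINGS
print(lm.estimate_cost(num_input_tokens=1000, num_output_tokens=1000))
# (0.00059 * 1000 + 0.00079 * 1000) / 1000 = $0.00138
```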
@@ -206,10 +327,37 @@ class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
   model = 'mixtral-8x7b-32768'
 
 
-class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
+class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
+  """Gemma2 9B with 8K context window.
+
+  See: https://huggingface.co/google/gemma-2-9b-it
+  """
+
+  model = 'gemma2-9b-it'
+
+
+class GroqGemma_7B_IT(Groq):  # pylint: disable=invalid-name
   """Gemma 7B with 8K context window.
 
   See: https://huggingface.co/google/gemma-1.1-7b-it
   """
 
   model = 'gemma-7b-it'
+
+
+class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
+  """Whisper Large V3 with 8K context window.
+
+  See: https://huggingface.co/openai/whisper-large-v3
+  """
+
+  model = 'whisper-large-v3'
+
+
+class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
+  """Whisper Large V3 Turbo with 8K context window.
+
+  See: https://huggingface.co/openai/whisper-large-v3-turbo
+  """
+
+  model = 'whisper-large-v3-turbo'
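The Whisper wrappers follow the same subclass pattern, though Whisper is a speech-to-text model: the diff adds no pricing entries for it, so `estimate_cost` returns `None`, and whether Groq's chat-completions endpoint (which this wrapper targets) accepts these model ids is not established by this diff. A minimal sketch:

```python
# Sketch: Whisper wrappers have no pricing entry, so cost estimation is None.
import langfun as lf

lm = lf.llms.GroqWhisper_Large_v3Turbo()  # model='whisper-large-v3-turbo'
print(lm.estimate_cost(num_input_tokens=1000, num_output_tokens=1000))  # None
```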