langfun 0.1.2.dev202410100804__py3-none-any.whl → 0.1.2.dev202410120803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. langfun/core/__init__.py +1 -0
  2. langfun/core/eval/base_test.py +1 -0
  3. langfun/core/langfunc_test.py +2 -2
  4. langfun/core/language_model.py +140 -24
  5. langfun/core/language_model_test.py +166 -36
  6. langfun/core/llms/__init__.py +8 -1
  7. langfun/core/llms/anthropic.py +72 -7
  8. langfun/core/llms/cache/in_memory_test.py +3 -2
  9. langfun/core/llms/fake_test.py +7 -0
  10. langfun/core/llms/groq.py +154 -6
  11. langfun/core/llms/openai.py +300 -42
  12. langfun/core/llms/openai_test.py +35 -8
  13. langfun/core/llms/vertexai.py +121 -16
  14. langfun/core/logging.py +150 -43
  15. langfun/core/logging_test.py +33 -0
  16. langfun/core/message.py +249 -70
  17. langfun/core/message_test.py +70 -45
  18. langfun/core/modalities/audio.py +1 -1
  19. langfun/core/modalities/audio_test.py +1 -1
  20. langfun/core/modalities/image.py +1 -1
  21. langfun/core/modalities/image_test.py +9 -3
  22. langfun/core/modalities/mime.py +39 -3
  23. langfun/core/modalities/mime_test.py +39 -0
  24. langfun/core/modalities/ms_office.py +2 -5
  25. langfun/core/modalities/ms_office_test.py +1 -1
  26. langfun/core/modalities/pdf_test.py +1 -1
  27. langfun/core/modalities/video.py +1 -1
  28. langfun/core/modalities/video_test.py +2 -2
  29. langfun/core/structured/completion_test.py +1 -0
  30. langfun/core/structured/mapping.py +38 -0
  31. langfun/core/structured/mapping_test.py +55 -0
  32. langfun/core/structured/parsing_test.py +2 -1
  33. langfun/core/structured/prompting_test.py +1 -0
  34. langfun/core/structured/schema.py +34 -0
  35. langfun/core/template.py +110 -1
  36. langfun/core/template_test.py +37 -0
  37. langfun/core/templates/selfplay_test.py +4 -2
  38. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/METADATA +1 -1
  39. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/RECORD +42 -42
  40. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/LICENSE +0 -0
  41. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/WHEEL +0 -0
  42. {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/top_level.txt +0 -0
@@ -40,24 +40,106 @@ except ImportError:
40
40
  Credentials = Any
41
41
 
42
42
 
43
+ # https://cloud.google.com/vertex-ai/generative-ai/pricing
44
+ # describes that the average number of characters per token is about 4.
45
+ AVGERAGE_CHARS_PER_TOEKN = 4
46
+
47
+
48
+ # Price in US dollars,
49
+ # from https://cloud.google.com/vertex-ai/generative-ai/pricing
50
+ # as of 2024-10-10.
43
51
  SUPPORTED_MODELS_AND_SETTINGS = {
44
- 'gemini-1.5-pro-001': pg.Dict(api='gemini', rpm=500),
45
- 'gemini-1.5-pro-002': pg.Dict(api='gemini', rpm=500),
46
- 'gemini-1.5-flash-002': pg.Dict(api='gemini', rpm=500),
47
- 'gemini-1.5-flash-001': pg.Dict(api='gemini', rpm=500),
48
- 'gemini-1.5-pro': pg.Dict(api='gemini', rpm=500),
49
- 'gemini-1.5-flash': pg.Dict(api='gemini', rpm=500),
50
- 'gemini-1.5-pro-latest': pg.Dict(api='gemini', rpm=500),
51
- 'gemini-1.5-flash-latest': pg.Dict(api='gemini', rpm=500),
52
- 'gemini-1.5-pro-preview-0514': pg.Dict(api='gemini', rpm=50),
53
- 'gemini-1.5-pro-preview-0409': pg.Dict(api='gemini', rpm=50),
54
- 'gemini-1.5-flash-preview-0514': pg.Dict(api='gemini', rpm=200),
55
- 'gemini-1.0-pro': pg.Dict(api='gemini', rpm=300),
56
- 'gemini-1.0-pro-vision': pg.Dict(api='gemini', rpm=100),
52
+ 'gemini-1.5-pro-001': pg.Dict(
53
+ api='gemini',
54
+ rpm=500,
55
+ cost_per_1k_input_chars=0.0003125,
56
+ cost_per_1k_output_chars=0.00125,
57
+ ),
58
+ 'gemini-1.5-pro-002': pg.Dict(
59
+ api='gemini',
60
+ rpm=500,
61
+ cost_per_1k_input_chars=0.0003125,
62
+ cost_per_1k_output_chars=0.00125,
63
+ ),
64
+ 'gemini-1.5-flash-002': pg.Dict(
65
+ api='gemini',
66
+ rpm=500,
67
+ cost_per_1k_input_chars=0.00001875,
68
+ cost_per_1k_output_chars=0.000075,
69
+ ),
70
+ 'gemini-1.5-flash-001': pg.Dict(
71
+ api='gemini',
72
+ rpm=500,
73
+ cost_per_1k_input_chars=0.00001875,
74
+ cost_per_1k_output_chars=0.000075,
75
+ ),
76
+ 'gemini-1.5-pro': pg.Dict(
77
+ api='gemini',
78
+ rpm=500,
79
+ cost_per_1k_input_chars=0.0003125,
80
+ cost_per_1k_output_chars=0.00125,
81
+ ),
82
+ 'gemini-1.5-flash': pg.Dict(
83
+ api='gemini',
84
+ rpm=500,
85
+ cost_per_1k_input_chars=0.00001875,
86
+ cost_per_1k_output_chars=0.000075,
87
+ ),
88
+ 'gemini-1.5-pro-latest': pg.Dict(
89
+ api='gemini',
90
+ rpm=500,
91
+ cost_per_1k_input_chars=0.0003125,
92
+ cost_per_1k_output_chars=0.00125,
93
+ ),
94
+ 'gemini-1.5-flash-latest': pg.Dict(
95
+ api='gemini',
96
+ rpm=500,
97
+ cost_per_1k_input_chars=0.00001875,
98
+ cost_per_1k_output_chars=0.000075,
99
+ ),
100
+ 'gemini-1.5-pro-preview-0514': pg.Dict(
101
+ api='gemini',
102
+ rpm=50,
103
+ cost_per_1k_input_chars=0.0003125,
104
+ cost_per_1k_output_chars=0.00125,
105
+ ),
106
+ 'gemini-1.5-pro-preview-0409': pg.Dict(
107
+ api='gemini',
108
+ rpm=50,
109
+ cost_per_1k_input_chars=0.0003125,
110
+ cost_per_1k_output_chars=0.00125,
111
+ ),
112
+ 'gemini-1.5-flash-preview-0514': pg.Dict(
113
+ api='gemini',
114
+ rpm=200,
115
+ cost_per_1k_input_chars=0.00001875,
116
+ cost_per_1k_output_chars=0.000075,
117
+ ),
118
+ 'gemini-1.0-pro': pg.Dict(
119
+ api='gemini',
120
+ rpm=300,
121
+ cost_per_1k_input_chars=0.000125,
122
+ cost_per_1k_output_chars=0.000375,
123
+ ),
124
+ 'gemini-1.0-pro-vision': pg.Dict(
125
+ api='gemini',
126
+ rpm=100,
127
+ cost_per_1k_input_chars=0.000125,
128
+ cost_per_1k_output_chars=0.000375,
129
+ ),
57
130
  # PaLM APIs.
58
- 'text-bison': pg.Dict(api='palm', rpm=1600),
59
- 'text-bison-32k': pg.Dict(api='palm', rpm=300),
60
- 'text-unicorn': pg.Dict(api='palm', rpm=100),
131
+ 'text-bison': pg.Dict(
132
+ api='palm',
133
+ rpm=1600
134
+ ),
135
+ 'text-bison-32k': pg.Dict(
136
+ api='palm',
137
+ rpm=300
138
+ ),
139
+ 'text-unicorn': pg.Dict(
140
+ api='palm',
141
+ rpm=100
142
+ ),
61
143
  # Endpoint
62
144
  # TODO(chengrun): Set a more appropriate rpm for endpoint.
63
145
  'custom': pg.Dict(api='endpoint', rpm=20),
@@ -161,6 +243,25 @@ class VertexAI(lf.LanguageModel):
161
243
  tokens_per_min=0,
162
244
  )
163
245
 
246
+ def estimate_cost(
247
+ self,
248
+ num_input_tokens: int,
249
+ num_output_tokens: int
250
+ ) -> float | None:
251
+ """Estimate the cost based on usage."""
252
+ cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
253
+ 'cost_per_1k_input_chars', None
254
+ )
255
+ cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
256
+ 'cost_per_1k_output_chars', None
257
+ )
258
+ if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
259
+ return None
260
+ return (
261
+ cost_per_1k_input_chars * num_input_tokens
262
+ + cost_per_1k_output_chars * num_output_tokens
263
+ ) * AVGERAGE_CHARS_PER_TOEKN / 1000
264
+
164
265
  def _generation_config(
165
266
  self, prompt: lf.Message, options: lf.LMSamplingOptions
166
267
  ) -> Any: # generative_models.GenerationConfig
@@ -285,6 +386,10 @@ class VertexAI(lf.LanguageModel):
285
386
  prompt_tokens=usage_metadata.prompt_token_count,
286
387
  completion_tokens=usage_metadata.candidates_token_count,
287
388
  total_tokens=usage_metadata.total_token_count,
389
+ estimated_cost=self.estimate_cost(
390
+ num_input_tokens=usage_metadata.prompt_token_count,
391
+ num_output_tokens=usage_metadata.candidates_token_count,
392
+ ),
288
393
  )
289
394
  return lf.LMSamplingResult(
290
395
  [
langfun/core/logging.py CHANGED
@@ -13,16 +13,13 @@
13
13
  # limitations under the License.
14
14
  """Langfun event logging."""
15
15
 
16
- from collections.abc import Iterator
17
16
  import contextlib
18
17
  import datetime
19
- import io
20
18
  import typing
21
- from typing import Any, Literal
19
+ from typing import Any, Iterator, Literal, Sequence
22
20
 
23
21
  from langfun.core import component
24
22
  from langfun.core import console
25
- from langfun.core import repr_utils
26
23
  import pyglove as pg
27
24
 
28
25
 
@@ -56,49 +53,159 @@ class LogEntry(pg.Object):
56
53
  def should_output(self, min_log_level: LogLevel) -> bool:
57
54
  return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
58
55
 
59
- def _repr_html_(self) -> str:
60
- s = io.StringIO()
61
- padding_left = 50 * self.indent
62
- s.write(f'<div style="padding-left: {padding_left}px;">')
63
- s.write(self._message_display)
64
- if self.metadata:
65
- s.write(repr_utils.html_repr(self.metadata))
66
- s.write('</div>')
67
- return s.getvalue()
68
-
69
- @property
70
- def _message_text_bgcolor(self) -> str:
71
- match self.level:
72
- case 'debug':
73
- return '#EEEEEE'
74
- case 'info':
75
- return '#A3E4D7'
76
- case 'warning':
77
- return '#F8C471'
78
- case 'error':
79
- return '#F5C6CB'
80
- case 'fatal':
81
- return '#F19CBB'
82
- case _:
83
- raise ValueError(f'Unknown log level: {self.level}')
84
-
85
- @property
86
- def _time_display(self) -> str:
87
- display_text = self.time.strftime('%H:%M:%S')
88
- alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
89
- return (
90
- '<span style="background-color: #BBBBBB; color: white; '
91
- 'border-radius:5px; padding:0px 5px 0px 5px;" '
92
- f'title="{alt_text}">{display_text}</span>'
56
+ def _html_tree_view_summary(
57
+ self,
58
+ view: pg.views.HtmlTreeView,
59
+ title: str | pg.Html | None = None,
60
+ max_str_len_for_summary: int = pg.View.PresetArgValue(80), # pytype: disable=annotation-type-mismatch
61
+ **kwargs
62
+ ) -> str:
63
+ if len(self.message) > max_str_len_for_summary:
64
+ message = self.message[:max_str_len_for_summary] + '...'
65
+ else:
66
+ message = self.message
67
+
68
+ s = pg.Html(
69
+ pg.Html.element(
70
+ 'span',
71
+ [self.time.strftime('%H:%M:%S')],
72
+ css_class=['log-time']
73
+ ),
74
+ pg.Html.element(
75
+ 'span',
76
+ [pg.Html.escape(message)],
77
+ css_class=['log-summary'],
78
+ ),
79
+ )
80
+ return view.summary(
81
+ self,
82
+ title=title or s,
83
+ max_str_len_for_summary=max_str_len_for_summary,
84
+ **kwargs,
93
85
  )
94
86
 
95
- @property
96
- def _message_display(self) -> str:
97
- return repr_utils.html_round_text(
98
- self._time_display + '&nbsp;' + self.message,
99
- background_color=self._message_text_bgcolor,
87
+ # pytype: disable=annotation-type-mismatch
88
+ def _html_tree_view_content(
89
+ self,
90
+ view: pg.views.HtmlTreeView,
91
+ root_path: pg.KeyPath,
92
+ collapse_log_metadata_level: int | None = pg.View.PresetArgValue(0),
93
+ max_str_len_for_summary: int = pg.View.PresetArgValue(80),
94
+ collapse_level: int | None = pg.View.PresetArgValue(1),
95
+ **kwargs
96
+ ) -> pg.Html:
97
+ # pytype: enable=annotation-type-mismatch
98
+ def render_message_text():
99
+ if len(self.message) < max_str_len_for_summary:
100
+ return None
101
+ return pg.Html.element(
102
+ 'span',
103
+ [pg.Html.escape(self.message)],
104
+ css_class=['log-text'],
105
+ )
106
+
107
+ def render_metadata():
108
+ if not self.metadata:
109
+ return None
110
+ child_path = root_path + 'metadata'
111
+ return pg.Html.element(
112
+ 'div',
113
+ [
114
+ view.render(
115
+ self.metadata,
116
+ name='metadata',
117
+ root_path=child_path,
118
+ parent=self,
119
+ collapse_level=(
120
+ view.max_collapse_level(
121
+ collapse_level,
122
+ collapse_log_metadata_level,
123
+ child_path
124
+ )
125
+ )
126
+ )
127
+ ],
128
+ css_class=['log-metadata'],
129
+ )
130
+
131
+ return pg.Html.element(
132
+ 'div',
133
+ [
134
+ render_message_text(),
135
+ render_metadata(),
136
+ ],
137
+ css_class=['complex_value'],
100
138
  )
101
139
 
140
+ def _html_style(self) -> list[str]:
141
+ return super()._html_style() + [
142
+ """
143
+ .log-time {
144
+ color: #222;
145
+ font-size: 12px;
146
+ padding-right: 10px;
147
+ }
148
+ .log-summary {
149
+ font-weight: normal;
150
+ font-style: italic;
151
+ padding: 4px;
152
+ }
153
+ .log-debug > summary > .summary_title::before {
154
+ content: '🛠️ '
155
+ }
156
+ .log-info > summary > .summary_title::before {
157
+ content: '💡 '
158
+ }
159
+ .log-warning > summary > .summary_title::before {
160
+ content: '❗ '
161
+ }
162
+ .log-error > summary > .summary_title::before {
163
+ content: '❌ '
164
+ }
165
+ .log-fatal > summary > .summary_title::before {
166
+ content: '💀 '
167
+ }
168
+ .log-text {
169
+ display: block;
170
+ color: black;
171
+ font-style: italic;
172
+ padding: 20px;
173
+ border-radius: 5px;
174
+ background: rgba(255, 255, 255, 0.5);
175
+ white-space: pre-wrap;
176
+ }
177
+ details.log-entry {
178
+ margin: 0px 0px 10px;
179
+ border: 0px;
180
+ }
181
+ div.log-metadata {
182
+ margin: 10px 0px 0px 0px;
183
+ }
184
+ .log-metadata > details {
185
+ background-color: rgba(255, 255, 255, 0.5);
186
+ border: 1px solid transparent;
187
+ }
188
+ .log-debug {
189
+ background-color: #EEEEEE
190
+ }
191
+ .log-warning {
192
+ background-color: #F8C471
193
+ }
194
+ .log-info {
195
+ background-color: #A3E4D7
196
+ }
197
+ .log-error {
198
+ background-color: #F5C6CB
199
+ }
200
+ .log-fatal {
201
+ background-color: #F19CBB
202
+ }
203
+ """
204
+ ]
205
+
206
+ def _html_element_class(self) -> Sequence[str] | None:
207
+ return super()._html_element_class() + [f'log-{self.level}']
208
+
102
209
 
103
210
  def log(level: LogLevel,
104
211
  message: str,
@@ -13,6 +13,8 @@
13
13
  # limitations under the License.
14
14
  """Tests for langfun.core.logging."""
15
15
 
16
+ import datetime
17
+ import inspect
16
18
  import unittest
17
19
 
18
20
  from langfun.core import logging
@@ -52,6 +54,37 @@ class LoggingTest(unittest.TestCase):
52
54
  assert_color(logging.error('hi', indent=2, x=1, y=2), '#F5C6CB')
53
55
  assert_color(logging.fatal('hi', indent=2, x=1, y=2), '#F19CBB')
54
56
 
57
+ def assert_html_content(self, html, expected):
58
+ expected = inspect.cleandoc(expected).strip()
59
+ actual = html.content.strip()
60
+ if actual != expected:
61
+ print(actual)
62
+ self.assertEqual(actual, expected)
63
+
64
+ def test_html(self):
65
+ time = datetime.datetime(2024, 10, 10, 12, 30, 45)
66
+ self.assert_html_content(
67
+ logging.LogEntry(
68
+ level='info', message='5 + 2 > 3',
69
+ time=time, metadata={}
70
+ ).to_html(enable_summary_tooltip=False),
71
+ """
72
+ <details open class="pyglove log-entry log-info"><summary><div class="summary_title"><span class="log-time">12:30:45</span><span class="log-summary">5 + 2 &gt; 3</span></div></summary><div class="complex_value"></div></details>
73
+ """
74
+ )
75
+ self.assert_html_content(
76
+ logging.LogEntry(
77
+ level='error', message='This is a longer message: 5 + 2 > 3',
78
+ time=time, metadata=dict(x=1, y=2)
79
+ ).to_html(
80
+ max_str_len_for_summary=10,
81
+ enable_summary_tooltip=False,
82
+ collapse_log_metadata_level=1
83
+ ),
84
+ """
85
+ <details open class="pyglove log-entry log-error"><summary><div class="summary_title"><span class="log-time">12:30:45</span><span class="log-summary">This is a ...</span></div></summary><div class="complex_value"><span class="log-text">This is a longer message: 5 + 2 &gt; 3</span><div class="log-metadata"><details open class="pyglove dict"><summary><div class="summary_name">metadata</div><div class="summary_title">Dict(...)</div></summary><div class="complex_value dict"><table><tr><td><span class="object_key str">x</span><span class="tooltip key-path">metadata.x</span></td><td><div><span class="simple_value int">1</span></div></td></tr><tr><td><span class="object_key str">y</span><span class="tooltip key-path">metadata.y</span></td><td><div><span class="simple_value int">2</span></div></td></tr></table></div></details></div></div></details>
86
+ """
87
+ )
55
88
 
56
89
  if __name__ == '__main__':
57
90
  unittest.main()