langfun 0.1.2.dev202410100804__py3-none-any.whl → 0.1.2.dev202410120803__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry, and is provided for informational purposes only.
- langfun/core/__init__.py +1 -0
- langfun/core/eval/base_test.py +1 -0
- langfun/core/langfunc_test.py +2 -2
- langfun/core/language_model.py +140 -24
- langfun/core/language_model_test.py +166 -36
- langfun/core/llms/__init__.py +8 -1
- langfun/core/llms/anthropic.py +72 -7
- langfun/core/llms/cache/in_memory_test.py +3 -2
- langfun/core/llms/fake_test.py +7 -0
- langfun/core/llms/groq.py +154 -6
- langfun/core/llms/openai.py +300 -42
- langfun/core/llms/openai_test.py +35 -8
- langfun/core/llms/vertexai.py +121 -16
- langfun/core/logging.py +150 -43
- langfun/core/logging_test.py +33 -0
- langfun/core/message.py +249 -70
- langfun/core/message_test.py +70 -45
- langfun/core/modalities/audio.py +1 -1
- langfun/core/modalities/audio_test.py +1 -1
- langfun/core/modalities/image.py +1 -1
- langfun/core/modalities/image_test.py +9 -3
- langfun/core/modalities/mime.py +39 -3
- langfun/core/modalities/mime_test.py +39 -0
- langfun/core/modalities/ms_office.py +2 -5
- langfun/core/modalities/ms_office_test.py +1 -1
- langfun/core/modalities/pdf_test.py +1 -1
- langfun/core/modalities/video.py +1 -1
- langfun/core/modalities/video_test.py +2 -2
- langfun/core/structured/completion_test.py +1 -0
- langfun/core/structured/mapping.py +38 -0
- langfun/core/structured/mapping_test.py +55 -0
- langfun/core/structured/parsing_test.py +2 -1
- langfun/core/structured/prompting_test.py +1 -0
- langfun/core/structured/schema.py +34 -0
- langfun/core/template.py +110 -1
- langfun/core/template_test.py +37 -0
- langfun/core/templates/selfplay_test.py +4 -2
- {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/RECORD +42 -42
- {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202410100804.dist-info → langfun-0.1.2.dev202410120803.dist-info}/top_level.txt +0 -0
langfun/core/llms/vertexai.py
CHANGED
@@ -40,24 +40,106 @@ except ImportError:
 Credentials = Any
 
 
+# https://cloud.google.com/vertex-ai/generative-ai/pricing
+# describes that the average number of characters per token is about 4.
+AVGERAGE_CHARS_PER_TOEKN = 4
+
+
+# Price in US dollars,
+# from https://cloud.google.com/vertex-ai/generative-ai/pricing
+# as of 2024-10-10.
 SUPPORTED_MODELS_AND_SETTINGS = {
-    'gemini-1.5-pro-001': pg.Dict(
-        …
-    'gemini-1.5-pro-…
-        …
-    'gemini-1.…
+    'gemini-1.5-pro-001': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-pro-002': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-flash-002': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.00001875,
+        cost_per_1k_output_chars=0.000075,
+    ),
+    'gemini-1.5-flash-001': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.00001875,
+        cost_per_1k_output_chars=0.000075,
+    ),
+    'gemini-1.5-pro': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-flash': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.00001875,
+        cost_per_1k_output_chars=0.000075,
+    ),
+    'gemini-1.5-pro-latest': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-flash-latest': pg.Dict(
+        api='gemini',
+        rpm=500,
+        cost_per_1k_input_chars=0.00001875,
+        cost_per_1k_output_chars=0.000075,
+    ),
+    'gemini-1.5-pro-preview-0514': pg.Dict(
+        api='gemini',
+        rpm=50,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-pro-preview-0409': pg.Dict(
+        api='gemini',
+        rpm=50,
+        cost_per_1k_input_chars=0.0003125,
+        cost_per_1k_output_chars=0.00125,
+    ),
+    'gemini-1.5-flash-preview-0514': pg.Dict(
+        api='gemini',
+        rpm=200,
+        cost_per_1k_input_chars=0.00001875,
+        cost_per_1k_output_chars=0.000075,
+    ),
+    'gemini-1.0-pro': pg.Dict(
+        api='gemini',
+        rpm=300,
+        cost_per_1k_input_chars=0.000125,
+        cost_per_1k_output_chars=0.000375,
+    ),
+    'gemini-1.0-pro-vision': pg.Dict(
+        api='gemini',
+        rpm=100,
+        cost_per_1k_input_chars=0.000125,
+        cost_per_1k_output_chars=0.000375,
+    ),
     # PaLM APIs.
-    'text-bison': pg.Dict(
-        …
+    'text-bison': pg.Dict(
+        api='palm',
+        rpm=1600
+    ),
+    'text-bison-32k': pg.Dict(
+        api='palm',
+        rpm=300
+    ),
+    'text-unicorn': pg.Dict(
+        api='palm',
+        rpm=100
+    ),
     # Endpoint
     # TODO(chengrun): Set a more appropriate rpm for endpoint.
     'custom': pg.Dict(api='endpoint', rpm=20),
@@ -161,6 +243,25 @@ class VertexAI(lf.LanguageModel):
         tokens_per_min=0,
     )
 
+  def estimate_cost(
+      self,
+      num_input_tokens: int,
+      num_output_tokens: int
+  ) -> float | None:
+    """Estimate the cost based on usage."""
+    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_input_chars', None
+    )
+    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+        'cost_per_1k_output_chars', None
+    )
+    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
+      return None
+    return (
+        cost_per_1k_input_chars * num_input_tokens
+        + cost_per_1k_output_chars * num_output_tokens
+    ) * AVGERAGE_CHARS_PER_TOEKN / 1000
+
   def _generation_config(
       self, prompt: lf.Message, options: lf.LMSamplingOptions
   ) -> Any:  # generative_models.GenerationConfig
@@ -285,6 +386,10 @@ class VertexAI(lf.LanguageModel):
         prompt_tokens=usage_metadata.prompt_token_count,
         completion_tokens=usage_metadata.candidates_token_count,
         total_tokens=usage_metadata.total_token_count,
+        estimated_cost=self.estimate_cost(
+            num_input_tokens=usage_metadata.prompt_token_count,
+            num_output_tokens=usage_metadata.candidates_token_count,
+        ),
     )
     return lf.LMSamplingResult(
         [
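
The `estimate_cost` method added above prices a sampling call by converting token counts to characters with the 4-characters-per-token average (the `AVGERAGE_CHARS_PER_TOEKN` constant) and applying the per-1k-character rates from `SUPPORTED_MODELS_AND_SETTINGS`. A minimal standalone sketch of that arithmetic, using the `gemini-1.5-pro-001` rates from the table above and made-up token counts (the counts are illustrative, not part of the diff):

    # Standalone sketch of the cost arithmetic introduced in this diff.
    # Rates are the gemini-1.5-pro-001 entries above; token counts are made up.
    AVERAGE_CHARS_PER_TOKEN = 4          # spelled AVGERAGE_CHARS_PER_TOEKN in the diff
    COST_PER_1K_INPUT_CHARS = 0.0003125
    COST_PER_1K_OUTPUT_CHARS = 0.00125

    def estimate_cost(num_input_tokens: int, num_output_tokens: int) -> float:
      """Mirrors the new VertexAI.estimate_cost: chars = tokens * 4, priced per 1k chars."""
      return (
          COST_PER_1K_INPUT_CHARS * num_input_tokens
          + COST_PER_1K_OUTPUT_CHARS * num_output_tokens
      ) * AVERAGE_CHARS_PER_TOKEN / 1000

    # 2,000 prompt tokens and 500 output tokens:
    # (0.0003125 * 2000 + 0.00125 * 500) * 4 / 1000 ~= 0.005 USD
    print(estimate_cost(2000, 500))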
langfun/core/logging.py
CHANGED
@@ -13,16 +13,13 @@
 # limitations under the License.
 """Langfun event logging."""
 
-from collections.abc import Iterator
 import contextlib
 import datetime
-import io
 import typing
-from typing import Any, Literal
+from typing import Any, Iterator, Literal, Sequence
 
 from langfun.core import component
 from langfun.core import console
-from langfun.core import repr_utils
 import pyglove as pg
 
 
@@ -56,49 +53,159 @@ class LogEntry(pg.Object):
   def should_output(self, min_log_level: LogLevel) -> bool:
     return _LOG_LEVELS.index(self.level) >= _LOG_LEVELS.index(min_log_level)
 
-  def …
-    …
-    alt_text = self.time.strftime('%Y-%m-%d %H:%M:%S.%f')
-    return (
-        '<span style="background-color: #BBBBBB; color: white; '
-        'border-radius:5px; padding:0px 5px 0px 5px;" '
-        f'title="{alt_text}">{display_text}</span>'
+  def _html_tree_view_summary(
+      self,
+      view: pg.views.HtmlTreeView,
+      title: str | pg.Html | None = None,
+      max_str_len_for_summary: int = pg.View.PresetArgValue(80),  # pytype: disable=annotation-type-mismatch
+      **kwargs
+  ) -> str:
+    if len(self.message) > max_str_len_for_summary:
+      message = self.message[:max_str_len_for_summary] + '...'
+    else:
+      message = self.message
+
+    s = pg.Html(
+        pg.Html.element(
+            'span',
+            [self.time.strftime('%H:%M:%S')],
+            css_class=['log-time']
+        ),
+        pg.Html.element(
+            'span',
+            [pg.Html.escape(message)],
+            css_class=['log-summary'],
+        ),
+    )
+    return view.summary(
+        self,
+        title=title or s,
+        max_str_len_for_summary=max_str_len_for_summary,
+        **kwargs,
     )
 
-
-  def …
-    …
+  # pytype: disable=annotation-type-mismatch
+  def _html_tree_view_content(
+      self,
+      view: pg.views.HtmlTreeView,
+      root_path: pg.KeyPath,
+      collapse_log_metadata_level: int | None = pg.View.PresetArgValue(0),
+      max_str_len_for_summary: int = pg.View.PresetArgValue(80),
+      collapse_level: int | None = pg.View.PresetArgValue(1),
+      **kwargs
+  ) -> pg.Html:
+    # pytype: enable=annotation-type-mismatch
+    def render_message_text():
+      if len(self.message) < max_str_len_for_summary:
+        return None
+      return pg.Html.element(
+          'span',
+          [pg.Html.escape(self.message)],
+          css_class=['log-text'],
+      )
+
+    def render_metadata():
+      if not self.metadata:
+        return None
+      child_path = root_path + 'metadata'
+      return pg.Html.element(
+          'div',
+          [
+              view.render(
+                  self.metadata,
+                  name='metadata',
+                  root_path=child_path,
+                  parent=self,
+                  collapse_level=(
+                      view.max_collapse_level(
+                          collapse_level,
+                          collapse_log_metadata_level,
+                          child_path
+                      )
+                  )
+              )
+          ],
+          css_class=['log-metadata'],
+      )
+
+    return pg.Html.element(
+        'div',
+        [
+            render_message_text(),
+            render_metadata(),
+        ],
+        css_class=['complex_value'],
     )
 
+  def _html_style(self) -> list[str]:
+    return super()._html_style() + [
+        """
+        .log-time {
+          color: #222;
+          font-size: 12px;
+          padding-right: 10px;
+        }
+        .log-summary {
+          font-weight: normal;
+          font-style: italic;
+          padding: 4px;
+        }
+        .log-debug > summary > .summary_title::before {
+          content: '🛠️ '
+        }
+        .log-info > summary > .summary_title::before {
+          content: '💡 '
+        }
+        .log-warning > summary > .summary_title::before {
+          content: '❗ '
+        }
+        .log-error > summary > .summary_title::before {
+          content: '❌ '
+        }
+        .log-fatal > summary > .summary_title::before {
+          content: '💀 '
+        }
+        .log-text {
+          display: block;
+          color: black;
+          font-style: italic;
+          padding: 20px;
+          border-radius: 5px;
+          background: rgba(255, 255, 255, 0.5);
+          white-space: pre-wrap;
+        }
+        details.log-entry {
+          margin: 0px 0px 10px;
+          border: 0px;
+        }
+        div.log-metadata {
+          margin: 10px 0px 0px 0px;
+        }
+        .log-metadata > details {
+          background-color: rgba(255, 255, 255, 0.5);
+          border: 1px solid transparent;
+        }
+        .log-debug {
+          background-color: #EEEEEE
+        }
+        .log-warning {
+          background-color: #F8C471
+        }
+        .log-info {
+          background-color: #A3E4D7
+        }
+        .log-error {
+          background-color: #F5C6CB
+        }
+        .log-fatal {
+          background-color: #F19CBB
+        }
+        """
+    ]
+
+  def _html_element_class(self) -> Sequence[str] | None:
+    return super()._html_element_class() + [f'log-{self.level}']
+
 
 def log(level: LogLevel,
         message: str,
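
The `LogEntry` changes above replace the hand-written HTML snippet with pyglove's HTML tree-view hooks: `_html_tree_view_summary` renders the time and a truncated message as the summary line, `_html_tree_view_content` renders the full message plus metadata in the body, `_html_style` supplies the CSS, and `_html_element_class` tags each entry with a `log-<level>` class. A minimal usage sketch, assuming only the fields and methods visible in this diff and the test below (the example entry itself is made up):

    # Sketch only: exercises the new HtmlTreeView hooks on LogEntry.
    import datetime

    from langfun.core import logging as lf_logging

    entry = lf_logging.LogEntry(
        level='warning',
        message='token budget at 90%',
        time=datetime.datetime.now(),
        metadata=dict(model='gemini-1.5-pro'),
    )
    # to_html() comes from the pyglove tree-view protocol; per the new
    # _html_element_class, the entry should render as a <details> element
    # carrying the "log-entry log-warning" CSS classes styled above.
    html = entry.to_html(enable_summary_tooltip=False)
    print(html.content)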
langfun/core/logging_test.py
CHANGED
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Tests for langfun.core.logging."""
 
+import datetime
+import inspect
 import unittest
 
 from langfun.core import logging
@@ -52,6 +54,37 @@ class LoggingTest(unittest.TestCase):
     assert_color(logging.error('hi', indent=2, x=1, y=2), '#F5C6CB')
     assert_color(logging.fatal('hi', indent=2, x=1, y=2), '#F19CBB')
 
+  def assert_html_content(self, html, expected):
+    expected = inspect.cleandoc(expected).strip()
+    actual = html.content.strip()
+    if actual != expected:
+      print(actual)
+    self.assertEqual(actual, expected)
+
+  def test_html(self):
+    time = datetime.datetime(2024, 10, 10, 12, 30, 45)
+    self.assert_html_content(
+        logging.LogEntry(
+            level='info', message='5 + 2 > 3',
+            time=time, metadata={}
+        ).to_html(enable_summary_tooltip=False),
+        """
+        <details open class="pyglove log-entry log-info"><summary><div class="summary_title"><span class="log-time">12:30:45</span><span class="log-summary">5 + 2 > 3</span></div></summary><div class="complex_value"></div></details>
+        """
+    )
+    self.assert_html_content(
+        logging.LogEntry(
+            level='error', message='This is a longer message: 5 + 2 > 3',
+            time=time, metadata=dict(x=1, y=2)
+        ).to_html(
+            max_str_len_for_summary=10,
+            enable_summary_tooltip=False,
+            collapse_log_metadata_level=1
+        ),
+        """
+        <details open class="pyglove log-entry log-error"><summary><div class="summary_title"><span class="log-time">12:30:45</span><span class="log-summary">This is a ...</span></div></summary><div class="complex_value"><span class="log-text">This is a longer message: 5 + 2 > 3</span><div class="log-metadata"><details open class="pyglove dict"><summary><div class="summary_name">metadata</div><div class="summary_title">Dict(...)</div></summary><div class="complex_value dict"><table><tr><td><span class="object_key str">x</span><span class="tooltip key-path">metadata.x</span></td><td><div><span class="simple_value int">1</span></div></td></tr><tr><td><span class="object_key str">y</span><span class="tooltip key-path">metadata.y</span></td><td><div><span class="simple_value int">2</span></div></td></tr></table></div></details></div></div></details>
+        """
+    )
 
 if __name__ == '__main__':
   unittest.main()