langfun 0.1.2.dev202409100804__py3-none-any.whl → 0.1.2.dev202409130804__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langfun/core/component.py +1 -1
- langfun/core/concurrent.py +16 -1
- langfun/core/concurrent_test.py +1 -1
- langfun/core/eval/base.py +1 -1
- langfun/core/llms/__init__.py +5 -0
- langfun/core/llms/google_genai.py +4 -4
- langfun/core/llms/google_genai_test.py +27 -25
- langfun/core/llms/openai.py +35 -4
- langfun/core/llms/openai_test.py +8 -2
- langfun/core/structured/__init__.py +1 -0
- langfun/core/structured/prompting.py +55 -0
- langfun/core/structured/prompting_test.py +124 -0
- langfun/core/structured/schema.py +2 -12
- {langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/METADATA +3 -3
- {langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/RECORD +18 -18
- {langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/LICENSE +0 -0
- {langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/WHEEL +0 -0
- {langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/top_level.txt +0 -0
langfun/core/component.py
CHANGED
@@ -73,7 +73,7 @@ class Component(pg.Object):
|
|
73
73
|
field.value.set_default(attr_value)
|
74
74
|
additional_fields.append(field)
|
75
75
|
if additional_fields:
|
76
|
-
|
76
|
+
cls.update_schema(additional_fields)
|
77
77
|
|
78
78
|
def _on_bound(self):
|
79
79
|
super()._on_bound()
|
langfun/core/concurrent.py
CHANGED
@@ -469,7 +469,9 @@ class ProgressBar:
|
|
469
469
|
# Process uninstall requests.
|
470
470
|
if cls._uninstall_requests:
|
471
471
|
for bar_id in cls._uninstall_requests:
|
472
|
-
cls._progress_bars.pop(bar_id, None)
|
472
|
+
bar = cls._progress_bars.pop(bar_id, None)
|
473
|
+
if bar is not None:
|
474
|
+
bar.close()
|
473
475
|
cls._uninstall_requests.clear()
|
474
476
|
|
475
477
|
|
@@ -765,6 +767,10 @@ class _ProgressControl(pg.Object):
|
|
765
767
|
def refresh(self) -> None:
|
766
768
|
"""Refresh progress bar."""
|
767
769
|
|
770
|
+
@abc.abstractmethod
def close(self) -> None:
  """Close progress bar.

  Concrete controls decide what closing means for their output channel
  (e.g. closing a tqdm bar, or flushing stderr for console output).
  """
|
773
|
+
|
768
774
|
|
769
775
|
class _TqdmProgressControl(_ProgressControl):
|
770
776
|
"""Tqdm-based progress control."""
|
@@ -791,6 +797,9 @@ class _TqdmProgressControl(_ProgressControl):
|
|
791
797
|
self._tqdm.colour = self.color
|
792
798
|
self._tqdm.refresh()
|
793
799
|
|
800
|
+
def close(self):
  """Closes the underlying tqdm progress bar."""
  tqdm_bar = self._tqdm
  tqdm_bar.close()
|
802
|
+
|
794
803
|
|
795
804
|
class _ConsoleProgressControl(_ProgressControl):
|
796
805
|
"""Simple progress control by printing the status to the console."""
|
@@ -824,6 +833,9 @@ class _ConsoleProgressControl(_ProgressControl):
|
|
824
833
|
s.write(f' : {status}')
|
825
834
|
sys.stderr.write(s.getvalue() + '\n')
|
826
835
|
|
836
|
+
def close(self):
  """Flushes stderr so every progress line written by `refresh` is emitted."""
  stream = sys.stderr
  stream.flush()
|
838
|
+
|
827
839
|
|
828
840
|
class _NoopProgressControl(_ProgressControl):
|
829
841
|
"""No-op progress control."""
|
@@ -834,6 +846,9 @@ class _NoopProgressControl(_ProgressControl):
|
|
834
846
|
def refresh(self) -> None:
  """Intentionally does nothing: this control displays no progress."""
  return None
|
836
848
|
|
849
|
+
def close(self) -> None:
  """Intentionally does nothing: there is no output resource to release."""
  return None
|
851
|
+
|
837
852
|
|
838
853
|
def _progress_control(
|
839
854
|
total: int,
|
langfun/core/concurrent_test.py
CHANGED
@@ -576,8 +576,8 @@ class ConcurrentMapTest(unittest.TestCase):
|
|
576
576
|
(3, pg.MISSING_VALUE),
|
577
577
|
],
|
578
578
|
)
|
579
|
-
self.assertIn('100%', string_io.getvalue())
|
580
579
|
concurrent.ProgressBar.uninstall(bar_id)
|
580
|
+
self.assertIn('100%', string_io.getvalue())
|
581
581
|
|
582
582
|
|
583
583
|
class ExecutorPoolTest(unittest.TestCase):
|
langfun/core/eval/base.py
CHANGED
@@ -941,7 +941,7 @@ class Evaluation(Evaluable):
|
|
941
941
|
|
942
942
|
fields = list(cls.__schema__.values())
|
943
943
|
fields.insert(0, (self.completion_prompt_field, pg.typing.Str()))
|
944
|
-
|
944
|
+
cls.update_schema(fields, extend=False)
|
945
945
|
|
946
946
|
def _maybe_adjust_examples_for_completion(
|
947
947
|
self,
|
langfun/core/llms/__init__.py
CHANGED
@@ -39,6 +39,11 @@ from langfun.core.llms.google_genai import Palm2_IT
|
|
39
39
|
# OpenAI models.
|
40
40
|
from langfun.core.llms.openai import OpenAI
|
41
41
|
|
42
|
+
from langfun.core.llms.openai import GptO1Preview
|
43
|
+
from langfun.core.llms.openai import GptO1Preview_20240912
|
44
|
+
from langfun.core.llms.openai import GptO1Mini
|
45
|
+
from langfun.core.llms.openai import GptO1Mini_20240912
|
46
|
+
|
42
47
|
from langfun.core.llms.openai import Gpt4oMini
|
43
48
|
from langfun.core.llms.openai import Gpt4oMini_20240718
|
44
49
|
from langfun.core.llms.openai import Gpt4o
|
@@ -27,18 +27,18 @@ try:
|
|
27
27
|
import google.generativeai as genai # pylint: disable=g-import-not-at-top
|
28
28
|
BlobDict = genai.types.BlobDict
|
29
29
|
GenerativeModel = genai.GenerativeModel
|
30
|
-
Completion = genai.types
|
30
|
+
Completion = getattr(genai.types, 'Completion', Any)
|
31
|
+
ChatResponse = getattr(genai.types, 'ChatResponse', Any)
|
32
|
+
GenerateContentResponse = getattr(genai.types, 'GenerateContentResponse', Any)
|
31
33
|
GenerationConfig = genai.GenerationConfig
|
32
|
-
GenerateContentResponse = genai.types.GenerateContentResponse
|
33
|
-
ChatResponse = genai.types.ChatResponse
|
34
34
|
except ImportError:
|
35
35
|
genai = None
|
36
36
|
BlobDict = Any
|
37
37
|
GenerativeModel = Any
|
38
38
|
Completion = Any
|
39
|
+
ChatResponse = Any
|
39
40
|
GenerationConfig = Any
|
40
41
|
GenerateContentResponse = Any
|
41
|
-
ChatResponse = Any
|
42
42
|
|
43
43
|
|
44
44
|
@lf.use_init_args(['model'])
|
@@ -192,39 +192,41 @@ class GenAITest(unittest.TestCase):
|
|
192
192
|
def test_call_with_legacy_completion_model(self):
  # Patch the module-level genai entry points with mocks. The originals are
  # restored in `finally` so a failing assertion cannot leak the mocks into
  # other tests. `generate_text` is only patched/restored when the installed
  # genai package provides it; restoring it unconditionally would add a
  # spurious `generate_text = None` attribute to the module.
  orig_get_model = genai.get_model
  had_generate_text = hasattr(genai, 'generate_text')
  orig_generate_text = getattr(genai, 'generate_text', None)
  genai.get_model = mock_get_model
  if orig_generate_text is not None:
    genai.generate_text = mock_generate_text
  try:
    lm = google_genai.Palm2(api_key='test_key')
    self.maxDiff = None
    self.assertEqual(
        lm('hello', temperature=2.0, top_k=20).text,
        (
            "hello to models/text-bison-001 with {'temperature': 2.0, "
            "'top_k': 20, 'top_p': None, 'candidate_count': 1, "
            "'max_output_tokens': None, 'stop_sequences': None}"
        ),
    )
  finally:
    genai.get_model = orig_get_model
    if had_generate_text:
      genai.generate_text = orig_generate_text
|
210
211
|
|
211
212
|
def test_call_with_legacy_chat_model(self):
  # Patch the module-level genai entry points with mocks. The originals are
  # restored in `finally` so a failing assertion cannot leak the mocks into
  # other tests. `chat` is only patched/restored when the installed genai
  # package provides it; restoring it unconditionally would add a spurious
  # `chat = None` attribute to the module.
  orig_get_model = genai.get_model
  had_chat = hasattr(genai, 'chat')
  orig_chat = getattr(genai, 'chat', None)
  genai.get_model = mock_get_model
  if orig_chat is not None:
    genai.chat = mock_chat
  try:
    lm = google_genai.Palm2_IT(api_key='test_key')
    self.maxDiff = None
    self.assertEqual(
        lm('hello', temperature=2.0, top_k=20).text,
        (
            "hello to models/chat-bison-001 with {'temperature': 2.0, "
            "'top_k': 20, 'top_p': None, 'candidate_count': 1}"
        ),
    )
  finally:
    genai.get_model = orig_get_model
    if had_chat:
      genai.chat = orig_chat
|
228
230
|
|
229
231
|
|
230
232
|
if __name__ == '__main__':
|
langfun/core/llms/openai.py
CHANGED
@@ -49,6 +49,11 @@ _DEFAULT_RPM = 3000
|
|
49
49
|
SUPPORTED_MODELS_AND_SETTINGS = {
|
50
50
|
# Models from https://platform.openai.com/docs/models
|
51
51
|
# RPM is from https://platform.openai.com/docs/guides/rate-limits
|
52
|
+
# o1 (preview) models.
|
53
|
+
'o1-preview': pg.Dict(rpm=10000, tpm=5000000),
|
54
|
+
'o1-preview-2024-09-12': pg.Dict(rpm=10000, tpm=5000000),
|
55
|
+
'o1-mini': pg.Dict(rpm=10000, tpm=5000000),
|
56
|
+
'o1-mini-2024-09-12': pg.Dict(rpm=10000, tpm=5000000),
|
52
57
|
# GPT-4o models
|
53
58
|
'gpt-4o-mini': pg.Dict(rpm=10000, tpm=5000000),
|
54
59
|
'gpt-4o-mini-2024-07-18': pg.Dict(rpm=10000, tpm=5000000),
|
@@ -175,7 +180,7 @@ class OpenAI(lf.LanguageModel):
|
|
175
180
|
@property
def is_chat_model(self):
  """Whether this model is a chat model.

  Decided purely by the model-id prefix: ids starting with 'o1', 'gpt-4' or
  'gpt-3.5-turbo' are treated as chat models.
  """
  model_id = self.model
  return any(
      model_id.startswith(prefix)
      for prefix in ('o1', 'gpt-4', 'gpt-3.5-turbo')
  )
|
179
184
|
|
180
185
|
def _get_request_args(
|
181
186
|
self, options: lf.LMSamplingOptions) -> dict[str, Any]:
|
@@ -186,9 +191,14 @@ class OpenAI(lf.LanguageModel):
|
|
186
191
|
n=options.n,
|
187
192
|
stream=False,
|
188
193
|
timeout=self.timeout,
|
189
|
-
logprobs=options.logprobs,
|
190
194
|
top_logprobs=options.top_logprobs,
|
191
195
|
)
|
196
|
+
if options.logprobs:
  # Reasoning models (o1 series) do not support `logprobs` as of 2024/09/12.
  # The prefix matches the one `is_chat_model` uses for the o1 family.
  if self.model.startswith('o1'):
    # f-string so the offending model id actually appears in the message.
    raise RuntimeError(f'`logprobs` is not supported on {self.model!r}.')
  args['logprobs'] = options.logprobs
|
201
|
+
|
192
202
|
# Completion and ChatCompletion uses different parameter name for model.
|
193
203
|
args['model' if self.is_chat_model else 'engine'] = self.model
|
194
204
|
|
@@ -316,14 +326,15 @@ class OpenAI(lf.LanguageModel):
|
|
316
326
|
samples = []
|
317
327
|
for choice in response.choices:
|
318
328
|
logprobs = None
|
319
|
-
|
329
|
+
choice_logprobs = getattr(choice, 'logprobs', None)
|
330
|
+
if choice_logprobs:
|
320
331
|
logprobs = [
|
321
332
|
(
|
322
333
|
t.token,
|
323
334
|
t.logprob,
|
324
335
|
[(tt.token, tt.logprob) for tt in t.top_logprobs],
|
325
336
|
)
|
326
|
-
for t in
|
337
|
+
for t in choice_logprobs.content
|
327
338
|
]
|
328
339
|
samples.append(
|
329
340
|
lf.LMSample(
|
@@ -353,6 +364,26 @@ class OpenAI(lf.LanguageModel):
|
|
353
364
|
)
|
354
365
|
|
355
366
|
|
367
|
+
class GptO1Preview(OpenAI):
  """OpenAI o1-preview reasoning model."""
  model = 'o1-preview'


class GptO1Preview_20240912(OpenAI):  # pylint: disable=invalid-name
  """OpenAI o1-preview reasoning model, 2024-09-12 snapshot."""
  model = 'o1-preview-2024-09-12'


class GptO1Mini(OpenAI):
  """OpenAI o1-mini reasoning model."""
  model = 'o1-mini'


class GptO1Mini_20240912(OpenAI):  # pylint: disable=invalid-name
  """OpenAI o1-mini reasoning model, 2024-09-12 snapshot."""
  model = 'o1-mini-2024-09-12'
|
385
|
+
|
386
|
+
|
356
387
|
class Gpt4(OpenAI):
|
357
388
|
"""GPT-4."""
|
358
389
|
model = 'gpt-4'
|
langfun/core/llms/openai_test.py
CHANGED
@@ -117,12 +117,13 @@ class OpenAITest(unittest.TestCase):
|
|
117
117
|
openai.Gpt35(api_key='test_key', timeout=90.0)._get_request_args(
|
118
118
|
lf.LMSamplingOptions(
|
119
119
|
temperature=2.0,
|
120
|
+
logprobs=True,
|
120
121
|
n=2,
|
121
122
|
max_tokens=4096,
|
122
123
|
top_p=1.0)),
|
123
124
|
dict(
|
124
125
|
engine='text-davinci-003',
|
125
|
-
logprobs=
|
126
|
+
logprobs=True,
|
126
127
|
top_logprobs=None,
|
127
128
|
n=2,
|
128
129
|
temperature=2.0,
|
@@ -140,7 +141,6 @@ class OpenAITest(unittest.TestCase):
|
|
140
141
|
),
|
141
142
|
dict(
|
142
143
|
model='gpt-4',
|
143
|
-
logprobs=False,
|
144
144
|
top_logprobs=None,
|
145
145
|
n=1,
|
146
146
|
temperature=1.0,
|
@@ -150,6 +150,12 @@ class OpenAITest(unittest.TestCase):
|
|
150
150
|
seed=123,
|
151
151
|
),
|
152
152
|
)
|
153
|
+
with self.assertRaisesRegex(RuntimeError, '`logprobs` is not supported.*'):
|
154
|
+
openai.GptO1Preview(api_key='test_key')._get_request_args(
|
155
|
+
lf.LMSamplingOptions(
|
156
|
+
temperature=1.0, logprobs=True
|
157
|
+
)
|
158
|
+
)
|
153
159
|
|
154
160
|
def test_call_completion(self):
|
155
161
|
with mock.patch('openai.Completion.create') as mock_completion:
|
@@ -68,6 +68,7 @@ from langfun.core.structured.prompting import QueryStructurePython
|
|
68
68
|
from langfun.core.structured.prompting import query
|
69
69
|
from langfun.core.structured.prompting import query_prompt
|
70
70
|
from langfun.core.structured.prompting import query_output
|
71
|
+
from langfun.core.structured.prompting import query_reward
|
71
72
|
|
72
73
|
from langfun.core.structured.description import DescribeStructure
|
73
74
|
from langfun.core.structured.description import describe
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
"""Symbolic query."""
|
15
15
|
|
16
|
+
import functools
|
16
17
|
from typing import Any, Callable, Type, Union
|
17
18
|
|
18
19
|
import langfun.core as lf
|
@@ -265,3 +266,57 @@ def query_output(
|
|
265
266
|
return query(
|
266
267
|
'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
|
267
268
|
)
|
269
|
+
|
270
|
+
|
271
|
+
def query_reward(
    mapping_example: Union[str, mapping.MappingExample],
    response: Union[str, lf.Message],
) -> float | None:
  """Scores an LLM response against a mapping example.

  The output class is resolved from the example in one of two ways:
  - If the example has a schema describing a `pg.Object`, the schema's class
    is used.
  - If the example has no schema, the class of the example's output is used,
    provided that output is a `pg.Object`.

  The response is then parsed into that class via `query_output`, and the
  class's `__reward__` method scores it.

  Args:
    mapping_example: The reference example, or its JSON string form (which is
      deserialized with `pg.from_json_str`).
    response: The LLM response text or message to score.

  Returns:
    The reward computed by the output class's `__reward__`, or None when no
    output class can be resolved or it defines no callable `__reward__`.

  Raises:
    TypeError: Propagated from `_reward_fn` when `__reward__` has an
      unsupported number of arguments.
  """
  example = mapping_example
  if isinstance(example, str):
    example = pg.from_json_str(example)
  assert isinstance(example, mapping.MappingExample), example

  schema = example.schema
  if schema is None:
    # No schema: fall back to the type of the expected output, if symbolic.
    expected = example.output
    output_cls = expected.__class__ if isinstance(expected, pg.Object) else None
  elif schema and isinstance(schema.spec, pg.typing.Object):
    output_cls = schema.spec.cls
  else:
    output_cls = None

  reward_fn = _reward_fn(output_cls)
  if reward_fn is None:
    return None

  actual_output = query_output(response, output_cls)
  return reward_fn(
      actual_output, example.input, example.output, example.metadata
  )
|
298
|
+
|
299
|
+
|
300
|
+
@functools.cache
def _reward_fn(cls) -> Callable[
    [
        pg.Object,  # Actual output object.
        Any,  # Input object.
        pg.Object,  # Expected output object.
        pg.Dict,  # User metadata.
    ],
    float,
] | None:
  """Adapts `cls.__reward__` into a callable taking all four reward inputs.

  `__reward__` may declare between 2 and 4 args, counted via
  `pg.typing.signature(...).args`. They follow the order
  `(self, input, [expected_output], [expected_metadata])`.

  The returned callable always accepts
  `(output, input, expected_output, metadata)`. It forwards only as many
  leading arguments as `__reward__` declares. Results are memoized per class
  by `functools.cache`.

  Args:
    cls: The class being queried (callers may pass None).

  Returns:
    The adapter callable, or None if `cls` has no callable `__reward__`.

  Raises:
    TypeError: If `__reward__` declares fewer than 2 or more than 4 args.
  """
  if not callable(getattr(cls, '__reward__', None)):
    return None

  arity = len(pg.typing.signature(cls.__reward__).args)
  if not 2 <= arity <= 4:
    raise TypeError(
        f'`{cls.__type_name__}.__reward__` should have signature: '
        '`__reward__(self, input, [expected_output], [expected_metadata])`.'
    )

  def _reward(output, inputs, expected_output, metadata):
    all_args = (output, inputs, expected_output, metadata)
    return cls.__reward__(*all_args[:arity])

  return _reward
|
@@ -14,6 +14,8 @@
|
|
14
14
|
"""Tests for structured prompting."""
|
15
15
|
|
16
16
|
import inspect
|
17
|
+
import math
|
18
|
+
from typing import Any
|
17
19
|
import unittest
|
18
20
|
|
19
21
|
import langfun.core as lf
|
@@ -382,6 +384,128 @@ class QueryTest(unittest.TestCase):
|
|
382
384
|
1,
|
383
385
|
)
|
384
386
|
|
387
|
+
def test_query_reward(self):

  class Answer(pg.Object):
    final_answer: int

    def __reward__(self, inputs: lf.Template) -> float:
      diff = abs(self.final_answer - (inputs.x + inputs.y))
      # Sigmoid of the absolute error, rescaled so that an exact answer
      # scores 1.0 and the score approaches -1.0 as the error grows.
      return 4 / (1 + math.exp(diff)) - 1.0

  # Case 1: Reward function based on input and output.
  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              schema=Answer,
              output=Answer(final_answer=2),
          ),
          'Answer(2)'
      ),
      1.0
  )
  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=2, y=3),
              output=Answer(final_answer=2),
          ).to_json_str(),
          'Answer(5)'
      ),
      1.0
  )

  # Case 2: Reward function based on input, result and expected output.
  class Answer2(pg.Object):
    final_answer: int

    def __reward__(
        self, inputs: lf.Template, expected_output: 'Answer2'
    ) -> float:
      return (
          1.0 if self.final_answer == expected_output.final_answer else -1.0
      )

  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer2(final_answer=2),
          ),
          'Answer2(3)'
      ),
      -1.0
  )

  # Case 3: Reward function based on input, result, expected output
  # and metadata.
  class Answer3(pg.Object):
    final_answer: int

    def __reward__(self,
                   inputs: lf.Template,
                   expected_output: 'Answer3',
                   metadata: dict[str, Any]) -> float:
      del inputs
      return (
          1.0 if self.final_answer == expected_output.final_answer else -1.0
      ) * metadata['weight']

  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer3(final_answer=2),
              metadata=dict(weight=0.5)
          ),
          'Answer3(3)'
      ),
      -0.5
  )

  # Case 4: No reward function is provided.
  class Answer4(pg.Object):
    final_answer: int

  self.assertIsNone(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer4(final_answer=2),
          ),
          'Answer4(2)'
      )
  )

  # Case 5: Not a structured output.
  self.assertIsNone(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output='2',
          ),
          '2'
      )
  )

  # Case 6: Bad reward function.
  class Answer5(pg.Object):
    final_answer: int

    def __reward__(self) -> float:
      return 0.0

  with self.assertRaisesRegex(
      TypeError, '.*Answer5.__reward__` should have signature'
  ):
    prompting.query_reward(
        mapping.MappingExample(
            input=lf.Template('{{x}} + {{y}}', x=1, y=1),
            output=Answer5(final_answer=2),
        ),
        'Answer5(2)'
    )
|
508
|
+
|
385
509
|
|
386
510
|
class QueryStructurePythonTest(unittest.TestCase):
|
387
511
|
|
@@ -262,7 +262,7 @@ def class_dependencies(
|
|
262
262
|
)
|
263
263
|
|
264
264
|
# Add members as dependencies.
|
265
|
-
for field in
|
265
|
+
for field in pg.schema(vs.cls).values():
|
266
266
|
_fill_dependencies(field.value, include_subclasses)
|
267
267
|
_add_dependency(vs.cls)
|
268
268
|
|
@@ -390,7 +390,7 @@ def class_definition(
|
|
390
390
|
) -> str:
|
391
391
|
"""Returns the Python class definition."""
|
392
392
|
out = io.StringIO()
|
393
|
-
schema =
|
393
|
+
schema = pg.schema(cls)
|
394
394
|
eligible_bases = []
|
395
395
|
for base_cls in cls.__bases__:
|
396
396
|
if base_cls is not object:
|
@@ -913,13 +913,3 @@ class Unknown(pg.Object, pg.typing.CustomTyping):
|
|
913
913
|
|
914
914
|
|
915
915
|
UNKNOWN = Unknown()
|
916
|
-
|
917
|
-
|
918
|
-
def _pg_schema(cls: Type[Any]) -> pg.Schema:
|
919
|
-
"""Returns PyGlove schema for the constructor of a class."""
|
920
|
-
schema = getattr(cls, '__schema__', None)
|
921
|
-
if schema is None:
|
922
|
-
schema = pg.symbolic.callable_schema(
|
923
|
-
cls.__init__, auto_typing=True, auto_doc=True, remove_self=True
|
924
|
-
)
|
925
|
-
return schema
|
{langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: langfun
|
3
|
-
Version: 0.1.2.
|
3
|
+
Version: 0.1.2.dev202409130804
|
4
4
|
Summary: Langfun: Language as Functions.
|
5
5
|
Home-page: https://github.com/google/langfun
|
6
6
|
Author: Langfun Authors
|
@@ -21,11 +21,11 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
21
21
|
Classifier: Topic :: Software Development :: Libraries
|
22
22
|
Description-Content-Type: text/markdown
|
23
23
|
License-File: LICENSE
|
24
|
-
Requires-Dist: pyglove>=0.4.5.
|
24
|
+
Requires-Dist: pyglove>=0.4.5.dev202409110000
|
25
25
|
Requires-Dist: jinja2>=3.1.2
|
26
26
|
Requires-Dist: requests>=2.31.0
|
27
27
|
Provides-Extra: all
|
28
|
-
Requires-Dist: pyglove>=0.4.5.
|
28
|
+
Requires-Dist: pyglove>=0.4.5.dev202409110000; extra == "all"
|
29
29
|
Requires-Dist: jinja2>=3.1.2; extra == "all"
|
30
30
|
Requires-Dist: requests>=2.31.0; extra == "all"
|
31
31
|
Requires-Dist: termcolor==1.1.0; extra == "all"
|
@@ -1,9 +1,9 @@
|
|
1
1
|
langfun/__init__.py,sha256=mCES7t3R7Z-ZQYvG38-yrVqZubrXNfGCa8tI5HGB7mE,2274
|
2
2
|
langfun/core/__init__.py,sha256=r86kuy-BiJIveqnXx5OklUUXtMG3q79nWRBum6zFOCQ,4835
|
3
|
-
langfun/core/component.py,sha256=
|
3
|
+
langfun/core/component.py,sha256=nxF4UD7NLduPjrf4IAtysMAIg3IChfoQHjY77vC8f_E,10263
|
4
4
|
langfun/core/component_test.py,sha256=q15Xn51cVTu2RKxZ9U5VQgT3bm6RQ4638bKhWBtvW5o,8220
|
5
|
-
langfun/core/concurrent.py,sha256=
|
6
|
-
langfun/core/concurrent_test.py,sha256
|
5
|
+
langfun/core/concurrent.py,sha256=D_zoLGFREGfm0G93V27wzOlFh3MjaDxUZKb9g6Z69d4,28019
|
6
|
+
langfun/core/concurrent_test.py,sha256=-FO_VfzNtMvdCv-qKU5StVSStDHldVLzN3xBqkAnw3U,16943
|
7
7
|
langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
|
8
8
|
langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
|
9
9
|
langfun/core/langfunc.py,sha256=G50YgoVZ0y1GFw2ev41MlOqr6qa8YakbvNC0h_E0PiA,11140
|
@@ -44,7 +44,7 @@ langfun/core/coding/python/parsing_test.py,sha256=9vAWF484kWIm6JZq8NFiMgKUDhXV-d
|
|
44
44
|
langfun/core/coding/python/permissions.py,sha256=1QWGHvzL8MM0Ok_auQ9tURqZHtdOfJaDpBzZ29GUE-c,2544
|
45
45
|
langfun/core/coding/python/permissions_test.py,sha256=w5EDb8QxpxgJyZkojyzVWQvDfg366zn99-g__6TbPQ0,2699
|
46
46
|
langfun/core/eval/__init__.py,sha256=Ogdr9OtTywhhLPHi3AZzOD2mXX2oyaHWflrSTMm96uA,1899
|
47
|
-
langfun/core/eval/base.py,sha256=
|
47
|
+
langfun/core/eval/base.py,sha256=xDZQ3lu5oJaPDZCE5-QbBEajq--HRU64GVKb3xB64aI,74738
|
48
48
|
langfun/core/eval/base_test.py,sha256=VEraWaRybSxOCOcZrZouNkiroDEPR6uyFBJoAz-1pQg,26930
|
49
49
|
langfun/core/eval/matching.py,sha256=9GX8HfO9jKxgNLAivgy5K88Xhoh6Z7Pptq65pe7vht8,9762
|
50
50
|
langfun/core/eval/matching_test.py,sha256=f7iVyXH5KGJBWt4Wp14Bt9J3X59A6Ayfog9MbuFvPew,5532
|
@@ -52,19 +52,19 @@ langfun/core/eval/patching.py,sha256=R0s2eAd1m97exQt06dmUL0V_MBG0W2Hxg7fhNB7cXW0
|
|
52
52
|
langfun/core/eval/patching_test.py,sha256=8kCd54Egjju22FMgtJuxEsrXkW8ifs-UUBHtrCG1L6w,4775
|
53
53
|
langfun/core/eval/scoring.py,sha256=AlCwEVrU6nvURDB1aPxA2XBUmOjWxuNJDXJoS4-6VbU,6386
|
54
54
|
langfun/core/eval/scoring_test.py,sha256=O8olHbrUEg60gMxwOkWzKBJZpZoUlmVnBANX5Se2SXM,4546
|
55
|
-
langfun/core/llms/__init__.py,sha256=
|
55
|
+
langfun/core/llms/__init__.py,sha256=ZgnzZSjI37LhUVNvgPGNSLS-HMCqPjaCGdwPd2Ij5Rs,5031
|
56
56
|
langfun/core/llms/anthropic.py,sha256=Gon3fOi31RhZFgNd0ijyTnKnUdp9hrWrCoSXyO4UaLw,7316
|
57
57
|
langfun/core/llms/anthropic_test.py,sha256=T-swuMkfnlgs8Fpif4rtXs579exGk0TsbLMirXDZCkg,5533
|
58
58
|
langfun/core/llms/fake.py,sha256=gCHBYBLvBCsC78HI1hpoqXCS-p1FMTgY1P1qh_sGBPk,3070
|
59
59
|
langfun/core/llms/fake_test.py,sha256=sIl_Mg7nFVjaN7AJhYCpA_qzDJpSnJzkazepGXpfQQg,7338
|
60
|
-
langfun/core/llms/google_genai.py,sha256=
|
61
|
-
langfun/core/llms/google_genai_test.py,sha256=
|
60
|
+
langfun/core/llms/google_genai.py,sha256=btUIfWteBoj8Jl0j8d3e8hyI6p3Biq4rldlQYctVQfg,10936
|
61
|
+
langfun/core/llms/google_genai_test.py,sha256=zw14sgWmk0P_irHyb7vpPy1WAuLEE0PmyfiFElu03sA,7686
|
62
62
|
langfun/core/llms/groq.py,sha256=pqtyOZ_1_OJMOg8xATWT_B_SVbuT9nMRf4VkH9GzW8g,6308
|
63
63
|
langfun/core/llms/groq_test.py,sha256=GYF_Qtq5S1H1TrKH38t6_lkdroqT7v-joYLDKnmS9e0,5274
|
64
64
|
langfun/core/llms/llama_cpp.py,sha256=9tXQntSCDtjTF3bnyJrAPCr4N6wycy5nXYvp9uduygE,2843
|
65
65
|
langfun/core/llms/llama_cpp_test.py,sha256=MWO_qaOeKjRniGjcaWPDScd7HPaIJemqUZoslrt4FPs,1806
|
66
|
-
langfun/core/llms/openai.py,sha256=
|
67
|
-
langfun/core/llms/openai_test.py,sha256=
|
66
|
+
langfun/core/llms/openai.py,sha256=vnDrKuD-pli0AtDIDq_TmlltOk7z7_PQ-xpU4K1ARdU,17083
|
67
|
+
langfun/core/llms/openai_test.py,sha256=UcBFW_7RkkMEo47Tn5RuVRK_DryTN7bb9ITphlzthE8,17762
|
68
68
|
langfun/core/llms/rest.py,sha256=laopuq-zD8V-3Y6eFDngftHEbE66VlUkCD2-rvvRaLU,3388
|
69
69
|
langfun/core/llms/rest_test.py,sha256=NZ3Nf0XQVpT9kLP5cBVo_yBHLI7vWTYhWQxYEJVMGs4,3472
|
70
70
|
langfun/core/llms/vertexai.py,sha256=mEQVwO3Kf3rGRmsI-qKrV6vg0hYy6OH1lEVOM81cb3U,15134
|
@@ -89,7 +89,7 @@ langfun/core/modalities/pdf.py,sha256=mfaeCbUA4JslFVTARiJh8hW7imvL4tLVw9gUhO5bAZ
|
|
89
89
|
langfun/core/modalities/pdf_test.py,sha256=KE40zJD3Whe6ty2OULkp1J8jwLmB4ZjGXlGekluTP48,1952
|
90
90
|
langfun/core/modalities/video.py,sha256=sKcXxbx9S1ERjH8yEzkbtySpcRJD40QiPIQiIBy-U5I,955
|
91
91
|
langfun/core/modalities/video_test.py,sha256=GbsoefSeO7y8kCYhTtp4s9E3ah_eYrb6Z-MXpS01RFc,2046
|
92
|
-
langfun/core/structured/__init__.py,sha256=
|
92
|
+
langfun/core/structured/__init__.py,sha256=7EgI6pQIWnSNxhcIawBcSRDR8GPq5ytcmGxgPEjbxeA,3894
|
93
93
|
langfun/core/structured/completion.py,sha256=cS2PjG7sqzDu5x0xoTk8RmNcoeX55iVwH38NTefkMHg,8108
|
94
94
|
langfun/core/structured/completion_test.py,sha256=2mUzDMKGF_WGfTtsnfmfMDx97dkJ-98y8leen__qWLA,19281
|
95
95
|
langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
|
@@ -100,9 +100,9 @@ langfun/core/structured/mapping.py,sha256=dKOCvIA_kCQ88KoCnP5k0iOe9xRt8WLC2kbT6q
|
|
100
100
|
langfun/core/structured/mapping_test.py,sha256=zQoVx3kAD5oSm_OJAQA6q41NXLLyn8qs6CIVJgAoP_w,4489
|
101
101
|
langfun/core/structured/parsing.py,sha256=keoVqEfzAbdULh6GawWFsTQzU91MzJXYFZjXGXLaD8g,11492
|
102
102
|
langfun/core/structured/parsing_test.py,sha256=34wDrXaQ-EYhJLfDL8mX9K53oQMSzh5pVYdKjnESmK8,20895
|
103
|
-
langfun/core/structured/prompting.py,sha256=
|
104
|
-
langfun/core/structured/prompting_test.py,sha256=
|
105
|
-
langfun/core/structured/schema.py,sha256=
|
103
|
+
langfun/core/structured/prompting.py,sha256=huwwh01AQQCwPBQESOMI_V1V5PZkVQ8C89Yjk67_4Uw,10677
|
104
|
+
langfun/core/structured/prompting_test.py,sha256=P4SsnWKC1MuslxUVPd8rJQQrBbMD_WrItNWqqSc_EV8,26689
|
105
|
+
langfun/core/structured/schema.py,sha256=zpK3UqHPgxaHWptMr4I4zd_f22Yn4IFxF655d4O_jmQ,27416
|
106
106
|
langfun/core/structured/schema_generation.py,sha256=U3nRQsqmMZg_qIVDh2fiY3K4JLfsAL1LcKzIFP1iXFg,5316
|
107
107
|
langfun/core/structured/schema_generation_test.py,sha256=RM9s71kMNg2jTePwInkiW9fK1ACN37eyPeF8OII-0zw,2950
|
108
108
|
langfun/core/structured/schema_test.py,sha256=RjYhwTgktQgyqAjzLvo967nTiIK9KWgP-aNGg4e7ihE,25258
|
@@ -119,8 +119,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
|
|
119
119
|
langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
|
120
120
|
langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
|
121
121
|
langfun/core/templates/selfplay_test.py,sha256=rBW2Qr8yi-aWYwoTwRR-n1peKyMX9QXPZXURjLgoiRs,2264
|
122
|
-
langfun-0.1.2.
|
123
|
-
langfun-0.1.2.
|
124
|
-
langfun-0.1.2.
|
125
|
-
langfun-0.1.2.
|
126
|
-
langfun-0.1.2.
|
122
|
+
langfun-0.1.2.dev202409130804.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
123
|
+
langfun-0.1.2.dev202409130804.dist-info/METADATA,sha256=ILLD9AJGZDwWiYQXBUKIVMPWWGXdIdqaQn6hOCumo9I,8890
|
124
|
+
langfun-0.1.2.dev202409130804.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
125
|
+
langfun-0.1.2.dev202409130804.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
|
126
|
+
langfun-0.1.2.dev202409130804.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{langfun-0.1.2.dev202409100804.dist-info → langfun-0.1.2.dev202409130804.dist-info}/top_level.txt
RENAMED
File without changes
|