langfun 0.1.2.dev202504280818__py3-none-any.whl → 0.1.2.dev202504300804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of langfun might be problematic. Click here for more details.

@@ -15,8 +15,9 @@
15
15
 
16
16
  import contextlib
17
17
  import functools
18
+ import inspect
18
19
  import time
19
- from typing import Annotated, Any, Callable, Iterator, Type, Union
20
+ from typing import Annotated, Any, Callable, ClassVar, Iterator, Type, Union
20
21
  import uuid
21
22
 
22
23
  import langfun.core as lf
@@ -26,8 +27,35 @@ import pyglove as pg
26
27
 
27
28
 
28
29
  @lf.use_init_args(['schema', 'default', 'examples'])
29
- class _QueryStructure(mapping.Mapping):
30
- """Query an object out from a natural language text."""
30
+ class LfQuery(mapping.Mapping):
31
+ """Base class for different implementations of `lf.query`.
32
+
33
+ By subclassing this class, users could create variations of prompts for
34
+ `lf.query` and associate them with specific protocols and versions.
35
+
36
+ For example:
37
+
38
+ ```
39
+ class _MyLfQuery(LfQuery):
40
+ protocol = 'my_format'
41
+ version = '1.0'
42
+
43
+ template_str = inspect.cleandoc(
44
+ '''
45
+ ...
46
+ '''
47
+ )
48
+ mapping_template = lf.Template(
49
+ '''
50
+ ...
51
+ '''
52
+ )
53
+
54
+ lf.query(..., protocol='my_format:1.0')
55
+ ```
56
+
57
+ (THIS IS NOT A TEMPLATE)
58
+ """
31
59
 
32
60
  context_title = 'CONTEXT'
33
61
  input_title = 'INPUT_OBJECT'
@@ -37,8 +65,81 @@ class _QueryStructure(mapping.Mapping):
37
65
  schema_lib.schema_spec(), 'Required schema for parsing.'
38
66
  ]
39
67
 
68
+ # A map from (protocol, version) to the query structure class.
69
+ # This is used to map different protocols/versions to different templates.
70
+ # So users can use `lf.query(..., protocol='<protocol>:<version>')` to use
71
+ # a specific version of the prompt. We use this feature to support variations
72
+ # of prompts and maintain backward compatibility.
73
+ _OOP_PROMPT_MAP: ClassVar[
74
+ dict[
75
+ str, # protocol.
76
+ dict[
77
+ str, # version.
78
+ Type['LfQuery']
79
+ ]
80
+ ]
81
+ ] = {}
82
+
83
+ # Default version used for each protocol when no version is specified.
84
+ _DEFAULT_PROTOCOL_VERSIONS: ClassVar[dict[str, str]] = {
85
+ 'python': '2.0',
86
+ 'json': '1.0',
87
+ }
88
+
89
def __init_subclass__(cls, **kwargs) -> None:
  """Registers each concrete subclass under its (protocol, version) key.

  The protocol is the default value of the subclass's `protocol` schema field,
  and the version is the subclass's `version` class attribute. Registered
  classes are stored in the shared `_OOP_PROMPT_MAP`, which `from_protocol`
  uses to resolve '<protocol>:<version>' specs.

  Args:
    **kwargs: Class keyword arguments, forwarded unchanged to the parent
      `__init_subclass__` (as recommended by the Python data model) so that
      cooperative base classes still receive them.

  Raises:
    ValueError: If a class with a different `__type_name__` is already
      registered under the same protocol and version.
  """
  super().__init_subclass__(**kwargs)
  if inspect.isabstract(cls):
    # Abstract intermediate classes are not registered.
    return

  protocol = cls.__schema__['protocol'].default_value
  version_dict = cls._OOP_PROMPT_MAP.setdefault(protocol, {})
  dest_cls = version_dict.get(cls.version)
  # Only a *different* class is rejected; a class with the same type name
  # replaces the existing entry.
  if dest_cls is not None and dest_cls.__type_name__ != cls.__type_name__:
    raise ValueError(
        f'Version {cls.version} is already registered for {dest_cls!r} '
        f'under protocol {protocol!r}. Please use a different version.'
    )
  version_dict[cls.version] = cls
104
+
105
@classmethod
def from_protocol(cls, protocol: str) -> Type['LfQuery']:
  """Returns the query class registered for a protocol spec.

  Args:
    protocol: A protocol name (e.g. 'python'), optionally followed by a
      version in the form '<protocol>:<version>' (e.g. 'python:1.0'). Without
      an explicit version, the version from `_DEFAULT_PROTOCOL_VERSIONS` is
      used. If the protocol has no default version, it must have exactly one
      registered version.

  Returns:
    The `LfQuery` subclass registered under the resolved protocol and version.

  Raises:
    ValueError: If the protocol is not registered; if no version is given and
      the protocol has no default version but several registered versions; or
      if the requested version is not registered for the protocol.
  """
  if ':' in protocol:
    # Split on the first colon only, so a malformed spec such as
    # 'python:1.0:x' is reported as an unsupported version instead of
    # failing with an opaque tuple-unpacking error.
    protocol, _, version = protocol.partition(':')
  else:
    version = cls._DEFAULT_PROTOCOL_VERSIONS.get(protocol)

  version_dict = cls._OOP_PROMPT_MAP.get(protocol)
  if version_dict is None:
    raise ValueError(
        f'Protocol {protocol!r} is not supported. Available protocols: '
        f'{sorted(cls._OOP_PROMPT_MAP.keys())}.'
    )

  if version is None:
    # Without an explicit or default version, the choice is unambiguous only
    # when exactly one version is registered for this protocol.
    if len(version_dict) != 1:
      raise ValueError(
          f'Multiple versions found for protocol {protocol!r}, please '
          f'specify a version with "{protocol}:<version>".'
      )
    version = next(iter(version_dict))

  dest_cls = version_dict.get(version)
  if dest_cls is None:
    raise ValueError(
        f'Version {version!r} is not supported for protocol {protocol!r}. '
        f'Available versions: {sorted(version_dict.keys())}.'
    )
  return dest_cls
40
140
 
41
- class _QueryStructureJson(_QueryStructure):
141
+
142
+ class _LfQueryJsonV1(LfQuery):
42
143
  """Query a structured value using JSON as the protocol."""
43
144
 
44
145
  preamble = """
@@ -58,12 +159,13 @@ class _QueryStructureJson(_QueryStructure):
58
159
  {"result": {"_type": "langfun.core.structured.query.Answer", "final_answer": 2}}
59
160
  """
60
161
 
162
+ version = '1.0'
61
163
  protocol = 'json'
62
164
  schema_title = 'SCHEMA'
63
165
  output_title = 'JSON'
64
166
 
65
167
 
66
- class _QueryStructurePython(_QueryStructure):
168
+ class _LfQueryPythonV1(LfQuery):
67
169
  """Query a structured value using Python as the protocol."""
68
170
 
69
171
  preamble = """
@@ -87,20 +189,87 @@ class _QueryStructurePython(_QueryStructure):
87
189
  )
88
190
  ```
89
191
  """
192
+ version = '1.0'
90
193
  protocol = 'python'
91
194
  schema_title = 'OUTPUT_TYPE'
92
195
  output_title = 'OUTPUT_OBJECT'
196
+ mapping_template = lf.Template(
197
+ """
198
+ {%- if example.context -%}
199
+ {{ context_title}}:
200
+ {{ example.context | indent(2, True)}}
93
201
 
202
+ {% endif -%}
94
203
 
95
- def _query_structure_cls(
96
- protocol: schema_lib.SchemaProtocol,
97
- ) -> Type[_QueryStructure]:
98
- if protocol == 'json':
99
- return _QueryStructureJson
100
- elif protocol == 'python':
101
- return _QueryStructurePython
102
- else:
103
- raise ValueError(f'Unknown protocol: {protocol!r}.')
204
+ {{ input_title }}:
205
+ {{ example.input_repr(protocol, compact=False) | indent(2, True) }}
206
+
207
+ {% if example.schema -%}
208
+ {{ schema_title }}:
209
+ {{ example.schema_repr(protocol) | indent(2, True) }}
210
+
211
+ {% endif -%}
212
+
213
+ {{ output_title }}:
214
+ {%- if example.has_output %}
215
+ {{ example.output_repr(protocol, compact=False) | indent(2, True) }}
216
+ {% endif -%}
217
+ """
218
+ )
219
+
220
+
221
+ class _LfQueryPythonV2(LfQuery):
222
+ """Query a structured value using Python as the protocol."""
223
+
224
+ preamble = """
225
+ Please respond to the last {{ input_title }} with {{ output_title }} only according to {{ schema_title }}.
226
+
227
+ {{ input_title }}:
228
+ 1 + 1 =
229
+
230
+ {{ schema_title }}:
231
+ Answer
232
+
233
+ ```python
234
+ class Answer:
235
+ final_answer: int
236
+ ```
237
+
238
+ {{ output_title }}:
239
+ ```python
240
+ output = Answer(
241
+ final_answer=2
242
+ )
243
+ ```
244
+ """
245
+ # Registered as version '2.0' of the 'python' protocol; '2.0' is also the
+ # default 'python' version listed in `_DEFAULT_PROTOCOL_VERSIONS`.
+ version = '2.0'
246
+ protocol = 'python'
247
+ input_title = 'REQUEST'
248
+ schema_title = 'OUTPUT PYTHON TYPE'
249
+ output_title = 'OUTPUT PYTHON OBJECT'
250
+ # Same layout as the `_LfQueryPythonV1` mapping template, except that the
+ # example output is rendered with `assign_to_var='output'`.
+ mapping_template = lf.Template(
251
+ """
252
+ {%- if example.context -%}
253
+ {{ context_title}}:
254
+ {{ example.context | indent(2, True)}}
255
+
256
+ {% endif -%}
257
+
258
+ {{ input_title }}:
259
+ {{ example.input_repr(protocol, compact=False) | indent(2, True) }}
260
+
261
+ {% if example.schema -%}
262
+ {{ schema_title }}:
263
+ {{ example.schema_repr(protocol) | indent(2, True) }}
264
+
265
+ {% endif -%}
266
+
267
+ {{ output_title }}:
268
+ {%- if example.has_output %}
269
+ {{ example.output_repr(protocol, compact=False, assign_to_var='output') | indent(2, True) }}
270
+ {% endif -%}
271
+ """
272
+ )
104
273
 
105
274
 
106
275
  def query(
@@ -116,7 +285,7 @@ def query(
116
285
  response_postprocess: Callable[[str], str] | None = None,
117
286
  autofix: int = 0,
118
287
  autofix_lm: lf.LanguageModel | None = None,
119
- protocol: schema_lib.SchemaProtocol = 'python',
288
+ protocol: str | None = None,
120
289
  returns_message: bool = False,
121
290
  skip_lm: bool = False,
122
291
  invocation_id: str | None = None,
@@ -259,8 +428,14 @@ def query(
259
428
  disable auto-fixing. Not supported with the `'json'` protocol.
260
429
  autofix_lm: The LM to use for auto-fixing. Defaults to the `autofix_lm`
261
430
  from `lf.context` or the main `lm`.
262
- protocol: Format for schema representation. Choices are `'json'` or
263
- `'python'`. Default is `'python'`.
431
+ protocol: Format for schema representation. Builtin choices are `'json'` or
432
+ `'python'`, users could extend with their own protocols by subclassing
433
+ `lf.structured.LfQuery`. A protocol could also be specified with a version
434
+ in the format of 'protocol:version', e.g., 'python:1.0', so users could
435
+ use a specific version of the prompt based on the protocol. Please see the
436
+ documentation of `LfQuery` for more details. If None, the protocol from
437
+ context manager `lf.query_protocol` will be used, or 'python' if not
438
+ specified.
264
439
  returns_message: If `True`, returns an `lf.Message` object instead of
265
440
  the final parsed result.
266
441
  skip_lm: If `True`, skips the LLM call and returns the rendered
@@ -280,6 +455,9 @@ def query(
280
455
  """
281
456
  # Internal usage logging.
282
457
 
458
+ if protocol is None:
459
+ protocol = lf.context_value('__query_protocol__', 'python')
460
+
283
461
  invocation_id = invocation_id or f'query@{uuid.uuid4().hex[-7:]}'
284
462
  # Multiple quries will be issued when `lm` is a list or `num_samples` is
285
463
  # greater than 1.
@@ -382,7 +560,7 @@ def query(
382
560
  output_message = lf.AIMessage(processed_text, source=output_message)
383
561
  else:
384
562
  # Query with structured output.
385
- output_message = _query_structure_cls(protocol)(
563
+ output_message = LfQuery.from_protocol(protocol)(
386
564
  input=(
387
565
  query_input.render(lm=lm)
388
566
  if isinstance(query_input, lf.Template)
@@ -436,6 +614,15 @@ def query(
436
614
  return output_message if returns_message else _result(output_message)
437
615
 
438
616
 
617
@contextlib.contextmanager
def query_protocol(protocol: str) -> Iterator[None]:
  """Sets the protocol used by `lf.query` calls within the scope.

  The protocol is stored in `lf.context` under the key `__query_protocol__`.
  `lf.query` reads that key when its `protocol` argument is None.

  Example:

    with lf.query_protocol('python:1.0'):
      lf.query(...)  # `protocol` left as None, so 'python:1.0' is used.

  Args:
    protocol: A protocol spec accepted by `lf.query`, either '<protocol>' or
      '<protocol>:<version>'.

  Yields:
    None.
  """
  # `lf.context` already manages entering and leaving the scope; the former
  # `try: yield / finally: pass` wrapper added nothing and is not needed.
  with lf.context(__query_protocol__=protocol):
    yield
625
+
439
626
  #
440
627
  # Helper function for map-reduce style querying.
441
628
  #