langfun 0.0.2.dev20240314__py3-none-any.whl → 0.0.2.dev20240316__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langfun/__init__.py CHANGED
@@ -31,6 +31,9 @@ query = structured.query
31
31
  describe = structured.describe
32
32
  complete = structured.complete
33
33
  score = structured.score
34
+ generate_class = structured.generate_class
35
+
36
+ source_form = structured.source_form
34
37
 
35
38
  from langfun.core import eval # pylint: disable=redefined-builtin
36
39
  from langfun.core import templates
@@ -95,8 +95,8 @@ class LangFuncCallTest(unittest.TestCase):
95
95
  ' lm=ExcitedEchoer(sampling_options=LMSamplingOptions(temperature=0.0,'
96
96
  ' max_tokens=1024, n=1, top_k=40, top_p=None, stop=None,'
97
97
  ' random_seed=None, logprobs=False, top_logprobs=None), cache=None,'
98
- ' timeout=120.0, max_attempts=5, retry_interval=(5, 60),'
99
- ' exponential_backoff=True, debug=False))',
98
+ ' max_concurrency=None, timeout=120.0, max_attempts=5,'
99
+ ' retry_interval=(5, 60), exponential_backoff=True, debug=False))',
100
100
  )
101
101
 
102
102
  l = LangFunc('Hello')
@@ -17,8 +17,9 @@ import abc
17
17
  import dataclasses
18
18
  import enum
19
19
  import time
20
- from typing import Annotated, Any
20
+ from typing import Annotated, Any, Callable, Sequence, Tuple, Type, Union
21
21
  from langfun.core import component
22
+ from langfun.core import concurrent
22
23
  from langfun.core import console
23
24
  from langfun.core import message as message_lib
24
25
  import pyglove as pg
@@ -209,6 +210,22 @@ class LanguageModel(component.Component):
209
210
  )
210
211
  ] = component.contextual(default=None)
211
212
 
213
+ max_concurrency: Annotated[
214
+ int | None,
215
+ (
216
+ 'Max concurrent requests being sent to the server. '
217
+ 'If None, there is no limit. '
218
+ 'Please note that the concurrency control is based on the '
219
+ '`resource_id` property, meaning that model instances shared '
220
+ 'the same resource ID will be accounted under the same concurrency '
221
+ 'control key. This allows a process-level concurrency control '
222
+ 'for specific models regardless the number of LM (client) instances '
223
+ 'created by the program. Subclasses could override this number or '
224
+ 'replace it with a `max_concurrency` property to allow dynamic '
225
+ 'concurrency control.'
226
+ ),
227
+ ] = None
228
+
212
229
  timeout: Annotated[
213
230
  float | None, 'Timeout in seconds. If None, there is no timeout.'
214
231
  ] = 120.0
@@ -284,11 +301,6 @@ class LanguageModel(component.Component):
284
301
  """Resource ID for performing request parallelism control."""
285
302
  return self.model_id
286
303
 
287
- @property
288
- def max_concurrency(self) -> int:
289
- """Max concurrent requests."""
290
- return 32
291
-
292
304
  def sample(
293
305
  self,
294
306
  prompts: list[str | message_lib.Message],
@@ -355,6 +367,28 @@ class LanguageModel(component.Component):
355
367
  ) -> list[LMSamplingResult]:
356
368
  """Subclass should override."""
357
369
 
370
+ def _parallel_execute_with_currency_control(
371
+ self,
372
+ action: Callable[..., Any],
373
+ inputs: Sequence[Any],
374
+ retry_on_errors: Union[
375
+ None,
376
+ Union[Type[Exception], Tuple[Type[Exception], str]],
377
+ Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
378
+ ] = None,
379
+ ) -> Any:
380
+ """Helper method for subclasses for implementing _sample."""
381
+ return concurrent.concurrent_execute(
382
+ action,
383
+ inputs,
384
+ executor=self.resource_id if self.max_concurrency else None,
385
+ max_workers=self.max_concurrency or len(inputs),
386
+ retry_on_errors=retry_on_errors,
387
+ max_attempts=self.max_attempts,
388
+ retry_interval=self.retry_interval,
389
+ exponential_backoff=self.exponential_backoff,
390
+ )
391
+
358
392
  def __call__(
359
393
  self, prompt: message_lib.Message, *, cache_seed: int = 0, **kwargs
360
394
  ) -> message_lib.Message:
@@ -89,7 +89,7 @@ class LanguageModelTest(unittest.TestCase):
89
89
  lm = MockModel(1, temperature=0.5, top_k=2, max_attempts=2)
90
90
  self.assertEqual(lm.model_id, 'MockModel')
91
91
  self.assertEqual(lm.resource_id, 'MockModel')
92
- self.assertEqual(lm.max_concurrency, 32)
92
+ self.assertIsNone(lm.max_concurrency)
93
93
  self.assertEqual(lm.failures_before_attempt, 1)
94
94
  self.assertEqual(lm.sampling_options.temperature, 0.5)
95
95
  self.assertEqual(lm.sampling_options.top_k, 2)
@@ -133,14 +133,9 @@ class Gemini(lf.LanguageModel):
133
133
 
134
134
  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
135
135
  assert self._api_initialized, 'Vertex AI API is not initialized.'
136
- return lf.concurrent_execute(
136
+ return self._parallel_execute_with_currency_control(
137
137
  self._sample_single,
138
138
  prompts,
139
- executor=self.resource_id,
140
- max_workers=self.max_concurrency,
141
- # NOTE(daiyip): Vertex has its own policy on handling
142
- # with rate limit, so we do not retry on errors.
143
- retry_on_errors=None,
144
139
  )
145
140
 
146
141
  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
@@ -67,13 +67,6 @@ class LlamaCppRemote(lf.LanguageModel):
67
67
  results.append(result)
68
68
  return results
69
69
 
70
- return lf.concurrent_execute(
71
- _complete_fn,
72
- [prompts],
73
- executor=self.resource_id,
74
- max_workers=self.max_concurrency,
75
- retry_on_errors=(),
76
- max_attempts=self.max_attempts,
77
- retry_interval=self.retry_interval,
78
- exponential_backoff=self.exponential_backoff,
70
+ return self._parallel_execute_with_currency_control(
71
+ _complete_fn, [prompts]
79
72
  )[0]
@@ -214,18 +214,13 @@ class OpenAI(lf.LanguageModel):
214
214
  for index in sorted(samples_by_index.keys())
215
215
  ]
216
216
 
217
- return lf.concurrent_execute(
217
+ return self._parallel_execute_with_currency_control(
218
218
  _open_ai_completion,
219
219
  [prompts],
220
- executor=self.resource_id,
221
- max_workers=self.max_concurrency,
222
220
  retry_on_errors=(
223
221
  openai_error.ServiceUnavailableError,
224
222
  openai_error.RateLimitError,
225
223
  ),
226
- max_attempts=self.max_attempts,
227
- retry_interval=self.retry_interval,
228
- exponential_backoff=self.exponential_backoff,
229
224
  )[0]
230
225
 
231
226
  def _chat_complete_batch(
@@ -280,18 +275,13 @@ class OpenAI(lf.LanguageModel):
280
275
  ),
281
276
  )
282
277
 
283
- return lf.concurrent_execute(
278
+ return self._parallel_execute_with_currency_control(
284
279
  _open_ai_chat_completion,
285
280
  prompts,
286
- executor=self.resource_id,
287
- max_workers=self.max_concurrency,
288
281
  retry_on_errors=(
289
282
  openai_error.ServiceUnavailableError,
290
283
  openai_error.RateLimitError,
291
284
  ),
292
- max_attempts=self.max_attempts,
293
- retry_interval=self.retry_interval,
294
- exponential_backoff=self.exponential_backoff,
295
285
  )
296
286
 
297
287
 
@@ -41,8 +41,12 @@ from langfun.core.structured.schema import ValueRepr
41
41
  from langfun.core.structured.schema import ValueJsonRepr
42
42
  from langfun.core.structured.schema import ValuePythonRepr
43
43
  from langfun.core.structured.schema import schema_repr
44
+ from langfun.core.structured.schema import source_form
44
45
  from langfun.core.structured.schema import value_repr
45
46
 
47
+ from langfun.core.structured.schema_generation import generate_class
48
+ from langfun.core.structured.schema_generation import classgen_example
49
+ from langfun.core.structured.schema_generation import default_classgen_examples
46
50
 
47
51
  from langfun.core.structured.mapping import Mapping
48
52
  from langfun.core.structured.mapping import MappingExample
@@ -68,8 +72,8 @@ from langfun.core.structured.scoring import score
68
72
 
69
73
  # Expose default examples for structured operations so users could refer to
70
74
  # them.
71
- from langfun.core.structured.parsing import DEFAULT_PARSE_EXAMPLES
72
- from langfun.core.structured.description import DEFAULT_DESCRIBE_EXAMPLES
75
+ from langfun.core.structured.parsing import default_parse_examples
76
+ from langfun.core.structured.description import default_describe_examples
73
77
 
74
78
  # Default examples.
75
79
 
@@ -106,58 +106,61 @@ def describe(
106
106
  Returns:
107
107
  The description text produced by the LM for the input value.
108
108
  """
109
- if examples is None:
110
- examples = DEFAULT_DESCRIBE_EXAMPLES
111
109
  return DescribeStructure(
112
- input=value, context=context, examples=examples, **kwargs
110
+ input=value,
111
+ context=context,
112
+ examples=examples or default_describe_examples(),
113
+ **kwargs,
113
114
  )(lm=lm, cache_seed=cache_seed).text
114
115
 
115
116
 
116
- class _Country(pg.Object):
117
- """A example dataclass for structured mapping."""
118
-
119
- name: str
120
- continents: list[
121
- Literal[
122
- 'Africa',
123
- 'Asia',
124
- 'Europe',
125
- 'Oceania',
126
- 'North America',
127
- 'South America',
128
- ]
117
+ def default_describe_examples() -> list[mapping.MappingExample]:
118
+ """Default describe examples."""
119
+
120
+ class Country(pg.Object):
121
+ """A example dataclass for structured mapping."""
122
+
123
+ name: str
124
+ continents: list[
125
+ Literal[
126
+ 'Africa',
127
+ 'Asia',
128
+ 'Europe',
129
+ 'Oceania',
130
+ 'North America',
131
+ 'South America',
132
+ ]
133
+ ]
134
+ num_states: int
135
+ neighbor_countries: list[str]
136
+ population: int
137
+ capital: str | None
138
+ president: str | None
139
+
140
+ return [
141
+ mapping.MappingExample(
142
+ context='Brief intro to United States',
143
+ input=Country(
144
+ name='The United States of America',
145
+ continents=['North America'],
146
+ num_states=50,
147
+ neighbor_countries=[
148
+ 'Canada',
149
+ 'Mexico',
150
+ 'Bahamas',
151
+ 'Cuba',
152
+ 'Russia',
153
+ ],
154
+ population=333000000,
155
+ capital='Washington, D.C',
156
+ president=None,
157
+ ),
158
+ output=inspect.cleandoc("""
159
+ The United States of America is a country primarily located in North America
160
+ consisting of fifty states. It shares land borders with Canada to its north
161
+ and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
162
+ Russia, and other nations. With a population of over 333 million. The national
163
+ capital of the United States is Washington, D.C.
164
+ """),
165
+ ),
129
166
  ]
130
- num_states: int
131
- neighbor_countries: list[str]
132
- population: int
133
- capital: str | None
134
- president: str | None
135
-
136
-
137
- DEFAULT_DESCRIBE_EXAMPLES: list[mapping.MappingExample] = [
138
- mapping.MappingExample(
139
- context='Brief intro to United States',
140
- input=_Country(
141
- name='The United States of America',
142
- continents=['North America'],
143
- num_states=50,
144
- neighbor_countries=[
145
- 'Canada',
146
- 'Mexico',
147
- 'Bahamas',
148
- 'Cuba',
149
- 'Russia',
150
- ],
151
- population=333000000,
152
- capital='Washington, D.C',
153
- president=None,
154
- ),
155
- output=inspect.cleandoc("""
156
- The United States of America is a country primarily located in North America
157
- consisting of fifty states. It shares land borders with Canada to its north
158
- and with Mexico to its south and has maritime borders with the Bahamas, Cuba,
159
- Russia, and other nations. With a population of over 333 million. The national
160
- capital of the United States is Washington, D.C.
161
- """),
162
- ),
163
- ]
@@ -293,25 +293,27 @@ class Mapping(lf.LangFunc):
293
293
 
294
294
  def transform_output(self, lm_output: lf.Message) -> lf.Message:
295
295
  """Transforms LM response into structure if schema is present."""
296
- schema = self.mapping_request.schema
297
- if schema is None:
298
- return lm_output
299
-
300
296
  try:
301
- result = schema.parse(
302
- lm_output.text,
303
- protocol=self.protocol,
304
- additional_context=self.globals(),
305
- autofix=self.autofix,
306
- autofix_lm=self.autofix_lm or self.lm,
307
- )
308
- lm_output.result = self.postprocess_result(result)
297
+ lm_output.result = self.postprocess_result(self.parse_result(lm_output))
309
298
  except Exception as e: # pylint: disable=broad-exception-caught
310
299
  if self.default == lf.RAISE_IF_HAS_ERROR:
311
300
  raise e
312
301
  lm_output.result = self.default
313
302
  return lm_output
314
303
 
304
+ def parse_result(self, lm_output: lf.Message) -> Any:
305
+ """Parse result from LLM response."""
306
+ schema = self.mapping_request.schema
307
+ if schema is None:
308
+ return None
309
+ return schema.parse(
310
+ lm_output.text,
311
+ protocol=self.protocol,
312
+ additional_context=self.globals(),
313
+ autofix=self.autofix,
314
+ autofix_lm=self.autofix_lm or self.lm,
315
+ )
316
+
315
317
  def postprocess_result(self, result: Any) -> Any:
316
318
  """Post process structured output."""
317
319
  return result
@@ -162,11 +162,11 @@ def parse(
162
162
  message.source = lf.UserMessage(user_prompt, tags=['lm-input'])
163
163
  context = getattr(message.lm_input, 'text', None) if include_context else None
164
164
 
165
- if examples is None:
166
- examples = DEFAULT_PARSE_EXAMPLES
167
-
168
165
  t = _parse_structure_cls(protocol)(
169
- schema=schema, context=context, default=default, examples=examples
166
+ schema=schema,
167
+ context=context,
168
+ default=default,
169
+ examples=examples or default_parse_examples(),
170
170
  )
171
171
 
172
172
  # Setting up context.
@@ -296,17 +296,19 @@ def _parse_structure_cls(
296
296
  raise ValueError(f'Unknown protocol: {protocol!r}.')
297
297
 
298
298
 
299
- class _AdditionResults(pg.Object):
300
- one_plus_one_equals: int | None
301
- two_plus_two_equals: int | None
299
+ def default_parse_examples() -> list[mapping.MappingExample]:
300
+ """Default parsing examples."""
302
301
 
302
+ class AdditionResults(pg.Object):
303
+ one_plus_one_equals: int | None
304
+ two_plus_two_equals: int | None
303
305
 
304
- DEFAULT_PARSE_EXAMPLES: list[mapping.MappingExample] = [
305
- mapping.MappingExample(
306
- input='Two plus two equals four. Three plus three equals six.',
307
- schema=_AdditionResults,
308
- output=_AdditionResults(
309
- one_plus_one_equals=None, two_plus_two_equals=4
310
- ),
311
- ),
312
- ]
306
+ return [
307
+ mapping.MappingExample(
308
+ input='Two plus two equals four. Three plus three equals six.',
309
+ schema=AdditionResults,
310
+ output=AdditionResults(
311
+ one_plus_one_equals=None, two_plus_two_equals=4
312
+ ),
313
+ ),
314
+ ]
@@ -301,23 +301,43 @@ class SchemaRepr(metaclass=abc.ABCMeta):
301
301
  class SchemaPythonRepr(SchemaRepr):
302
302
  """Python-representation for a schema."""
303
303
 
304
- def repr(self, schema: Schema) -> str:
305
- ret = self.result_definition(schema)
306
- class_definition_str = self.class_definitions(schema)
304
+ def repr(
305
+ self,
306
+ schema: Schema,
307
+ *,
308
+ include_result_definition: bool = True,
309
+ markdown: bool = True,
310
+ **kwargs,
311
+ ) -> str:
312
+ ret = ''
313
+ if include_result_definition:
314
+ ret += self.result_definition(schema)
315
+ class_definition_str = self.class_definitions(
316
+ schema, markdown=markdown, **kwargs
317
+ )
307
318
  if class_definition_str:
308
- ret += f'\n\n```python\n{class_definition_str}```'
309
- return ret
319
+ ret += f'\n\n{class_definition_str}'
320
+ return ret.strip()
310
321
 
311
- def class_definitions(self, schema: Schema) -> str | None:
322
+ def class_definitions(self, schema: Schema, **kwargs) -> str | None:
312
323
  deps = schema.class_dependencies(include_subclasses=True)
313
- return class_definitions(deps)
324
+ return class_definitions(deps, **kwargs)
314
325
 
315
326
  def result_definition(self, schema: Schema) -> str:
316
327
  return annotation(schema.spec)
317
328
 
318
329
 
330
+ def source_form(value, markdown: bool = False) -> str:
331
+ """Returns the source code form of an object."""
332
+ return ValuePythonRepr().repr(value, markdown=markdown)
333
+
334
+
319
335
  def class_definitions(
320
- classes: Sequence[Type[Any]], strict: bool = False, markdown: bool = False
336
+ classes: Sequence[Type[Any]],
337
+ *,
338
+ include_pg_object_as_base: bool = False,
339
+ strict: bool = False,
340
+ markdown: bool = False,
321
341
  ) -> str | None:
322
342
  """Returns a str for class definitions."""
323
343
  if not classes:
@@ -326,14 +346,22 @@ def class_definitions(
326
346
  for i, cls in enumerate(classes):
327
347
  if i > 0:
328
348
  def_str.write('\n')
329
- def_str.write(class_definition(cls, strict))
349
+ def_str.write(
350
+ class_definition(
351
+ cls,
352
+ strict=strict,
353
+ include_pg_object_as_base=include_pg_object_as_base,
354
+ )
355
+ )
330
356
  ret = def_str.getvalue()
331
357
  if markdown and ret:
332
358
  ret = f'```python\n{ret}```'
333
359
  return ret
334
360
 
335
361
 
336
- def class_definition(cls, strict: bool = False) -> str:
362
+ def class_definition(
363
+ cls, strict: bool = False, include_pg_object_as_base: bool = False
364
+ ) -> str:
337
365
  """Returns the Python class definition."""
338
366
  out = io.StringIO()
339
367
  if not issubclass(cls, pg.Object):
@@ -344,10 +372,9 @@ def class_definition(cls, strict: bool = False) -> str:
344
372
  schema = cls.__schema__
345
373
  eligible_bases = []
346
374
  for base_cls in cls.__bases__:
347
- if issubclass(base_cls, pg.Symbolic) and not base_cls.__module__.startswith(
348
- 'pyglove'
349
- ):
350
- eligible_bases.append(base_cls.__name__)
375
+ if issubclass(base_cls, pg.Object):
376
+ if include_pg_object_as_base or base_cls is not pg.Object:
377
+ eligible_bases.append(base_cls.__name__)
351
378
  if eligible_bases:
352
379
  base_cls_str = ', '.join(eligible_bases)
353
380
  out.write(f'class {cls.__name__}({base_cls_str}):\n')
@@ -547,8 +574,20 @@ class ValuePythonRepr(ValueRepr):
547
574
  markdown: bool = True,
548
575
  **kwargs) -> str:
549
576
  del schema
550
- object_code = pg.format(
551
- value, compact=compact, verbose=verbose, python_format=True)
577
+ if inspect.isclass(value):
578
+ cls_schema = Schema.from_value(value)
579
+ if isinstance(cls_schema.spec, pg.typing.Object):
580
+ object_code = SchemaPythonRepr().class_definitions(
581
+ cls_schema, markdown=markdown, include_pg_object_as_base=True
582
+ )
583
+ assert object_code is not None
584
+ return object_code
585
+ else:
586
+ object_code = SchemaPythonRepr().result_definition(cls_schema)
587
+ else:
588
+ object_code = pg.format(
589
+ value, compact=compact, verbose=verbose, python_format=True
590
+ )
552
591
  if markdown:
553
592
  return f'```python\n{ object_code }\n```'
554
593
  return object_code
@@ -588,6 +627,7 @@ def structure_from_python(
588
627
  global_vars = global_vars or {}
589
628
  global_vars.update({
590
629
  'pg': pg,
630
+ 'Object': pg.Object,
591
631
  'Any': typing.Any,
592
632
  'List': typing.List,
593
633
  'Tuple': typing.Tuple,
@@ -0,0 +1,175 @@
1
+ # Copyright 2023 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """LLM-based class generation."""
15
+
16
+ import typing
17
+ from typing import Any, Type
18
+ import langfun.core as lf
19
+ from langfun.core.coding.python import correction
20
+ from langfun.core.structured import mapping
21
+ import pyglove as pg
22
+
23
+
24
+ class GenerateClass(mapping.Mapping):
25
+ """Python class generation."""
26
+
27
+ input_title = 'GENERATION_CONTEXT'
28
+ context_title = 'CLASS_NAME'
29
+ output_title = 'OUTPUT_CLASS'
30
+
31
+ preamble: lf.Template = lf.Template("""
32
+ Help generate a class based on the last {{ context_title }} and {{ input_title }}.
33
+
34
+ Instructions:
35
+ - Use `Object` as the base class for all generated classes
36
+ - Create auxillary classes for composition if needed.
37
+ - Use Python type annotation for declaraing fields:
38
+ (e.g. bool, str, int, float, Optional[str], List[int], Union[str, int])
39
+ - Do not use types that need import.
40
+ - Avoid self-referential types. e.g:
41
+ ```
42
+ class Node(Object):
43
+ children: list[Node]
44
+ ```
45
+ - Do not generate methods.
46
+ """)
47
+
48
+ def parse_result(self, lm_output: lf.Message) -> Type[Any]:
49
+ output_vars, final_code = correction.run_with_correction(
50
+ lm_output.text,
51
+ global_vars=self.allowed_annotation_types,
52
+ sandbox=False,
53
+ max_attempts=self.autofix,
54
+ lm=self.autofix_lm,
55
+ returns_code=True,
56
+ outputs_intermediate=True,
57
+ )
58
+ class_name = self.context
59
+ cls = output_vars.get(class_name, None)
60
+ if cls is None:
61
+ raise correction.errors.CodeError(
62
+ final_code,
63
+ TypeError(f'Class {class_name} is absent from LLM output.'),
64
+ )
65
+ return cls
66
+
67
+ @property
68
+ def allowed_annotation_types(self):
69
+ return dict(
70
+ pg=pg,
71
+ Any=typing.Any,
72
+ Object=pg.Object,
73
+ List=typing.List,
74
+ Dict=typing.Tuple,
75
+ Tuple=typing.Tuple,
76
+ Sequence=typing.Sequence,
77
+ Optional=typing.Optional,
78
+ Union=typing.Union,
79
+ )
80
+
81
+
82
+ def generate_class(
83
+ name: str,
84
+ prompt: str | pg.Symbolic,
85
+ *,
86
+ lm: lf.LanguageModel | None = None,
87
+ examples: list[mapping.MappingExample] | None = None,
88
+ returns_message: bool = False,
89
+ skip_lm: bool = False,
90
+ **kwargs,
91
+ ) -> Type[Any] | lf.Message:
92
+ """Generate a class with specified name based on the prompt.
93
+
94
+ Example:
95
+ ```
96
+ trip_cls = lf.generate_class(
97
+ 'Trip',
98
+ 'A trip plan to visit {{ city }}', city='San Francisco',
99
+ lm=lf.llms.GeminiPro()
100
+ )
101
+ ```
102
+
103
+ Args:
104
+ name: Class name to be generated.
105
+ prompt: A str (may contain {{}} as template) as natural language input, or a
106
+ `pg.Symbolic` object as structured input as prompt to LLM.
107
+ lm: The language model to use. If not specified, the language model from
108
+ `lf.context` context manager will be used.
109
+ examples: An optional list of fewshot examples for helping class generation.
110
+ If None, a default single shot example will be used. Use
111
+ `lf.structured.classgen_example` to generate example.
112
+ returns_message: If True, returns `lf.Message` as the output, instead of
113
+ returning the structured `message.result`.
114
+ skip_lm: If True, returns the rendered prompt as a UserMessage object.
115
+ Otherwise, returns the LLM response based on the rendered prompt.
116
+ **kwargs: Template variables passed to `prompt` and keyword arguments passed
117
+ to `lf.structured.GenerateClass`.
118
+
119
+ Returns:
120
+ Generated class.
121
+
122
+ Raises:
123
+ CodeError: if generation failed.
124
+ """
125
+ if isinstance(prompt, str):
126
+ prompt = lf.Template(prompt, **kwargs)
127
+ elif isinstance(prompt, lf.Template):
128
+ prompt = prompt.rebind(**kwargs, raise_on_no_change=False)
129
+
130
+ if isinstance(prompt, lf.Template):
131
+ prompt = prompt.render(lm=lm)
132
+
133
+ call_kwargs = dict(skip_lm=skip_lm)
134
+ if lm is not None:
135
+ call_kwargs['lm'] = lm
136
+ message = GenerateClass(
137
+ input=prompt,
138
+ context=name,
139
+ examples=examples or default_classgen_examples(),
140
+ **kwargs,
141
+ )(**call_kwargs)
142
+ return message if returns_message else message.result
143
+
144
+
145
+ def classgen_example(
146
+ class_name: str, prompt: str | pg.Symbolic, cls: Type[Any]
147
+ ) -> mapping.MappingExample:
148
+ """Creates a class generation example."""
149
+ if isinstance(prompt, lf.Template):
150
+ prompt = prompt.render()
151
+ return mapping.MappingExample(
152
+ input=prompt,
153
+ context=class_name,
154
+ output=cls,
155
+ )
156
+
157
+
158
+ def default_classgen_examples() -> list[mapping.MappingExample]:
159
+ """Default examples for class generation."""
160
+
161
+ class Step(pg.Object):
162
+ description: str
163
+ output: float
164
+
165
+ class Solution(pg.Object):
166
+ steps: list[Step] # pytype: disable=invalid-annotation
167
+ result: float
168
+
169
+ return [
170
+ classgen_example(
171
+ 'Solution',
172
+ 'How to evaluate an arithmetic expression?',
173
+ Solution,
174
+ )
175
+ ]
@@ -0,0 +1,104 @@
1
+ # Copyright 2024 The Langfun Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import inspect
15
+ import unittest
16
+
17
+ import langfun.core.coding as lf_coding
18
+ from langfun.core.llms import fake
19
+ from langfun.core.structured import schema_generation
20
+
21
+
22
+ class GenerateClassTest(unittest.TestCase):
23
+
24
+ def test_generate_class_prompt(self):
25
+ input_message = schema_generation.generate_class(
26
+ 'Trip',
27
+ 'Generate a trip class',
28
+ skip_lm=True,
29
+ returns_message=True,
30
+ )
31
+ self.maxDiff = None
32
+ self.assertEqual(
33
+ input_message.text,
34
+ inspect.cleandoc("""
35
+ Help generate a class based on the last CLASS_NAME and GENERATION_CONTEXT.
36
+
37
+ Instructions:
38
+ - Use `Object` as the base class for all generated classes
39
+ - Create auxillary classes for composition if needed.
40
+ - Use Python type annotation for declaraing fields:
41
+ (e.g. bool, str, int, float, Optional[str], List[int], Union[str, int])
42
+ - Do not use types that need import.
43
+ - Avoid self-referential types. e.g:
44
+ ```
45
+ class Node(Object):
46
+ children: list[Node]
47
+ ```
48
+ - Do not generate methods.
49
+
50
+ CLASS_NAME:
51
+ Solution
52
+
53
+ GENERATION_CONTEXT:
54
+ How to evaluate an arithmetic expression?
55
+
56
+ OUTPUT_CLASS:
57
+ ```python
58
+ class Step(Object):
59
+ description: str
60
+ output: float
61
+
62
+ class Solution(Object):
63
+ steps: list[Step]
64
+ result: float
65
+ ```
66
+
67
+
68
+ CLASS_NAME:
69
+ Trip
70
+
71
+ GENERATION_CONTEXT:
72
+ Generate a trip class
73
+
74
+ OUTPUT_CLASS:
75
+ """),
76
+ )
77
+
78
+ def test_generate_class(self):
79
+ lm = fake.StaticResponse("""
80
+ ```python
81
+ class A(Object):
82
+ x: int
83
+
84
+ class B(Object):
85
+ a: A
86
+ ```
87
+ """)
88
+ cls = schema_generation.generate_class(
89
+ 'B',
90
+ 'Generate a B class with a field pointing to another class A',
91
+ lm=lm,
92
+ )
93
+ self.assertIs(cls.__name__, 'B')
94
+
95
+ with self.assertRaises(lf_coding.CodeError):
96
+ schema_generation.generate_class(
97
+ 'Foo',
98
+ 'Generate a Foo class with a field pointing to another class A',
99
+ lm=lm,
100
+ )
101
+
102
+
103
+ if __name__ == '__main__':
104
+ unittest.main()
@@ -435,6 +435,10 @@ class SchemaPythonReprTest(unittest.TestCase):
435
435
  schema_lib.class_definition(A),
436
436
  'class A:\n pass\n',
437
437
  )
438
+ self.assertEqual(
439
+ schema_lib.class_definition(A, include_pg_object_as_base=True),
440
+ 'class A(Object):\n pass\n',
441
+ )
438
442
 
439
443
  class B:
440
444
  pass
@@ -520,6 +524,32 @@ class SchemaPythonReprTest(unittest.TestCase):
520
524
  ```
521
525
  """),
522
526
  )
527
+ self.assertEqual(
528
+ schema_lib.SchemaPythonRepr().repr(
529
+ schema,
530
+ include_result_definition=False,
531
+ include_pg_object_as_base=True,
532
+ markdown=False,
533
+ ),
534
+ inspect.cleandoc("""
535
+ class Foo(Object):
536
+ x: int
537
+
538
+ class A(Object):
539
+ foo: Foo
540
+
541
+ class Bar(Object):
542
+ y: str
543
+
544
+ class Baz(Bar):
545
+ y: str
546
+
547
+ class B(A):
548
+ foo: Foo
549
+ bar: Bar
550
+ foo2: Foo
551
+ """),
552
+ )
523
553
 
524
554
 
525
555
  class SchemaJsonReprTest(unittest.TestCase):
@@ -559,6 +589,20 @@ class ValuePythonReprTest(unittest.TestCase):
559
589
  ),
560
590
  "A(foo=[Foo(x=1), Foo(x=2)], y='bar')",
561
591
  )
592
+ self.assertEqual(
593
+ schema_lib.ValuePythonRepr().repr(A),
594
+ inspect.cleandoc("""
595
+ ```python
596
+ class Foo(Object):
597
+ x: int
598
+
599
+ class A(Object):
600
+ foo: list[Foo]
601
+ y: str | None
602
+ ```
603
+ """),
604
+ )
605
+ self.assertEqual(schema_lib.source_form(int), 'int')
562
606
 
563
607
  def test_parse(self):
564
608
  class Foo(pg.Object):
langfun/core/template.py CHANGED
@@ -124,9 +124,12 @@ class Template(
124
124
 
125
125
  @classmethod
126
126
  def resolve_vars(cls, template_str: str) -> Set[str]:
127
- env = jinja2.Environment()
128
- ast = env.parse(template_str)
129
- return jinja2_meta.find_undeclared_variables(ast)
127
+ try:
128
+ env = jinja2.Environment()
129
+ ast = env.parse(template_str)
130
+ return jinja2_meta.find_undeclared_variables(ast)
131
+ except jinja2.TemplateSyntaxError as e:
132
+ raise ValueError(f'Bad template string:\n\n{template_str}') from e
130
133
 
131
134
  def _on_bound(self) -> None:
132
135
  super()._on_bound()
@@ -154,6 +154,10 @@ class DefinitionTest(unittest.TestCase):
154
154
  with self.assertRaisesRegex(TypeError, '.* missing 1 required argument'):
155
155
  MyPrompt(y=2)()
156
156
 
157
+ def test_bad_template(self):
158
+ with self.assertRaisesRegex(ValueError, 'Bad template string.*'):
159
+ Template('{{x=1')
160
+
157
161
 
158
162
  class VarsTest(unittest.TestCase):
159
163
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.0.2.dev20240314
3
+ Version: 0.0.2.dev20240316
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -24,7 +24,7 @@ License-File: LICENSE
24
24
  Requires-Dist: google-generativeai >=0.3.2
25
25
  Requires-Dist: jinja2 >=3.1.2
26
26
  Requires-Dist: openai ==0.27.2
27
- Requires-Dist: pyglove >=0.4.5.dev20240201
27
+ Requires-Dist: pyglove >=0.4.5.dev20240314
28
28
  Requires-Dist: python-magic >=0.4.27
29
29
  Requires-Dist: requests >=2.31.0
30
30
  Requires-Dist: termcolor ==1.1.0
@@ -1,4 +1,4 @@
1
- langfun/__init__.py,sha256=8H9dYWG6gM3SlQfTH9BNiwVtIbe7Wz8XUIK5uolY1Z0,1760
1
+ langfun/__init__.py,sha256=PqX3u18BC0szYIMu00j-RKxvwkNPwXtAFZ-96oxrQ0M,1841
2
2
  langfun/core/__init__.py,sha256=sVcPl89lWYHQ1cUoaLaM8dErCovugJo5e2F3A_94Q3Y,4192
3
3
  langfun/core/component.py,sha256=VRPfDB_2jEnxcB3-HoiVjG4ID-SMenNPIsytb0uXMPg,9674
4
4
  langfun/core/component_test.py,sha256=VAPd6V_-odAe8rBvesW3ogYDd6OSqRq4FaPhfgOM4Zg,7949
@@ -7,9 +7,9 @@ langfun/core/concurrent_test.py,sha256=mwFMZhDUdppnDr7vDSTwcbMHwrdsIoKJwRYNtl4ZW
7
7
  langfun/core/console.py,sha256=bk5rNPNm9rMGW5YT2HixxU04p2umnoabn5SDz6Dqe88,2317
8
8
  langfun/core/console_test.py,sha256=5SYJdxpJGLgdSSQqqMPoA1X6jpsLD8rgcyk-EgI65oE,1077
9
9
  langfun/core/langfunc.py,sha256=WXdTc3QsmGD_n80KD9dFRr5MHpGZ9E_y_Rhtk4t9-3w,11852
10
- langfun/core/langfunc_test.py,sha256=8WeiyNauZPkbAA3HiLjVw5-pRSmiLlz-77lB_fjHGdA,8317
11
- langfun/core/language_model.py,sha256=Qbm7wxgxW26bCVwtgpp-4aV3BKYAsb4IJrJuzhf3Q6o,15507
12
- langfun/core/language_model_test.py,sha256=h5MWooOb9HubvOzxaBnH6WuDYBdxTetu7JZSWDzva3M,11368
10
+ langfun/core/langfunc_test.py,sha256=dFNJoEXExIkrAJ9_PSWh_iRQoR4Gmp2VOZ_ve61DSHM,8339
11
+ langfun/core/language_model.py,sha256=jPuFfjnRCnbT8po-CBPgmXoa09Yfk5_21snCXURqaKU,17011
12
+ langfun/core/language_model_test.py,sha256=q7pNdirVWfkQXPA3taCGnyLB2NNs1KqX4JjjnoJvFOQ,11365
13
13
  langfun/core/memory.py,sha256=f-asN1F7Vehgdn_fK84v73GrEUOxRtaW934keutTKjk,2416
14
14
  langfun/core/message.py,sha256=QhvV9t5qaryPcruyxxcXi3gm9QDInkSldwTtK6sVJ3c,15734
15
15
  langfun/core/message_test.py,sha256=Z23pUM5vPnDrYkIIibe2KL73D5HKur_awI0ut_EQFQA,9501
@@ -21,8 +21,8 @@ langfun/core/sampling.py,sha256=vygWvgC8MFw0_AKNSmz-ywMXJYWf8cl0tI8QycvAmyI,5795
21
21
  langfun/core/sampling_test.py,sha256=U7PANpMsl9E_pa4_Y4FzesSjcwg-u-LKHGCWSgv-8FY,3663
22
22
  langfun/core/subscription.py,sha256=euawEuSZP-BHydaT-AQpfYFL0m5pWPGcW0upFhrojqc,10930
23
23
  langfun/core/subscription_test.py,sha256=Y4ZdbZEwm83YNZBxHff0QR4QUa4rdaNXA3_jfIcArBo,8717
24
- langfun/core/template.py,sha256=1UnXgVuqdGRes0vSMamIQ8KpG0NgTXV1iCSWVbCybN4,17530
25
- langfun/core/template_test.py,sha256=WNf7O45V5BZz7IaAAcZhGANyaKMwpWjBpcpEIASfh-Q,13446
24
+ langfun/core/template.py,sha256=zVD8dAsXFfgF25aKh2WqSuCEHVqriCC-4tLbQqTMa2w,17662
25
+ langfun/core/template_test.py,sha256=1hDdYfvXJVoslTUudh3WhxU7VnDSiIz6MkxPfmuHKAY,13572
26
26
  langfun/core/text_formatting.py,sha256=ytjj7opnRJ6w-pkglL2CZUyfYDXLpNf65E42LBb31gc,5158
27
27
  langfun/core/text_formatting_test.py,sha256=nyKC6tn2L4hPJiqQHgxcbQsJJi4A4Nbj8FiO8iT6B80,1514
28
28
  langfun/core/coding/__init__.py,sha256=5utju_fwEsImaiftx4oXKl9FAM8p281k8-Esdh_-m1w,835
@@ -49,11 +49,11 @@ langfun/core/eval/scoring_test.py,sha256=_L_B40VZkyI2_PJce-jVKYC4llrO4jGUR5j86Gu
49
49
  langfun/core/llms/__init__.py,sha256=T4mgT091BLA4mHrOjAvEGhZPHf0tiYgqD88l_JTp1dQ,2386
50
50
  langfun/core/llms/fake.py,sha256=dVzOrW27RZ1p3DdQoRCRZs_vfoQcTcNrlWxia7oqmvw,2499
51
51
  langfun/core/llms/fake_test.py,sha256=Qk_Yoi4Z7P9o6f8Q_BZkaSlvxH89ZVsDxnVIbSBRBXk,3555
52
- langfun/core/llms/gemini.py,sha256=9HxrTvac_dMbDytNCEo6YcYqs8vsePtScfI_EygpI5Y,5677
52
+ langfun/core/llms/gemini.py,sha256=E7JGewkgjMzuDQxAn8CBbwWsDDZH4jcmNbzmO3OvdsY,5474
53
53
  langfun/core/llms/gemini_test.py,sha256=ybNNCn3JW3hYpMe0wT5ILGDrMPaYYU8PN2kSookM0jk,5433
54
- langfun/core/llms/llama_cpp.py,sha256=EIjJa1-Tg4_VaIxVR88oDWSWc_axc1r2KwSPpl4PSp0,2549
54
+ langfun/core/llms/llama_cpp.py,sha256=sJ9TOismqwGJ7QhgdYknWTEkqrbeZpWYc_nClOh36NU,2320
55
55
  langfun/core/llms/llama_cpp_test.py,sha256=ZxC6defGd_HX9SFRU9U4cJiQnBKundbOrchbXuC1Z2M,1683
56
- langfun/core/llms/openai.py,sha256=-PnJ8OICuPTzr-unIys4HftNVZ6seAhV5nXny4OfVYc,11715
56
+ langfun/core/llms/openai.py,sha256=BV8NWjB1b6A1X4Kff8Pub5AECodsngZnXqeBvRIHFM0,11331
57
57
  langfun/core/llms/openai_test.py,sha256=yfw7A-4Zo9u1cIkAMk39evE-tO7z6isNYTXiSnJXDQw,7599
58
58
  langfun/core/llms/cache/__init__.py,sha256=QAo3InUMDM_YpteNnVCSejI4zOsnjSMWKJKzkb3VY64,993
59
59
  langfun/core/llms/cache/base.py,sha256=cFfYvOIUae842pncqCAsRvqXCk2AnAsRYVx0mcIoAeY,3338
@@ -69,19 +69,21 @@ langfun/core/modalities/mime.py,sha256=wVfaYflhGz1W4v3m972rAplW3OGOFtjFpHDYIaUD5
69
69
  langfun/core/modalities/mime_test.py,sha256=cVHxRvJ1QXC1SVhBmWkJdWGpL9Xl0UNfTQq6j0OGGL4,1881
70
70
  langfun/core/modalities/video.py,sha256=5-sIlzXb_ZY84RMFcpVD9ysP9GbcwbdKaZOEm3jECtc,1469
71
71
  langfun/core/modalities/video_test.py,sha256=jYuI2m8S8zDCAVBPEUbbpP205dXAht90A2_PHWo4-r8,2039
72
- langfun/core/structured/__init__.py,sha256=LZ5BFLX6VXy1zH17yChWCdg8bvIDrhtL2lqtSCwtZ-M,3187
72
+ langfun/core/structured/__init__.py,sha256=SpObW-HKpyKvkLlX8FV5ixz7CRm098j2aGfOguM3AUI,3462
73
73
  langfun/core/structured/completion.py,sha256=skBxt6V_fv2TBUKnzFgnPMbVY8HSYn8sY04MLok2yvs,7299
74
74
  langfun/core/structured/completion_test.py,sha256=98UCgA4gzfp6H6HgP2s2kcKs25YH3k4Nxj1rgAvmVBw,19249
75
- langfun/core/structured/description.py,sha256=3MLTbpTpeiBqMRe3WfDNIxtrE6WQsKJsJdkbfcyPlsg,5088
75
+ langfun/core/structured/description.py,sha256=SXW4MJvshFjbR-0gw6rE21o6WXq12UlRXawvDBXMZFA,5211
76
76
  langfun/core/structured/description_test.py,sha256=UtZGjSFUaQ6130t1E5tcL7ODu0xIefkapb53TbnqsK8,7362
77
- langfun/core/structured/mapping.py,sha256=lGkjhmvVdhBGgJmc5KbfT2xQjC1MuU4OCcCfsAYJjaQ,10192
77
+ langfun/core/structured/mapping.py,sha256=tahkaAB-L6yKbYb7qjVI301-FfIARdw4w8nP3wqS2-k,10291
78
78
  langfun/core/structured/mapping_test.py,sha256=07DDCGbwytQHSMm7fCi5-Ly-JNgdV4ubHZq0wthX4A4,3338
79
- langfun/core/structured/parsing.py,sha256=YKWl9ZQ2uFkt78SXiRISWHg8_cDMGMwAN3SeK-OqWt4,11382
79
+ langfun/core/structured/parsing.py,sha256=yTKuezai5i-X9W-jU0DeEZzqHHbCFom0plj-D0bhp98,11436
80
80
  langfun/core/structured/parsing_test.py,sha256=2_Uf3LYNRON1-5ysEr75xiG_cAxR3ZiixSfvUQu6mOQ,20846
81
81
  langfun/core/structured/prompting.py,sha256=0xRPC0K_RaFRv-j52x8_-1n1eRFSomJEpdZApVXsCV0,6902
82
82
  langfun/core/structured/prompting_test.py,sha256=SwoYbPyKhUT1H2QbqHvl93biCiE9Ttn1aWixoHH-v9Y,19129
83
- langfun/core/structured/schema.py,sha256=5DKba0LrvXCJFRY-NVfER3p54BLOB7M3Yi2-u5IAJTw,24115
84
- langfun/core/structured/schema_test.py,sha256=LEtCST5Bfwoke59I6Q1mnOJLf2cFXQwKwTeAkI2hgqM,20912
83
+ langfun/core/structured/schema.py,sha256=60griJ-yC1SExX6g-aOcAOo8yFh53CdwMV4EVK3ivug,25207
84
+ langfun/core/structured/schema_generation.py,sha256=Yv9flJ4GTtLw-bDB8S7A93G-z4gXsFMkMASkbiduT3E,5353
85
+ langfun/core/structured/schema_generation_test.py,sha256=cfZyP0gHno2fXy_c9vsVdvHmqKQSfuyUsCtfO3JFmYQ,2945
86
+ langfun/core/structured/schema_test.py,sha256=kMIgnAzm3f2O5ofn0pPKjT6H8hny4cWVaUVDOZuyjOQ,21987
85
87
  langfun/core/structured/scoring.py,sha256=a3vfGnqf-DOWjD07MF54GCZTO_R1RTxTDVPzerXnU0s,2325
86
88
  langfun/core/structured/scoring_test.py,sha256=TznLMl0x9QxzmhHz_3Vr44VOXuvFnUSeRQVhu33W5cA,1437
87
89
  langfun/core/templates/__init__.py,sha256=bO0eMsVJbi7sxEB2YlInKRQ2EVP-RyyKUwcD-8msuN4,927
@@ -93,8 +95,8 @@ langfun/core/templates/demonstration.py,sha256=vCrgYubdZM5Umqcgp8NUVGXgr4P_c-fik
93
95
  langfun/core/templates/demonstration_test.py,sha256=SafcDQ0WgI7pw05EmPI2S4v1t3ABKzup8jReCljHeK4,2162
94
96
  langfun/core/templates/selfplay.py,sha256=yhgrJbiYwq47TgzThmHrDQTF4nDrTI09CWGhuQPNv-s,2273
95
97
  langfun/core/templates/selfplay_test.py,sha256=IB5rWbjK_9CTkqEo1BclQPzFAKcIiusJckH8J19HFgI,2096
96
- langfun-0.0.2.dev20240314.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
97
- langfun-0.0.2.dev20240314.dist-info/METADATA,sha256=zA8mV-vbd9FftlkgYSeTTvvUqkuTmjewdbc4NSDhuas,3405
98
- langfun-0.0.2.dev20240314.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
99
- langfun-0.0.2.dev20240314.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
100
- langfun-0.0.2.dev20240314.dist-info/RECORD,,
98
+ langfun-0.0.2.dev20240316.dist-info/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
99
+ langfun-0.0.2.dev20240316.dist-info/METADATA,sha256=rvpQMtNiFs55Okrd1TNlJOS8szUWshlCt5NFB_2vPfs,3405
100
+ langfun-0.0.2.dev20240316.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
101
+ langfun-0.0.2.dev20240316.dist-info/top_level.txt,sha256=RhlEkHxs1qtzmmtWSwYoLVJAc1YrbPtxQ52uh8Z9VvY,8
102
+ langfun-0.0.2.dev20240316.dist-info/RECORD,,