langfun 0.1.2.dev202505130804__py3-none-any.whl → 0.1.2.dev202505140804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@
14
14
  """Query LLM for structured output."""
15
15
 
16
16
  import contextlib
17
+ import dataclasses
17
18
  import functools
18
19
  import inspect
19
20
  import time
@@ -549,76 +550,120 @@ def query(
549
550
  else:
550
551
  query_input = schema_lib.mark_missing(prompt)
551
552
 
552
- with lf.track_usages() as usage_summary:
553
- start_time = time.time()
554
- if schema in (None, str):
555
- # Query with natural language output.
556
- output_message = lf.LangFunc.from_value(query_input, **kwargs)(
557
- lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
558
- )
559
- if response_postprocess:
560
- processed_text = response_postprocess(output_message.text)
561
- if processed_text != output_message.text:
562
- output_message = lf.AIMessage(processed_text, source=output_message)
563
- else:
564
- # Query with structured output.
565
- query_cls = LfQuery.from_protocol(protocol)
566
- if ':' not in protocol:
567
- protocol = f'{protocol}:{query_cls.version}'
568
- output_message = query_cls(
569
- input=(
570
- query_input.render(lm=lm)
571
- if isinstance(query_input, lf.Template)
572
- else query_input
573
- ),
574
- schema=schema,
575
- default=default,
576
- examples=examples,
577
- response_postprocess=response_postprocess,
578
- autofix=autofix if protocol.startswith('python:') else 0,
579
- **kwargs,
580
- )(
581
- lm=lm,
582
- autofix_lm=autofix_lm or lm,
583
- cache_seed=cache_seed,
584
- skip_lm=skip_lm,
585
- )
586
- end_time = time.time()
553
+ # Determine query class.
554
+ if schema in (None, str):
555
+ # Non-structured query.
556
+ query_cls = None
557
+ else:
558
+ # Query with structured output.
559
+ query_cls = LfQuery.from_protocol(protocol)
560
+ if ':' not in protocol:
561
+ protocol = f'{protocol}:{query_cls.version}'
562
+
563
+ # `skip_lm`` is True when `lf.query_prompt` is called.
564
+ # and `prompt` is `pg.MISSING_VALUE` when `lf.query_output` is called.
565
+ # In these cases, we do not track the query invocation.
566
+ if skip_lm or pg.MISSING_VALUE == prompt:
567
+ trackers = []
568
+ else:
569
+ trackers = lf.context_value('__query_trackers__', [])
587
570
 
588
- def _result(message: lf.Message):
589
- return message.text if schema in (None, str) else message.result
571
+ # Mark query start with trackers.
572
+ # NOTE: prompt is MISSING_VALUE when `lf.query_output` is called.
573
+ # We do not track the query invocation in this case.
574
+ if trackers:
575
+ invocation = QueryInvocation(
576
+ id=_invocation_id(),
577
+ input=pg.Ref(query_input),
578
+ schema=(
579
+ schema_lib.Schema.from_value(schema)
580
+ if schema not in (None, str) else None
581
+ ),
582
+ default=default,
583
+ lm=pg.Ref(lm),
584
+ examples=pg.Ref(examples) if examples else [],
585
+ protocol=protocol,
586
+ kwargs={k: pg.Ref(v) for k, v in kwargs.items()},
587
+ start_time=time.time(),
588
+ )
589
+ for i, tracker in enumerate(trackers):
590
+ if i == 0 or tracker.include_child_scopes:
591
+ tracker.track(invocation)
592
+ else:
593
+ invocation = None
590
594
 
591
- # Track the query invocations.
592
- if pg.MISSING_VALUE != prompt and not skip_lm:
593
- trackers = lf.context_value('__query_trackers__', [])
594
- if trackers:
595
+ def _mark_query_completed(output_message, error, usage_summary):
596
+ # Mark query completion with trackers.
597
+ if not trackers:
598
+ return
599
+
600
+ if output_message is not None:
595
601
  # To minimize payload for serialization, we remove the result and usage
596
602
  # fields from the metadata. They will be computed on the fly when the
597
603
  # invocation is rendered.
598
604
  metadata = dict(output_message.metadata)
599
605
  metadata.pop('result', None)
600
606
  metadata.pop('usage', None)
607
+ lm_response = lf.AIMessage(output_message.text, metadata=metadata)
608
+ else:
609
+ lm_response = None
601
610
 
602
- invocation = QueryInvocation(
603
- id=_invocation_id(),
604
- input=pg.Ref(query_input),
605
- schema=(
606
- schema_lib.Schema.from_value(schema)
607
- if schema not in (None, str) else None
608
- ),
609
- lm=pg.Ref(lm),
610
- examples=pg.Ref(examples) if examples else [],
611
- protocol=protocol,
612
- kwargs={k: pg.Ref(v) for k, v in kwargs.items()},
613
- lm_response=lf.AIMessage(output_message.text, metadata=metadata),
614
- usage_summary=usage_summary,
615
- start_time=start_time,
616
- end_time=end_time,
611
+ assert invocation is not None
612
+ invocation.mark_completed(
613
+ lm_response=lm_response, error=error, usage_summary=usage_summary,
614
+ )
615
+ for i, tracker in enumerate(trackers):
616
+ if i == 0 or tracker.include_child_scopes:
617
+ tracker.mark_completed(invocation)
618
+
619
+ with lf.track_usages() as usage_summary:
620
+ try:
621
+ if query_cls is None:
622
+ # Query with natural language output.
623
+ output_message = lf.LangFunc.from_value(query_input, **kwargs)(
624
+ lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
625
+ )
626
+ if response_postprocess:
627
+ processed_text = response_postprocess(output_message.text)
628
+ if processed_text != output_message.text:
629
+ output_message = lf.AIMessage(processed_text, source=output_message)
630
+ else:
631
+ # Query with structured output.
632
+ output_message = query_cls(
633
+ input=(
634
+ query_input.render(lm=lm)
635
+ if isinstance(query_input, lf.Template)
636
+ else query_input
637
+ ),
638
+ schema=schema,
639
+ examples=examples,
640
+ response_postprocess=response_postprocess,
641
+ autofix=autofix if protocol.startswith('python:') else 0,
642
+ **kwargs,
643
+ )(
644
+ lm=lm,
645
+ autofix_lm=autofix_lm or lm,
646
+ cache_seed=cache_seed,
647
+ skip_lm=skip_lm,
648
+ )
649
+ _mark_query_completed(output_message, None, usage_summary)
650
+ except mapping.MappingError as e:
651
+ _mark_query_completed(
652
+ e.lm_response, pg.utils.ErrorInfo.from_exception(e), usage_summary
617
653
  )
618
- for i, (tracker, include_child_scopes) in enumerate(trackers):
619
- if i == 0 or include_child_scopes:
620
- tracker.append(invocation)
621
- return output_message if returns_message else _result(output_message)
654
+ if lf.RAISE_IF_HAS_ERROR == default:
655
+ raise e
656
+ output_message = e.lm_response
657
+ output_message.result = default
658
+ except BaseException as e:
659
+ _mark_query_completed(
660
+ None, pg.utils.ErrorInfo.from_exception(e), usage_summary
661
+ )
662
+ raise e
663
+
664
+ if returns_message:
665
+ return output_message
666
+ return output_message.text if schema in (None, str) else output_message.result
622
667
 
623
668
 
624
669
  @contextlib.contextmanager
@@ -768,6 +813,10 @@ def _reward_fn(cls) -> Callable[
768
813
  class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
769
814
  """A class to represent the invocation of `lf.query`."""
770
815
 
816
+ #
817
+ # Query input.
818
+ #
819
+
771
820
  id: Annotated[
772
821
  str,
773
822
  'The ID of the query invocation.'
@@ -777,42 +826,69 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
777
826
  Union[lf.Template, pg.Symbolic],
778
827
  'Mapping input of `lf.query`.'
779
828
  ]
829
+
780
830
  schema: pg.typing.Annotated[
781
831
  schema_lib.schema_spec(noneable=True),
782
832
  'Schema of `lf.query`.'
783
833
  ]
784
- lm_response: Annotated[
785
- lf.Message,
786
- 'Raw LM response.'
787
- ]
834
+
835
+ default: Annotated[
836
+ Any,
837
+ 'Default value of `lf.query`.'
838
+ ] = lf.RAISE_IF_HAS_ERROR
839
+
788
840
  lm: Annotated[
789
841
  lf.LanguageModel,
790
842
  'Language model used for `lf.query`.'
791
843
  ]
844
+
792
845
  examples: Annotated[
793
846
  list[mapping.MappingExample],
794
847
  'Fewshot exemplars for `lf.query`.'
795
848
  ]
849
+
796
850
  protocol: Annotated[
797
851
  str,
798
852
  'Protocol of `lf.query`.'
799
853
  ] = 'python'
854
+
800
855
  kwargs: Annotated[
801
856
  dict[str, Any],
802
857
  'Kwargs of `lf.query`.'
803
858
  ] = {}
804
- usage_summary: Annotated[
805
- lf.UsageSummary,
806
- 'Usage summary for `lf.query`.'
807
- ]
859
+
860
+ #
861
+ # Query output.
862
+ #
863
+
864
+ lm_response: Annotated[
865
+ lf.Message | None,
866
+ 'Raw LM response. If None, query is not completed yet or failed.'
867
+ ] = None
868
+
869
+ error: Annotated[
870
+ pg.utils.ErrorInfo | None,
871
+ 'Error info if the query failed.'
872
+ ] = None
873
+
874
+ #
875
+ # Execution details.
876
+ #
877
+
808
878
  start_time: Annotated[
809
879
  float,
810
880
  'Start time of query.'
811
881
  ]
882
+
812
883
  end_time: Annotated[
813
- float,
814
- 'End time of query.'
815
- ]
884
+ float | None,
885
+ 'End time of query. If None, query is not completed yet.'
886
+ ] = None
887
+
888
+ usage_summary: Annotated[
889
+ lf.UsageSummary,
890
+ 'Usage summary of the query.'
891
+ ] = lf.UsageSummary()
816
892
 
817
893
  @functools.cached_property
818
894
  def lm_request(self) -> lf.Message:
@@ -822,22 +898,26 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
822
898
  **self.kwargs
823
899
  )
824
900
 
825
- @functools.cached_property
901
+ @property
826
902
  def output(self) -> Any:
827
- """The output of `lf.query`. If it failed, returns the `MappingError`."""
828
- try:
829
- return query_output(self.lm_response, self.schema, protocol=self.protocol)
830
- except mapping.MappingError as e:
831
- return e
903
+ """The output of `lf.query`. If it failed, returns None."""
904
+ return self._output
832
905
 
833
906
  @property
834
907
  def has_error(self) -> bool:
835
908
  """Returns True if the query failed to generate a valid output."""
836
- return isinstance(self.output, BaseException)
909
+ return self.error is not None
910
+
911
+ @property
912
+ def has_oop_error(self) -> bool:
913
+ """Returns True if the query failed due to out of memory error."""
914
+ return self.error is not None and self.error.tag.startswith('MappingError')
837
915
 
838
916
  @property
839
917
  def elapse(self) -> float:
840
918
  """Returns query elapse in seconds."""
919
+ if self.end_time is None:
920
+ return time.time() - self.start_time
841
921
  return self.end_time - self.start_time
842
922
 
843
923
  def as_mapping_example(
@@ -848,14 +928,94 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
848
928
  return mapping.MappingExample(
849
929
  input=self.input,
850
930
  schema=self.schema,
851
- output=self.output,
931
+ output=self.lm_response.text if self.has_oop_error else self.output,
852
932
  metadata=metadata or {},
853
933
  )
854
934
 
855
935
  def _on_bound(self):
856
936
  super()._on_bound()
937
+ self._tab_control = None
938
+ self._output = None
857
939
  self.__dict__.pop('lm_request', None)
858
- self.__dict__.pop('output', None)
940
+
941
+ @property
942
+ def is_completed(self) -> bool:
943
+ """Returns True if the query is completed."""
944
+ return self.end_time is not None
945
+
946
+ def mark_completed(
947
+ self,
948
+ lm_response: lf.Message | None,
949
+ error: pg.utils.ErrorInfo | None = None,
950
+ usage_summary: lf.UsageSummary | None = None) -> None:
951
+ """Marks the query as completed."""
952
+ assert self.end_time is None, 'Query is already completed.'
953
+
954
+ if error is None:
955
+ # Autofix could lead to a successful `lf.query`, however, the initial
956
+ # lm_response may not be valid. When Error is None, we always try to parse
957
+ # the lm_response into the output. If the output is not valid, the error
958
+ # will be updated accordingly. This logic could be optimized in future by
959
+ # returning attempt information when autofix is enabled.
960
+ if self.schema is not None:
961
+ try:
962
+ output = query_output(
963
+ lm_response, self.schema,
964
+ default=self.default, protocol=self.protocol
965
+ )
966
+ except mapping.MappingError as e:
967
+ output = None
968
+ error = pg.utils.ErrorInfo.from_exception(e)
969
+ self._output = output
970
+ else:
971
+ assert lm_response is not None
972
+ self._output = lm_response.text
973
+ elif (error.tag.startswith('MappingError')
974
+ and self.default != lf.RAISE_IF_HAS_ERROR):
975
+ self._output = self.default
976
+
977
+ self.rebind(
978
+ lm_response=lm_response,
979
+ error=error,
980
+ end_time=time.time(),
981
+ skip_notification=True,
982
+ )
983
+ if usage_summary is not None:
984
+ self.usage_summary.merge(usage_summary)
985
+
986
+ # Refresh the tab control.
987
+ if self._tab_control is None:
988
+ return
989
+
990
+ self._tab_control.insert(
991
+ 'schema',
992
+ pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
993
+ 'output',
994
+ pg.view(self.output, collapse_level=None),
995
+ name='output',
996
+ ),
997
+ )
998
+ if self.error is not None:
999
+ self._tab_control.insert(
1000
+ 'schema',
1001
+ pg.views.html.controls.Tab(
1002
+ 'error',
1003
+ pg.view(self.error, collapse_level=None),
1004
+ name='error',
1005
+ )
1006
+ )
1007
+ if self.lm_response is not None:
1008
+ self._tab_control.append(
1009
+ pg.views.html.controls.Tab(
1010
+ 'lm_response',
1011
+ pg.view(
1012
+ self.lm_response,
1013
+ extra_flags=dict(include_message_metadata=True)
1014
+ ),
1015
+ name='lm_response',
1016
+ )
1017
+ )
1018
+ self._tab_control.select(['error', 'output', 'lm_response'])
859
1019
 
860
1020
  def _html_tree_view_summary(
861
1021
  self,
@@ -872,6 +1032,7 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
872
1032
  [
873
1033
  pg.views.html.controls.Label(
874
1034
  'lf.query',
1035
+ tooltip=f'[{self.id}] Query invocation',
875
1036
  css_classes=['query-invocation-type-name']
876
1037
  ),
877
1038
  pg.views.html.controls.Badge(
@@ -888,7 +1049,9 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
888
1049
  f'{int(self.elapse)} seconds',
889
1050
  css_classes=['query-invocation-time']
890
1051
  ),
891
- self.usage_summary.to_html(extra_flags=dict(as_badge=True))
1052
+ self.usage_summary.to_html(
1053
+ extra_flags=dict(as_badge=True)
1054
+ ),
892
1055
  ],
893
1056
  css_classes=['query-invocation-title']
894
1057
  ),
@@ -900,20 +1063,31 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
900
1063
  self,
901
1064
  *,
902
1065
  view: pg.views.HtmlTreeView,
1066
+ extra_flags: dict[str, Any] | None = None,
903
1067
  **kwargs: Any
904
1068
  ) -> pg.Html:
905
- return pg.views.html.controls.TabControl([
1069
+ extra_flags = extra_flags or {}
1070
+ interactive = extra_flags.get('interactive', True)
1071
+ tab_control = pg.views.html.controls.TabControl([
906
1072
  pg.views.html.controls.Tab(
907
1073
  'input',
908
1074
  pg.view(self.input, collapse_level=None),
1075
+ name='input',
909
1076
  ),
910
- pg.views.html.controls.Tab(
1077
+ pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
911
1078
  'output',
912
1079
  pg.view(self.output, collapse_level=None),
913
- ),
1080
+ name='output',
1081
+ ) if self.is_completed else None,
1082
+ pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
1083
+ 'error',
1084
+ pg.view(self.error, collapse_level=None),
1085
+ name='error',
1086
+ ) if self.has_error else None,
914
1087
  pg.views.html.controls.Tab(
915
1088
  'schema',
916
1089
  pg.view(self.schema),
1090
+ name='schema',
917
1091
  ),
918
1092
  pg.views.html.controls.Tab(
919
1093
  'lm_request',
@@ -921,15 +1095,20 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
921
1095
  self.lm_request,
922
1096
  extra_flags=dict(include_message_metadata=False),
923
1097
  ),
1098
+ name='lm_request',
924
1099
  ),
925
- pg.views.html.controls.Tab(
1100
+ pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
926
1101
  'lm_response',
927
1102
  pg.view(
928
1103
  self.lm_response,
929
1104
  extra_flags=dict(include_message_metadata=True)
930
1105
  ),
931
- ),
932
- ], tab_position='top', selected=1).to_html()
1106
+ name='lm_response',
1107
+ ) if self.is_completed else None,
1108
+ ], tab_position='top', selected=1)
1109
+ if interactive:
1110
+ self._tab_control = tab_control
1111
+ return tab_control.to_html(extra_flags=extra_flags)
933
1112
 
934
1113
  @classmethod
935
1114
  def _html_tree_view_css_styles(cls) -> list[str]:
@@ -962,9 +1141,57 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
962
1141
  ]
963
1142
 
964
1143
 
1144
+ @dataclasses.dataclass
1145
+ class _QueryTracker:
1146
+ """Query tracker for `track_queries`."""
1147
+
1148
+ include_child_scopes: Annotated[
1149
+ bool,
1150
+ (
1151
+ 'If True, the queries made in nested `track_queries` contexts will '
1152
+ 'be tracked by this tracker. Otherwise, only the queries made in the '
1153
+ 'current scope will be included.'
1154
+ )
1155
+ ] = True
1156
+
1157
+ start_callabck: Annotated[
1158
+ Callable[[QueryInvocation], None] | None,
1159
+ (
1160
+ 'A callback function to be called when a query is started.'
1161
+ )
1162
+ ] = None
1163
+
1164
+ end_callabck: Annotated[
1165
+ Callable[[QueryInvocation], None] | None,
1166
+ (
1167
+ 'A callback function to be called when a query is completed.'
1168
+ )
1169
+ ] = None
1170
+
1171
+ tracked_queries: Annotated[
1172
+ list[QueryInvocation],
1173
+ (
1174
+ 'The list of queries tracked by this tracker.'
1175
+ )
1176
+ ] = dataclasses.field(default_factory=list)
1177
+
1178
+ def track(self, invocation: QueryInvocation) -> None:
1179
+ self.tracked_queries.append(invocation)
1180
+ if self.start_callabck is not None:
1181
+ self.start_callabck(invocation)
1182
+
1183
+ def mark_completed(self, invocation: QueryInvocation) -> None:
1184
+ assert invocation in self.tracked_queries, invocation
1185
+ if self.end_callabck is not None:
1186
+ self.end_callabck(invocation)
1187
+
1188
+
965
1189
  @contextlib.contextmanager
966
1190
  def track_queries(
967
- include_child_scopes: bool = True
1191
+ include_child_scopes: bool = True,
1192
+ *,
1193
+ start_callabck: Callable[[QueryInvocation], None] | None = None,
1194
+ end_callabck: Callable[[QueryInvocation], None] | None = None,
968
1195
  ) -> Iterator[list[QueryInvocation]]:
969
1196
  """Track all queries made during the context.
970
1197
 
@@ -982,18 +1209,24 @@ def track_queries(
982
1209
  include_child_scopes: If True, the queries made in child scopes will be
983
1210
  included in the returned list. Otherwise, only the queries made in the
984
1211
  current scope will be included.
1212
+ start_callabck: A callback function to be called when a query is started.
1213
+ end_callabck: A callback function to be called when a query is completed.
985
1214
 
986
1215
  Yields:
987
1216
  A list of `QueryInvocation` objects representing the queries made during
988
1217
  the context.
989
1218
  """
990
1219
  trackers = lf.context_value('__query_trackers__', [])
991
- tracker = []
1220
+ tracker = _QueryTracker(
1221
+ include_child_scopes=include_child_scopes,
1222
+ start_callabck=start_callabck,
1223
+ end_callabck=end_callabck
1224
+ )
992
1225
 
993
1226
  with lf.context(
994
- __query_trackers__=[(tracker, include_child_scopes)] + trackers
1227
+ __query_trackers__=[tracker] + trackers
995
1228
  ):
996
1229
  try:
997
- yield tracker
1230
+ yield tracker.tracked_queries
998
1231
  finally:
999
1232
  pass