langfun 0.1.2.dev202505130804__py3-none-any.whl → 0.1.2.dev202505150805__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langfun might be problematic. Click here for more details.
- langfun/core/agentic/action.py +237 -108
- langfun/core/agentic/action_eval.py +4 -6
- langfun/core/agentic/action_test.py +15 -9
- langfun/core/coding/python/correction.py +4 -0
- langfun/core/console.py +6 -3
- langfun/core/language_model.py +4 -2
- langfun/core/llms/anthropic.py +4 -8
- langfun/core/llms/anthropic_test.py +38 -13
- langfun/core/llms/gemini.py +2 -2
- langfun/core/logging.py +3 -4
- langfun/core/structured/mapping.py +6 -0
- langfun/core/structured/querying.py +324 -91
- langfun/core/structured/querying_test.py +242 -2
- langfun/core/structured/schema.py +8 -0
- langfun/core/structured/schema_generation.py +1 -0
- langfun/core/structured/schema_test.py +6 -3
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/METADATA +1 -1
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/RECORD +21 -21
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/WHEEL +1 -1
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/licenses/LICENSE +0 -0
- {langfun-0.1.2.dev202505130804.dist-info → langfun-0.1.2.dev202505150805.dist-info}/top_level.txt +0 -0
langfun/core/llms/gemini.py
CHANGED
@@ -155,7 +155,7 @@ SUPPORTED_MODELS = [
|
|
155
155
|
GeminiModelInfo(
|
156
156
|
model_id='gemini-2.5-pro-preview-05-06',
|
157
157
|
in_service=True,
|
158
|
-
provider=pg.oneof(['Google GenAI', '
|
158
|
+
provider=pg.oneof(['Google GenAI', 'VertexAI']),
|
159
159
|
model_type='instruction-tuned',
|
160
160
|
description='Gemini 2.5 Pro.',
|
161
161
|
release_date=datetime.datetime(2025, 5, 6),
|
@@ -178,7 +178,7 @@ SUPPORTED_MODELS = [
|
|
178
178
|
GeminiModelInfo(
|
179
179
|
model_id='gemini-2.5-flash-preview-04-17',
|
180
180
|
in_service=True,
|
181
|
-
provider=pg.oneof(['Google GenAI', '
|
181
|
+
provider=pg.oneof(['Google GenAI', 'VertexAI']),
|
182
182
|
model_type='instruction-tuned',
|
183
183
|
description='Gemini 2.5 Flash.',
|
184
184
|
release_date=datetime.datetime(2025, 4, 17),
|
langfun/core/logging.py
CHANGED
@@ -268,14 +268,13 @@ def log(level: LogLevel,
|
|
268
268
|
metadata=kwargs,
|
269
269
|
)
|
270
270
|
|
271
|
-
if entry.should_output(get_log_level()):
|
271
|
+
if console and entry.should_output(get_log_level()):
|
272
272
|
if console_lib.under_notebook():
|
273
273
|
console_lib.display(entry)
|
274
|
-
|
274
|
+
else:
|
275
275
|
# TODO(daiyip): Improve the console output formatting.
|
276
276
|
console_lib.write(entry)
|
277
|
-
|
278
|
-
if not console:
|
277
|
+
elif not console:
|
279
278
|
if kwargs:
|
280
279
|
message = f'{message} (metadata: {pg.format(kwargs)})'
|
281
280
|
system_log_func(level)(message)
|
@@ -278,6 +278,11 @@ class Mapping(lf.LangFunc):
|
|
278
278
|
'A `lf.structured.Schema` object that constrains mapping output ',
|
279
279
|
] = None
|
280
280
|
|
281
|
+
permission: Annotated[
|
282
|
+
pg.coding.CodePermission,
|
283
|
+
'The permission to run the LLM generated code.'
|
284
|
+
] = pg.coding.CodePermission.ASSIGN | pg.coding.CodePermission.CALL
|
285
|
+
|
281
286
|
@property
|
282
287
|
def mapping_request(self) -> MappingExample:
|
283
288
|
"""Returns a MappingExample as the mapping request."""
|
@@ -434,6 +439,7 @@ class Mapping(lf.LangFunc):
|
|
434
439
|
additional_context=self.globals(),
|
435
440
|
autofix=self.autofix,
|
436
441
|
autofix_lm=self.autofix_lm or self.lm,
|
442
|
+
permission=self.permission,
|
437
443
|
)
|
438
444
|
|
439
445
|
def postprocess_response(self, response: lf.Message) -> lf.Message:
|
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Query LLM for structured output."""
|
15
15
|
|
16
16
|
import contextlib
|
17
|
+
import dataclasses
|
17
18
|
import functools
|
18
19
|
import inspect
|
19
20
|
import time
|
@@ -549,76 +550,120 @@ def query(
|
|
549
550
|
else:
|
550
551
|
query_input = schema_lib.mark_missing(prompt)
|
551
552
|
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
input=(
|
570
|
-
query_input.render(lm=lm)
|
571
|
-
if isinstance(query_input, lf.Template)
|
572
|
-
else query_input
|
573
|
-
),
|
574
|
-
schema=schema,
|
575
|
-
default=default,
|
576
|
-
examples=examples,
|
577
|
-
response_postprocess=response_postprocess,
|
578
|
-
autofix=autofix if protocol.startswith('python:') else 0,
|
579
|
-
**kwargs,
|
580
|
-
)(
|
581
|
-
lm=lm,
|
582
|
-
autofix_lm=autofix_lm or lm,
|
583
|
-
cache_seed=cache_seed,
|
584
|
-
skip_lm=skip_lm,
|
585
|
-
)
|
586
|
-
end_time = time.time()
|
553
|
+
# Determine query class.
|
554
|
+
if schema in (None, str):
|
555
|
+
# Non-structured query.
|
556
|
+
query_cls = None
|
557
|
+
else:
|
558
|
+
# Query with structured output.
|
559
|
+
query_cls = LfQuery.from_protocol(protocol)
|
560
|
+
if ':' not in protocol:
|
561
|
+
protocol = f'{protocol}:{query_cls.version}'
|
562
|
+
|
563
|
+
# `skip_lm`` is True when `lf.query_prompt` is called.
|
564
|
+
# and `prompt` is `pg.MISSING_VALUE` when `lf.query_output` is called.
|
565
|
+
# In these cases, we do not track the query invocation.
|
566
|
+
if skip_lm or pg.MISSING_VALUE == prompt:
|
567
|
+
trackers = []
|
568
|
+
else:
|
569
|
+
trackers = lf.context_value('__query_trackers__', [])
|
587
570
|
|
588
|
-
|
589
|
-
|
571
|
+
# Mark query start with trackers.
|
572
|
+
# NOTE: prompt is MISSING_VALUE when `lf.query_output` is called.
|
573
|
+
# We do not track the query invocation in this case.
|
574
|
+
if trackers:
|
575
|
+
invocation = QueryInvocation(
|
576
|
+
id=_invocation_id(),
|
577
|
+
input=pg.Ref(query_input),
|
578
|
+
schema=(
|
579
|
+
schema_lib.Schema.from_value(schema)
|
580
|
+
if schema not in (None, str) else None
|
581
|
+
),
|
582
|
+
default=default,
|
583
|
+
lm=pg.Ref(lm),
|
584
|
+
examples=pg.Ref(examples) if examples else [],
|
585
|
+
protocol=protocol,
|
586
|
+
kwargs={k: pg.Ref(v) for k, v in kwargs.items()},
|
587
|
+
start_time=time.time(),
|
588
|
+
)
|
589
|
+
for i, tracker in enumerate(trackers):
|
590
|
+
if i == 0 or tracker.include_child_scopes:
|
591
|
+
tracker.track(invocation)
|
592
|
+
else:
|
593
|
+
invocation = None
|
590
594
|
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
+
def _mark_query_completed(output_message, error, usage_summary):
|
596
|
+
# Mark query completion with trackers.
|
597
|
+
if not trackers:
|
598
|
+
return
|
599
|
+
|
600
|
+
if output_message is not None:
|
595
601
|
# To minimize payload for serialization, we remove the result and usage
|
596
602
|
# fields from the metadata. They will be computed on the fly when the
|
597
603
|
# invocation is rendered.
|
598
604
|
metadata = dict(output_message.metadata)
|
599
605
|
metadata.pop('result', None)
|
600
606
|
metadata.pop('usage', None)
|
607
|
+
lm_response = lf.AIMessage(output_message.text, metadata=metadata)
|
608
|
+
else:
|
609
|
+
lm_response = None
|
601
610
|
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
611
|
+
assert invocation is not None
|
612
|
+
invocation.mark_completed(
|
613
|
+
lm_response=lm_response, error=error, usage_summary=usage_summary,
|
614
|
+
)
|
615
|
+
for i, tracker in enumerate(trackers):
|
616
|
+
if i == 0 or tracker.include_child_scopes:
|
617
|
+
tracker.mark_completed(invocation)
|
618
|
+
|
619
|
+
with lf.track_usages() as usage_summary:
|
620
|
+
try:
|
621
|
+
if query_cls is None:
|
622
|
+
# Query with natural language output.
|
623
|
+
output_message = lf.LangFunc.from_value(query_input, **kwargs)(
|
624
|
+
lm=lm, cache_seed=cache_seed, skip_lm=skip_lm
|
625
|
+
)
|
626
|
+
if response_postprocess:
|
627
|
+
processed_text = response_postprocess(output_message.text)
|
628
|
+
if processed_text != output_message.text:
|
629
|
+
output_message = lf.AIMessage(processed_text, source=output_message)
|
630
|
+
else:
|
631
|
+
# Query with structured output.
|
632
|
+
output_message = query_cls(
|
633
|
+
input=(
|
634
|
+
query_input.render(lm=lm)
|
635
|
+
if isinstance(query_input, lf.Template)
|
636
|
+
else query_input
|
637
|
+
),
|
638
|
+
schema=schema,
|
639
|
+
examples=examples,
|
640
|
+
response_postprocess=response_postprocess,
|
641
|
+
autofix=autofix if protocol.startswith('python:') else 0,
|
642
|
+
**kwargs,
|
643
|
+
)(
|
644
|
+
lm=lm,
|
645
|
+
autofix_lm=autofix_lm or lm,
|
646
|
+
cache_seed=cache_seed,
|
647
|
+
skip_lm=skip_lm,
|
648
|
+
)
|
649
|
+
_mark_query_completed(output_message, None, usage_summary)
|
650
|
+
except mapping.MappingError as e:
|
651
|
+
_mark_query_completed(
|
652
|
+
e.lm_response, pg.utils.ErrorInfo.from_exception(e), usage_summary
|
617
653
|
)
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
654
|
+
if lf.RAISE_IF_HAS_ERROR == default:
|
655
|
+
raise e
|
656
|
+
output_message = e.lm_response
|
657
|
+
output_message.result = default
|
658
|
+
except BaseException as e:
|
659
|
+
_mark_query_completed(
|
660
|
+
None, pg.utils.ErrorInfo.from_exception(e), usage_summary
|
661
|
+
)
|
662
|
+
raise e
|
663
|
+
|
664
|
+
if returns_message:
|
665
|
+
return output_message
|
666
|
+
return output_message.text if schema in (None, str) else output_message.result
|
622
667
|
|
623
668
|
|
624
669
|
@contextlib.contextmanager
|
@@ -768,6 +813,10 @@ def _reward_fn(cls) -> Callable[
|
|
768
813
|
class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
769
814
|
"""A class to represent the invocation of `lf.query`."""
|
770
815
|
|
816
|
+
#
|
817
|
+
# Query input.
|
818
|
+
#
|
819
|
+
|
771
820
|
id: Annotated[
|
772
821
|
str,
|
773
822
|
'The ID of the query invocation.'
|
@@ -777,42 +826,69 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
777
826
|
Union[lf.Template, pg.Symbolic],
|
778
827
|
'Mapping input of `lf.query`.'
|
779
828
|
]
|
829
|
+
|
780
830
|
schema: pg.typing.Annotated[
|
781
831
|
schema_lib.schema_spec(noneable=True),
|
782
832
|
'Schema of `lf.query`.'
|
783
833
|
]
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
834
|
+
|
835
|
+
default: Annotated[
|
836
|
+
Any,
|
837
|
+
'Default value of `lf.query`.'
|
838
|
+
] = lf.RAISE_IF_HAS_ERROR
|
839
|
+
|
788
840
|
lm: Annotated[
|
789
841
|
lf.LanguageModel,
|
790
842
|
'Language model used for `lf.query`.'
|
791
843
|
]
|
844
|
+
|
792
845
|
examples: Annotated[
|
793
846
|
list[mapping.MappingExample],
|
794
847
|
'Fewshot exemplars for `lf.query`.'
|
795
848
|
]
|
849
|
+
|
796
850
|
protocol: Annotated[
|
797
851
|
str,
|
798
852
|
'Protocol of `lf.query`.'
|
799
853
|
] = 'python'
|
854
|
+
|
800
855
|
kwargs: Annotated[
|
801
856
|
dict[str, Any],
|
802
857
|
'Kwargs of `lf.query`.'
|
803
858
|
] = {}
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
859
|
+
|
860
|
+
#
|
861
|
+
# Query output.
|
862
|
+
#
|
863
|
+
|
864
|
+
lm_response: Annotated[
|
865
|
+
lf.Message | None,
|
866
|
+
'Raw LM response. If None, query is not completed yet or failed.'
|
867
|
+
] = None
|
868
|
+
|
869
|
+
error: Annotated[
|
870
|
+
pg.utils.ErrorInfo | None,
|
871
|
+
'Error info if the query failed.'
|
872
|
+
] = None
|
873
|
+
|
874
|
+
#
|
875
|
+
# Execution details.
|
876
|
+
#
|
877
|
+
|
808
878
|
start_time: Annotated[
|
809
879
|
float,
|
810
880
|
'Start time of query.'
|
811
881
|
]
|
882
|
+
|
812
883
|
end_time: Annotated[
|
813
|
-
float,
|
814
|
-
'End time of query.'
|
815
|
-
]
|
884
|
+
float | None,
|
885
|
+
'End time of query. If None, query is not completed yet.'
|
886
|
+
] = None
|
887
|
+
|
888
|
+
usage_summary: Annotated[
|
889
|
+
lf.UsageSummary,
|
890
|
+
'Usage summary of the query.'
|
891
|
+
] = lf.UsageSummary()
|
816
892
|
|
817
893
|
@functools.cached_property
|
818
894
|
def lm_request(self) -> lf.Message:
|
@@ -822,22 +898,26 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
822
898
|
**self.kwargs
|
823
899
|
)
|
824
900
|
|
825
|
-
@
|
901
|
+
@property
|
826
902
|
def output(self) -> Any:
|
827
|
-
"""The output of `lf.query`. If it failed, returns
|
828
|
-
|
829
|
-
return query_output(self.lm_response, self.schema, protocol=self.protocol)
|
830
|
-
except mapping.MappingError as e:
|
831
|
-
return e
|
903
|
+
"""The output of `lf.query`. If it failed, returns None."""
|
904
|
+
return self._output
|
832
905
|
|
833
906
|
@property
|
834
907
|
def has_error(self) -> bool:
|
835
908
|
"""Returns True if the query failed to generate a valid output."""
|
836
|
-
return
|
909
|
+
return self.error is not None
|
910
|
+
|
911
|
+
@property
|
912
|
+
def has_oop_error(self) -> bool:
|
913
|
+
"""Returns True if the query failed due to out of memory error."""
|
914
|
+
return self.error is not None and self.error.tag.startswith('MappingError')
|
837
915
|
|
838
916
|
@property
|
839
917
|
def elapse(self) -> float:
|
840
918
|
"""Returns query elapse in seconds."""
|
919
|
+
if self.end_time is None:
|
920
|
+
return time.time() - self.start_time
|
841
921
|
return self.end_time - self.start_time
|
842
922
|
|
843
923
|
def as_mapping_example(
|
@@ -848,14 +928,94 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
848
928
|
return mapping.MappingExample(
|
849
929
|
input=self.input,
|
850
930
|
schema=self.schema,
|
851
|
-
output=self.output,
|
931
|
+
output=self.lm_response.text if self.has_oop_error else self.output,
|
852
932
|
metadata=metadata or {},
|
853
933
|
)
|
854
934
|
|
855
935
|
def _on_bound(self):
|
856
936
|
super()._on_bound()
|
937
|
+
self._tab_control = None
|
938
|
+
self._output = None
|
857
939
|
self.__dict__.pop('lm_request', None)
|
858
|
-
|
940
|
+
|
941
|
+
@property
|
942
|
+
def is_completed(self) -> bool:
|
943
|
+
"""Returns True if the query is completed."""
|
944
|
+
return self.end_time is not None
|
945
|
+
|
946
|
+
def mark_completed(
|
947
|
+
self,
|
948
|
+
lm_response: lf.Message | None,
|
949
|
+
error: pg.utils.ErrorInfo | None = None,
|
950
|
+
usage_summary: lf.UsageSummary | None = None) -> None:
|
951
|
+
"""Marks the query as completed."""
|
952
|
+
assert self.end_time is None, 'Query is already completed.'
|
953
|
+
|
954
|
+
if error is None:
|
955
|
+
# Autofix could lead to a successful `lf.query`, however, the initial
|
956
|
+
# lm_response may not be valid. When Error is None, we always try to parse
|
957
|
+
# the lm_response into the output. If the output is not valid, the error
|
958
|
+
# will be updated accordingly. This logic could be optimized in future by
|
959
|
+
# returning attempt information when autofix is enabled.
|
960
|
+
if self.schema is not None:
|
961
|
+
try:
|
962
|
+
output = query_output(
|
963
|
+
lm_response, self.schema,
|
964
|
+
default=self.default, protocol=self.protocol
|
965
|
+
)
|
966
|
+
except mapping.MappingError as e:
|
967
|
+
output = None
|
968
|
+
error = pg.utils.ErrorInfo.from_exception(e)
|
969
|
+
self._output = output
|
970
|
+
else:
|
971
|
+
assert lm_response is not None
|
972
|
+
self._output = lm_response.text
|
973
|
+
elif (error.tag.startswith('MappingError')
|
974
|
+
and self.default != lf.RAISE_IF_HAS_ERROR):
|
975
|
+
self._output = self.default
|
976
|
+
|
977
|
+
self.rebind(
|
978
|
+
lm_response=lm_response,
|
979
|
+
error=error,
|
980
|
+
end_time=time.time(),
|
981
|
+
skip_notification=True,
|
982
|
+
)
|
983
|
+
if usage_summary is not None:
|
984
|
+
self.usage_summary.merge(usage_summary)
|
985
|
+
|
986
|
+
# Refresh the tab control.
|
987
|
+
if self._tab_control is None:
|
988
|
+
return
|
989
|
+
|
990
|
+
self._tab_control.insert(
|
991
|
+
'schema',
|
992
|
+
pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
|
993
|
+
'output',
|
994
|
+
pg.view(self.output, collapse_level=None),
|
995
|
+
name='output',
|
996
|
+
),
|
997
|
+
)
|
998
|
+
if self.error is not None:
|
999
|
+
self._tab_control.insert(
|
1000
|
+
'schema',
|
1001
|
+
pg.views.html.controls.Tab(
|
1002
|
+
'error',
|
1003
|
+
pg.view(self.error, collapse_level=None),
|
1004
|
+
name='error',
|
1005
|
+
)
|
1006
|
+
)
|
1007
|
+
if self.lm_response is not None:
|
1008
|
+
self._tab_control.append(
|
1009
|
+
pg.views.html.controls.Tab(
|
1010
|
+
'lm_response',
|
1011
|
+
pg.view(
|
1012
|
+
self.lm_response,
|
1013
|
+
extra_flags=dict(include_message_metadata=True)
|
1014
|
+
),
|
1015
|
+
name='lm_response',
|
1016
|
+
)
|
1017
|
+
)
|
1018
|
+
self._tab_control.select(['error', 'output', 'lm_response'])
|
859
1019
|
|
860
1020
|
def _html_tree_view_summary(
|
861
1021
|
self,
|
@@ -872,6 +1032,7 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
872
1032
|
[
|
873
1033
|
pg.views.html.controls.Label(
|
874
1034
|
'lf.query',
|
1035
|
+
tooltip=f'[{self.id}] Query invocation',
|
875
1036
|
css_classes=['query-invocation-type-name']
|
876
1037
|
),
|
877
1038
|
pg.views.html.controls.Badge(
|
@@ -888,7 +1049,9 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
888
1049
|
f'{int(self.elapse)} seconds',
|
889
1050
|
css_classes=['query-invocation-time']
|
890
1051
|
),
|
891
|
-
self.usage_summary.to_html(
|
1052
|
+
self.usage_summary.to_html(
|
1053
|
+
extra_flags=dict(as_badge=True)
|
1054
|
+
),
|
892
1055
|
],
|
893
1056
|
css_classes=['query-invocation-title']
|
894
1057
|
),
|
@@ -900,20 +1063,31 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
900
1063
|
self,
|
901
1064
|
*,
|
902
1065
|
view: pg.views.HtmlTreeView,
|
1066
|
+
extra_flags: dict[str, Any] | None = None,
|
903
1067
|
**kwargs: Any
|
904
1068
|
) -> pg.Html:
|
905
|
-
|
1069
|
+
extra_flags = extra_flags or {}
|
1070
|
+
interactive = extra_flags.get('interactive', True)
|
1071
|
+
tab_control = pg.views.html.controls.TabControl([
|
906
1072
|
pg.views.html.controls.Tab(
|
907
1073
|
'input',
|
908
1074
|
pg.view(self.input, collapse_level=None),
|
1075
|
+
name='input',
|
909
1076
|
),
|
910
|
-
pg.views.html.controls.Tab(
|
1077
|
+
pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
|
911
1078
|
'output',
|
912
1079
|
pg.view(self.output, collapse_level=None),
|
913
|
-
|
1080
|
+
name='output',
|
1081
|
+
) if self.is_completed else None,
|
1082
|
+
pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
|
1083
|
+
'error',
|
1084
|
+
pg.view(self.error, collapse_level=None),
|
1085
|
+
name='error',
|
1086
|
+
) if self.has_error else None,
|
914
1087
|
pg.views.html.controls.Tab(
|
915
1088
|
'schema',
|
916
1089
|
pg.view(self.schema),
|
1090
|
+
name='schema',
|
917
1091
|
),
|
918
1092
|
pg.views.html.controls.Tab(
|
919
1093
|
'lm_request',
|
@@ -921,15 +1095,20 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
921
1095
|
self.lm_request,
|
922
1096
|
extra_flags=dict(include_message_metadata=False),
|
923
1097
|
),
|
1098
|
+
name='lm_request',
|
924
1099
|
),
|
925
|
-
pg.views.html.controls.Tab(
|
1100
|
+
pg.views.html.controls.Tab( # pylint: disable=g-long-ternary
|
926
1101
|
'lm_response',
|
927
1102
|
pg.view(
|
928
1103
|
self.lm_response,
|
929
1104
|
extra_flags=dict(include_message_metadata=True)
|
930
1105
|
),
|
931
|
-
|
932
|
-
|
1106
|
+
name='lm_response',
|
1107
|
+
) if self.is_completed else None,
|
1108
|
+
], tab_position='top', selected=1)
|
1109
|
+
if interactive:
|
1110
|
+
self._tab_control = tab_control
|
1111
|
+
return tab_control.to_html(extra_flags=extra_flags)
|
933
1112
|
|
934
1113
|
@classmethod
|
935
1114
|
def _html_tree_view_css_styles(cls) -> list[str]:
|
@@ -962,9 +1141,57 @@ class QueryInvocation(pg.Object, pg.views.HtmlTreeView.Extension):
|
|
962
1141
|
]
|
963
1142
|
|
964
1143
|
|
1144
|
+
@dataclasses.dataclass
|
1145
|
+
class _QueryTracker:
|
1146
|
+
"""Query tracker for `track_queries`."""
|
1147
|
+
|
1148
|
+
include_child_scopes: Annotated[
|
1149
|
+
bool,
|
1150
|
+
(
|
1151
|
+
'If True, the queries made in nested `track_queries` contexts will '
|
1152
|
+
'be tracked by this tracker. Otherwise, only the queries made in the '
|
1153
|
+
'current scope will be included.'
|
1154
|
+
)
|
1155
|
+
] = True
|
1156
|
+
|
1157
|
+
start_callabck: Annotated[
|
1158
|
+
Callable[[QueryInvocation], None] | None,
|
1159
|
+
(
|
1160
|
+
'A callback function to be called when a query is started.'
|
1161
|
+
)
|
1162
|
+
] = None
|
1163
|
+
|
1164
|
+
end_callabck: Annotated[
|
1165
|
+
Callable[[QueryInvocation], None] | None,
|
1166
|
+
(
|
1167
|
+
'A callback function to be called when a query is completed.'
|
1168
|
+
)
|
1169
|
+
] = None
|
1170
|
+
|
1171
|
+
tracked_queries: Annotated[
|
1172
|
+
list[QueryInvocation],
|
1173
|
+
(
|
1174
|
+
'The list of queries tracked by this tracker.'
|
1175
|
+
)
|
1176
|
+
] = dataclasses.field(default_factory=list)
|
1177
|
+
|
1178
|
+
def track(self, invocation: QueryInvocation) -> None:
|
1179
|
+
self.tracked_queries.append(invocation)
|
1180
|
+
if self.start_callabck is not None:
|
1181
|
+
self.start_callabck(invocation)
|
1182
|
+
|
1183
|
+
def mark_completed(self, invocation: QueryInvocation) -> None:
|
1184
|
+
assert invocation in self.tracked_queries, invocation
|
1185
|
+
if self.end_callabck is not None:
|
1186
|
+
self.end_callabck(invocation)
|
1187
|
+
|
1188
|
+
|
965
1189
|
@contextlib.contextmanager
|
966
1190
|
def track_queries(
|
967
|
-
include_child_scopes: bool = True
|
1191
|
+
include_child_scopes: bool = True,
|
1192
|
+
*,
|
1193
|
+
start_callabck: Callable[[QueryInvocation], None] | None = None,
|
1194
|
+
end_callabck: Callable[[QueryInvocation], None] | None = None,
|
968
1195
|
) -> Iterator[list[QueryInvocation]]:
|
969
1196
|
"""Track all queries made during the context.
|
970
1197
|
|
@@ -982,18 +1209,24 @@ def track_queries(
|
|
982
1209
|
include_child_scopes: If True, the queries made in child scopes will be
|
983
1210
|
included in the returned list. Otherwise, only the queries made in the
|
984
1211
|
current scope will be included.
|
1212
|
+
start_callabck: A callback function to be called when a query is started.
|
1213
|
+
end_callabck: A callback function to be called when a query is completed.
|
985
1214
|
|
986
1215
|
Yields:
|
987
1216
|
A list of `QueryInvocation` objects representing the queries made during
|
988
1217
|
the context.
|
989
1218
|
"""
|
990
1219
|
trackers = lf.context_value('__query_trackers__', [])
|
991
|
-
tracker =
|
1220
|
+
tracker = _QueryTracker(
|
1221
|
+
include_child_scopes=include_child_scopes,
|
1222
|
+
start_callabck=start_callabck,
|
1223
|
+
end_callabck=end_callabck
|
1224
|
+
)
|
992
1225
|
|
993
1226
|
with lf.context(
|
994
|
-
__query_trackers__=[
|
1227
|
+
__query_trackers__=[tracker] + trackers
|
995
1228
|
):
|
996
1229
|
try:
|
997
|
-
yield tracker
|
1230
|
+
yield tracker.tracked_queries
|
998
1231
|
finally:
|
999
1232
|
pass
|