langfun 0.1.2.dev202410100804__tar.gz → 0.1.2.dev202410120803__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/PKG-INFO +1 -1
  2. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/__init__.py +1 -0
  3. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/base_test.py +1 -0
  4. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/langfunc_test.py +2 -2
  5. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/language_model.py +140 -24
  6. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/language_model_test.py +166 -36
  7. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/__init__.py +8 -1
  8. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/anthropic.py +72 -7
  9. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/cache/in_memory_test.py +3 -2
  10. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/fake_test.py +7 -0
  11. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/groq.py +154 -6
  12. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/openai.py +300 -42
  13. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/openai_test.py +35 -8
  14. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/vertexai.py +121 -16
  15. langfun-0.1.2.dev202410120803/langfun/core/logging.py +254 -0
  16. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/logging_test.py +33 -0
  17. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/message.py +249 -70
  18. langfun-0.1.2.dev202410120803/langfun/core/message_test.py +408 -0
  19. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/audio.py +1 -1
  20. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/audio_test.py +1 -1
  21. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/image.py +1 -1
  22. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/image_test.py +9 -3
  23. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/mime.py +39 -3
  24. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/mime_test.py +39 -0
  25. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/ms_office.py +2 -5
  26. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/ms_office_test.py +1 -1
  27. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/pdf_test.py +1 -1
  28. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/video.py +1 -1
  29. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/video_test.py +2 -2
  30. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/completion_test.py +1 -0
  31. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/mapping.py +38 -0
  32. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/mapping_test.py +55 -0
  33. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/parsing_test.py +2 -1
  34. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/prompting_test.py +1 -0
  35. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/schema.py +34 -0
  36. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/template.py +110 -1
  37. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/template_test.py +37 -0
  38. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/selfplay_test.py +4 -2
  39. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun.egg-info/PKG-INFO +1 -1
  40. langfun-0.1.2.dev202410100804/langfun/core/logging.py +0 -147
  41. langfun-0.1.2.dev202410100804/langfun/core/message_test.py +0 -383
  42. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/LICENSE +0 -0
  43. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/README.md +0 -0
  44. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/__init__.py +0 -0
  45. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/__init__.py +0 -0
  46. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/__init__.py +0 -0
  47. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/correction.py +0 -0
  48. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/correction_test.py +0 -0
  49. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/errors.py +0 -0
  50. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/errors_test.py +0 -0
  51. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/execution.py +0 -0
  52. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/execution_test.py +0 -0
  53. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/generation.py +0 -0
  54. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/generation_test.py +0 -0
  55. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/parsing.py +0 -0
  56. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/parsing_test.py +0 -0
  57. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/permissions.py +0 -0
  58. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/coding/python/permissions_test.py +0 -0
  59. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/component.py +0 -0
  60. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/component_test.py +0 -0
  61. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/concurrent.py +0 -0
  62. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/concurrent_test.py +0 -0
  63. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/console.py +0 -0
  64. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/console_test.py +0 -0
  65. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/__init__.py +0 -0
  66. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/base.py +0 -0
  67. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/matching.py +0 -0
  68. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/matching_test.py +0 -0
  69. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/patching.py +0 -0
  70. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/patching_test.py +0 -0
  71. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/scoring.py +0 -0
  72. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/eval/scoring_test.py +0 -0
  73. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/langfunc.py +0 -0
  74. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/anthropic_test.py +0 -0
  75. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/cache/__init__.py +0 -0
  76. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/cache/base.py +0 -0
  77. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/cache/in_memory.py +0 -0
  78. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/fake.py +0 -0
  79. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/google_genai.py +0 -0
  80. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/google_genai_test.py +0 -0
  81. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/groq_test.py +0 -0
  82. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/llama_cpp.py +0 -0
  83. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/llama_cpp_test.py +0 -0
  84. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/rest.py +0 -0
  85. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/rest_test.py +0 -0
  86. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/llms/vertexai_test.py +0 -0
  87. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/memories/__init__.py +0 -0
  88. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/memories/conversation_history.py +0 -0
  89. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/memories/conversation_history_test.py +0 -0
  90. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/memory.py +0 -0
  91. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/__init__.py +0 -0
  92. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modalities/pdf.py +0 -0
  93. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modality.py +0 -0
  94. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/modality_test.py +0 -0
  95. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/natural_language.py +0 -0
  96. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/natural_language_test.py +0 -0
  97. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/repr_utils.py +0 -0
  98. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/repr_utils_test.py +0 -0
  99. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/sampling.py +0 -0
  100. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/sampling_test.py +0 -0
  101. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/__init__.py +0 -0
  102. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/completion.py +0 -0
  103. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/description.py +0 -0
  104. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/description_test.py +0 -0
  105. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/function_generation.py +0 -0
  106. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/function_generation_test.py +0 -0
  107. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/parsing.py +0 -0
  108. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/prompting.py +0 -0
  109. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/schema_generation.py +0 -0
  110. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/schema_generation_test.py +0 -0
  111. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/schema_test.py +0 -0
  112. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/scoring.py +0 -0
  113. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/scoring_test.py +0 -0
  114. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/tokenization.py +0 -0
  115. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/structured/tokenization_test.py +0 -0
  116. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/subscription.py +0 -0
  117. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/subscription_test.py +0 -0
  118. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/__init__.py +0 -0
  119. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/completion.py +0 -0
  120. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/completion_test.py +0 -0
  121. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/conversation.py +0 -0
  122. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/conversation_test.py +0 -0
  123. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/demonstration.py +0 -0
  124. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/demonstration_test.py +0 -0
  125. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/templates/selfplay.py +0 -0
  126. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/text_formatting.py +0 -0
  127. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun/core/text_formatting_test.py +0 -0
  128. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun.egg-info/SOURCES.txt +0 -0
  129. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun.egg-info/dependency_links.txt +0 -0
  130. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun.egg-info/requires.txt +0 -0
  131. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/langfun.egg-info/top_level.txt +0 -0
  132. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/setup.cfg +0 -0
  133. {langfun-0.1.2.dev202410100804 → langfun-0.1.2.dev202410120803}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: langfun
3
- Version: 0.1.2.dev202410100804
3
+ Version: 0.1.2.dev202410120803
4
4
  Summary: Langfun: Language as Functions.
5
5
  Home-page: https://github.com/google/langfun
6
6
  Author: Langfun Authors
@@ -103,6 +103,7 @@ from langfun.core.language_model import LMSample
103
103
  from langfun.core.language_model import LMSamplingOptions
104
104
  from langfun.core.language_model import LMSamplingUsage
105
105
  from langfun.core.language_model import UsageNotAvailable
106
+ from langfun.core.language_model import UsageSummary
106
107
  from langfun.core.language_model import LMSamplingResult
107
108
  from langfun.core.language_model import LMScoringResult
108
109
  from langfun.core.language_model import LMCache
@@ -194,6 +194,7 @@ class EvaluationTest(unittest.TestCase):
194
194
  cache_seed=0,
195
195
  score=1.0,
196
196
  logprobs=None,
197
+ is_cached=False,
197
198
  usage=lf.LMSamplingUsage(387, 24, 411),
198
199
  tags=['lm-response', 'lm-output', 'transformed'],
199
200
  ),
@@ -89,7 +89,7 @@ class LangFuncCallTest(unittest.TestCase):
89
89
  self.assertEqual(
90
90
  r,
91
91
  message.AIMessage(
92
- 'Hello!!!', score=0.0, logprobs=None,
92
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
93
93
  usage=language_model.UsageNotAvailable()
94
94
  )
95
95
  )
@@ -120,7 +120,7 @@ class LangFuncCallTest(unittest.TestCase):
120
120
  self.assertEqual(
121
121
  r,
122
122
  message.AIMessage(
123
- 'Hello!!!', score=0.0, logprobs=None,
123
+ 'Hello!!!', score=0.0, logprobs=None, is_cached=False,
124
124
  usage=language_model.UsageNotAvailable()
125
125
  )
126
126
  )
@@ -19,7 +19,7 @@ import dataclasses
19
19
  import enum
20
20
  import threading
21
21
  import time
22
- from typing import Annotated, Any, Callable, Iterator, Sequence, Tuple, Type, Union
22
+ from typing import Annotated, Any, Callable, Iterator, Optional, Sequence, Tuple, Type, Union
23
23
  from langfun.core import component
24
24
  from langfun.core import concurrent
25
25
  from langfun.core import console
@@ -86,25 +86,75 @@ class LMSamplingUsage(pg.Object):
86
86
  completion_tokens: int
87
87
  total_tokens: int
88
88
  num_requests: int = 1
89
+ estimated_cost: Annotated[
90
+ float | None,
91
+ (
92
+ 'Estimated cost in US dollars. If None, cost estimating is not '
93
+ 'suppported on the model being queried.'
94
+ )
95
+ ] = None
96
+
97
+ def __bool__(self) -> bool:
98
+ return self.num_requests > 0
99
+
100
+ @property
101
+ def average_prompt_tokens(self) -> int:
102
+ """Returns the average prompt tokens per request."""
103
+ return self.prompt_tokens // self.num_requests
104
+
105
+ @property
106
+ def average_completion_tokens(self) -> int:
107
+ """Returns the average completion tokens per request."""
108
+ return self.completion_tokens // self.num_requests
109
+
110
+ @property
111
+ def average_total_tokens(self) -> int:
112
+ """Returns the average total tokens per request."""
113
+ return self.total_tokens // self.num_requests
89
114
 
90
- def __add__(self, other: 'LMSamplingUsage') -> 'LMSamplingUsage':
115
+ @property
116
+ def average_estimated_cost(self) -> float | None:
117
+ """Returns the average estimated cost per request."""
118
+ if self.estimated_cost is None:
119
+ return None
120
+ return self.estimated_cost / self.num_requests
121
+
122
+ def __add__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
123
+ if other is None:
124
+ return self
91
125
  return LMSamplingUsage(
92
126
  prompt_tokens=self.prompt_tokens + other.prompt_tokens,
93
127
  completion_tokens=self.completion_tokens + other.completion_tokens,
94
128
  total_tokens=self.total_tokens + other.total_tokens,
95
129
  num_requests=self.num_requests + other.num_requests,
130
+ estimated_cost=(
131
+ self.estimated_cost + other.estimated_cost # pylint: disable=g-long-ternary
132
+ if (self.estimated_cost is not None
133
+ and other.estimated_cost is not None)
134
+ else None
135
+ )
96
136
  )
97
137
 
138
+ def __radd__(self, other: Optional['LMSamplingUsage']) -> 'LMSamplingUsage':
139
+ return self + other
140
+
98
141
 
99
142
  class UsageNotAvailable(LMSamplingUsage):
100
143
  """Usage information not available."""
101
144
  prompt_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
102
145
  completion_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
103
146
  total_tokens: pg.typing.Int(0).freeze() # pytype: disable=invalid-annotation
104
- num_requests: pg.typing.Int(1).freeze() # pytype: disable=invalid-annotation
147
+ estimated_cost: pg.typing.Float(default=None, is_noneable=True).freeze() # pytype: disable=invalid-annotation
105
148
 
106
- def __bool__(self) -> bool:
107
- return False
149
+ def __add__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
150
+ if other is None:
151
+ return self
152
+ return UsageNotAvailable(
153
+ num_requests=self.num_requests + other.num_requests
154
+ )
155
+
156
+ def __radd__(self, other: Optional['LMSamplingUsage']) -> 'UsageNotAvailable':
157
+ return self + other
108
158
 
109
159
 
110
160
  class LMSamplingResult(pg.Object):
@@ -123,6 +173,11 @@ class LMSamplingResult(pg.Object):
123
173
  'Usage information. Currently only OpenAI models are supported.',
124
174
  ] = UsageNotAvailable()
125
175
 
176
+ is_cached: Annotated[
177
+ bool,
178
+ 'Whether the result is from cache or not.'
179
+ ] = False
180
+
126
181
 
127
182
  class LMSamplingOptions(component.Component):
128
183
  """Language model sampling options."""
@@ -425,12 +480,13 @@ class LanguageModel(component.Component):
425
480
  response = sample.response
426
481
  response.metadata.score = sample.score
427
482
  response.metadata.logprobs = sample.logprobs
483
+ response.metadata.is_cached = result.is_cached
428
484
 
429
485
  # NOTE(daiyip): Current usage is computed at per-result level,
430
486
  # which is accurate when n=1. For n > 1, we average the usage across
431
487
  # multiple samples.
432
488
  usage = result.usage
433
- if len(result.samples) == 1 or not usage:
489
+ if len(result.samples) == 1 or isinstance(usage, UsageNotAvailable):
434
490
  response.metadata.usage = usage
435
491
  else:
436
492
  n = len(result.samples)
@@ -438,6 +494,9 @@ class LanguageModel(component.Component):
438
494
  prompt_tokens=usage.prompt_tokens // n,
439
495
  completion_tokens=usage.completion_tokens // n,
440
496
  total_tokens=usage.total_tokens // n,
497
+ estimated_cost=(
498
+ usage.estimated_cost / n if usage.estimated_cost else None
499
+ )
441
500
  )
442
501
 
443
502
  # Track usage.
@@ -445,7 +504,7 @@ class LanguageModel(component.Component):
445
504
  if trackers:
446
505
  model_id = self.model_id
447
506
  for tracker in trackers:
448
- tracker.track(model_id, usage)
507
+ tracker.track(model_id, usage, result.is_cached)
449
508
 
450
509
  # Track the prompt for corresponding response.
451
510
  response.source = prompt
@@ -474,7 +533,9 @@ class LanguageModel(component.Component):
474
533
  request_to_result_index[len(requests)] = i
475
534
  requests.append(prompt)
476
535
  else:
477
- results[i] = r.clone()
536
+ result = r.clone()
537
+ assert result.is_cached, result
538
+ results[i] = result
478
539
 
479
540
  # Sample non-cache-hit prompts.
480
541
  if requests:
@@ -491,8 +552,12 @@ class LanguageModel(component.Component):
491
552
  sample.response.set('cache_seed', cache_seed)
492
553
 
493
554
  if cache_seed is not None:
494
- self.cache.put(self, prompt, result.clone(), seed=cache_seed)
495
-
555
+ self.cache.put(
556
+ self,
557
+ prompt,
558
+ result.clone(override=dict(is_cached=True)),
559
+ seed=cache_seed
560
+ )
496
561
  return results # pytype: disable=bad-return-type
497
562
 
498
563
  @abc.abstractmethod
@@ -800,30 +865,81 @@ class LanguageModel(component.Component):
800
865
  return DEFAULT_MAX_CONCURRENCY # Default of 1
801
866
 
802
867
 
868
+ class UsageSummary(pg.Object):
869
+ """Usage sumary."""
870
+
871
+ class AggregatedUsage(pg.Object):
872
+ """Aggregated usage."""
873
+
874
+ total: LMSamplingUsage = LMSamplingUsage(0, 0, 0, 0, 0.0)
875
+ breakdown: dict[str, LMSamplingUsage] = {}
876
+
877
+ def __bool__(self) -> bool:
878
+ """Returns True if the usage is non-empty."""
879
+ return bool(self.breakdown)
880
+
881
+ def add(
882
+ self,
883
+ model_id: str,
884
+ usage: LMSamplingUsage,
885
+ ) -> None:
886
+ """Adds an entry to the breakdown."""
887
+ aggregated = self.breakdown.get(model_id, None)
888
+ with pg.notify_on_change(False):
889
+ self.breakdown[model_id] = usage + aggregated
890
+ self.rebind(total=self.total + usage, skip_notification=True)
891
+
892
+ @property
893
+ def total(self) -> LMSamplingUsage:
894
+ return self.cached.total + self.uncached.total
895
+
896
+ def update(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
897
+ """Updates the usage summary."""
898
+ if is_cached:
899
+ usage.rebind(estimated_cost=0.0, skip_notification=True)
900
+ self.cached.add(model_id, usage)
901
+ else:
902
+ self.uncached.add(model_id, usage)
903
+
904
+
905
+ pg.members(
906
+ dict(
907
+ cached=(
908
+ pg.typing.Object(
909
+ UsageSummary.AggregatedUsage,
910
+ default=UsageSummary.AggregatedUsage()
911
+ ),
912
+ 'Aggregated usages for cached LLM calls.'
913
+ ),
914
+ uncached=(
915
+ pg.typing.Object(
916
+ UsageSummary.AggregatedUsage,
917
+ default=UsageSummary.AggregatedUsage()
918
+ ),
919
+ 'Aggregated usages for uncached LLM calls.'
920
+ ),
921
+ )
922
+ )(UsageSummary)
923
+
924
+
803
925
  class _UsageTracker:
804
926
  """Usage tracker."""
805
927
 
806
928
  def __init__(self, model_ids: set[str] | None):
807
929
  self.model_ids = model_ids
930
+ self.usage_summary = UsageSummary()
808
931
  self._lock = threading.Lock()
809
- self.usages = {
810
- m: LMSamplingUsage(0, 0, 0, 0) for m in model_ids
811
- } if model_ids else {}
812
-
813
- def track(self, model_id: str, usage: LMSamplingUsage):
814
- if self.model_ids is not None and model_id not in self.model_ids:
815
- return
816
- with self._lock:
817
- if not isinstance(usage, UsageNotAvailable) and model_id in self.usages:
818
- self.usages[model_id] += usage
819
- else:
820
- self.usages[model_id] = usage
932
+
933
+ def track(self, model_id: str, usage: LMSamplingUsage, is_cached: bool):
934
+ if self.model_ids is None or model_id in self.model_ids:
935
+ with self._lock:
936
+ self.usage_summary.update(model_id, usage, is_cached)
821
937
 
822
938
 
823
939
  @contextlib.contextmanager
824
940
  def track_usages(
825
941
  *lm: Union[str, LanguageModel]
826
- ) -> Iterator[dict[str, LMSamplingUsage]]:
942
+ ) -> Iterator[UsageSummary]:
827
943
  """Context manager to track the usages of all language models in scope.
828
944
 
829
945
  `lf.track_usages` works with threads spawned by `lf.concurrent_map` and
@@ -854,6 +970,6 @@ def track_usages(
854
970
  tracker = _UsageTracker(set(model_ids) if model_ids else None)
855
971
  with component.context(__usage_trackers__=trackers + [tracker]):
856
972
  try:
857
- yield tracker.usages
973
+ yield tracker.usage_summary
858
974
  finally:
859
975
  pass