langfun 0.0.2.dev20240420__tar.gz → 0.0.2.dev20240423__tar.gz

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (113)
  1. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/PKG-INFO +1 -1
  2. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component.py +6 -0
  3. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component_test.py +1 -0
  4. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model.py +14 -0
  5. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model_test.py +32 -0
  6. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/__init__.py +7 -0
  7. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic.py +36 -22
  8. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic_test.py +7 -7
  9. langfun-0.0.2.dev20240423/langfun/core/llms/groq.py +260 -0
  10. langfun-0.0.2.dev20240423/langfun/core/llms/groq_test.py +170 -0
  11. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/openai.py +55 -50
  12. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/openai_test.py +3 -3
  13. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/template.py +26 -8
  14. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/template_test.py +9 -0
  15. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/PKG-INFO +1 -1
  16. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/SOURCES.txt +2 -0
  17. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/LICENSE +0 -0
  18. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/README.md +0 -0
  19. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/__init__.py +0 -0
  20. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/__init__.py +0 -0
  21. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/__init__.py +0 -0
  22. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/__init__.py +0 -0
  23. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/correction.py +0 -0
  24. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/correction_test.py +0 -0
  25. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/errors.py +0 -0
  26. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/errors_test.py +0 -0
  27. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/execution.py +0 -0
  28. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/execution_test.py +0 -0
  29. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/generation.py +0 -0
  30. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/generation_test.py +0 -0
  31. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/parsing.py +0 -0
  32. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/parsing_test.py +0 -0
  33. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/permissions.py +0 -0
  34. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/coding/python/permissions_test.py +0 -0
  35. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/concurrent.py +0 -0
  36. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/concurrent_test.py +0 -0
  37. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/console.py +0 -0
  38. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/console_test.py +0 -0
  39. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/__init__.py +0 -0
  40. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/base.py +0 -0
  41. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/base_test.py +0 -0
  42. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/matching.py +0 -0
  43. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/matching_test.py +0 -0
  44. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/scoring.py +0 -0
  45. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/eval/scoring_test.py +0 -0
  46. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/langfunc.py +0 -0
  47. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/langfunc_test.py +0 -0
  48. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/__init__.py +0 -0
  49. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/base.py +0 -0
  50. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/in_memory.py +0 -0
  51. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/cache/in_memory_test.py +0 -0
  52. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/fake.py +0 -0
  53. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/fake_test.py +0 -0
  54. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/google_genai.py +0 -0
  55. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/google_genai_test.py +0 -0
  56. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/llama_cpp.py +0 -0
  57. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/llama_cpp_test.py +0 -0
  58. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/__init__.py +0 -0
  59. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/conversation_history.py +0 -0
  60. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memories/conversation_history_test.py +0 -0
  61. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/memory.py +0 -0
  62. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/message.py +0 -0
  63. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/message_test.py +0 -0
  64. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/__init__.py +0 -0
  65. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/image.py +0 -0
  66. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/image_test.py +0 -0
  67. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/mime.py +0 -0
  68. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/mime_test.py +0 -0
  69. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/video.py +0 -0
  70. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modalities/video_test.py +0 -0
  71. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modality.py +0 -0
  72. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/modality_test.py +0 -0
  73. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/natural_language.py +0 -0
  74. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/natural_language_test.py +0 -0
  75. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/sampling.py +0 -0
  76. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/sampling_test.py +0 -0
  77. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/__init__.py +0 -0
  78. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/completion.py +0 -0
  79. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/completion_test.py +0 -0
  80. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/description.py +0 -0
  81. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/description_test.py +0 -0
  82. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/function_generation.py +0 -0
  83. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/function_generation_test.py +0 -0
  84. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/mapping.py +0 -0
  85. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/mapping_test.py +0 -0
  86. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/parsing.py +0 -0
  87. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/parsing_test.py +0 -0
  88. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/prompting.py +0 -0
  89. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/prompting_test.py +0 -0
  90. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema.py +0 -0
  91. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_generation.py +0 -0
  92. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_generation_test.py +0 -0
  93. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/schema_test.py +0 -0
  94. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/scoring.py +0 -0
  95. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/structured/scoring_test.py +0 -0
  96. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/subscription.py +0 -0
  97. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/subscription_test.py +0 -0
  98. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/__init__.py +0 -0
  99. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/completion.py +0 -0
  100. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/completion_test.py +0 -0
  101. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/conversation.py +0 -0
  102. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/conversation_test.py +0 -0
  103. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/demonstration.py +0 -0
  104. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/demonstration_test.py +0 -0
  105. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/selfplay.py +0 -0
  106. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/templates/selfplay_test.py +0 -0
  107. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/text_formatting.py +0 -0
  108. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/text_formatting_test.py +0 -0
  109. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/dependency_links.txt +0 -0
  110. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/requires.txt +0 -0
  111. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun.egg-info/top_level.txt +0 -0
  112. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/setup.cfg +0 -0
  113. {langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/setup.py +0 -0
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240420
+Version: 0.0.2.dev20240423
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component.py
@@ -210,6 +210,12 @@ def get_contextual_override(var_name: str) -> ContextualOverride | None:
   return _get_scoped_value(_global_tls, _CONTEXT_OVERRIDES, var_name)
 
 
+def all_contextual_values() -> dict[str, Any]:
+  """Returns all contextual values provided from `lf.context` in scope."""
+  overrides = getattr(_global_tls, _CONTEXT_OVERRIDES, {})
+  return {k: v.value for k, v in overrides.items()}
+
+
 @contextlib.contextmanager
 def _contextual_scope(
     tls: threading.local, tls_key, **variables
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/component_test.py
@@ -84,6 +84,7 @@ class ComponentContextTest(unittest.TestCase):
         lf.get_contextual_override('y'),
         lf.ContextualOverride(3, cascade=False, override_attrs=False),
     )
+    self.assertEqual(lf.all_contextual_values(), dict(x=3, y=3, z=3))
 
     # Member attributes take precedence over `lf.context`.
     self.assertEqual(a1.x, 1)
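A quick usage sketch of the new helper, mirroring the test above (assumes the conventional `import langfun as lf` alias):

```python
# Hedged sketch: values supplied via `lf.context` become visible through
# the new lf.all_contextual_values() inside the scope.
import langfun as lf

with lf.context(x=3, y=3, z=3):
  assert lf.all_contextual_values() == dict(x=3, y=3, z=3)
```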
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model.py
@@ -24,6 +24,9 @@ from langfun.core import console
 from langfun.core import message as message_lib
 import pyglove as pg
 
+TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
+DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
+
 
 class LMSample(pg.Object):
   """Response candidate."""
@@ -604,3 +607,14 @@ class LanguageModel(component.Component):
           f'score: {r.score}',
           color='blue',
       )
+
+  def rate_to_max_concurrency(
+      self, requests_per_min: float = 0, tokens_per_min: float = 0
+  ) -> int:
+    """Converts a rate to a max concurrency."""
+    if tokens_per_min > 0:
+      return max(int(tokens_per_min / TOKENS_PER_REQUEST / 60), 1)
+    elif requests_per_min > 0:
+      return max(int(requests_per_min / 60), 1)  # Max concurrency can't be zero
+    else:
+      return DEFAULT_MAX_CONCURRENCY  # Default of 1
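A quick check of the conversion arithmetic, using the constants above and the rates exercised in the tests below:

```python
# TPM path: 1e7 tokens/min -> int(1e7 / 250 / 60) = 666 concurrent requests.
# RPM path: 1e4 requests/min -> int(1e4 / 60) = 166 concurrent requests.
# Neither specified (or non-positive): DEFAULT_MAX_CONCURRENCY = 1.
assert int(1e7 / 250 / 60) == 666
assert int(1e4 / 60) == 166
```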
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/language_model_test.py
@@ -394,6 +394,38 @@ class LanguageModelTest(unittest.TestCase):
     with self.assertRaises(NotImplementedError):
       MockModel().score('hi', ['1', '2'])
 
+  def test_rate_to_max_concurrency_no_rpm_no_tpm(self) -> None:
+    lm = MockModel()
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=0, tokens_per_min=0),
+    )
+    self.assertEqual(
+        lm_lib.DEFAULT_MAX_CONCURRENCY,
+        lm.rate_to_max_concurrency(requests_per_min=-1, tokens_per_min=-1),
+    )
+
+  def test_rate_to_max_concurrency_only_rpm_specified_uses_rpm(self) -> None:
+    lm = MockModel()
+    test_rpm = 1e4
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=test_rpm),
+        int(test_rpm / 60)
+    )
+
+  def test_rate_to_max_concurrency_tpm_specified_uses_tpm(self) -> None:
+    lm = MockModel()
+    test_tpm = 1e7
+    self.assertEqual(
+        lm.rate_to_max_concurrency(requests_per_min=1, tokens_per_min=test_tpm),
+        int(test_tpm / lm_lib.TOKENS_PER_REQUEST / 60),
+    )
+
+  def test_rate_to_max_concurrency_small_rate_returns_one(self) -> None:
+    lm = MockModel()
+    self.assertEqual(lm.rate_to_max_concurrency(requests_per_min=1), 1)
+    self.assertEqual(lm.rate_to_max_concurrency(tokens_per_min=1), 1)
+
 
 if __name__ == '__main__':
   unittest.main()
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/__init__.py
@@ -66,6 +66,13 @@ from langfun.core.llms.anthropic import Claude3Opus
 from langfun.core.llms.anthropic import Claude3Sonnet
 from langfun.core.llms.anthropic import Claude3Haiku
 
+from langfun.core.llms.groq import Groq
+from langfun.core.llms.groq import GroqLlama3_70B
+from langfun.core.llms.groq import GroqLlama3_8B
+from langfun.core.llms.groq import GroqLlama2_70B
+from langfun.core.llms.groq import GroqMistral_8x7B
+from langfun.core.llms.groq import GroqGemma7B_IT
+
 
 # LLaMA C++ models.
 from langfun.core.llms.llama_cpp import LlamaCppRemote
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic.py
@@ -26,12 +26,15 @@ import requests
 
 SUPPORTED_MODELS_AND_SETTINGS = {
     # See https://docs.anthropic.com/claude/docs/models-overview
-    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'claude-2.1': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'claude-2.0': pg.Dict(max_tokens=4096, max_concurrency=16),
-    'claude-instant-1.2': pg.Dict(max_tokens=4096, max_concurrency=16),
+    # Rate limits from https://docs.anthropic.com/claude/reference/rate-limits
+    # RPM/TPM for Claude-2.1, Claude-2.0, and Claude-Instant-1.2 estimated
+    # as RPM/TPM of the largest-available model (Claude-3-Opus).
+    'claude-3-opus-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-sonnet-20240229': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-3-haiku-20240307': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.1': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-2.0': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
+    'claude-instant-1.2': pg.Dict(max_tokens=4096, rpm=4000, tpm=400000),
 }
 
 
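Fed through `rate_to_max_concurrency` above (where TPM takes precedence over RPM), the TPM entry now determines Anthropic's effective concurrency:

```python
# tpm=400000 and TOKENS_PER_REQUEST=250 give
# int(400000 / 250 / 60) = 26 concurrent requests (previously a fixed 16).
assert int(400000 / 250 / 60) == 26
```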
@@ -81,6 +84,7 @@ class Anthropic(lf.LanguageModel):
     super()._on_bound()
     self._api_key = None
     self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
 
   @functools.cached_property
   def _api_initialized(self):
@@ -93,6 +97,17 @@ class Anthropic(lf.LanguageModel):
     self._api_key = api_key
     return True
 
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'x-api-key': self._api_key,
+        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'content-type': 'application/json',
+    })
+    return s
+
   @property
   def model_id(self) -> str:
     """Returns a string to identify the model."""
@@ -100,7 +115,11 @@ class Anthropic(lf.LanguageModel):
 
   @property
   def max_concurrency(self) -> int:
-    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+    rpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('rpm', 0)
+    tpm = SUPPORTED_MODELS_AND_SETTINGS[self.model].get('tpm', 0)
+    return self.rate_to_max_concurrency(
+        requests_per_min=rpm, tokens_per_min=tpm
+    )
 
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
     assert self._api_initialized
@@ -165,8 +184,8 @@ class Anthropic(lf.LanguageModel):
   def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
     """Parses Anthropic's response."""
     # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    output = response.json()
     if response.status_code == 200:
+      output = response.json()
       message = self._message_from_content(output['content'])
       input_tokens = output['usage']['input_tokens']
       output_tokens = output['usage']['output_tokens']
@@ -181,12 +200,11 @@ class Anthropic(lf.LanguageModel):
     else:
       if response.status_code == 429:
         error_cls = RateLimitError
-      elif response.status_code == 529:
+      elif response.status_code in (502, 529):
         error_cls = OverloadedError
       else:
         error_cls = AnthropicError
-      error = output['error']
-      raise error_cls(f'{error["type"]}: {error["message"]}')
+      raise error_cls(f'{response.status_code}: {response.content}')
 
   def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
     request = dict()
@@ -198,17 +216,13 @@ class Anthropic(lf.LanguageModel):
             ]
         )
     )
-    response = requests.post(
-        _ANTHROPIC_MESSAGE_API_ENDPOINT,
-        json=request,
-        headers={
-            'x-api-key': self._api_key,
-            'anthropic-version': _ANTHROPIC_API_VERSION,
-            'content-type': 'application/json',
-        },
-        timeout=self.timeout,
-    )
-    return self._parse_response(response)
+    try:
+      response = self._session.post(
+          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
 
 
 class Claude3(Anthropic):
{langfun-0.0.2.dev20240420 → langfun-0.0.2.dev20240423}/langfun/core/llms/anthropic_test.py
@@ -98,20 +98,20 @@ def mock_requests_post_error(status_code, error_type, error_message):
   return _mock_requests
 
 
-class AuthropicTest(unittest.TestCase):
+class AnthropicTest(unittest.TestCase):
 
   def test_basics(self):
     self.assertEqual(
         anthropic.Claude3Haiku().model_id, 'claude-3-haiku-20240307'
     )
-    self.assertEqual(anthropic.Claude3Haiku().max_concurrency, 16)
+    self.assertGreater(anthropic.Claude3Haiku().max_concurrency, 0)
 
   def test_api_key(self):
     lm = anthropic.Claude3Haiku()
     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
       lm('hi')
 
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
 
       lm = anthropic.Claude3Haiku(api_key='fake key')
@@ -123,7 +123,7 @@ class AuthropicTest(unittest.TestCase):
     del os.environ['ANTHROPIC_API_KEY']
 
   def test_call(self):
-    with mock.patch('requests.post') as mock_request:
+    with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm('hello', temperature=0.0, top_k=0.1, top_p=0.2, stop=['\n'])
@@ -140,7 +140,7 @@ class AuthropicTest(unittest.TestCase):
     self.assertIsNotNone(response.usage.total_tokens, 3)
 
   def test_mm_call(self):
-    with mock.patch('requests.post') as mock_mm_request:
+    with mock.patch('requests.Session.post') as mock_mm_request:
       mock_mm_request.side_effect = mock_mm_requests_post
       lm = anthropic.Claude3Haiku(api_key='fake_key')
       response = lm(lf_modalities.Image.from_bytes(image_content), lm=lm)
@@ -152,13 +152,13 @@ class AuthropicTest(unittest.TestCase):
         (529, 'service_unavailable', 'Service unavailable.'),
         (500, 'bad_request', 'Bad request.'),
     ]:
-      with mock.patch('requests.post') as mock_mm_request:
+      with mock.patch('requests.Session.post') as mock_mm_request:
         mock_mm_request.side_effect = mock_requests_post_error(
             status_code, error_type, error_message
         )
        lm = anthropic.Claude3Haiku(api_key='fake_key')
         with self.assertRaisesRegex(
-            Exception, f'{error_type}: {error_message}'
+            Exception, f'.*{status_code}: .*{error_message}'
         ):
           lm('hello', lm=lm, max_attempts=1)
 
langfun-0.0.2.dev20240423/langfun/core/llms/groq.py
@@ -0,0 +1,260 @@
+# Copyright 2024 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from Groq."""
+
+import functools
+import os
+from typing import Annotated, Any
+
+import langfun.core as lf
+from langfun.core import modalities as lf_modalities
+import pyglove as pg
+import requests
+
+
+SUPPORTED_MODELS_AND_SETTINGS = {
+    # Refer https://console.groq.com/docs/models
+    'llama3-8b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama3-70b-8192': pg.Dict(max_tokens=8192, max_concurrency=16),
+    'llama2-70b-4096': pg.Dict(max_tokens=4096, max_concurrency=16),
+    'mixtral-8x7b-32768': pg.Dict(max_tokens=32768, max_concurrency=16),
+    'gemma-7b-it': pg.Dict(max_tokens=8192, max_concurrency=16),
+}
+
+
+class GroqError(Exception):  # pylint: disable=g-bad-exception-name
+  """Base class for Groq errors."""
+
+
+class RateLimitError(GroqError):
+  """Error for rate limit reached."""
+
+
+class OverloadedError(GroqError):
+  """Groq's server is temporarily overloaded."""
+
+
+_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
+
+
+@lf.use_init_args(['model'])
+class Groq(lf.LanguageModel):
+  """Groq LLMs through REST APIs (OpenAI compatible).
+
+  See https://platform.openai.com/docs/api-reference/chat
+  """
+
+  model: pg.typing.Annotated[
+      pg.typing.Enum(
+          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+      ),
+      'The name of the model to use.',
+  ]
+
+  multimodal: Annotated[bool, 'Whether this model has multimodal support.'] = (
+      False
+  )
+
+  api_key: Annotated[
+      str | None,
+      (
+          'API key. If None, the key will be read from environment variable '
+          "'GROQ_API_KEY'."
+      ),
+  ] = None
+
+  def _on_bound(self):
+    super()._on_bound()
+    self._api_key = None
+    self.__dict__.pop('_api_initialized', None)
+    self.__dict__.pop('_session', None)
+
+  @functools.cached_property
+  def _api_initialized(self):
+    api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
+    if not api_key:
+      raise ValueError(
+          'Please specify `api_key` during `__init__` or set environment '
+          'variable `GROQ_API_KEY` with your Groq API key.'
+      )
+    self._api_key = api_key
+    return True
+
+  @functools.cached_property
+  def _session(self) -> requests.Session:
+    assert self._api_initialized
+    s = requests.Session()
+    s.headers.update({
+        'Authorization': f'Bearer {self._api_key}',
+        'Content-Type': 'application/json',
+    })
+    return s
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return self.model
+
+  @property
+  def max_concurrency(self) -> int:
+    return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
+    args = dict(
+        model=self.model,
+        n=options.n,
+        stream=False,
+    )
+
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    if options.max_tokens is not None:
+      args['max_tokens'] = options.max_tokens
+    if options.top_p is not None:
+      args['top_p'] = options.top_p
+    if options.stop:
+      args['stop'] = options.stop
+    return args
+
+  def _content_from_message(self, prompt: lf.Message) -> list[dict[str, Any]]:
+    """Converts a message to Groq's content protocol (list of dicts)."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = []
+    for chunk in prompt.chunk():
+      if isinstance(chunk, str):
+        item = dict(type='text', text=chunk)
+      elif (
+          self.multimodal
+          and isinstance(chunk, lf_modalities.Image)
+          and chunk.uri
+      ):
+        # NOTE(daiyip): Groq only supports image URLs.
+        item = dict(type='image_url', image_url=chunk.uri)
+      else:
+        raise ValueError(f'Unsupported modality object: {chunk!r}.')
+      content.append(item)
+    return content
+
+  def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
+    """Converts Groq's content protocol to a message."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/create
+    content = choice['message']['content']
+    if isinstance(content, str):
+      return lf.AIMessage(content)
+    return lf.AIMessage.from_chunks(
+        [x['text'] for x in content if x['type'] == 'text']
+    )
+
+  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
+    """Parses Groq's response."""
+    # Refer: https://platform.openai.com/docs/api-reference/chat/object
+    if response.status_code == 200:
+      output = response.json()
+      samples = [
+          lf.LMSample(self._message_from_choice(choice), score=0.0)
+          for choice in output['choices']
+      ]
+      usage = output['usage']
+      return lf.LMSamplingResult(
+          samples,
+          usage=lf.LMSamplingUsage(
+              prompt_tokens=usage['prompt_tokens'],
+              completion_tokens=usage['completion_tokens'],
+              total_tokens=usage['total_tokens'],
+          ),
+      )
+    else:
+      # https://platform.openai.com/docs/guides/error-codes/api-errors
+      if response.status_code == 429:
+        error_cls = RateLimitError
+      elif response.status_code in (500, 502, 503):
+        error_cls = OverloadedError
+      else:
+        error_cls = GroqError
+      raise error_cls(f'{response.status_code}: {response.content}')
+
+  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
+    assert self._api_initialized
+    return self._parallel_execute_with_currency_control(
+        self._sample_single,
+        prompts,
+        retry_on_errors=(RateLimitError, OverloadedError),
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = dict()
+    request.update(self._get_request_args(self.sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    try:
+      response = self._session.post(
+          _CHAT_COMPLETE_API_ENDPOINT,
+          json=request,
+          timeout=self.timeout,
+      )
+      return self._parse_response(response)
+    except ConnectionError as e:
+      raise OverloadedError(str(e)) from e
+
+
+class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
+  """Llama3-8B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+  """
+
+  model = 'llama3-8b-8192'
+
+
+class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
+  """Llama3-70B with 8K context window.
+
+  See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
+  """
+
+  model = 'llama3-70b-8192'
+
+
+class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
+  """Llama2-70B with 4K context window.
+
+  See: https://huggingface.co/meta-llama/Llama-2-70b
+  """
+
+  model = 'llama2-70b-4096'
+
+
+class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
+  """Mixtral 8x7B with 32K context window.
+
+  See: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+  """
+
+  model = 'mixtral-8x7b-32768'
+
+
+class GroqGemma7B_IT(Groq):  # pylint: disable=invalid-name
+  """Gemma 7B with 8K context window.
+
+  See: https://huggingface.co/google/gemma-1.1-7b-it
+  """
+
+  model = 'gemma-7b-it'
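A hedged usage sketch of the new backend (requires a real key in the `GROQ_API_KEY` environment variable; model aliases as defined above):

```python
# Groq models are now reachable via langfun.core.llms, per the
# __init__.py hunk earlier in this diff.
from langfun.core import llms

lm = llms.GroqLlama3_8B()  # Alias for Groq(model='llama3-8b-8192').
# Any supported model id also works: llms.Groq('mixtral-8x7b-32768').
print(lm('Hello, Groq!'))  # One chat-completion round trip.
```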