langfun 0.0.2.dev20240601__tar.gz → 0.0.2.dev20240604__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/PKG-INFO +1 -1
  2. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/__init__.py +1 -1
  3. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/__init__.py +5 -0
  4. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/language_model.py +27 -1
  5. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/__init__.py +3 -0
  6. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic.py +44 -80
  7. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic_test.py +1 -1
  8. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq.py +42 -87
  9. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq_test.py +1 -1
  10. langfun-0.0.2.dev20240604/langfun/core/llms/llama_cpp.py +84 -0
  11. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/llama_cpp_test.py +14 -8
  12. langfun-0.0.2.dev20240604/langfun/core/llms/rest.py +112 -0
  13. langfun-0.0.2.dev20240604/langfun/core/llms/rest_test.py +111 -0
  14. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/PKG-INFO +1 -1
  15. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/SOURCES.txt +2 -0
  16. langfun-0.0.2.dev20240601/langfun/core/llms/llama_cpp.py +0 -74
  17. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/LICENSE +0 -0
  18. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/README.md +0 -0
  19. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/__init__.py +0 -0
  20. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/__init__.py +0 -0
  21. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/correction.py +0 -0
  22. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/correction_test.py +0 -0
  23. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/errors.py +0 -0
  24. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/errors_test.py +0 -0
  25. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/execution.py +0 -0
  26. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/execution_test.py +0 -0
  27. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/generation.py +0 -0
  28. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/generation_test.py +0 -0
  29. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/parsing.py +0 -0
  30. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/parsing_test.py +0 -0
  31. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/permissions.py +0 -0
  32. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/coding/python/permissions_test.py +0 -0
  33. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/component.py +0 -0
  34. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/component_test.py +0 -0
  35. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/concurrent.py +0 -0
  36. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/concurrent_test.py +0 -0
  37. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/console.py +0 -0
  38. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/console_test.py +0 -0
  39. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/__init__.py +0 -0
  40. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/base.py +0 -0
  41. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/base_test.py +0 -0
  42. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/matching.py +0 -0
  43. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/matching_test.py +0 -0
  44. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/patching.py +0 -0
  45. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/patching_test.py +0 -0
  46. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/scoring.py +0 -0
  47. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/eval/scoring_test.py +0 -0
  48. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/langfunc.py +0 -0
  49. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/langfunc_test.py +0 -0
  50. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/language_model_test.py +0 -0
  51. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/__init__.py +0 -0
  52. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/base.py +0 -0
  53. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/in_memory.py +0 -0
  54. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/cache/in_memory_test.py +0 -0
  55. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/fake.py +0 -0
  56. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/fake_test.py +0 -0
  57. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/google_genai.py +0 -0
  58. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/google_genai_test.py +0 -0
  59. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/openai.py +0 -0
  60. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/openai_test.py +0 -0
  61. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/vertexai.py +0 -0
  62. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/vertexai_test.py +0 -0
  63. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/__init__.py +0 -0
  64. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/conversation_history.py +0 -0
  65. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memories/conversation_history_test.py +0 -0
  66. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/memory.py +0 -0
  67. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/message.py +0 -0
  68. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/message_test.py +0 -0
  69. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/__init__.py +0 -0
  70. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/audio.py +0 -0
  71. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/audio_test.py +0 -0
  72. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/image.py +0 -0
  73. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/image_test.py +0 -0
  74. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/mime.py +0 -0
  75. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/mime_test.py +0 -0
  76. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/ms_office.py +0 -0
  77. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/ms_office_test.py +0 -0
  78. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/pdf.py +0 -0
  79. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/pdf_test.py +0 -0
  80. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/video.py +0 -0
  81. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modalities/video_test.py +0 -0
  82. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modality.py +0 -0
  83. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/modality_test.py +0 -0
  84. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/natural_language.py +0 -0
  85. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/natural_language_test.py +0 -0
  86. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/sampling.py +0 -0
  87. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/sampling_test.py +0 -0
  88. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/__init__.py +0 -0
  89. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/completion.py +0 -0
  90. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/completion_test.py +0 -0
  91. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/description.py +0 -0
  92. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/description_test.py +0 -0
  93. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/function_generation.py +0 -0
  94. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/function_generation_test.py +0 -0
  95. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/mapping.py +0 -0
  96. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/mapping_test.py +0 -0
  97. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/parsing.py +0 -0
  98. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/parsing_test.py +0 -0
  99. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/prompting.py +0 -0
  100. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/prompting_test.py +0 -0
  101. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema.py +0 -0
  102. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_generation.py +0 -0
  103. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_generation_test.py +0 -0
  104. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/schema_test.py +0 -0
  105. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/scoring.py +0 -0
  106. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/structured/scoring_test.py +0 -0
  107. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/subscription.py +0 -0
  108. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/subscription_test.py +0 -0
  109. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/template.py +0 -0
  110. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/template_test.py +0 -0
  111. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/__init__.py +0 -0
  112. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/completion.py +0 -0
  113. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/completion_test.py +0 -0
  114. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/conversation.py +0 -0
  115. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/conversation_test.py +0 -0
  116. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/demonstration.py +0 -0
  117. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/demonstration_test.py +0 -0
  118. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/selfplay.py +0 -0
  119. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/templates/selfplay_test.py +0 -0
  120. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/text_formatting.py +0 -0
  121. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/text_formatting_test.py +0 -0
  122. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/dependency_links.txt +0 -0
  123. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/requires.txt +0 -0
  124. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun.egg-info/top_level.txt +0 -0
  125. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/setup.cfg +0 -0
  126. {langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/setup.py +0 -0
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: langfun
-Version: 0.0.2.dev20240601
+Version: 0.0.2.dev20240604
 Summary: Langfun: Language as Functions.
 Home-page: https://github.com/google/langfun
 Author: Langfun Authors
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/__init__.py
@@ -63,7 +63,7 @@ Image = modalities.Image
 Video = modalities.Video
 PDF = modalities.PDF
 
-# Error types.
+# Additional error types.
 MappingError = structured.MappingError
 SchemaError = structured.SchemaError
 JsonError = structured.JsonError
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/__init__.py
@@ -106,6 +106,11 @@ from langfun.core.language_model import LMScoringResult
 from langfun.core.language_model import LMCache
 from langfun.core.language_model import LMDebugMode
 
+from langfun.core.language_model import LMError
+from langfun.core.language_model import RetryableLMError
+from langfun.core.language_model import RateLimitError
+from langfun.core.language_model import TemporaryLMError
+
 # Components for building agents.
 from langfun.core.memory import Memory
 
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/language_model.py
@@ -29,6 +29,32 @@ TOKENS_PER_REQUEST = 250  # Estimated num tokens for a single request
 DEFAULT_MAX_CONCURRENCY = 1  # Use this as max concurrency if no RPM or TPM data
 
 
+#
+# Common errors during calling language models.
+#
+
+
+class LMError(RuntimeError):
+  """Base class for language model errors."""
+
+
+class RetryableLMError(LMError):
+  """Base class for LLM errors that can be solved by retrying."""
+
+
+class RateLimitError(RetryableLMError):
+  """Error for rate limit reached."""
+
+
+class TemporaryLMError(RetryableLMError):
+  """Error for temporary service issues that can be retried."""
+
+
+#
+# Language model input/output interfaces.
+#
+
+
 class LMSample(pg.Object):
   """Response candidate."""
 
@@ -445,7 +471,7 @@ class LanguageModel(component.Component):
           None,
           Union[Type[Exception], Tuple[Type[Exception], str]],
           Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
-      ] = None,
+      ] = RetryableLMError,
   ) -> Any:
     """Helper method for subclasses for implementing _sample."""
     return concurrent.concurrent_execute(
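
Taken together, the two language_model.py hunks above introduce a shared error hierarchy and make `RetryableLMError` the default for `retry_on_errors`, so any provider error that subclasses it is retried uniformly. Below is a minimal, self-contained sketch of how such a hierarchy drives generic retry logic; the classes are re-declared locally and `call_with_retries` is a hypothetical helper, not langfun API.

import time


class LMError(RuntimeError):
  """Base class for language model errors."""


class RetryableLMError(LMError):
  """Base class for LLM errors that can be solved by retrying."""


class RateLimitError(RetryableLMError):
  """Error for rate limit reached."""


def call_with_retries(fn, max_attempts=3, base_delay=0.1):
  """Retries `fn` on RetryableLMError with exponential backoff."""
  for attempt in range(max_attempts):
    try:
      return fn()
    except RetryableLMError:
      if attempt == max_attempts - 1:
        raise  # Out of attempts; surface the retryable error.
      time.sleep(base_delay * 2 ** attempt)


attempts = []


def flaky():
  attempts.append(1)
  if len(attempts) < 3:
    raise RateLimitError('429: rate limited')
  return 'ok'


assert call_with_retries(flaky) == 'ok'  # Succeeds on the third attempt.
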
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/__init__.py
@@ -24,6 +24,9 @@ from langfun.core.llms.fake import StaticMapping
 from langfun.core.llms.fake import StaticResponse
 from langfun.core.llms.fake import StaticSequence
 
+# REST-based models.
+from langfun.core.llms.rest import REST
+
 # Gemini models.
 from langfun.core.llms.google_genai import GenAI
 from langfun.core.llms.google_genai import GeminiPro
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic.py
@@ -14,14 +14,13 @@
 """Language models from Anthropic."""
 
 import base64
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -38,24 +37,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class AnthropicError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Anthropic errors."""
-
-
-class RateLimitError(AnthropicError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(AnthropicError):
-  """Anthropic's server is temporarily overloaded."""
-
-
-_ANTHROPIC_MESSAGE_API_ENDPOINT = 'https://api.anthropic.com/v1/messages'
-_ANTHROPIC_API_VERSION = '2023-06-01'
-
-
 @lf.use_init_args(['model'])
-class Anthropic(lf.LanguageModel):
+class Anthropic(rest.REST):
   """Anthropic LLMs (Claude) through REST APIs.
 
   See https://docs.anthropic.com/claude/reference/messages_post
@@ -80,14 +63,18 @@ class Anthropic(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.anthropic.com/v1/messages'
+
+  api_version: Annotated[
+      str,
+      'Anthropic API version.'
+  ] = '2023-06-01'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('ANTHROPIC_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -95,18 +82,14 @@ class Anthropic(lf.LanguageModel):
           'variable `ANTHROPIC_API_KEY` with your Anthropic API key.'
       )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
         'x-api-key': self._api_key,
-        'anthropic-version': _ANTHROPIC_API_VERSION,
+        'anthropic-version': self.api_version,
         'content-type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -121,13 +104,24 @@ class Anthropic(lf.LanguageModel):
         requests_per_min=rpm, tokens_per_min=tpm
     )
 
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single, prompts, retry_on_errors=(RateLimitError)
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
     )
+    return request
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # Anthropic requires `max_tokens` to be specified.
     max_tokens = (
@@ -174,6 +168,19 @@ class Anthropic(lf.LanguageModel):
     else:
       return [dict(type='text', text=prompt.text)]
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    message = self._message_from_content(json['content'])
+    input_tokens = json['usage']['input_tokens']
+    output_tokens = json['usage']['output_tokens']
+    return lf.LMSamplingResult(
+        [lf.LMSample(message)],
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=input_tokens,
+            completion_tokens=output_tokens,
+            total_tokens=input_tokens + output_tokens,
+        ),
+    )
+
   def _message_from_content(self, content: list[dict[str, Any]]) -> lf.Message:
     """Converts Anthropic's content protocol to message."""
     # Refer: https://docs.anthropic.com/claude/reference/messages-examples
@@ -181,49 +188,6 @@ class Anthropic(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Anthropic's response."""
-    # NOTE(daiyip): Refer https://docs.anthropic.com/claude/reference/errors
-    if response.status_code == 200:
-      output = response.json()
-      message = self._message_from_content(output['content'])
-      input_tokens = output['usage']['input_tokens']
-      output_tokens = output['usage']['output_tokens']
-      return lf.LMSamplingResult(
-          [lf.LMSample(message)],
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=input_tokens,
-              completion_tokens=output_tokens,
-              total_tokens=input_tokens + output_tokens,
-          ),
-      )
-    else:
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (502, 529):
-        error_cls = OverloadedError
-      else:
-        error_cls = AnthropicError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _ANTHROPIC_MESSAGE_API_ENDPOINT, json=request, timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class Claude3(Anthropic):
   """Base class for Claude 3 models. 200K input tokens and 4K output tokens."""
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/anthropic_test.py
@@ -160,7 +160,7 @@ class AnthropicTest(unittest.TestCase):
     with self.assertRaisesRegex(
         Exception, f'.*{status_code}: .*{error_message}'
     ):
-      lm('hello', lm=lm, max_attempts=1)
+      lm('hello', max_attempts=1)
 
 
 if __name__ == '__main__':
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq.py
@@ -13,14 +13,13 @@
 # limitations under the License.
 """Language models from Groq."""
 
-import functools
 import os
 from typing import Annotated, Any
 
 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg
-import requests
 
 
 SUPPORTED_MODELS_AND_SETTINGS = {
@@ -33,23 +32,8 @@ SUPPORTED_MODELS_AND_SETTINGS = {
 }
 
 
-class GroqError(Exception):  # pylint: disable=g-bad-exception-name
-  """Base class for Groq errors."""
-
-
-class RateLimitError(GroqError):
-  """Error for rate limit reached."""
-
-
-class OverloadedError(GroqError):
-  """Groq's server is temporarily overloaded."""
-
-
-_CHAT_COMPLETE_API_ENDPOINT = 'https://api.groq.com/openai/v1/chat/completions'
-
-
 @lf.use_init_args(['model'])
-class Groq(lf.LanguageModel):
+class Groq(rest.REST):
   """Groq LLMs through REST APIs (OpenAI compatible).
 
   See https://platform.openai.com/docs/api-reference/chat
@@ -74,14 +58,13 @@ class Groq(lf.LanguageModel):
       ),
   ] = None
 
+  api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
+
   def _on_bound(self):
     super()._on_bound()
     self._api_key = None
-    self.__dict__.pop('_api_initialized', None)
-    self.__dict__.pop('_session', None)
 
-  @functools.cached_property
-  def _api_initialized(self):
+  def _initialize(self):
     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
     if not api_key:
       raise ValueError(
@@ -89,17 +72,13 @@ class Groq(lf.LanguageModel):
         'variable `GROQ_API_KEY` with your Groq API key.'
     )
     self._api_key = api_key
-    return True
 
-  @functools.cached_property
-  def _session(self) -> requests.Session:
-    assert self._api_initialized
-    s = requests.Session()
-    s.headers.update({
+  @property
+  def headers(self) -> dict[str, Any]:
+    return {
         'Authorization': f'Bearer {self._api_key}',
         'Content-Type': 'application/json',
-    })
-    return s
+    }
 
   @property
   def model_id(self) -> str:
@@ -110,7 +89,24 @@ class Groq(lf.LanguageModel):
   def max_concurrency(self) -> int:
     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
 
-  def _get_request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+  def request(
+      self,
+      prompt: lf.Message,
+      sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    request.update(
+        dict(
+            messages=[
+                dict(role='user', content=self._content_from_message(prompt))
+            ]
+        )
+    )
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
     """Returns a dict as request arguments."""
     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
     args = dict(
@@ -148,6 +144,21 @@ class Groq(lf.LanguageModel):
       content.append(item)
     return content
 
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    samples = [
+        lf.LMSample(self._message_from_choice(choice), score=0.0)
+        for choice in json['choices']
+    ]
+    usage = json['usage']
+    return lf.LMSamplingResult(
+        samples,
+        usage=lf.LMSamplingUsage(
+            prompt_tokens=usage['prompt_tokens'],
+            completion_tokens=usage['completion_tokens'],
+            total_tokens=usage['total_tokens'],
+        ),
+    )
+
   def _message_from_choice(self, choice: dict[str, Any]) -> lf.Message:
     """Converts Groq's content protocol to message."""
     # Refer: https://platform.openai.com/docs/api-reference/chat/create
@@ -158,62 +169,6 @@ class Groq(lf.LanguageModel):
         [x['text'] for x in content if x['type'] == 'text']
     )
 
-  def _parse_response(self, response: requests.Response) -> lf.LMSamplingResult:
-    """Parses Groq's response."""
-    # Refer: https://platform.openai.com/docs/api-reference/chat/object
-    if response.status_code == 200:
-      output = response.json()
-      samples = [
-          lf.LMSample(self._message_from_choice(choice), score=0.0)
-          for choice in output['choices']
-      ]
-      usage = output['usage']
-      return lf.LMSamplingResult(
-          samples,
-          usage=lf.LMSamplingUsage(
-              prompt_tokens=usage['prompt_tokens'],
-              completion_tokens=usage['completion_tokens'],
-              total_tokens=usage['total_tokens'],
-          ),
-      )
-    else:
-      # https://platform.openai.com/docs/guides/error-codes/api-errors
-      if response.status_code == 429:
-        error_cls = RateLimitError
-      elif response.status_code in (500, 502, 503):
-        error_cls = OverloadedError
-      else:
-        error_cls = GroqError
-      raise error_cls(f'{response.status_code}: {response.content}')
-
-  def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-    assert self._api_initialized
-    return self._parallel_execute_with_currency_control(
-        self._sample_single,
-        prompts,
-        retry_on_errors=(RateLimitError, OverloadedError),
-    )
-
-  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
-    request = dict()
-    request.update(self._get_request_args(self.sampling_options))
-    request.update(
-        dict(
-            messages=[
-                dict(role='user', content=self._content_from_message(prompt))
-            ]
-        )
-    )
-    try:
-      response = self._session.post(
-          _CHAT_COMPLETE_API_ENDPOINT,
-          json=request,
-          timeout=self.timeout,
-      )
-      return self._parse_response(response)
-    except ConnectionError as e:
-      raise OverloadedError(str(e)) from e
-
 
 class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
   """Llama3-8B with 8K context window.
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/groq_test.py
@@ -163,7 +163,7 @@ class AuthropicTest(unittest.TestCase):
     with self.assertRaisesRegex(
         Exception, f'{status_code}:.*{error_type}'
     ):
-      lm('hello', lm=lm, max_attempts=1)
+      lm('hello', max_attempts=1)
 
 
 if __name__ == '__main__':
langfun-0.0.2.dev20240604/langfun/core/llms/llama_cpp.py (new file)
@@ -0,0 +1,84 @@
+# Copyright 2023 The Langfun Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Language models from llama.cpp."""
+
+from typing import Any
+
+import langfun.core as lf
+from langfun.core.llms import rest
+import pyglove as pg
+
+
+class LlamaCppRemote(rest.REST):
+  """The remote LLaMA C++ model.
+
+  The Remote LLaMA C++ models can be launched via
+  https://github.com/ggerganov/llama.cpp/tree/master/examples/server
+  """
+
+  @pg.explicit_method_override
+  def __init__(self, url: str, model: str | None = None, **kwargs):
+    super().__init__(api_endpoint=f'{url}/completion', model=model, **kwargs)
+
+  @property
+  def model_id(self) -> str:
+    """Returns a string to identify the model."""
+    return f'LLaMAC++({self.model or ""})'
+
+  def request(
+      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
+  ) -> dict[str, Any]:
+    """Returns the JSON input for a message."""
+    request = dict()
+    request.update(self._request_args(sampling_options))
+    # NOTE(daiyip): multi-modal is currently not supported.
+    request['prompt'] = prompt.text
+    return request
+
+  def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+    """Returns a dict as request arguments."""
+    args = dict(
+        n_predict=options.max_tokens or 1024,
+        top_k=options.top_k or 50,
+        top_p=options.top_p or 0.95,
+    )
+    if options.temperature is not None:
+      args['temperature'] = options.temperature
+    return args
+
+  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
+    return lf.LMSamplingResult(
+        [lf.LMSample(item['content'], score=0.0) for item in json['items']]
+    )
+
+  def _sample_single(self, prompt: lf.Message) -> lf.LMSamplingResult:
+    request = self.request(prompt, self.sampling_options)
+
+    def _sample_one_example(request):
+      response = self._session.post(
+          self.api_endpoint,
+          json=request,
+          timeout=self.timeout,
+      )
+      if response.status_code == 200:
+        return response.json()
+      else:
+        error_cls = self._error_cls_from_status(response.status_code)
+        raise error_cls(f'{response.status_code}: {response.content}')
+
+    items = self._parallel_execute_with_currency_control(
+        _sample_one_example,
+        [request] * (self.sampling_options.n or 1),
+    )
+    return self.result(dict(items=items))
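
A short usage sketch for the new class, assuming a llama.cpp server is already running locally (launched from the llama.cpp examples/server directory) and this langfun build is installed; only the base URL is needed, since the constructor appends `/completion`.

from langfun.core.llms import llama_cpp

# Point at a llama.cpp server; no API key or headers are required.
lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080')
print(lm.model_id)  # -> LLaMAC++()

# The /completion endpoint returns one candidate per call, so n > 1 fans out
# into n POSTs via _parallel_execute_with_currency_control (see
# _sample_single above). Requires a reachable server:
# [result] = lm.sample(['Write a haiku about tokens.'], n=2)
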
{langfun-0.0.2.dev20240601 → langfun-0.0.2.dev20240604}/langfun/core/llms/llama_cpp_test.py
@@ -17,7 +17,6 @@ import typing
 import unittest
 from unittest import mock
 
-import langfun.core as lf
 from langfun.core.llms import llama_cpp
 
 
@@ -25,6 +24,9 @@ def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
   del kwargs
 
   class TEMP:
+    @property
+    def status_code(self):
+      return 200
 
     def json(self):
       return {"content": json["prompt"] + "\n" + url}
@@ -36,19 +38,23 @@ class LlamaCppRemoteTest(unittest.TestCase):
   """Tests for the LlamaCppRemote model."""
 
   def test_call_completion(self):
-    with mock.patch("requests.post") as mock_request:
+    with mock.patch("requests.Session.post") as mock_request:
       mock_request.side_effect = mock_requests_post
-      lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
-      response = lm("hello", sampling_options=lf.LMSamplingOptions(n=1))
+      lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+      [result] = lm.sample(["hello"], n=2)
       self.assertEqual(
-          response.text,
+          len(result.samples),
+          2
+      )
+      self.assertEqual(
+          str(result.samples[0].response),
           "hello\nhttp://127.0.0.1:8080/completion",
       )
 
-  def test_name(self):
-    lm = llama_cpp.LlamaCppRemote()
+  def test_model_id(self):
+    lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
     self.assertEqual(lm.model_id, "LLaMAC++()")
-    lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
+    lm = llama_cpp.LlamaCppRemote("xxx", model="x")
     self.assertEqual(lm.model_id, "LLaMAC++(x)")
 