langfun-0.0.2.dev20240330-py3-none-any.whl → langfun-0.1.2.dev202501140804-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (145)
  1. langfun/__init__.py +22 -2
  2. langfun/core/__init__.py +17 -5
  3. langfun/core/agentic/__init__.py +30 -0
  4. langfun/core/agentic/action.py +854 -0
  5. langfun/core/agentic/action_eval.py +150 -0
  6. langfun/core/agentic/action_eval_test.py +109 -0
  7. langfun/core/agentic/action_test.py +136 -0
  8. langfun/core/coding/python/__init__.py +5 -11
  9. langfun/core/coding/python/correction.py +37 -28
  10. langfun/core/coding/python/correction_test.py +29 -3
  11. langfun/core/coding/python/execution.py +40 -216
  12. langfun/core/coding/python/execution_test.py +29 -89
  13. langfun/core/coding/python/generation.py +21 -11
  14. langfun/core/coding/python/generation_test.py +2 -2
  15. langfun/core/coding/python/parsing.py +108 -193
  16. langfun/core/coding/python/parsing_test.py +2 -105
  17. langfun/core/component.py +69 -2
  18. langfun/core/component_test.py +54 -0
  19. langfun/core/concurrent.py +414 -117
  20. langfun/core/concurrent_test.py +111 -24
  21. langfun/core/console.py +18 -5
  22. langfun/core/console_test.py +17 -0
  23. langfun/core/eval/__init__.py +17 -0
  24. langfun/core/eval/base.py +767 -140
  25. langfun/core/eval/base_test.py +238 -53
  26. langfun/core/eval/matching.py +80 -76
  27. langfun/core/eval/matching_test.py +19 -9
  28. langfun/core/eval/patching.py +130 -0
  29. langfun/core/eval/patching_test.py +170 -0
  30. langfun/core/eval/scoring.py +37 -28
  31. langfun/core/eval/scoring_test.py +21 -3
  32. langfun/core/eval/v2/__init__.py +42 -0
  33. langfun/core/eval/v2/checkpointing.py +380 -0
  34. langfun/core/eval/v2/checkpointing_test.py +228 -0
  35. langfun/core/eval/v2/eval_test_helper.py +136 -0
  36. langfun/core/eval/v2/evaluation.py +725 -0
  37. langfun/core/eval/v2/evaluation_test.py +180 -0
  38. langfun/core/eval/v2/example.py +305 -0
  39. langfun/core/eval/v2/example_test.py +128 -0
  40. langfun/core/eval/v2/experiment.py +1048 -0
  41. langfun/core/eval/v2/experiment_test.py +433 -0
  42. langfun/core/eval/v2/metric_values.py +156 -0
  43. langfun/core/eval/v2/metric_values_test.py +80 -0
  44. langfun/core/eval/v2/metrics.py +357 -0
  45. langfun/core/eval/v2/metrics_test.py +203 -0
  46. langfun/core/eval/v2/progress.py +348 -0
  47. langfun/core/eval/v2/progress_test.py +82 -0
  48. langfun/core/eval/v2/progress_tracking.py +210 -0
  49. langfun/core/eval/v2/progress_tracking_test.py +66 -0
  50. langfun/core/eval/v2/reporting.py +270 -0
  51. langfun/core/eval/v2/reporting_test.py +158 -0
  52. langfun/core/eval/v2/runners.py +488 -0
  53. langfun/core/eval/v2/runners_test.py +334 -0
  54. langfun/core/langfunc.py +3 -21
  55. langfun/core/langfunc_test.py +26 -8
  56. langfun/core/language_model.py +686 -48
  57. langfun/core/language_model_test.py +681 -44
  58. langfun/core/llms/__init__.py +100 -12
  59. langfun/core/llms/anthropic.py +488 -0
  60. langfun/core/llms/anthropic_test.py +235 -0
  61. langfun/core/llms/cache/base.py +21 -2
  62. langfun/core/llms/cache/in_memory.py +13 -0
  63. langfun/core/llms/cache/in_memory_test.py +88 -28
  64. langfun/core/llms/compositional.py +101 -0
  65. langfun/core/llms/compositional_test.py +73 -0
  66. langfun/core/llms/deepseek.py +117 -0
  67. langfun/core/llms/deepseek_test.py +61 -0
  68. langfun/core/llms/fake.py +39 -26
  69. langfun/core/llms/fake_test.py +136 -11
  70. langfun/core/llms/gemini.py +507 -0
  71. langfun/core/llms/gemini_test.py +195 -0
  72. langfun/core/llms/google_genai.py +62 -218
  73. langfun/core/llms/google_genai_test.py +9 -197
  74. langfun/core/llms/groq.py +276 -0
  75. langfun/core/llms/groq_test.py +64 -0
  76. langfun/core/llms/llama_cpp.py +15 -40
  77. langfun/core/llms/llama_cpp_test.py +4 -30
  78. langfun/core/llms/openai.py +436 -226
  79. langfun/core/llms/openai_compatible.py +179 -0
  80. langfun/core/llms/openai_compatible_test.py +495 -0
  81. langfun/core/llms/openai_test.py +35 -174
  82. langfun/core/llms/rest.py +113 -0
  83. langfun/core/llms/rest_test.py +111 -0
  84. langfun/core/llms/vertexai.py +192 -0
  85. langfun/core/llms/vertexai_test.py +52 -0
  86. langfun/core/logging.py +284 -0
  87. langfun/core/logging_test.py +125 -0
  88. langfun/core/message.py +319 -9
  89. langfun/core/message_test.py +190 -13
  90. langfun/core/modalities/__init__.py +6 -2
  91. langfun/core/modalities/audio.py +30 -0
  92. langfun/core/modalities/audio_test.py +63 -0
  93. langfun/core/modalities/image.py +39 -20
  94. langfun/core/modalities/image_test.py +52 -9
  95. langfun/core/modalities/mime.py +206 -29
  96. langfun/core/modalities/mime_test.py +90 -9
  97. langfun/core/modalities/ms_office.py +117 -0
  98. langfun/core/modalities/ms_office_test.py +389 -0
  99. langfun/core/modalities/pdf.py +22 -0
  100. langfun/core/modalities/pdf_test.py +57 -0
  101. langfun/core/modalities/video.py +9 -23
  102. langfun/core/modalities/video_test.py +3 -3
  103. langfun/core/modality.py +26 -3
  104. langfun/core/modality_test.py +2 -2
  105. langfun/core/sampling.py +11 -11
  106. langfun/core/structured/__init__.py +15 -16
  107. langfun/core/structured/completion.py +32 -5
  108. langfun/core/structured/completion_test.py +9 -8
  109. langfun/core/structured/description.py +2 -2
  110. langfun/core/structured/description_test.py +3 -3
  111. langfun/core/structured/function_generation.py +278 -0
  112. langfun/core/structured/function_generation_test.py +399 -0
  113. langfun/core/structured/mapping.py +150 -46
  114. langfun/core/structured/mapping_test.py +105 -0
  115. langfun/core/structured/parsing.py +33 -21
  116. langfun/core/structured/parsing_test.py +71 -22
  117. langfun/core/structured/querying.py +746 -0
  118. langfun/core/structured/{prompting_test.py → querying_test.py} +545 -60
  119. langfun/core/structured/schema.py +208 -99
  120. langfun/core/structured/schema_generation.py +1 -1
  121. langfun/core/structured/schema_generation_test.py +2 -2
  122. langfun/core/structured/schema_test.py +133 -34
  123. langfun/core/structured/scoring.py +125 -19
  124. langfun/core/structured/scoring_test.py +30 -0
  125. langfun/core/structured/tokenization.py +64 -0
  126. langfun/core/structured/tokenization_test.py +48 -0
  127. langfun/core/template.py +240 -11
  128. langfun/core/template_test.py +146 -1
  129. langfun/core/templates/conversation.py +9 -0
  130. langfun/core/templates/conversation_test.py +4 -3
  131. langfun/core/templates/selfplay_test.py +14 -2
  132. langfun-0.1.2.dev202501140804.dist-info/METADATA +225 -0
  133. langfun-0.1.2.dev202501140804.dist-info/RECORD +153 -0
  134. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/WHEEL +1 -1
  135. langfun/core/coding/python/errors.py +0 -108
  136. langfun/core/coding/python/errors_test.py +0 -99
  137. langfun/core/coding/python/permissions.py +0 -90
  138. langfun/core/coding/python/permissions_test.py +0 -86
  139. langfun/core/structured/prompting.py +0 -217
  140. langfun/core/text_formatting.py +0 -162
  141. langfun/core/text_formatting_test.py +0 -47
  142. langfun-0.0.2.dev20240330.dist-info/METADATA +0 -99
  143. langfun-0.0.2.dev20240330.dist-info/RECORD +0 -102
  144. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/LICENSE +0 -0
  145. {langfun-0.0.2.dev20240330.dist-info → langfun-0.1.2.dev202501140804.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ langfun/core/llms/groq.py
@@ -0,0 +1,276 @@
+ # Copyright 2024 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Language models from Groq."""
+
+ import os
+ from typing import Annotated, Any
+
+ import langfun.core as lf
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
+
+
+ SUPPORTED_MODELS_AND_SETTINGS = {
+     # Refer https://console.groq.com/docs/models
+     # Price in US dollars at https://groq.com/pricing/ as of 2024-10-10.
+     'llama-3.2-3b-preview': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=64,
+         cost_per_1k_input_tokens=0.00006,
+         cost_per_1k_output_tokens=0.00006,
+     ),
+     'llama-3.2-1b-preview': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=64,
+         cost_per_1k_input_tokens=0.00004,
+         cost_per_1k_output_tokens=0.00004,
+     ),
+     'llama-3.1-70b-versatile': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00059,
+         cost_per_1k_output_tokens=0.00079,
+     ),
+     'llama-3.1-8b-instant': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00005,
+         cost_per_1k_output_tokens=0.00008,
+     ),
+     'llama3-70b-8192': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00059,
+         cost_per_1k_output_tokens=0.00079,
+     ),
+     'llama3-8b-8192': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00005,
+         cost_per_1k_output_tokens=0.00008,
+     ),
+     'llama2-70b-4096': pg.Dict(
+         max_tokens=4096,
+         max_concurrency=16,
+     ),
+     'mixtral-8x7b-32768': pg.Dict(
+         max_tokens=32768,
+         max_concurrency=16,
+         cost_per_1k_input_tokens=0.00024,
+         cost_per_1k_output_tokens=0.00024,
+     ),
+     'gemma2-9b-it': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.0002,
+         cost_per_1k_output_tokens=0.0002,
+     ),
+     'gemma-7b-it': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=32,
+         cost_per_1k_input_tokens=0.00007,
+         cost_per_1k_output_tokens=0.00007,
+     ),
+     'whisper-large-v3': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+     ),
+     'whisper-large-v3-turbo': pg.Dict(
+         max_tokens=8192,
+         max_concurrency=16,
+     ),
+ }
+
+
+ @lf.use_init_args(['model'])
+ class Groq(openai_compatible.OpenAICompatible):
+   """Groq LLMs through REST APIs (OpenAI compatible).
+
+   See https://platform.openai.com/docs/api-reference/chat
+   """
+
+   model: pg.typing.Annotated[
+       pg.typing.Enum(
+           pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
+       ),
+       'The name of the model to use.',
+   ]
+
+   api_key: Annotated[
+       str | None,
+       (
+           'API key. If None, the key will be read from environment variable '
+           "'GROQ_API_KEY'."
+       ),
+   ] = None
+
+   api_endpoint: str = 'https://api.groq.com/openai/v1/chat/completions'
+
+   @property
+   def headers(self) -> dict[str, Any]:
+     api_key = self.api_key or os.environ.get('GROQ_API_KEY', None)
+     if not api_key:
+       raise ValueError(
+           'Please specify `api_key` during `__init__` or set environment '
+           'variable `GROQ_API_KEY` with your Groq API key.'
+       )
+     headers = super().headers
+     headers.update({
+         'Authorization': f'Bearer {api_key}',
+     })
+     return headers
+
+   @property
+   def model_id(self) -> str:
+     """Returns a string to identify the model."""
+     return self.model
+
+   @property
+   def max_concurrency(self) -> int:
+     return SUPPORTED_MODELS_AND_SETTINGS[self.model].max_concurrency
+
+   def estimate_cost(
+       self,
+       num_input_tokens: int,
+       num_output_tokens: int
+   ) -> float | None:
+     """Estimate the cost based on usage."""
+     cost_per_1k_input_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_input_tokens', None
+     )
+     cost_per_1k_output_tokens = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
+         'cost_per_1k_output_tokens', None
+     )
+     if cost_per_1k_input_tokens is None or cost_per_1k_output_tokens is None:
+       return None
+     return (
+         cost_per_1k_input_tokens * num_input_tokens
+         + cost_per_1k_output_tokens * num_output_tokens
+     ) / 1000
+
+   def _request_args(self, options: lf.LMSamplingOptions) -> dict[str, Any]:
+     """Returns a dict as request arguments."""
+     # `logprobs` and `top_logprobs` flags are not supported on Groq yet.
+     args = super()._request_args(options)
+     args.pop('logprobs', None)
+     args.pop('top_logprobs', None)
+     return args
+
+
+ class GroqLlama3_2_3B(Groq):  # pylint: disable=invalid-name
+   """Llama3.2-3B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Llama-3.2-3B
+   """
+
+   model = 'llama-3.2-3b-preview'
+
+
+ class GroqLlama3_2_1B(Groq):  # pylint: disable=invalid-name
+   """Llama3.2-1B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Llama-3.2-1B
+   """
+
+   model = 'llama-3.2-1b-preview'
+
+
+ class GroqLlama3_8B(Groq):  # pylint: disable=invalid-name
+   """Llama3-8B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Meta-Llama-3-8B
+   """
+
+   model = 'llama3-8b-8192'
+
+
+ class GroqLlama3_1_70B(Groq):  # pylint: disable=invalid-name
+   """Llama3.1-70B with 8K context window.
+
+   See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+   """
+
+   model = 'llama-3.1-70b-versatile'
+
+
+ class GroqLlama3_1_8B(Groq):  # pylint: disable=invalid-name
+   """Llama3.1-8B with 8K context window.
+
+   See: https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/MODEL_CARD.md  # pylint: disable=line-too-long
+   """
+
+   model = 'llama-3.1-8b-instant'
+
+
+ class GroqLlama3_70B(Groq):  # pylint: disable=invalid-name
+   """Llama3-70B with 8K context window.
+
+   See: https://huggingface.co/meta-llama/Meta-Llama-3-70B
+   """
+
+   model = 'llama3-70b-8192'
+
+
+ class GroqLlama2_70B(Groq):  # pylint: disable=invalid-name
+   """Llama2-70B with 4K context window.
+
+   See: https://huggingface.co/meta-llama/Llama-2-70b
+   """
+
+   model = 'llama2-70b-4096'
+
+
+ class GroqMistral_8x7B(Groq):  # pylint: disable=invalid-name
+   """Mixtral 8x7B with 32K context window.
+
+   See: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+   """
+
+   model = 'mixtral-8x7b-32768'
+
+
+ class GroqGemma2_9B_IT(Groq):  # pylint: disable=invalid-name
+   """Gemma2 9B with 8K context window.
+
+   See: https://huggingface.co/google/gemma-2-9b-it
+   """
+
+   model = 'gemma2-9b-it'
+
+
+ class GroqGemma_7B_IT(Groq):  # pylint: disable=invalid-name
+   """Gemma 7B with 8K context window.
+
+   See: https://huggingface.co/google/gemma-1.1-7b-it
+   """
+
+   model = 'gemma-7b-it'
+
+
+ class GroqWhisper_Large_v3(Groq):  # pylint: disable=invalid-name
+   """Whisper Large V3 with 8K context window.
+
+   See: https://huggingface.co/openai/whisper-large-v3
+   """
+
+   model = 'whisper-large-v3'
+
+
+ class GroqWhisper_Large_v3Turbo(Groq):  # pylint: disable=invalid-name
+   """Whisper Large V3 Turbo with 8K context window.
+
+   See: https://huggingface.co/openai/whisper-large-v3-turbo
+   """
+
+   model = 'whisper-large-v3-turbo'
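
A minimal usage sketch of the new `Groq` family, assuming a valid Groq API key (per `headers` above, the key may also come from the `GROQ_API_KEY` environment variable; the key string below is a placeholder):

```python
from langfun.core.llms import groq

lm = groq.GroqLlama3_1_8B(api_key='YOUR_GROQ_API_KEY')  # placeholder key
print(lm.model_id)         # 'llama-3.1-8b-instant'
print(lm.max_concurrency)  # 32, per SUPPORTED_MODELS_AND_SETTINGS

# estimate_cost applies the per-1k-token rates from the table above:
# (0.00005 * 1000 + 0.00008 * 500) / 1000 = 9e-05 USD.
print(lm.estimate_cost(num_input_tokens=1000, num_output_tokens=500))
```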
--- /dev/null
+++ langfun/core/llms/groq_test.py
@@ -0,0 +1,64 @@
+ # Copyright 2023 The Langfun Authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ import unittest
+ import langfun.core as lf
+ from langfun.core.llms import groq
+
+
+ class GroqTest(unittest.TestCase):
+
+   def test_basics(self):
+     self.assertEqual(groq.GroqMistral_8x7B().model_id, 'mixtral-8x7b-32768')
+     self.assertEqual(groq.GroqMistral_8x7B().max_concurrency, 16)
+     self.assertEqual(groq.GroqMistral_8x7B().estimate_cost(100, 100), 4.8e-5)
+
+   def test_request_args(self):
+     args = groq.GroqMistral_8x7B()._request_args(
+         lf.LMSamplingOptions(
+             temperature=1.0, stop=['\n'], n=1, random_seed=123,
+             logprobs=True, top_logprobs=True
+         )
+     )
+     self.assertNotIn('logprobs', args)
+     self.assertNotIn('top_logprobs', args)
+
+   def test_api_key(self):
+     lm = groq.GroqMistral_8x7B()
+     with self.assertRaisesRegex(ValueError, 'Please specify `api_key`'):
+       _ = lm.headers
+
+     lm = groq.GroqMistral_8x7B(api_key='fake key')
+     self.assertEqual(
+         lm.headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer fake key',
+         }
+     )
+
+     os.environ['GROQ_API_KEY'] = 'abc'
+     lm = groq.GroqMistral_8x7B()
+     self.assertEqual(
+         lm.headers,
+         {
+             'Content-Type': 'application/json',
+             'Authorization': 'Bearer abc',
+         }
+     )
+     del os.environ['GROQ_API_KEY']
+
+
+ if __name__ == '__main__':
+   unittest.main()
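
The `estimate_cost(100, 100) == 4.8e-5` assertion can be checked by hand against the `mixtral-8x7b-32768` entry in `SUPPORTED_MODELS_AND_SETTINGS` (0.00024 USD per 1k tokens for both input and output):

```python
# Worked check of the cost assertion in test_basics above.
cost = (0.00024 * 100 + 0.00024 * 100) / 1000
assert abs(cost - 4.8e-5) < 1e-12  # 4.8e-05 USD
```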
--- langfun/core/llms/llama_cpp.py
+++ langfun/core/llms/llama_cpp.py
@@ -14,59 +14,34 @@
  """Language models from llama.cpp."""
 
  from typing import Annotated
+ from langfun.core.llms import openai_compatible
+ import pyglove as pg
 
- import langfun.core as lf
- import requests
 
-
- @lf.use_init_args(["url"])
- class LlamaCppRemote(lf.LanguageModel):
+ @pg.use_init_args(['url', 'model'])
+ @pg.members([('api_endpoint', pg.typing.Str().freeze(''))])
+ class LlamaCppRemote(openai_compatible.OpenAICompatible):
    """The remote LLaMA C++ model.
 
    The Remote LLaMA C++ models can be launched via
    https://github.com/ggerganov/llama.cpp/tree/master/examples/server
    """
-
    url: Annotated[
        str,
-       "The name of the model to use.",
-   ] = ""
+       'The URL of the LLaMA C++ server.',
+   ]
 
-   name: Annotated[
+   model: Annotated[
        str,
-       "The abbreviation for the LLaMA CPP-based model name.",
-   ] = ""
+       'The name of the model to use.',
+   ] = ''
+
+   @property
+   def api_endpoint(self) -> str:
+     return self.url + '/completion'
 
    @property
    def model_id(self) -> str:
      """Returns a string to identify the model."""
-     return f"LLaMAC++({self.name})"
-
-   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
-     def _complete_fn(cur_prompts):
-       results = []
-       for prompt in cur_prompts:
-         result = lf.LMSamplingResult()
-         for _ in range(self.sampling_options.n or 1):
-           data = {
-               "prompt": prompt.text,
-               "n_predict": self.sampling_options.max_tokens,
-               "temperature": self.sampling_options.temperature,
-               "top_k": self.sampling_options.top_k or 50,
-               "top_p": self.sampling_options.top_p or 0.95,
-           }
-           response = requests.post(
-               f"{self.url}/completion",
-               json=data,
-               headers={"Content-Type": "application/json"},
-               timeout=self.timeout,
-           )
-           decoded_response = response.json()
-           response = decoded_response["content"]
-           result.samples.append(lf.LMSample(response, score=0.0))
-         results.append(result)
-       return results
+     return f'LLaMAC++({self.model or ""})'
 
-     return self._parallel_execute_with_currency_control(
-         _complete_fn, [prompts]
-     )[0]
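
A quick sketch of the refactored `LlamaCppRemote` in use; positional arguments follow the new `pg.use_init_args(['url', 'model'])` order, and sampling now rides on the shared `OpenAICompatible` REST path instead of the removed hand-rolled `requests.post` loop. The server address is an assumption, and the final call requires a live llama.cpp server:

```python
from langfun.core.llms import llama_cpp

lm = llama_cpp.LlamaCppRemote('http://127.0.0.1:8080', model='my-model')
print(lm.api_endpoint)  # 'http://127.0.0.1:8080/completion'
print(lm.model_id)      # 'LLaMAC++(my-model)'

response = lm('Hello')  # issues a REST request to the running server
print(response.text)
```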
--- langfun/core/llms/llama_cpp_test.py
+++ langfun/core/llms/llama_cpp_test.py
@@ -11,44 +11,18 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """Tests for llama cpp models."""
-
- import typing
  import unittest
- from unittest import mock
-
- import langfun.core as lf
  from langfun.core.llms import llama_cpp
 
 
- def mock_requests_post(url: str, json: typing.Dict[str, typing.Any], **kwargs):
-   del kwargs
-
-   class TEMP:
-
-     def json(self):
-       return {"content": json["prompt"] + "\n" + url}
-
-   return TEMP()
-
-
  class LlamaCppRemoteTest(unittest.TestCase):
    """Tests for the LlamaCppRemote model."""
 
-   def test_call_completion(self):
-     with mock.patch("requests.post") as mock_request:
-       mock_request.side_effect = mock_requests_post
-       lm = llama_cpp.LlamaCppRemote(url="http://127.0.0.1:8080")
-       response = lm("hello", sampling_options=lf.LMSamplingOptions(n=1))
-       self.assertEqual(
-           response.text,
-           "hello\nhttp://127.0.0.1:8080/completion",
-       )
-
-   def test_name(self):
-     lm = llama_cpp.LlamaCppRemote()
+   def test_basics(self):
+     lm = llama_cpp.LlamaCppRemote("http://127.0.0.1:8080")
+     self.assertEqual(lm.api_endpoint, "http://127.0.0.1:8080/completion")
      self.assertEqual(lm.model_id, "LLaMAC++()")
-     lm = llama_cpp.LlamaCppRemote(url="xxx", name="x")
+     lm = llama_cpp.LlamaCppRemote("xxx", model="x")
      self.assertEqual(lm.model_id, "LLaMAC++(x)")
 