langfun 0.1.2.dev202502110804__py3-none-any.whl → 0.1.2.dev202502120804__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -29,10 +29,6 @@ class OpenAICompatible(rest.REST):
       str, 'The name of the model to use.',
   ] = ''
 
-  multimodal: Annotated[
-      bool, 'Whether this model has multimodal support.'
-  ] = False
-
   @property
   def headers(self) -> dict[str, Any]:
     return {
@@ -71,7 +67,8 @@ class OpenAICompatible(rest.REST):
     for chunk in message.chunk():
       if isinstance(chunk, str):
         item = dict(type='text', text=chunk)
-      elif isinstance(chunk, lf_modalities.Image) and self.multimodal:
+      elif (isinstance(chunk, lf_modalities.Image)
+            and self.supports_input(chunk.mime_type)):
         item = dict(type='image_url', image_url=dict(url=chunk.embeddable_uri))
       else:
         raise ValueError(f'Unsupported modality: {chunk!r}.')
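For context on the hunk above: the hardcoded `multimodal` flag is gone, and image chunks are now admitted per MIME type via `supports_input`. A minimal, self-contained sketch of that dispatch pattern follows; only `supports_input`, `mime_type`, and the dict shapes mirror the diff, while the class names and the MIME list are illustrative assumptions rather than langfun's actual API.

from dataclasses import dataclass

@dataclass
class ImageChunk:
  # Stand-in for lf_modalities.Image; carries a URI and a MIME type.
  uri: str
  mime_type: str = 'image/png'

class SketchModel:
  # Instead of a single `multimodal` boolean, the model advertises which
  # input MIME types it accepts, and each chunk is checked individually.
  input_mime_types = frozenset({'image/png', 'image/jpeg'})

  def supports_input(self, mime_type: str) -> bool:
    return mime_type in self.input_mime_types

  def to_content(self, chunk):
    if isinstance(chunk, str):
      return dict(type='text', text=chunk)
    elif isinstance(chunk, ImageChunk) and self.supports_input(chunk.mime_type):
      return dict(type='image_url', image_url=dict(url=chunk.uri))
    raise ValueError(f'Unsupported modality: {chunk!r}.')

print(SketchModel().to_content(ImageChunk('https://example.com/cat.png')))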
@@ -162,18 +159,6 @@ class OpenAICompatible(rest.REST):
             prompt_tokens=usage['prompt_tokens'],
             completion_tokens=usage['completion_tokens'],
             total_tokens=usage['total_tokens'],
-            estimated_cost=self.estimate_cost(
-                num_input_tokens=usage['prompt_tokens'],
-                num_output_tokens=usage['completion_tokens'],
-            )
+
         ),
     )
-
-  def estimate_cost(
-      self,
-      num_input_tokens: int,
-      num_output_tokens: int
-  ) -> float | None:
-    """Estimate the cost based on usage."""
-    del num_input_tokens, num_output_tokens
-    return None
@@ -207,37 +207,52 @@ class OpenAIComptibleTest(unittest.TestCase):
     )
 
   def test_call_chat_completion_vision(self):
+
+    class FakeImage(lf_modalities.Image):
+
+      @property
+      def mime_type(self) -> str:
+        return 'image/png'
+
     with mock.patch('requests.Session.post') as mock_request:
       mock_request.side_effect = mock_chat_completion_request_vision
       lm_1 = openai_compatible.OpenAICompatible(
           api_endpoint='https://test-server',
           model='test-model1',
-          multimodal=True
       )
       lm_2 = openai_compatible.OpenAICompatible(
           api_endpoint='https://test-server',
           model='test-model2',
-          multimodal=True
       )
       for lm in (lm_1, lm_2):
         self.assertEqual(
             lm(
                 lf.UserMessage(
                     'hello <<[[image]]>>',
-                    image=lf_modalities.Image.from_uri('https://fake/image')
+                    image=FakeImage.from_uri('https://fake/image')
                 ),
                 sampling_options=lf.LMSamplingOptions(n=2)
             ),
             'Sample 0 for message: https://fake/image',
         )
-      lm_3 = openai_compatible.OpenAICompatible(
+
+      class TextOnlyModel(openai_compatible.OpenAICompatible):
+
+        class ModelInfo(lf.ModelInfo):
+          input_modalities: list[str] = lf.ModelInfo.TEXT_INPUT_ONLY
+
+        @property
+        def model_info(self) -> lf.ModelInfo:
+          return TextOnlyModel.ModelInfo('text-only-model')
+
+      lm_3 = TextOnlyModel(
          api_endpoint='https://test-server', model='test-model3'
       )
       with self.assertRaisesRegex(ValueError, 'Unsupported modality'):
         lm_3(
             lf.UserMessage(
                 'hello <<[[image]]>>',
-                image=lf_modalities.Image.from_uri('https://fake/image')
+                image=FakeImage.from_uri('https://fake/image')
             ),
         )
 
@@ -30,11 +30,12 @@ class OpenAITest(unittest.TestCase):
 
   def test_model_id(self):
     self.assertEqual(
-        openai.Gpt35(api_key='test_key').model_id, 'OpenAI(text-davinci-003)')
+        openai.Gpt35(api_key='test_key').model_id, 'text-davinci-003')
 
   def test_resource_id(self):
     self.assertEqual(
-        openai.Gpt35(api_key='test_key').resource_id, 'OpenAI(text-davinci-003)'
+        openai.Gpt35(api_key='test_key').resource_id,
+        'openai://text-davinci-003'
     )
 
   def test_headers(self):
@@ -47,7 +48,9 @@ class OpenAITest(unittest.TestCase):
     )
 
   def test_max_concurrency(self):
-    self.assertGreater(openai.Gpt35(api_key='test_key').max_concurrency, 0)
+    self.assertGreater(
+        openai.Gpt4o(api_key='test_key').max_concurrency, 0
+    )
 
   def test_request_args(self):
     self.assertEqual(
@@ -75,11 +78,18 @@ class OpenAITest(unittest.TestCase):
75
78
  def test_estimate_cost(self):
76
79
  self.assertEqual(
77
80
  openai.Gpt4(api_key='test_key').estimate_cost(
78
- num_input_tokens=100, num_output_tokens=100
81
+ lf.LMSamplingUsage(
82
+ total_tokens=200,
83
+ prompt_tokens=100,
84
+ completion_tokens=100,
85
+ )
79
86
  ),
80
87
  0.009
81
88
  )
82
89
 
90
+ def test_lm_get(self):
91
+ self.assertIsInstance(lf.LanguageModel.get('gpt-4o'), openai.OpenAI)
92
+
83
93
 
84
94
  if __name__ == '__main__':
85
95
  unittest.main()
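Two behavioral notes on the test changes above: `estimate_cost` now takes a single `lf.LMSamplingUsage` rather than separate token counts, and models become retrievable by name through `lf.LanguageModel.get`. The expected value of 0.009 is consistent with GPT-4 list pricing of $0.03 per 1K input tokens and $0.06 per 1K output tokens; those prices are an assumption used for this arithmetic check, not something stated in the diff.

# Worked check of the 0.009 expectation, under the assumed GPT-4 pricing above.
input_cost = 100 / 1000 * 0.03    # 100 prompt tokens     -> $0.003
output_cost = 100 / 1000 * 0.06   # 100 completion tokens -> $0.006
print(round(input_cost + output_cost, 6))  # 0.009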
langfun/core/llms/rest.py CHANGED
@@ -49,11 +49,6 @@ class REST(lf.LanguageModel):
       'The headers for the REST API.'
   ] = None
 
-  @property
-  def model_id(self) -> str:
-    """Returns a string to identify the model."""
-    return self.model or 'unknown'
-
   @functools.cached_property
   def _api_initialized(self) -> bool:
     """Returns whether the API is initialized."""
@@ -76,7 +71,6 @@ class REST(lf.LanguageModel):
 
   def _on_bound(self):
     super()._on_bound()
-    self.__dict__.pop('_session', None)
     self.__dict__.pop('_api_initialized', None)
 
   def _sample(self, prompts: list[lf.Message]) -> list[lf.LMSamplingResult]:
@@ -121,3 +115,14 @@ class REST(lf.LanguageModel):
       return self.result(response.json())
     else:
       raise self._error(response.status_code, response.content)
+
+  @property
+  def max_concurrency(self) -> int | None:
+    """Returns the max concurrency for this model."""
+    rate_limits = self.model_info.rate_limits
+    if rate_limits is not None:
+      return self.estimate_max_concurrency(
+          max_requests_per_minute=rate_limits.max_requests_per_minute,
+          max_tokens_per_minute=rate_limits.max_tokens_per_minute
+      )
+    return None
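The new `max_concurrency` property derives a concurrency budget from the model's published rate limits instead of a fixed constant. Below is a self-contained sketch of that kind of derivation; the averaging constants and the formula are illustrative assumptions, not the actual implementation of `estimate_max_concurrency`.

# Sketch: turn per-minute rate limits into a rough in-flight request budget.
# The heuristic constants are assumptions for illustration only.
def estimate_max_concurrency(
    max_requests_per_minute: float | None,
    max_tokens_per_minute: float | None,
    avg_request_seconds: float = 5.0,
    avg_tokens_per_request: float = 1000.0,
) -> int:
  candidates = []
  if max_requests_per_minute:
    # Requests that can be in flight if each takes ~avg_request_seconds.
    candidates.append(max_requests_per_minute * avg_request_seconds / 60.0)
  if max_tokens_per_minute:
    # Same idea, but bounded by the token budget per minute.
    candidates.append(
        max_tokens_per_minute / avg_tokens_per_request * avg_request_seconds / 60.0
    )
  return max(1, int(min(candidates))) if candidates else 1

print(estimate_max_concurrency(600, 300_000))  # 25 under these assumptions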