model-library 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
1
1
  import io
2
2
  import logging
3
- import time
4
3
  from collections.abc import Sequence
5
4
  from typing import Any, Literal
6
5
 
@@ -13,14 +12,16 @@ from typing_extensions import override
13
12
  from model_library import model_library_settings
14
13
  from model_library.base import (
15
14
  LLM,
15
+ FileBase,
16
16
  FileInput,
17
17
  FileWithBase64,
18
18
  FileWithId,
19
- FileWithUrl,
20
19
  InputItem,
21
20
  LLMConfig,
22
21
  QueryResult,
23
22
  QueryResultMetadata,
23
+ RawInput,
24
+ RawResponse,
24
25
  TextInput,
25
26
  ToolBody,
26
27
  ToolCall,
@@ -69,27 +70,30 @@ class MistralModel(LLM):
69
70
  content_user: list[dict[str, Any]] = []
70
71
 
71
72
  def flush_content_user():
72
- nonlocal content_user
73
-
74
73
  if content_user:
75
- new_input.append({"role": "user", "content": content_user})
76
- content_user = []
74
+ # NOTE: must make new object as we clear()
75
+ new_input.append({"role": "user", "content": content_user.copy()})
76
+ content_user.clear()
77
77
 
78
78
  for item in input:
79
+ if isinstance(item, TextInput):
80
+ content_user.append({"type": "text", "text": item.text})
81
+ continue
82
+
83
+ if isinstance(item, FileBase):
84
+ match item.type:
85
+ case "image":
86
+ parsed = await self.parse_image(item)
87
+ case "file":
88
+ parsed = await self.parse_file(item)
89
+ content_user.append(parsed)
90
+ continue
91
+
92
+ # non content user item
93
+ flush_content_user()
94
+
79
95
  match item:
80
- case TextInput():
81
- content_user.append({"type": "text", "text": item.text})
82
- case FileWithBase64() | FileWithUrl() | FileWithId():
83
- match item.type:
84
- case "image":
85
- content_user.append(await self.parse_image(item))
86
- case "file":
87
- content_user.append(await self.parse_file(item))
88
- case AssistantMessage():
89
- flush_content_user()
90
- new_input.append(item)
91
96
  case ToolResult():
92
- flush_content_user()
93
97
  new_input.append(
94
98
  {
95
99
  "role": "tool",
@@ -98,9 +102,12 @@ class MistralModel(LLM):
98
102
  "tool_call_id": item.tool_call.id,
99
103
  }
100
104
  )
101
- case _:
102
- raise BadInputError("Unsupported input type")
105
+ case RawResponse():
106
+ new_input.append(item.response)
107
+ case RawInput():
108
+ new_input.append(item.input)
103
109
 
110
+ # in case content user item is the last item
104
111
  flush_content_user()
105
112
 
106
113
  return new_input
@@ -167,14 +174,13 @@ class MistralModel(LLM):
167
174
  raise NotImplementedError()
168
175
 
169
176
  @override
170
- async def _query_impl(
177
+ async def build_body(
171
178
  self,
172
179
  input: Sequence[InputItem],
173
180
  *,
174
181
  tools: list[ToolDefinition],
175
- query_logger: logging.Logger,
176
182
  **kwargs: object,
177
- ) -> QueryResult:
183
+ ) -> dict[str, Any]:
178
184
  # mistral supports max 8 images, merge extra images into the 8th image
179
185
  input = trim_images(input, max_images=8)
180
186
 
@@ -205,8 +211,18 @@ class MistralModel(LLM):
205
211
  body["top_p"] = self.top_p
206
212
 
207
213
  body.update(kwargs)
214
+ return body
208
215
 
209
- start = time.time()
216
+ @override
217
+ async def _query_impl(
218
+ self,
219
+ input: Sequence[InputItem],
220
+ *,
221
+ tools: list[ToolDefinition],
222
+ query_logger: logging.Logger,
223
+ **kwargs: object,
224
+ ) -> QueryResult:
225
+ body = await self.build_body(input, tools=tools, **kwargs)
210
226
 
211
227
  response: EventStreamAsync[
212
228
  CompletionEvent
@@ -247,8 +263,6 @@ class MistralModel(LLM):
247
263
  in_tokens += data.usage.prompt_tokens or 0
248
264
  out_tokens += data.usage.completion_tokens or 0
249
265
 
250
- self.logger.info(f"Finished in: {time.time() - start}")
251
-
252
266
  except Exception as e:
253
267
  self.logger.error(f"Error: {e}", exc_info=True)
254
268
  raise e
@@ -302,7 +316,7 @@ class MistralModel(LLM):
302
316
  return QueryResult(
303
317
  output_text=text,
304
318
  reasoning=reasoning or None,
305
- history=[*input, message],
319
+ history=[*input, RawResponse(response=message)],
306
320
  tool_calls=tool_calls,
307
321
  metadata=QueryResultMetadata(
308
322
  in_tokens=in_tokens,
@@ -16,6 +16,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
16
16
  from openai.types.create_embedding_response import CreateEmbeddingResponse
17
17
  from openai.types.moderation_create_response import ModerationCreateResponse
18
18
  from openai.types.responses import (
19
+ ResponseFunctionToolCall,
19
20
  ResponseOutputItem,
20
21
  ResponseOutputText,
21
22
  ResponseStreamEvent,
@@ -29,6 +30,7 @@ from model_library.base import (
29
30
  LLM,
30
31
  BatchResult,
31
32
  Citation,
33
+ FileBase,
32
34
  FileInput,
33
35
  FileWithBase64,
34
36
  FileWithId,
@@ -42,7 +44,8 @@ from model_library.base import (
42
44
  QueryResultCost,
43
45
  QueryResultExtras,
44
46
  QueryResultMetadata,
45
- RawInputItem,
47
+ RawInput,
48
+ RawResponse,
46
49
  TextInput,
47
50
  ToolBody,
48
51
  ToolCall,
@@ -53,6 +56,7 @@ from model_library.exceptions import (
53
56
  ImmediateRetryException,
54
57
  MaxOutputTokensExceededError,
55
58
  ModelNoOutputError,
59
+ NoMatchingToolCallError,
56
60
  )
57
61
  from model_library.model_utils import get_reasoning_in_tag
58
62
  from model_library.register_models import register_provider
@@ -258,7 +262,9 @@ class OpenAIModel(LLM):
258
262
  use_completions: bool = False,
259
263
  ):
260
264
  super().__init__(model_name, provider, config=config)
261
- self.use_completions: bool = use_completions
265
+ self.use_completions: bool = (
266
+ use_completions # TODO: do completions in a separate file
267
+ )
262
268
  self.deep_research = self.provider_config.deep_research
263
269
 
264
270
  # allow custom client to act as delegate (native)
@@ -270,6 +276,29 @@ class OpenAIModel(LLM):
270
276
  OpenAIBatchMixin(self) if self.supports_batch else None
271
277
  )
272
278
 
279
async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
    """Collect the ids of tool calls carried by ``RawResponse`` items in *input*.

    Only ``RawResponse`` items are inspected; every other input item is
    ignored. Which payload shape is examined depends on
    ``self.use_completions``:

    * completions mode: ``response`` must be a ``ChatCompletionMessage``;
      the ids of its ``tool_calls`` (if any) are collected.
    * otherwise: ``response`` is iterated and the ids of the
      ``ResponseFunctionToolCall`` entries are collected.

    Calls whose ``id`` is falsy (``None`` or empty) are dropped.

    Returns:
        The collected ids, in the order they appear in *input*.
    """
    use_completions = self.use_completions
    tool_call_ids: list[str] = []

    for item in input:
        if not isinstance(item, RawResponse):
            continue
        response = item.response

        if use_completions:
            # Payloads that are not chat-completion messages carry no
            # tool calls we know how to read in this mode.
            if not isinstance(response, ChatCompletionMessage):
                continue
            candidates = response.tool_calls or []
        else:
            # NOTE(review): this gathers `.id`, not `.call_id`, of each
            # function tool call - confirm that is the identifier callers
            # compare against.
            candidates = [
                part
                for part in response
                if isinstance(part, ResponseFunctionToolCall)
            ]

        for call in candidates:
            if call.id:
                tool_call_ids.append(call.id)

    return tool_call_ids
301
+
273
302
  @override
274
303
  async def parse_input(
275
304
  self,
@@ -277,63 +306,70 @@ class OpenAIModel(LLM):
277
306
  **kwargs: Any,
278
307
  ) -> list[dict[str, Any] | Any]:
279
308
  new_input: list[dict[str, Any] | Any] = []
309
+
280
310
  content_user: list[dict[str, Any]] = []
311
+
312
+ def flush_content_user():
313
+ if content_user:
314
+ # NOTE: must make new object as we clear()
315
+ new_input.append({"role": "user", "content": content_user.copy()})
316
+ content_user.clear()
317
+
318
+ tool_call_ids = await self.get_tool_call_ids(input)
319
+
281
320
  for item in input:
321
+ if isinstance(item, TextInput):
322
+ if self.use_completions:
323
+ text_key = "text"
324
+ else:
325
+ text_key = "input_text"
326
+ content_user.append({"type": text_key, "text": item.text})
327
+ continue
328
+
329
+ if isinstance(item, FileBase):
330
+ match item.type:
331
+ case "image":
332
+ parsed = await self.parse_image(item)
333
+ case "file":
334
+ parsed = await self.parse_file(item)
335
+ content_user.append(parsed)
336
+ continue
337
+
338
+ # non content user item
339
+ flush_content_user()
340
+
282
341
  match item:
283
- case TextInput():
342
+ case ToolResult():
343
+ if item.tool_call.id not in tool_call_ids:
344
+ raise NoMatchingToolCallError()
345
+
284
346
  if self.use_completions:
285
- content_user.append({"type": "text", "text": item.text})
347
+ new_input.append(
348
+ {
349
+ "role": "tool",
350
+ "tool_call_id": item.tool_call.id,
351
+ "content": item.result,
352
+ }
353
+ )
286
354
  else:
287
- content_user.append({"type": "input_text", "text": item.text})
288
- case FileWithBase64() | FileWithUrl() | FileWithId():
289
- match item.type:
290
- case "image":
291
- content_user.append(await self.parse_image(item))
292
- case "file":
293
- content_user.append(await self.parse_file(item))
294
- case _:
295
- if content_user:
296
- new_input.append({"role": "user", "content": content_user})
297
- content_user = []
298
- match item:
299
- case ToolResult():
300
- if not (
301
- not isinstance(x, dict)
302
- and x.type == "function_call"
303
- and x.call_id == item.tool_call.call_id
304
- for x in new_input
305
- ):
306
- raise Exception(
307
- "Tool call result provided with no matching tool call"
308
- )
309
- if self.use_completions:
310
- new_input.append(
311
- {
312
- "role": "tool",
313
- "tool_call_id": item.tool_call.id,
314
- "content": item.result,
315
- }
316
- )
317
- else:
318
- new_input.append(
319
- {
320
- "type": "function_call_output",
321
- "call_id": item.tool_call.call_id,
322
- "output": item.result,
323
- }
324
- )
325
- case dict(): # RawInputItem
326
- item = cast(RawInputItem, item)
327
- new_input.append(item)
328
- case _: # RawResponse
329
- if self.use_completions:
330
- item = cast(ChatCompletionMessageToolCall, item)
331
- else:
332
- item = cast(ResponseOutputItem, item)
333
- new_input.append(item)
334
-
335
- if content_user:
336
- new_input.append({"role": "user", "content": content_user})
355
+ new_input.append(
356
+ {
357
+ "type": "function_call_output",
358
+ "call_id": item.tool_call.call_id,
359
+ "output": item.result,
360
+ }
361
+ )
362
+ case RawResponse():
363
+ if self.use_completions:
364
+ pass
365
+ new_input.append(item.response)
366
+ else:
367
+ new_input.extend(item.response)
368
+ case RawInput():
369
+ new_input.append(item.input)
370
+
371
+ # in case content user item is the last item
372
+ flush_content_user()
337
373
 
338
374
  return new_input
339
375
 
@@ -469,19 +505,13 @@ class OpenAIModel(LLM):
469
505
  file_id=response.id,
470
506
  )
471
507
 
472
- async def _query_completions(
508
+ async def _build_body_completions(
473
509
  self,
474
510
  input: Sequence[InputItem],
475
511
  *,
476
512
  tools: list[ToolDefinition],
477
513
  **kwargs: object,
478
- ) -> QueryResult:
479
- """
480
- Completions endpoint
481
- Generally not used for openai models
482
- Used by some providers using openai as a delegate
483
- """
484
-
514
+ ) -> dict[str, Any]:
485
515
  parsed_input: list[dict[str, Any] | ChatCompletionMessage] = []
486
516
  if "system_prompt" in kwargs:
487
517
  parsed_input.append(
@@ -520,6 +550,23 @@ class OpenAIModel(LLM):
520
550
 
521
551
  body.update(kwargs)
522
552
 
553
+ return body
554
+
555
+ async def _query_completions(
556
+ self,
557
+ input: Sequence[InputItem],
558
+ *,
559
+ tools: list[ToolDefinition],
560
+ **kwargs: object,
561
+ ) -> QueryResult:
562
+ """
563
+ Completions endpoint
564
+ Generally not used for openai models
565
+ Used by providers using openai as a delegate
566
+ """
567
+
568
+ body = await self.build_body(input, tools=tools, **kwargs)
569
+
523
570
  output_text: str = ""
524
571
  reasoning_text: str = ""
525
572
  metadata: QueryResultMetadata = QueryResultMetadata()
@@ -632,7 +679,7 @@ class OpenAIModel(LLM):
632
679
  output_text=output_text,
633
680
  reasoning=reasoning_text,
634
681
  tool_calls=tool_calls,
635
- history=[*input, final_message],
682
+ history=[*input, RawResponse(response=final_message)],
636
683
  metadata=metadata,
637
684
  )
638
685
 
@@ -667,13 +714,17 @@ class OpenAIModel(LLM):
667
714
  if not valid:
668
715
  raise Exception("Deep research models require web search tools")
669
716
 
717
+ @override
670
718
  async def build_body(
671
719
  self,
672
720
  input: Sequence[InputItem],
673
721
  *,
674
- tools: Sequence[ToolDefinition],
722
+ tools: list[ToolDefinition],
675
723
  **kwargs: object,
676
724
  ) -> dict[str, Any]:
725
+ if self.use_completions:
726
+ return await self._build_body_completions(input, tools=tools, **kwargs)
727
+
677
728
  if self.deep_research:
678
729
  await self._check_deep_research_args(tools, **kwargs)
679
730
 
@@ -717,7 +768,6 @@ class OpenAIModel(LLM):
717
768
  _ = kwargs.pop("stream", None)
718
769
 
719
770
  body.update(kwargs)
720
-
721
771
  return body
722
772
 
723
773
  @override
@@ -785,13 +835,12 @@ class OpenAIModel(LLM):
785
835
  citations: list[Citation] = []
786
836
  reasoning = None
787
837
  for output in response.output:
788
- if self.deep_research:
789
- if output.type == "message":
790
- for content in output.content:
791
- if not isinstance(content, ResponseOutputText):
792
- continue
793
- for citation in content.annotations:
794
- citations.append(Citation(**citation.model_dump()))
838
+ if output.type == "message":
839
+ for content in output.content:
840
+ if not isinstance(content, ResponseOutputText):
841
+ continue
842
+ for citation in content.annotations:
843
+ citations.append(Citation(**citation.model_dump()))
795
844
 
796
845
  if output.type == "reasoning":
797
846
  reasoning = " ".join([i.text for i in output.summary])
@@ -814,7 +863,7 @@ class OpenAIModel(LLM):
814
863
  output_text=response.output_text,
815
864
  reasoning=reasoning,
816
865
  tool_calls=tool_calls,
817
- history=[*input, *response.output],
866
+ history=[*input, RawResponse(response=response.output)],
818
867
  extras=QueryResultExtras(citations=citations),
819
868
  )
820
869
  if response.usage:
@@ -51,7 +51,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
51
51
  "custom_id": custom_id,
52
52
  "method": "",
53
53
  "url": "",
54
- "body": await self._root.create_body(input, tools=[], **kwargs),
54
+ "body": await self._root.build_body(input, tools=[], **kwargs),
55
55
  }
56
56
 
57
57
  @override
@@ -227,7 +227,8 @@ class DummyAIModel(LLM):
227
227
  ) -> FileWithId:
228
228
  raise NotImplementedError()
229
229
 
230
- async def create_body(
230
+ @override
231
+ async def build_body(
231
232
  self,
232
233
  input: Sequence[InputItem],
233
234
  *,
@@ -275,7 +276,7 @@ class DummyAIModel(LLM):
275
276
  query_logger: logging.Logger,
276
277
  **kwargs: object,
277
278
  ) -> QueryResult:
278
- body = await self.create_body(input, tools=tools, **kwargs)
279
+ body = await self.build_body(input, tools=tools, **kwargs)
279
280
 
280
281
  fail_rate = FAIL_RATE
281
282