model-library 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import io
 import json
+import logging
 from typing import Any, Literal, Sequence, cast
 
 from openai import APIConnectionError, AsyncOpenAI
@@ -15,6 +16,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from openai.types.create_embedding_response import CreateEmbeddingResponse
 from openai.types.moderation_create_response import ModerationCreateResponse
 from openai.types.responses import (
+    ResponseFunctionToolCall,
     ResponseOutputItem,
     ResponseOutputText,
     ResponseStreamEvent,
@@ -28,6 +30,7 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -41,7 +44,8 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
-    RawInputItem,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -52,6 +56,7 @@ from model_library.exceptions import (
     ImmediateRetryException,
     MaxOutputTokensExceededError,
     ModelNoOutputError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
@@ -257,7 +262,9 @@ class OpenAIModel(LLM):
         use_completions: bool = False,
     ):
         super().__init__(model_name, provider, config=config)
-        self.use_completions: bool = use_completions
+        self.use_completions: bool = (
+            use_completions  # TODO: do completions in a separate file
+        )
         self.deep_research = self.provider_config.deep_research
 
         # allow custom client to act as delegate (native)
@@ -269,6 +276,29 @@ class OpenAIModel(LLM):
             OpenAIBatchMixin(self) if self.supports_batch else None
         )
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        if self.use_completions:
+            calls = [
+                y
+                for x in raw_responses
+                if isinstance(x.response, ChatCompletionMessage)
+                and x.response.tool_calls
+                for y in x.response.tool_calls
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        else:
+            calls = [
+                y
+                for x in raw_responses
+                for y in x.response
+                if isinstance(y, ResponseFunctionToolCall)
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
@@ -276,63 +306,70 @@ class OpenAIModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
+
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                if self.use_completions:
+                    text_key = "text"
+                else:
+                    text_key = "input_text"
+                content_user.append({"type": text_key, "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
                     if self.use_completions:
-                        content_user.append({"type": "text", "text": item.text})
+                        new_input.append(
+                            {
+                                "role": "tool",
+                                "tool_call_id": item.tool_call.id,
+                                "content": item.result,
+                            }
+                        )
                     else:
-                        content_user.append({"type": "input_text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                not isinstance(x, dict)
-                                and x.type == "function_call"
-                                and x.call_id == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            if self.use_completions:
-                                new_input.append(
-                                    {
-                                        "role": "tool",
-                                        "tool_call_id": item.tool_call.id,
-                                        "content": item.result,
-                                    }
-                                )
-                            else:
-                                new_input.append(
-                                    {
-                                        "type": "function_call_output",
-                                        "call_id": item.tool_call.call_id,
-                                        "output": item.result,
-                                    }
-                                )
-                        case dict():  # RawInputItem
-                            item = cast(RawInputItem, item)
-                            new_input.append(item)
-                        case _:  # RawResponse
-                            if self.use_completions:
-                                item = cast(ChatCompletionMessageToolCall, item)
-                            else:
-                                item = cast(ResponseOutputItem, item)
-                            new_input.append(item)
-
-        if content_user:
-            new_input.append({"role": "user", "content": content_user})
+                        new_input.append(
+                            {
+                                "type": "function_call_output",
+                                "call_id": item.tool_call.call_id,
+                                "output": item.result,
+                            }
+                        )
+                case RawResponse():
+                    if self.use_completions:
+                        pass
+                        new_input.append(item.response)
+                    else:
+                        new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()
 
         return new_input
 
@@ -468,19 +505,13 @@ class OpenAIModel(LLM):
             file_id=response.id,
         )
 
-    async def _query_completions(
+    async def _build_body_completions(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) -> QueryResult:
-        """
-        Completions endpoint
-        Generally not used for openai models
-        Used by some providers using openai as a delegate
-        """
-
+    ) -> dict[str, Any]:
         parsed_input: list[dict[str, Any] | ChatCompletionMessage] = []
         if "system_prompt" in kwargs:
             parsed_input.append(
@@ -505,8 +536,11 @@ class OpenAIModel(LLM):
         if self.reasoning:
             del body["max_tokens"]
             body["max_completion_tokens"] = self.max_tokens
-            if self.reasoning_effort:
-                body["reasoning_effort"] = self.reasoning_effort
+
+            # some model endpoints (like `fireworks/deepseek-v3p2`)
+            # require explicitly setting reasoning effort to disable thinking
+            if self.reasoning_effort is not None:
+                body["reasoning_effort"] = self.reasoning_effort
 
         if self.supports_temperature:
             if self.temperature is not None:
@@ -516,6 +550,23 @@ class OpenAIModel(LLM):
 
         body.update(kwargs)
 
+        return body
+
+    async def _query_completions(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> QueryResult:
+        """
+        Completions endpoint
+        Generally not used for openai models
+        Used by providers using openai as a delegate
+        """
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
         output_text: str = ""
         reasoning_text: str = ""
         metadata: QueryResultMetadata = QueryResultMetadata()
@@ -628,7 +679,7 @@ class OpenAIModel(LLM):
             output_text=output_text,
             reasoning=reasoning_text,
             tool_calls=tool_calls,
-            history=[*input, final_message],
+            history=[*input, RawResponse(response=final_message)],
             metadata=metadata,
         )
 
@@ -663,13 +714,17 @@ class OpenAIModel(LLM):
        if not valid:
            raise Exception("Deep research models require web search tools")
 
+    @override
    async def build_body(
        self,
        input: Sequence[InputItem],
        *,
-        tools: Sequence[ToolDefinition],
+        tools: list[ToolDefinition],
        **kwargs: object,
    ) -> dict[str, Any]:
+        if self.use_completions:
+            return await self._build_body_completions(input, tools=tools, **kwargs)
+
        if self.deep_research:
            await self._check_deep_research_args(tools, **kwargs)
 
@@ -701,8 +756,8 @@ class OpenAIModel(LLM):
 
        if self.reasoning:
            body["reasoning"] = {"summary": "auto"}
-            if self.reasoning_effort:
-                body["reasoning"]["effort"] = self.reasoning_effort
+            if self.reasoning_effort is not None:
+                body["reasoning"]["effort"] = self.reasoning_effort  # type: ignore[reportArgumentType]
 
        if self.supports_temperature:
            if self.temperature is not None:
@@ -713,7 +768,6 @@ class OpenAIModel(LLM):
        _ = kwargs.pop("stream", None)
 
        body.update(kwargs)
-
        return body
 
    @override
@@ -722,6 +776,7 @@ class OpenAIModel(LLM):
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
        **kwargs: object,
    ) -> QueryResult:
        if self.use_completions:
@@ -780,13 +835,12 @@ class OpenAIModel(LLM):
        citations: list[Citation] = []
        reasoning = None
        for output in response.output:
-            if self.deep_research:
-                if output.type == "message":
-                    for content in output.content:
-                        if not isinstance(content, ResponseOutputText):
-                            continue
-                        for citation in content.annotations:
-                            citations.append(Citation(**citation.model_dump()))
+            if output.type == "message":
+                for content in output.content:
+                    if not isinstance(content, ResponseOutputText):
+                        continue
+                    for citation in content.annotations:
+                        citations.append(Citation(**citation.model_dump()))
 
            if output.type == "reasoning":
                reasoning = " ".join([i.text for i in output.summary])
@@ -809,7 +863,7 @@ class OpenAIModel(LLM):
            output_text=response.output_text,
            reasoning=reasoning,
            tool_calls=tool_calls,
-            history=[*input, *response.output],
+            history=[*input, RawResponse(response=response.output)],
            extras=QueryResultExtras(citations=citations),
        )
        if response.usage:
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import io
 import json
+import logging
 import random
 import re
 import time
@@ -50,7 +51,7 @@ class DummyAIBatchMixin(LLMBatchMixin):
            "custom_id": custom_id,
            "method": "",
            "url": "",
-            "body": await self._root.create_body(input, tools=[], **kwargs),
+            "body": await self._root.build_body(input, tools=[], **kwargs),
        }
 
    @override
@@ -226,7 +227,8 @@ class DummyAIModel(LLM):
    ) -> FileWithId:
        raise NotImplementedError()
 
-    async def create_body(
+    @override
+    async def build_body(
        self,
        input: Sequence[InputItem],
        *,
@@ -271,9 +273,10 @@ class DummyAIModel(LLM):
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
        **kwargs: object,
    ) -> QueryResult:
-        body = await self.create_body(input, tools=tools, **kwargs)
+        body = await self.build_body(input, tools=tools, **kwargs)
 
        fail_rate = FAIL_RATE