model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/openai.py

@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+import datetime
 import io
 import json
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
 from openai import APIConnectionError, AsyncOpenAI
@@ -16,6 +18,7 @@ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
 from openai.types.create_embedding_response import CreateEmbeddingResponse
 from openai.types.moderation_create_response import ModerationCreateResponse
 from openai.types.responses import (
+    ResponseFunctionToolCall,
     ResponseOutputItem,
     ResponseOutputText,
     ResponseStreamEvent,
@@ -29,6 +32,8 @@ from model_library.base import (
     LLM,
     BatchResult,
     Citation,
+    DelegateConfig,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -42,7 +47,9 @@ from model_library.base import (
     QueryResultCost,
     QueryResultExtras,
     QueryResultMetadata,
-    RawInputItem,
+    RateLimit,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -53,9 +60,11 @@ from model_library.exceptions import (
     ImmediateRetryException,
     MaxOutputTokensExceededError,
     ModelNoOutputError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_reasoning_in_tag
 from model_library.register_models import register_provider
+from model_library.retriers.base import BaseRetrier
 from model_library.utils import create_openai_client_with_defaults
 
 
@@ -230,23 +239,31 @@ class OpenAIBatchMixin(LLMBatchMixin):
 
 class OpenAIConfig(ProviderConfig):
     deep_research: bool = False
+    verbosity: Literal["low", "medium", "high"] | None = None
 
 
 @register_provider("openai")
 class OpenAIModel(LLM):
     provider_config = OpenAIConfig()
 
-    _client: AsyncOpenAI | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.OPENAI_API_KEY
 
     @override
-    def get_client(self) -> AsyncOpenAI:
-        if self._delegate_client:
-            return self._delegate_client
-        if not OpenAIModel._client:
-            OpenAIModel._client = create_openai_client_with_defaults(
-                api_key=model_library_settings.OPENAI_API_KEY
+    def get_client(self, api_key: str | None = None) -> AsyncOpenAI:
+        if not self.has_client():
+            assert api_key
+            client = create_openai_client_with_defaults(
+                base_url=self.delegate_config.base_url
                if self.delegate_config
                else None,
+                api_key=api_key,
             )
-        return OpenAIModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -254,22 +271,48 @@ class OpenAIModel(LLM):
         provider: str = "openai",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncOpenAI | None = None,
         use_completions: bool = False,
+        delegate_config: DelegateConfig | None = None,
     ):
+        self.use_completions: bool = (
+            use_completions  # TODO: do completions in a separate file
+        )
+        self.delegate_config = delegate_config
+
         super().__init__(model_name, provider, config=config)
-        self.use_completions: bool = use_completions
-        self.deep_research = self.provider_config.deep_research
 
-        # allow custom client to act as delegate (native)
-        self._delegate_client: AsyncOpenAI | None = custom_client
+        self.deep_research = self.provider_config.deep_research
+        self.verbosity = self.provider_config.verbosity
 
         # batch client
-        self.supports_batch: bool = self.supports_batch and not custom_client
+        self.supports_batch: bool = self.supports_batch and not self.delegate_config
         self.batch: LLMBatchMixin | None = (
             OpenAIBatchMixin(self) if self.supports_batch else None
         )
 
+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        if self.use_completions:
+            calls = [
+                y
+                for x in raw_responses
+                if isinstance(x.response, ChatCompletionMessage)
+                and x.response.tool_calls
+                for y in x.response.tool_calls
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        else:
+            calls = [
+                y
+                for x in raw_responses
+                for y in x.response
+                if isinstance(y, ResponseFunctionToolCall)
+            ]
+            tool_call_ids.extend([x.id for x in calls if x.id])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
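The constructor now takes a `DelegateConfig` instead of a prebuilt `custom_client`, and the new `get_tool_call_ids` helper collects tool-call ids from prior `RawResponse` history items so tool results can be validated against them. A minimal construction sketch, assuming only the signatures visible in these hunks (the endpoint URL, key, and model name are illustrative):

    from pydantic import SecretStr

    from model_library.base import DelegateConfig
    from model_library.providers.openai import OpenAIModel

    # Route an OpenAI-compatible endpoint through OpenAIModel. Without a
    # delegate_config the model falls back to OPENAI_API_KEY; batch support
    # is disabled whenever a delegate is configured.
    model = OpenAIModel(
        "my-model",  # illustrative model name
        use_completions=True,
        delegate_config=DelegateConfig(
            base_url="https://api.example.com/v1",  # illustrative endpoint
            api_key=SecretStr("sk-..."),  # illustrative key
        ),
    )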
@@ -277,63 +320,69 @@ class OpenAIModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []
+
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)
+
         for item in input:
+            if isinstance(item, TextInput):
+                if self.use_completions:
+                    text_key = "text"
+                else:
+                    text_key = "input_text"
+                content_user.append({"type": text_key, "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
                     if self.use_completions:
-                        content_user.append({"type": "text", "text": item.text})
+                        new_input.append(
+                            {
+                                "role": "tool",
+                                "tool_call_id": item.tool_call.id,
+                                "content": item.result,
+                            }
+                        )
                     else:
-                        content_user.append({"type": "input_text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        new_input.append({"role": "user", "content": content_user})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if not (
-                                not isinstance(x, dict)
-                                and x.type == "function_call"
-                                and x.call_id == item.tool_call.call_id
-                                for x in new_input
-                            ):
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            if self.use_completions:
-                                new_input.append(
-                                    {
-                                        "role": "tool",
-                                        "tool_call_id": item.tool_call.id,
-                                        "content": item.result,
-                                    }
-                                )
-                            else:
-                                new_input.append(
-                                    {
-                                        "type": "function_call_output",
-                                        "call_id": item.tool_call.call_id,
-                                        "output": item.result,
-                                    }
-                                )
-                        case dict():  # RawInputItem
-                            item = cast(RawInputItem, item)
-                            new_input.append(item)
-                        case _:  # RawResponse
-                            if self.use_completions:
-                                item = cast(ChatCompletionMessageToolCall, item)
-                            else:
-                                item = cast(ResponseOutputItem, item)
-                            new_input.append(item)
-
-        if content_user:
-            new_input.append({"role": "user", "content": content_user})
+                        new_input.append(
+                            {
+                                "type": "function_call_output",
+                                "call_id": item.tool_call.call_id,
+                                "output": item.result,
+                            }
+                        )
+                case RawResponse():
+                    if self.use_completions:
+                        new_input.append(item.response)
+                    else:
+                        new_input.extend(item.response)
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()
 
         return new_input
 
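The rewritten `parse_input` buffers consecutive user-content items (text and files) and flushes them into a single `{"role": "user", ...}` message whenever a non-content item appears. A self-contained toy of that buffering pattern, with plain strings and a tuple standing in for the library's input types:

    messages: list[dict] = []
    buffer: list[str] = []

    def flush() -> None:
        if buffer:
            # copy the buffer, since it is cleared after each flush
            messages.append({"role": "user", "content": buffer.copy()})
            buffer.clear()

    for item in ["text-a", "image-b", ("tool", "42"), "text-c"]:
        if isinstance(item, str):  # content-user item: keep buffering
            buffer.append(item)
        else:  # non-content item: flush the pending user message first
            flush()
            messages.append({"type": "function_call_output", "output": item[1]})
    flush()  # in case a content item is last

    # messages ==
    # [{"role": "user", "content": ["text-a", "image-b"]},
    #  {"type": "function_call_output", "output": "42"},
    #  {"role": "user", "content": ["text-c"]}]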
@@ -469,19 +518,13 @@
             file_id=response.id,
         )
 
-    async def _query_completions(
+    async def _build_body_completions(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
         **kwargs: object,
-    ) -> QueryResult:
-        """
-        Completions endpoint
-        Generally not used for openai models
-        Used by some providers using openai as a delegate
-        """
-
+    ) -> dict[str, Any]:
         parsed_input: list[dict[str, Any] | ChatCompletionMessage] = []
         if "system_prompt" in kwargs:
             parsed_input.append(
@@ -492,18 +535,20 @@
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": parsed_input,
             # enable usage data in streaming responses
             "stream_options": {"include_usage": True},
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_tools:
             parsed_tools = await self.parse_tools(tools)
             if parsed_tools:
                 body["tools"] = parsed_tools
 
-        if self.reasoning:
+        if self.reasoning and self.max_tokens:
             del body["max_tokens"]
             body["max_completion_tokens"] = self.max_tokens
 
@@ -520,6 +565,23 @@
 
         body.update(kwargs)
 
+        return body
+
+    async def _query_completions(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        **kwargs: object,
+    ) -> QueryResult:
+        """
+        Completions endpoint
+        Generally not used for openai models
+        Used by providers using openai as a delegate
+        """
+
+        body = await self.build_body(input, tools=tools, **kwargs)
+
         output_text: str = ""
         reasoning_text: str = ""
         metadata: QueryResultMetadata = QueryResultMetadata()
@@ -632,7 +694,7 @@
             output_text=output_text,
             reasoning=reasoning_text,
             tool_calls=tool_calls,
-            history=[*input, final_message],
+            history=[*input, RawResponse(response=final_message)],
             metadata=metadata,
         )
 
@@ -640,7 +702,7 @@
         self, tools: Sequence[ToolDefinition], **kwargs: object
     ) -> None:
         min_tokens = 30_000
-        if self.max_tokens < min_tokens:
+        if not self.max_tokens or self.max_tokens < min_tokens:
            self.logger.warning(
                f"Recommended to set max_tokens >= {min_tokens} for deep research models"
            )
@@ -667,13 +729,17 @@
         if not valid:
             raise Exception("Deep research models require web search tools")
 
+    @override
     async def build_body(
         self,
         input: Sequence[InputItem],
         *,
-        tools: Sequence[ToolDefinition],
+        tools: list[ToolDefinition],
         **kwargs: object,
     ) -> dict[str, Any]:
+        if self.use_completions:
+            return await self._build_body_completions(input, tools=tools, **kwargs)
+
         if self.deep_research:
             await self._check_deep_research_args(tools, **kwargs)
 
@@ -694,10 +760,12 @@
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_output_tokens": self.max_tokens,
             "input": parsed_input,
         }
 
+        if self.max_tokens:
+            body["max_output_tokens"] = self.max_tokens
+
         if parsed_tools:
             body["tools"] = parsed_tools
         else:
@@ -708,6 +776,9 @@
         if self.reasoning_effort is not None:
             body["reasoning"]["effort"] = self.reasoning_effort  # type: ignore[reportArgumentType]
 
+        if self.verbosity is not None:
+            body["text"] = {"format": {"type": "text"}, "verbosity": self.verbosity}
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -717,7 +788,6 @@
         _ = kwargs.pop("stream", None)
 
         body.update(kwargs)
-
         return body
 
     @override
@@ -785,13 +855,12 @@
         citations: list[Citation] = []
         reasoning = None
         for output in response.output:
-            if self.deep_research:
-                if output.type == "message":
-                    for content in output.content:
-                        if not isinstance(content, ResponseOutputText):
-                            continue
-                        for citation in content.annotations:
-                            citations.append(Citation(**citation.model_dump()))
+            if output.type == "message":
+                for content in output.content:
+                    if not isinstance(content, ResponseOutputText):
+                        continue
+                    for citation in content.annotations:
+                        citations.append(Citation(**citation.model_dump()))
 
             if output.type == "reasoning":
                 reasoning = " ".join([i.text for i in output.summary])
@@ -814,7 +883,7 @@
             output_text=response.output_text,
             reasoning=reasoning,
             tool_calls=tool_calls,
-            history=[*input, *response.output],
+            history=[*input, RawResponse(response=response.output)],
             extras=QueryResultExtras(citations=citations),
         )
         if response.usage:
@@ -834,6 +903,61 @@
 
         return result
 
+    @override
+    async def get_rate_limit(self) -> RateLimit | None:
+        headers = {}
+
+        try:
+            # NOTE: with_streaming_response doesn't seem to always work
+            if self.use_completions:
+                response = (
+                    await self.get_client().chat.completions.with_raw_response.create(
+                        max_completion_tokens=16,
+                        model=self.model_name,
+                        messages=[
+                            {
+                                "role": "user",
+                                "content": "Ping",
+                            }
+                        ],
+                        stream=True,
+                    )
+                )
+            else:
+                response = await self.get_client().responses.with_raw_response.create(
+                    max_output_tokens=16,
+                    input="Ping",
+                    model=self.model_name,
+                )
+            headers = response.headers
+
+            server_time_str = headers.get("date")
+            if server_time_str:
+                server_time = datetime.datetime.strptime(
+                    server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+                ).replace(tzinfo=datetime.timezone.utc)
+                timestamp = server_time.timestamp()
+            else:
+                timestamp = time.time()
+
+            # NOTE: for openai, max_tokens is used to reject requests if the amount of tokens left is less than the max_tokens
+
+            # we calculate estimated_tokens as (character_count / 4) + max_tokens. Note that OpenAI's rate limiter doesn't tokenize the request using the model's specific tokenizer but relies on a character count-based heuristic.
+
+            return RateLimit(
+                raw=headers,
+                unix_timestamp=timestamp,
+                request_limit=headers.get("x-ratelimit-limit-requests", None)
+                or headers.get("x-ratelimit-limit", None),
+                request_remaining=headers.get("x-ratelimit-remaining-requests", None)
+                or headers.get("x-ratelimit-remaining"),
+                token_limit=int(headers["x-ratelimit-limit-tokens"]),
+                token_remaining=int(headers["x-ratelimit-remaining-tokens"]),
+            )
+        except Exception as e:
+            self.logger.warning(f"Failed to get rate limit: {e}")
+            return None
+
     @override
     async def query_json(
         self,
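The comments in `get_rate_limit` describe how OpenAI's limiter charges a request against the token budget: characters divided by four plus the requested `max_tokens`, rather than a real tokenizer count. A hypothetical helper mirroring that heuristic (the function name is illustrative, not library API):

    def estimate_request_tokens(prompt: str, max_tokens: int) -> int:
        # OpenAI's rate limiter uses a character-count heuristic,
        # not the model's tokenizer: chars / 4 + max_tokens
        return len(prompt) // 4 + max_tokens

    # e.g. a 2,000-character prompt with max_tokens=1,000 counts as
    # 2000 // 4 + 1000 = 1500 tokens against the per-minute limit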
@@ -857,7 +981,9 @@
         except APIConnectionError:
             raise ImmediateRetryException("Failed to connect to OpenAI")
 
-        response = await LLM.immediate_retry_wrapper(func=_query, logger=self.logger)
+        response = await BaseRetrier.immediate_retry_wrapper(
+            func=_query, logger=self.logger
+        )
 
         parsed: PydanticT | None = response.output_parsed
         if parsed is None:
@@ -888,7 +1014,7 @@
 
            return response.data[0].embedding
 
-        return await BaseRetrier.immediate_retry_wrapper(
+        return await BaseRetrier.immediate_retry_wrapper(
            func=_get_embedding, logger=self.logger
        )
 
@@ -903,7 +1029,7 @@
         except Exception as e:
             raise Exception("Failed to query OpenAI's Moderation endpoint") from e
 
-        return await LLM.immediate_retry_wrapper(
+        return await BaseRetrier.immediate_retry_wrapper(
            func=_moderate_content, logger=self.logger
        )
 
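`immediate_retry_wrapper` has moved from `LLM` to the new `BaseRetrier` class (see `model_library/retriers/base.py` in the file list). A minimal call sketch, assuming only the call shape visible in these hunks:

    import asyncio
    import logging

    from model_library.retriers.base import BaseRetrier

    async def main() -> None:
        async def _query() -> str:  # illustrative coroutine that may raise
            return "ok"

        # retries _query on immediate-retry exceptions, logging each attempt
        result = await BaseRetrier.immediate_retry_wrapper(
            func=_query, logger=logging.getLogger(__name__)
        )
        print(result)

    asyncio.run(main())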
model_library/providers/openrouter.py (new file)

@@ -0,0 +1,34 @@
+from typing import Literal
+
+from pydantic import SecretStr
+
+from model_library import model_library_settings
+from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
+    LLMConfig,
+)
+from model_library.register_models import register_provider
+
+
+@register_provider("openrouter")
+class OpenRouterModel(DelegateOnly):
+    def __init__(
+        self,
+        model_name: str,
+        provider: Literal["openrouter"] = "openrouter",
+        *,
+        config: LLMConfig | None = None,
+    ):
+        super().__init__(model_name, provider, config=config)
+
+        # https://openrouter.ai/docs/guides/community/openai-sdk
+        self.init_delegate(
+            config=config,
+            delegate_config=DelegateConfig(
+                base_url="https://openrouter.ai/api/v1",
+                api_key=SecretStr(model_library_settings.OPENROUTER_API_KEY),
+            ),
+            use_completions=True,
+            delegate_provider="openai",
+        )
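The new `OpenRouterModel` is a thin `DelegateOnly` wrapper that forwards every request to a delegate `OpenAIModel` over OpenRouter's OpenAI-compatible chat-completions API. A minimal usage sketch, assuming the registration shown above (the model name is illustrative):

    from model_library.providers.openrouter import OpenRouterModel

    # Requests go to https://openrouter.ai/api/v1 with OPENROUTER_API_KEY;
    # use_completions=True selects the chat-completions code path.
    model = OpenRouterModel("meta-llama/llama-3.1-8b-instruct")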
model_library/providers/perplexity.py

@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("perplexity")
@@ -22,13 +23,12 @@ class PerplexityModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.perplexity.ai/guides/chat-completions-guide
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.PERPLEXITY_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.perplexity.ai",
+                api_key=SecretStr(model_library_settings.PERPLEXITY_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )

model_library/providers/together.py

@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class TogetherConfig(ProviderConfig):
@@ -32,15 +32,14 @@ class TogetherModel(DelegateOnly):
     ):
         super().__init__(model_name, provider, config=config)
         # https://docs.together.ai/docs/openai-api-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.TOGETHER_API_KEY,
-                base_url="https://api.together.xyz/v1",
+            delegate_config=DelegateConfig(
+                base_url="https://api.together.xyz/v1/",
+                api_key=SecretStr(model_library_settings.TOGETHER_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override