model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/providers/anthropic.py

@@ -1,17 +1,21 @@
+import datetime
 import io
 import logging
+import time
 from typing import Any, Literal, Sequence, cast

-from anthropic import AsyncAnthropic
-from anthropic.types import TextBlock, ToolUseBlock
+from anthropic import APIConnectionError, AsyncAnthropic
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
-from anthropic.types.message import Message
+from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+from pydantic import SecretStr
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    DelegateConfig,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
@@ -22,7 +26,9 @@ from model_library.base import (
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
-    RawInputItem,
+    RateLimit,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -30,16 +36,15 @@ from model_library.base import (
     ToolResult,
 )
 from model_library.exceptions import (
+    ImmediateRetryException,
     MaxOutputTokensExceededError,
+    NoMatchingToolCallError,
 )
 from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import (
-    create_openai_client_with_defaults,
-    default_httpx_client,
-    filter_empty_text_blocks,
-    normalize_tool_result,
+    create_anthropic_client_with_defaults,
 )


@@ -62,9 +67,9 @@ class AnthropicBatchMixin(LLMBatchMixin):

         Format: {"custom_id": str, "params": {...message params...}}
         """
-        # Build the message body using the parent model's create_body method
+        # Build the message body using the parent model's build_body method
         tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
-        body = await self._root.create_body(input, tools=tools, **kwargs)
+        body = await self._root.build_body(input, tools=tools, **kwargs)

         return {
             "custom_id": custom_id,
@@ -246,21 +251,25 @@ class AnthropicBatchMixin(LLMBatchMixin):

 @register_provider("anthropic")
 class AnthropicModel(LLM):
-    _client: AsyncAnthropic | None = None
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.ANTHROPIC_API_KEY

     @override
-    def get_client(self) -> AsyncAnthropic:
-        if self._delegate_client:
-            return self._delegate_client
-        if not AnthropicModel._client:
+    def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+        if not self.has_client():
+            assert api_key
             headers: dict[str, str] = {}
-            AnthropicModel._client = AsyncAnthropic(
-                api_key=model_library_settings.ANTHROPIC_API_KEY,
-                http_client=default_httpx_client(),
-                max_retries=1,
+            client = create_anthropic_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
                 default_headers=headers,
             )
-        return AnthropicModel._client
+            self.assign_client(client)
+        return super().get_client()

     def __init__(
         self,
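A note on the refactor above: the class-level `_client` singleton is gone, and `get_client` now defers caching to `has_client()` / `assign_client()` / `super().get_client()` on the base `LLM` class (`model_library/base/base.py` also changed in this release, but its hunks are not shown here). A plausible shape for those hooks, keyed by API key so delegate and default credentials get distinct cached clients; the names and storage below are assumptions, not the actual base-class code:

from typing import Any

# Hypothetical sketch of the base-class client cache; illustrative only.
class LLM:
    _clients: dict[str, Any] = {}  # shared cache, keyed by API key

    def _get_default_api_key(self) -> str: ...  # overridden per provider

    def has_client(self) -> bool:
        return self._get_default_api_key() in LLM._clients

    def assign_client(self, client: Any) -> None:
        LLM._clients[self._get_default_api_key()] = client

    def get_client(self, api_key: str | None = None) -> Any:
        # By the time this runs, a subclass has already assigned the client
        return LLM._clients[self._get_default_api_key()]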
@@ -268,38 +277,51 @@ class AnthropicModel(LLM):
         provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncAnthropic | None = None,
+        delegate_config: DelegateConfig | None = None,
     ):
-        super().__init__(model_name, provider, config=config)
+        self.delegate_config = delegate_config

-        # allow custom client to act as delegate (native)
-        self._delegate_client: AsyncAnthropic | None = custom_client
+        super().__init__(model_name, provider, config=config)

         # https://docs.anthropic.com/en/api/openai-sdk
         self.delegate = (
             None
-            if self.native or custom_client
+            if self.native or self.delegate_config
             else OpenAIModel(
                 model_name=self.model_name,
-                provider=provider,
+                provider=self.provider,
                 config=config,
-                custom_client=create_openai_client_with_defaults(
-                    api_key=model_library_settings.ANTHROPIC_API_KEY,
+                use_completions=True,
+                delegate_config=DelegateConfig(
                     base_url="https://api.anthropic.com/v1/",
+                    api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                 ),
-                use_completions=True,
             )
         )

         # Initialize batch support if enabled
         # Disable batch when using custom_client (similar to OpenAI)
         self.supports_batch: bool = (
-            self.supports_batch and self.native and not custom_client
+            self.supports_batch and self.native and not self.delegate_config
         )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
         )

+    async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+        raw_responses = [x for x in input if isinstance(x, RawResponse)]
+        tool_call_ids: list[str] = []
+
+        calls = [
+            y
+            for x in raw_responses
+            if isinstance(x.response, ParsedBetaMessage)
+            for y in x.response.content  # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
+            if isinstance(y, BetaToolUseBlock)
+        ]
+        tool_call_ids.extend([x.id for x in calls])
+        return tool_call_ids
+
     @override
     async def parse_input(
         self,
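`get_tool_call_ids` above replaces the old first-pass scan inside `parse_input`: assistant turns now come back wrapped in `RawResponse` (see the `query` hunk further down, which appends `RawResponse(response=message)` to `history`), and this helper collects the `BetaToolUseBlock` ids so each `ToolResult` can be validated. A hedged round-trip fragment; `ToolResult`/`RawResponse` names come from this diff, but the exact `query()` signature, the model name, and `weather_tool` (a `ToolDefinition` defined elsewhere) are assumptions, and an async context is assumed:

# Hypothetical tool-call round trip; illustrative only.
model = AnthropicModel("claude-sonnet-4-5")  # illustrative model name
result = await model.query("Weather in Paris?", tools=[weather_tool])

followups = [
    # Each ToolResult must reference an id present in a RawResponse in the
    # history, otherwise parse_input raises NoMatchingToolCallError.
    ToolResult(tool_call=call, result='{"temp_c": 18}')
    for call in result.tool_calls
]
final = await model.query(followups, history=result.history, tools=[weather_tool])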
@@ -307,77 +329,61 @@ class AnthropicModel(LLM):
         **kwargs: Any,
     ) -> list[dict[str, Any] | Any]:
         new_input: list[dict[str, Any] | Any] = []
+
         content_user: list[dict[str, Any]] = []

-        # First pass: collect all tool calls from Message objects for validation
-        tool_calls_in_input: set[str] = set()
-        for item in input:
-            if hasattr(item, "content") and hasattr(item, "role"):
-                content_list = getattr(item, "content", [])
-                for content in content_list:
-                    # Check for both ToolUseBlock and BetaToolUseBlock
-                    if isinstance(content, (ToolUseBlock, BetaToolUseBlock)):
-                        tool_calls_in_input.add(content.id)
+        def flush_content_user():
+            if content_user:
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
+
+        tool_call_ids = await self.get_tool_call_ids(input)

         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    if item.text.strip():
-                        content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case _:
-                    if content_user:
-                        filtered = filter_empty_text_blocks(content_user)
-                        if filtered:
-                            new_input.append({"role": "user", "content": filtered})
-                        content_user = []
-                    match item:
-                        case ToolResult():
-                            if item.tool_call.id not in tool_calls_in_input:
-                                raise Exception(
-                                    "Tool call result provided with no matching tool call"
-                                )
-                            result_str = normalize_tool_result(item.result)
-                            new_input.append(
+                case ToolResult():
+                    if item.tool_call.id not in tool_call_ids:
+                        raise NoMatchingToolCallError()
+
+                    new_input.append(
+                        {
+                            "role": "user",
+                            "content": [
                                 {
-                                    "role": "user",
-                                    "content": [
-                                        {
-                                            "type": "tool_result",
-                                            "tool_use_id": item.tool_call.id,
-                                            "content": [
-                                                {"type": "text", "text": result_str}
-                                            ],
-                                        }
-                                    ],
+                                    "type": "tool_result",
+                                    "tool_use_id": item.tool_call.id,
+                                    "content": [{"type": "text", "text": item.result}],
                                 }
-                            )
-                        case dict():  # RawInputItem
-                            item = cast(RawInputItem, item)
-                            new_input.append(item)
-                        case _:  # RawResponse
-                            item = cast(Message, item)
-                            filtered_content = [
-                                block
-                                for block in item.content
-                                if not isinstance(block, TextBlock)
-                                or block.text.strip()
-                            ]
-                            if filtered_content:
-                                new_input.append(
-                                    {"role": "assistant", "content": filtered_content}
-                                )
-
-        if content_user:
-            filtered = filter_empty_text_blocks(content_user)
-            if filtered:
-                new_input.append({"role": "user", "content": filtered})
+                            ],
+                        }
+                    )
+                case RawResponse():
+                    content = cast(ParsedBetaMessage, item.response).content
+                    new_input.append({"role": "assistant", "content": content})
+                case RawInput():
+                    new_input.append(item.input)
+
+        # in case content user item is the last item
+        flush_content_user()

+        # cache control
         if new_input:
             last_msg = new_input[-1]
             if not isinstance(last_msg, dict):
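The `parse_input` rewrite above trades the deeply nested `match` for a small buffer-and-flush scheme: consecutive text/file items coalesce into a single user message, and any other item (tool result, raw response, raw input) flushes the buffer first. The same technique in isolation, runnable on plain data; strings stand in for `TextInput`/file items, dicts for anything that ends the pending user message:

# Standalone demonstration of the buffer-and-flush grouping parse_input uses.
def group(items: list) -> list[dict]:
    messages: list[dict] = []
    buffer: list = []

    def flush() -> None:
        if buffer:
            # .copy() matters: buffer.clear() below would empty the same list
            messages.append({"role": "user", "content": buffer.copy()})
            buffer.clear()

    for item in items:
        if isinstance(item, str):
            buffer.append(item)
            continue
        flush()  # non-content item: close out the pending user message
        messages.append(item)

    flush()  # in case a content item was last
    return messages


assert group(["a", "b", {"role": "assistant"}, "c"]) == [
    {"role": "user", "content": ["a", "b"]},
    {"role": "assistant"},
    {"role": "user", "content": ["c"]},
]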
@@ -495,7 +501,7 @@ class AnthropicModel(LLM):
         bytes: io.BytesIO,
         type: Literal["image", "file"] = "file",
     ) -> FileWithId:
-        file_mime = f"image/{mime}" if type == "image" else mime  # TODO:
+        file_mime = f"image/{mime}" if type == "image" else mime
         response = await self.get_client().beta.files.upload(
             file=(
                 name,
@@ -513,7 +519,8 @@

     cache_control = {"type": "ephemeral"}  # 5 min cache

-    async def create_body(
+    @override
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
@@ -521,7 +528,6 @@
         **kwargs: object,
     ) -> dict[str, Any]:
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "messages": await self.parse_input(input),
         }
@@ -535,6 +541,11 @@
             }
         ]

+        if not self.max_tokens:
+            raise Exception("Anthropic models require a max_tokens parameter")
+
+        body["max_tokens"] = self.max_tokens
+
         if self.reasoning:
             budget_tokens = kwargs.pop(
                 "budget_tokens", get_default_budget_tokens(self.max_tokens)
@@ -573,12 +584,12 @@
             input, tools=tools, query_logger=query_logger, **kwargs
         )

-        body = await self.create_body(input, tools=tools, **kwargs)
+        body = await self.build_body(input, tools=tools, **kwargs)

         client = self.get_client()

         # only send betas for the official Anthropic endpoint
-        is_anthropic_endpoint = self._delegate_client is None
+        is_anthropic_endpoint = self.delegate_config is None
         if not is_anthropic_endpoint:
             client_base_url = getattr(client, "_base_url", None) or getattr(
                 client, "base_url", None
@@ -593,11 +604,14 @@
             betas.append("context-1m-2025-08-07")
         stream_kwargs["betas"] = betas

-        async with client.beta.messages.stream(
-            **stream_kwargs,
-        ) as stream:  # pyright: ignore[reportAny]
-            message = await stream.get_final_message()
-            self.logger.info(f"Anthropic Response finished: {message.id}")
+        try:
+            async with client.beta.messages.stream(
+                **stream_kwargs,
+            ) as stream:  # pyright: ignore[reportAny]
+                message = await stream.get_final_message()
+                self.logger.info(f"Anthropic Response finished: {message.id}")
+        except APIConnectionError:
+            raise ImmediateRetryException("Failed to connect to Anthropic")

         text = ""
         reasoning = ""
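`APIConnectionError` is now translated into `ImmediateRetryException`, presumably so the new `model_library/retriers` package (added in this release: `backoff.py`, `base.py`, `token.py`, not shown in this diff) can retry a dropped connection right away rather than backing off. A sketch of that policy under those assumptions; the real retry logic is not shown here:

import asyncio

from model_library.exceptions import ImmediateRetryException

# Hypothetical retry policy; illustrative, not the model_library.retriers API.
async def run_with_retries(fn, max_attempts: int = 5):
    for attempt in range(max_attempts):
        try:
            return await fn()
        except ImmediateRetryException:
            continue  # transient connection failure: retry immediately
        except Exception:
            if attempt == max_attempts - 1:
                raise
            await asyncio.sleep(2**attempt)  # exponential backoff: 1s, 2s, 4s...
    raise RuntimeError("retries exhausted")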
@@ -630,9 +644,75 @@
                 cache_write_tokens=message.usage.cache_creation_input_tokens,
             ),
             tool_calls=tool_calls,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
         )

+    @override
+    async def get_rate_limit(self) -> RateLimit:
+        response = await self.get_client().messages.with_raw_response.create(
+            max_tokens=1,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Ping",
+                }
+            ],
+            model=self.model_name,
+        )
+        headers = response.headers
+
+        server_time_str = headers.get("date")
+        if server_time_str:
+            server_time = datetime.datetime.strptime(
+                server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+            ).replace(tzinfo=datetime.timezone.utc)
+            timestamp = server_time.timestamp()
+        else:
+            timestamp = time.time()
+
+        return RateLimit(
+            unix_timestamp=timestamp,
+            raw=headers,
+            request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+            request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+            token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+            token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+        )
+
+    @override
+    async def count_tokens(
+        self,
+        input: Sequence[InputItem],
+        *,
+        history: Sequence[InputItem] = [],
+        tools: list[ToolDefinition] = [],
+        **kwargs: object,
+    ) -> int:
+        """
+        Count the number of tokens using Anthropic's native token counting API.
+        https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        try:
+            input = [*history, *input]
+            if not input:
+                return 0
+
+            body = await self.build_body(input, tools=tools, **kwargs)
+
+            # Remove fields not supported by count_tokens endpoint
+            body.pop("max_tokens", None)
+            body.pop("temperature", None)
+
+            client = self.get_client()
+            response = await client.messages.count_tokens(**body)
+
+            return response.input_tokens
+        except Exception as e:
+            self.logger.error(f"Error counting tokens: {e}")
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
+
     @override
     async def _calculate_cost(
         self,
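The two new methods above round out the Anthropic integration: `get_rate_limit` issues a one-token "Ping" request and reads the `anthropic-ratelimit-*` response headers, anchoring the timestamp to the server's RFC 7231 `date` header with a local-time fallback, and `count_tokens` prefers Anthropic's native counting endpoint with a fallback to the base-class estimate. A usage sketch, assuming a configured `ANTHROPIC_API_KEY`; the model name and the `TextInput(text=...)` constructor are assumptions:

import asyncio

from model_library.base import TextInput
from model_library.providers.anthropic import AnthropicModel

async def main() -> None:
    model = AnthropicModel("claude-sonnet-4-5")  # illustrative model name

    limit = await model.get_rate_limit()
    print(f"requests remaining: {limit.request_remaining}/{limit.request_limit}")
    print(f"tokens remaining:   {limit.token_remaining}/{limit.token_limit}")

    # Native token counting, with automatic fallback on error
    n = await model.count_tokens([TextInput(text="Hello, Claude")])
    print(f"input tokens: {n}")

asyncio.run(main())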
model_library/providers/azure.py

@@ -1,3 +1,4 @@
+import json
 from typing import Literal

 from openai.lib.azure import AsyncAzureOpenAI
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client

 @register_provider("azure")
 class AzureOpenAIModel(OpenAIModel):
-    _azure_client: AsyncAzureOpenAI | None = None
-
     @override
-    def get_client(self) -> AsyncAzureOpenAI:
-        if not AzureOpenAIModel._azure_client:
-            AzureOpenAIModel._azure_client = AsyncAzureOpenAI(
-                api_key=model_library_settings.AZURE_API_KEY,
-                azure_endpoint=model_library_settings.AZURE_ENDPOINT,
-                api_version=model_library_settings.get(
+    def _get_default_api_key(self) -> str:
+        return json.dumps(
+            {
+                "AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
+                "AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
+                "AZURE_API_VERSION": model_library_settings.get(
                     "AZURE_API_VERSION", "2025-04-01-preview"
                 ),
+            }
+        )
+
+    @override
+    def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
+        if not self.has_client():
+            assert api_key
+            creds = json.loads(api_key)
+            client = AsyncAzureOpenAI(
+                api_key=creds["AZURE_API_KEY"],
+                azure_endpoint=creds["AZURE_ENDPOINT"],
+                api_version=creds["AZURE_API_VERSION"],
                 http_client=default_httpx_client(),
-                max_retries=1,
+                max_retries=3,
             )
-        return AzureOpenAIModel._azure_client
+            self.assign_client(client)
+        return super(OpenAIModel, self).get_client(api_key)

     def __init__(
         self,
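Azure needs three settings rather than a single key, so `_get_default_api_key` now packs them into one JSON string and `get_client` unpacks it; this keeps the base class's single-string `api_key` plumbing (and, presumably, its client-cache keying) uniform across providers. The round trip in isolation, with placeholder values:

import json

# Mirrors _get_default_api_key / get_client above; values are placeholders.
blob = json.dumps(
    {
        "AZURE_API_KEY": "azure-key-...",
        "AZURE_ENDPOINT": "https://example.openai.azure.com",
        "AZURE_API_VERSION": "2025-04-01-preview",
    }
)
creds = json.loads(blob)  # what get_client does before building the client
assert creds["AZURE_API_VERSION"] == "2025-04-01-preview"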
model_library/providers/cohere.py

@@ -1,13 +1,14 @@
 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://docs.cohere.com/docs/compatibility-api
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.COHERE_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.cohere.ai/compatibility/v1",
+                api_key=SecretStr(model_library_settings.COHERE_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
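Cohere here, and DeepSeek and Fireworks below, all converge on one pattern: instead of hand-building an `OpenAIModel` with a `custom_client`, each provider passes `init_delegate` a `DelegateConfig` (base URL plus a `SecretStr`-wrapped key) and names the delegate provider. (DeepSeek's base URL also gains an explicit `/v1` suffix in passing.) `init_delegate` lives in `model_library/base/delegate_only.py`, which changed in this release but is not shown here; a plausible sketch of what it does, with a hypothetical provider lookup:

# Hypothetical sketch of DelegateOnly.init_delegate; illustrative only.
class DelegateOnly(LLM):
    def init_delegate(
        self,
        *,
        config: LLMConfig | None,
        delegate_config: DelegateConfig,
        use_completions: bool = False,
        delegate_provider: str,
    ) -> None:
        provider_cls = lookup_provider(delegate_provider)  # hypothetical lookup
        # e.g. OpenAIModel for "openai"; the delegate performs actual requests
        self.delegate = provider_cls(
            model_name=self.model_name,
            provider=self.provider,
            config=config,
            use_completions=use_completions,
            delegate_config=delegate_config,
        )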
model_library/providers/deepseek.py

@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html

 from typing import Literal

+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)

         # https://api-docs.deepseek.com/
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.DEEPSEEK_API_KEY,
-                base_url="https://api.deepseek.com",
+            delegate_config=DelegateConfig(
+                base_url="https://api.deepseek.com/v1",
+                api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
model_library/providers/fireworks.py

@@ -1,18 +1,18 @@
 from typing import Literal

+from pydantic import SecretStr
 from typing_extensions import override

 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults


 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
         self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name

         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )

     @override