model_library-0.1.1-py3-none-any.whl → model_library-0.1.3-py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (45)
  1. model_library/base/__init__.py +7 -0
  2. model_library/{base.py → base/base.py} +58 -429
  3. model_library/base/batch.py +121 -0
  4. model_library/base/delegate_only.py +94 -0
  5. model_library/base/input.py +100 -0
  6. model_library/base/output.py +229 -0
  7. model_library/base/utils.py +43 -0
  8. model_library/config/ai21labs_models.yaml +1 -0
  9. model_library/config/all_models.json +461 -36
  10. model_library/config/anthropic_models.yaml +30 -3
  11. model_library/config/deepseek_models.yaml +3 -1
  12. model_library/config/google_models.yaml +49 -0
  13. model_library/config/openai_models.yaml +43 -4
  14. model_library/config/together_models.yaml +1 -0
  15. model_library/config/xai_models.yaml +63 -3
  16. model_library/exceptions.py +8 -2
  17. model_library/file_utils.py +1 -1
  18. model_library/providers/__init__.py +0 -0
  19. model_library/providers/ai21labs.py +2 -0
  20. model_library/providers/alibaba.py +16 -78
  21. model_library/providers/amazon.py +3 -0
  22. model_library/providers/anthropic.py +215 -8
  23. model_library/providers/azure.py +2 -0
  24. model_library/providers/cohere.py +14 -80
  25. model_library/providers/deepseek.py +14 -90
  26. model_library/providers/fireworks.py +17 -81
  27. model_library/providers/google/google.py +55 -47
  28. model_library/providers/inception.py +15 -83
  29. model_library/providers/kimi.py +15 -83
  30. model_library/providers/mistral.py +2 -0
  31. model_library/providers/openai.py +10 -2
  32. model_library/providers/perplexity.py +12 -79
  33. model_library/providers/together.py +19 -210
  34. model_library/providers/vals.py +2 -0
  35. model_library/providers/xai.py +2 -0
  36. model_library/providers/zai.py +15 -83
  37. model_library/register_models.py +75 -57
  38. model_library/registry_utils.py +5 -5
  39. model_library/utils.py +3 -28
  40. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/METADATA +2 -3
  41. model_library-0.1.3.dist-info/RECORD +61 -0
  42. model_library-0.1.1.dist-info/RECORD +0 -54
  43. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/WHEEL +0 -0
  44. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.1.dist-info → model_library-0.1.3.dist-info}/top_level.txt +0 -0

model_library/providers/anthropic.py

@@ -5,14 +5,18 @@ from anthropic import AsyncAnthropic
  from anthropic.types import TextBlock, ToolUseBlock
  from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
  from anthropic.types.message import Message
+ from typing_extensions import override
+
  from model_library import model_library_settings
  from model_library.base import (
      LLM,
+     BatchResult,
      FileInput,
      FileWithBase64,
      FileWithId,
      FileWithUrl,
      InputItem,
+     LLMBatchMixin,
      LLMConfig,
      QueryResult,
      QueryResultCost,
@@ -29,15 +33,217 @@ from model_library.exceptions import (
  )
  from model_library.model_utils import get_default_budget_tokens
  from model_library.providers.openai import OpenAIModel
+ from model_library.register_models import register_provider
  from model_library.utils import (
      create_openai_client_with_defaults,
      default_httpx_client,
      filter_empty_text_blocks,
      normalize_tool_result,
  )
- from typing_extensions import override
 
 
+ class AnthropicBatchMixin(LLMBatchMixin):
+     """Batch processing support for Anthropic's Message Batches API."""
+
+     COMPLETED_RESULT_TYPES = ["succeeded", "errored", "canceled", "expired"]
+
+     def __init__(self, model: "AnthropicModel"):
+         self._root = model
+
+     @override
+     async def create_batch_query_request(
+         self,
+         custom_id: str,
+         input: Sequence[InputItem],
+         **kwargs: object,
+     ) -> dict[str, Any]:
+         """Create a single batch request in Anthropic's format.
+
+         Format: {"custom_id": str, "params": {...message params...}}
+         """
+         # Build the message body using the parent model's create_body method
+         tools = cast(list[ToolDefinition], kwargs.pop("tools", []))
+         body = await self._root.create_body(input, tools=tools, **kwargs)
+
+         return {
+             "custom_id": custom_id,
+             "params": body,
+         }
+
+     @override
+     async def batch_query(
+         self,
+         batch_name: str,
+         requests: list[dict[str, Any]],
+     ) -> str:
+         """Submit a batch of requests to Anthropic's Message Batches API.
+
+         Returns the batch ID for status tracking.
+         """
+         client = self._root.get_client()
+
+         # Create the batch using Anthropic's batches API
+         batch = await client.messages.batches.create(
+             requests=cast(Any, requests),  # Type mismatch in SDK, cast to Any
+         )
+
+         self._root.logger.info(
+             f"Created Anthropic batch {batch.id} with {len(requests)} requests"
+         )
+
+         return batch.id
+
+     @override
+     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
+         """Retrieve results from a completed batch.
+
+         Streams results using the SDK's batches.results() method.
+         """
+         client = self._root.get_client()
+
+         # Get batch status to verify it's completed
+         batch = await client.messages.batches.retrieve(batch_id)
+
+         if batch.processing_status != "ended":
+             raise ValueError(
+                 f"Batch {batch_id} is not completed yet. Status: {batch.processing_status}"
+             )
+
+         # Stream results using the SDK's results method
+         batch_results: list[BatchResult] = []
+         async for result_item in await client.messages.batches.results(batch_id):
+             # result_item is a MessageBatchIndividualResponse - convert to dict
+             result_dict = result_item.model_dump()
+             custom_id = cast(str, result_dict["custom_id"])
+             result_type = cast(str, result_dict["result"]["type"])
+
+             if result_type not in self.COMPLETED_RESULT_TYPES:
+                 self._root.logger.warning(
+                     f"Unknown result type '{result_type}' for request {custom_id}"
+                 )
+                 continue
+
+             if result_type == "succeeded":
+                 # Extract the message from the successful result
+                 message_data = cast(dict[str, Any], result_dict["result"]["message"])
+
+                 # Parse the message content to extract text, reasoning, and tool calls
+                 text = ""
+                 reasoning = ""
+                 tool_calls: list[ToolCall] = []
+
+                 for content in message_data.get("content", []):
+                     if content.get("type") == "text":
+                         text += content.get("text", "")
+                     elif content.get("type") == "thinking":
+                         reasoning += content.get("thinking", "")
+                     elif content.get("type") == "tool_use":
+                         tool_calls.append(
+                             ToolCall(
+                                 id=content["id"],
+                                 name=content["name"],
+                                 args=content.get("input", {}),
+                             )
+                         )
+
+                 # Extract usage information
+                 usage = message_data.get("usage", {})
+                 metadata = QueryResultMetadata(
+                     in_tokens=usage.get("input_tokens", 0),
+                     out_tokens=usage.get("output_tokens", 0),
+                     cache_read_tokens=usage.get("cache_read_input_tokens", 0),
+                     cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
+                 )
+
+                 query_result = QueryResult(
+                     output_text=text,
+                     reasoning=reasoning,
+                     metadata=metadata,
+                     tool_calls=tool_calls,
+                     history=[],  # History not available in batch results
+                 )
+
+                 batch_results.append(
+                     BatchResult(
+                         custom_id=custom_id,
+                         output=query_result,
+                     )
+                 )
+
+             elif result_type == "errored":
+                 # Handle errored results
+                 error = cast(dict[str, Any], result_dict["result"]["error"])
+                 error_message = f"{error.get('type', 'unknown_error')}: {error.get('message', 'Unknown error')}"
+                 output = QueryResult(output_text=error_message)
+                 batch_results.append(
+                     BatchResult(
+                         custom_id=custom_id,
+                         output=output,
+                         error_message=error_message,
+                     )
+                 )
+
+             elif result_type in ["canceled", "expired"]:
+                 # Handle canceled/expired results
+                 error_message = f"Request {result_type}"
+                 batch_results.append(
+                     BatchResult(
+                         custom_id=custom_id,
+                         output=QueryResult(output_text=""),
+                         error_message=error_message,
+                     )
+                 )
+
+         return batch_results
+
+     @override
+     async def get_batch_progress(self, batch_id: str) -> int:
+         """Get the number of completed requests in a batch."""
+         client = self._root.get_client()
+         batch = await client.messages.batches.retrieve(batch_id)
+
+         # Return the number of processed requests
+         request_counts = batch.request_counts
+         return (
+             request_counts.succeeded
+             + request_counts.errored
+             + request_counts.canceled
+             + request_counts.expired
+         )
+
+     @override
+     async def cancel_batch_request(self, batch_id: str) -> None:
+         """Cancel a running batch request."""
+         client = self._root.get_client()
+         await client.messages.batches.cancel(batch_id)
+         self._root.logger.info(f"Canceled Anthropic batch {batch_id}")
+
+     @override
+     async def get_batch_status(self, batch_id: str) -> str:
+         """Get the current status of a batch."""
+         client = self._root.get_client()
+         batch = await client.messages.batches.retrieve(batch_id)
+         return batch.processing_status
+
+     @classmethod
+     def is_batch_status_completed(cls, batch_status: str) -> bool:
+         """Check if a batch status indicates completion."""
+         return batch_status == "ended"
+
+     @classmethod
+     def is_batch_status_failed(cls, batch_status: str) -> bool:
+         """Check if a batch status indicates failure."""
+         # Anthropic batches can have individual request failures but the batch
+         # itself doesn't have a "failed" status - it just ends
+         return False
+
+     @classmethod
+     def is_batch_status_cancelled(cls, batch_status: str) -> bool:
+         """Check if a batch status indicates cancellation."""
+         return batch_status == "canceling" or batch_status == "canceled"
+
+
+ @register_provider("anthropic")
  class AnthropicModel(LLM):
      _client: AsyncAnthropic | None = None
 
@@ -78,6 +284,12 @@ class AnthropicModel(LLM):
              )
          )
 
+         # Initialize batch support if enabled
+         self.supports_batch: bool = self.supports_batch and self.native
+         self.batch: LLMBatchMixin | None = (
+             AnthropicBatchMixin(self) if self.supports_batch else None
+         )
+
      @override
      async def parse_input(
          self,
@@ -88,7 +300,6 @@ class AnthropicModel(LLM):
          content_user: list[dict[str, Any]] = []
 
          # First pass: collect all tool calls from Message objects for validation
-         # This handles both Message and BetaMessage types
          tool_calls_in_input: set[str] = set()
          for item in input:
              if hasattr(item, "content") and hasattr(item, "role"):
@@ -351,12 +562,8 @@ class AnthropicModel(LLM):
 
          body = await self.create_body(input, tools=tools, **kwargs)
 
-         betas = [
-             "files-api-2025-04-14",
-             "interleaved-thinking-2025-05-14",
-         ]
-
-         if "claude-sonnet-4-5" in self.model_name:
+         betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
+         if "sonnet-4-5" in self.model_name:
              betas.append("context-1m-2025-08-07")
 
          async with self.get_client().beta.messages.stream(
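
Taken together, the anthropic.py changes give AnthropicModel a submit, poll, and collect batch lifecycle through model.batch. A minimal sketch of how a caller might drive it, assuming hypothetical constructor arguments and input shapes (neither appears in this diff):

    import asyncio

    from model_library.providers.anthropic import AnthropicBatchMixin, AnthropicModel

    async def run_batch(inputs: list) -> None:
        # Each element of `inputs` is one request's Sequence[InputItem].
        # Hypothetical constructor arguments; __init__'s signature is not shown here.
        model = AnthropicModel(model_name="claude-sonnet-4-5", provider="anthropic")
        assert model.batch is not None  # set in __init__ when supports_batch and native

        # One request per input sequence, keyed by a caller-chosen custom_id
        requests = [
            await model.batch.create_batch_query_request(custom_id=f"req-{i}", input=items)
            for i, items in enumerate(inputs)
        ]
        batch_id = await model.batch.batch_query("example-batch", requests)

        # Poll until processing_status reaches the terminal "ended" state
        while not AnthropicBatchMixin.is_batch_status_completed(
            await model.batch.get_batch_status(batch_id)
        ):
            await asyncio.sleep(30)

        for result in await model.batch.get_batch_results(batch_id):
            print(result.custom_id, result.error_message or result.output.output_text)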

model_library/providers/azure.py

@@ -8,9 +8,11 @@ from model_library.base import (
      LLMConfig,
  )
  from model_library.providers.openai import OpenAIModel
+ from model_library.register_models import register_provider
  from model_library.utils import default_httpx_client
 
 
+ @register_provider("azure")
  class AzureOpenAIModel(OpenAIModel):
      _azure_client: AsyncAzureOpenAI | None = None
 
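The register_provider decorator used here and in anthropic.py comes from register_models.py, which this release reworks (+75 -57) but whose implementation is not part of this diff. A plausible minimal shape for such a class-registry decorator, offered as an assumption rather than the package's actual code:

    from typing import Callable, TypeVar

    T = TypeVar("T", bound=type)

    _PROVIDER_REGISTRY: dict[str, type] = {}

    def register_provider(name: str) -> Callable[[T], T]:
        # Hypothetical: map a provider name to its decorated model class
        def decorator(cls: T) -> T:
            _PROVIDER_REGISTRY[name] = cls
            return cls

        return decorator

A registry populated by decoration would let register_models.py discover provider classes at import time instead of maintaining a hand-written mapping, which would be consistent with every provider module now importing register_provider.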

model_library/providers/cohere.py

@@ -1,27 +1,17 @@
- import io
- from typing import Any, Literal, Sequence
-
- from typing_extensions import override
+ from typing import Literal
 
  from model_library import model_library_settings
  from model_library.base import (
-     LLM,
-     FileInput,
-     FileWithId,
-     InputItem,
+     DelegateOnly,
      LLMConfig,
-     QueryResult,
-     ToolDefinition,
  )
  from model_library.providers.openai import OpenAIModel
+ from model_library.register_models import register_provider
  from model_library.utils import create_openai_client_with_defaults
 
 
- class CohereModel(LLM):
-     @override
-     def get_client(self) -> None:
-         raise NotImplementedError("Not implemented")
-
+ @register_provider("cohere")
+ class CohereModel(DelegateOnly):
      def __init__(
          self,
          model_name: str,
@@ -30,71 +20,15 @@ class CohereModel(LLM):
          config: LLMConfig | None = None,
      ):
          super().__init__(model_name, provider, config=config)
-         self.native: bool = False
 
          # https://docs.cohere.com/docs/compatibility-api
-         self.delegate: OpenAIModel | None = (
-             None
-             if self.native
-             else OpenAIModel(
-                 model_name=model_name,
-                 provider=provider,
-                 config=config,
-                 custom_client=create_openai_client_with_defaults(
-                     api_key=model_library_settings.COHERE_API_KEY,
-                     base_url="https://api.cohere.ai/compatibility/v1",
-                 ),
-                 use_completions=True,
-             )
+         self.delegate = OpenAIModel(
+             model_name=self.model_name,
+             provider=self.provider,
+             config=config,
+             custom_client=create_openai_client_with_defaults(
+                 api_key=model_library_settings.COHERE_API_KEY,
+                 base_url="https://api.cohere.ai/compatibility/v1",
+             ),
+             use_completions=True,
          )
-
-     @override
-     async def parse_input(
-         self,
-         input: Sequence[InputItem],
-         **kwargs: Any,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_image(
-         self,
-         image: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_file(
-         self,
-         file: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_tools(
-         self,
-         tools: list[ToolDefinition],
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def upload_file(
-         self,
-         name: str,
-         mime: str,
-         bytes: io.BytesIO,
-         type: Literal["image", "file"] = "file",
-     ) -> FileWithId:
-         raise NotImplementedError()
-
-     @override
-     async def _query_impl(
-         self,
-         input: Sequence[InputItem],
-         *,
-         tools: list[ToolDefinition],
-         **kwargs: object,
-     ) -> QueryResult:
-         if self.delegate:
-             return await self.delegate_query(input, tools=tools, **kwargs)
-         raise NotImplementedError()
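
The stub methods removed above are absorbed by the new DelegateOnly base class (model_library/base/delegate_only.py, +94 lines), whose source is not included in this diff. Judging from the code it replaces, a plausible minimal shape, assuming LLM provides the delegate_query helper seen in the removed _query_impl:

    from typing import Any, Sequence

    from model_library.base.base import LLM  # module path per this release's new layout

    class DelegateOnly(LLM):
        # Hypothetical sketch of the new base class: a provider that forwards
        # every query to an OpenAI-compatible delegate instead of implementing
        # a native client.
        delegate: Any  # set by subclasses, e.g. an OpenAIModel

        def get_client(self) -> None:
            raise NotImplementedError("delegate-only providers have no native client")

        async def _query_impl(
            self,
            input: Sequence[Any],
            *,
            tools: list[Any],
            **kwargs: object,
        ) -> Any:
            # The same forwarding the removed per-provider stubs performed
            return await self.delegate_query(input, tools=tools, **kwargs)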

model_library/providers/deepseek.py

@@ -3,31 +3,20 @@ See deepseek data retention policy
  https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
  """
 
- import io
- from typing import Any, Literal, Sequence
-
- from typing_extensions import override
+ from typing import Literal
 
  from model_library import model_library_settings
  from model_library.base import (
-     LLM,
-     FileInput,
-     FileWithId,
-     InputItem,
+     DelegateOnly,
      LLMConfig,
-     QueryResult,
-     ToolDefinition,
  )
- from model_library.exceptions import ToolCallingNotSupportedError
  from model_library.providers.openai import OpenAIModel
+ from model_library.register_models import register_provider
  from model_library.utils import create_openai_client_with_defaults
 
 
- class DeepSeekModel(LLM):
-     @override
-     def get_client(self) -> None:
-         raise NotImplementedError("Not implemented")
-
+ @register_provider("deepseek")
+ class DeepSeekModel(DelegateOnly):
      def __init__(
          self,
          model_name: str,
@@ -36,80 +25,15 @@ class DeepSeekModel(LLM):
          config: LLMConfig | None = None,
      ):
          super().__init__(model_name, provider, config=config)
-         self.model_name: str = model_name
-         self.native: bool = False
 
          # https://api-docs.deepseek.com/
-         self.delegate: OpenAIModel | None = (
-             None
-             if self.native
-             else OpenAIModel(
-                 model_name=self.model_name,
-                 provider=provider,
-                 config=config,
-                 custom_client=create_openai_client_with_defaults(
-                     api_key=model_library_settings.DEEPSEEK_API_KEY,
-                     base_url="https://api.deepseek.com",
-                 ),
-                 use_completions=True,
-             )
+         self.delegate = OpenAIModel(
+             model_name=self.model_name,
+             provider=self.provider,
+             config=config,
+             custom_client=create_openai_client_with_defaults(
+                 api_key=model_library_settings.DEEPSEEK_API_KEY,
+                 base_url="https://api.deepseek.com",
+             ),
+             use_completions=True,
          )
-
-     @override
-     async def parse_input(
-         self,
-         input: Sequence[InputItem],
-         **kwargs: Any,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_image(
-         self,
-         image: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_file(
-         self,
-         file: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_tools(
-         self,
-         tools: list[ToolDefinition],
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def upload_file(
-         self,
-         name: str,
-         mime: str,
-         bytes: io.BytesIO,
-         type: Literal["image", "file"] = "file",
-     ) -> FileWithId:
-         raise NotImplementedError()
-
-     @override
-     async def _query_impl(
-         self,
-         input: Sequence[InputItem],
-         *,
-         tools: list[ToolDefinition],
-         **kwargs: object,
-     ) -> QueryResult:
-         # DeepSeek reasoning models don't support tools - they auto-route to chat model
-         # which loses reasoning capability
-         if tools and not self.supports_tools:
-             raise ToolCallingNotSupportedError(
-                 f"DeepSeek model ({self.model_name}) does not support tools. "
-                 f"Use deepseek/deepseek-chat for tool calls."
-             )
-
-         if self.delegate:
-             return await self.delegate_query(input, tools=tools, **kwargs)
-         raise NotImplementedError()

model_library/providers/fireworks.py

@@ -1,22 +1,17 @@
- import io
- from typing import Any, Literal, Sequence
+ from typing import Literal
 
  from typing_extensions import override
 
  from model_library import model_library_settings
  from model_library.base import (
-     LLM,
-     FileInput,
-     FileWithId,
-     InputItem,
      LLMConfig,
      ProviderConfig,
-     QueryResult,
      QueryResultCost,
      QueryResultMetadata,
-     ToolDefinition,
  )
+ from model_library.base.delegate_only import DelegateOnly
  from model_library.providers.openai import OpenAIModel
+ from model_library.register_models import register_provider
  from model_library.utils import create_openai_client_with_defaults
 
 
@@ -24,13 +19,10 @@ class FireworksConfig(ProviderConfig):
      serverless: bool = True
 
 
- class FireworksModel(LLM):
+ @register_provider("fireworks")
+ class FireworksModel(DelegateOnly):
      provider_config = FireworksConfig()
 
-     @override
-     def get_client(self) -> None:
-         raise NotImplementedError("Not implemented")
-
      def __init__(
          self,
          model_name: str,
@@ -45,76 +37,18 @@ class FireworksModel(LLM):
          else:
              self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
-         # not using Fireworks SDK
-         self.native: bool = False
-
          # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-         self.delegate: OpenAIModel | None = (
-             None
-             if self.native
-             else OpenAIModel(
-                 model_name=self.model_name,
-                 provider=provider,
-                 config=config,
-                 custom_client=create_openai_client_with_defaults(
-                     api_key=model_library_settings.FIREWORKS_API_KEY,
-                     base_url="https://api.fireworks.ai/inference/v1",
-                 ),
-                 use_completions=True,
-             )
+         self.delegate = OpenAIModel(
+             model_name=self.model_name,
+             provider=self.provider,
+             config=config,
+             custom_client=create_openai_client_with_defaults(
+                 api_key=model_library_settings.FIREWORKS_API_KEY,
+                 base_url="https://api.fireworks.ai/inference/v1",
+             ),
+             use_completions=True,
          )
 
-     @override
-     async def parse_input(
-         self,
-         input: Sequence[InputItem],
-         **kwargs: Any,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_image(
-         self,
-         image: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_file(
-         self,
-         file: FileInput,
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def parse_tools(
-         self,
-         tools: list[ToolDefinition],
-     ) -> Any:
-         raise NotImplementedError()
-
-     @override
-     async def upload_file(
-         self,
-         name: str,
-         mime: str,
-         bytes: io.BytesIO,
-         type: Literal["image", "file"] = "file",
-     ) -> FileWithId:
-         raise NotImplementedError()
-
-     @override
-     async def _query_impl(
-         self,
-         input: Sequence[InputItem],
-         *,
-         tools: list[ToolDefinition],
-         **kwargs: object,
-     ) -> QueryResult:
-         if self.delegate:
-             return await self.delegate_query(input, tools=tools, **kwargs)
-         raise NotImplementedError()
-
      @override
      async def _calculate_cost(
          self,
@@ -130,4 +64,6 @@ class FireworksModel(LLM):
          # https://docs.fireworks.ai/faq-new/billing-pricing/is-prompt-caching-billed-differently
          # prompt caching does not affect billing for serverless models
 
-         return await super()._calculate_cost(metadata, batch, bill_reasoning=True)
+         return await super()._calculate_cost(
+             metadata, batch, bill_reasoning=bill_reasoning
+         )