mirascope 1.18.2__py3-none-any.whl → 1.18.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. mirascope/__init__.py +20 -1
  2. mirascope/beta/openai/__init__.py +1 -1
  3. mirascope/beta/openai/realtime/__init__.py +1 -1
  4. mirascope/beta/openai/realtime/tool.py +1 -1
  5. mirascope/beta/rag/__init__.py +2 -2
  6. mirascope/beta/rag/base/__init__.py +2 -2
  7. mirascope/beta/rag/weaviate/__init__.py +1 -1
  8. mirascope/core/__init__.py +29 -6
  9. mirascope/core/anthropic/__init__.py +3 -3
  10. mirascope/core/anthropic/_utils/_calculate_cost.py +114 -47
  11. mirascope/core/anthropic/call_response.py +9 -3
  12. mirascope/core/anthropic/call_response_chunk.py +7 -0
  13. mirascope/core/anthropic/stream.py +3 -1
  14. mirascope/core/azure/__init__.py +2 -2
  15. mirascope/core/azure/_utils/_calculate_cost.py +4 -1
  16. mirascope/core/azure/call_response.py +9 -3
  17. mirascope/core/azure/call_response_chunk.py +5 -0
  18. mirascope/core/azure/stream.py +3 -1
  19. mirascope/core/base/__init__.py +11 -9
  20. mirascope/core/base/_utils/__init__.py +10 -10
  21. mirascope/core/base/_utils/_get_common_usage.py +8 -4
  22. mirascope/core/base/_utils/_get_create_fn_or_async_create_fn.py +2 -2
  23. mirascope/core/base/_utils/_protocols.py +9 -8
  24. mirascope/core/base/call_response.py +22 -22
  25. mirascope/core/base/call_response_chunk.py +12 -1
  26. mirascope/core/base/stream.py +24 -21
  27. mirascope/core/base/tool.py +7 -5
  28. mirascope/core/base/types.py +22 -5
  29. mirascope/core/bedrock/__init__.py +3 -3
  30. mirascope/core/bedrock/_utils/_calculate_cost.py +4 -1
  31. mirascope/core/bedrock/call_response.py +8 -3
  32. mirascope/core/bedrock/call_response_chunk.py +5 -0
  33. mirascope/core/bedrock/stream.py +3 -1
  34. mirascope/core/cohere/__init__.py +2 -2
  35. mirascope/core/cohere/_utils/_calculate_cost.py +4 -3
  36. mirascope/core/cohere/call_response.py +9 -3
  37. mirascope/core/cohere/call_response_chunk.py +5 -0
  38. mirascope/core/cohere/stream.py +3 -1
  39. mirascope/core/gemini/__init__.py +2 -2
  40. mirascope/core/gemini/_utils/_calculate_cost.py +4 -1
  41. mirascope/core/gemini/_utils/_convert_message_params.py +1 -1
  42. mirascope/core/gemini/call_response.py +9 -3
  43. mirascope/core/gemini/call_response_chunk.py +5 -0
  44. mirascope/core/gemini/stream.py +3 -1
  45. mirascope/core/google/__init__.py +2 -2
  46. mirascope/core/google/_utils/_calculate_cost.py +141 -14
  47. mirascope/core/google/_utils/_convert_message_params.py +120 -115
  48. mirascope/core/google/_utils/_message_param_converter.py +34 -33
  49. mirascope/core/google/_utils/_validate_media_type.py +34 -0
  50. mirascope/core/google/call_response.py +38 -10
  51. mirascope/core/google/call_response_chunk.py +17 -9
  52. mirascope/core/google/stream.py +20 -2
  53. mirascope/core/groq/__init__.py +2 -2
  54. mirascope/core/groq/_utils/_calculate_cost.py +12 -11
  55. mirascope/core/groq/call_response.py +9 -3
  56. mirascope/core/groq/call_response_chunk.py +5 -0
  57. mirascope/core/groq/stream.py +3 -1
  58. mirascope/core/litellm/__init__.py +1 -1
  59. mirascope/core/litellm/_utils/_setup_call.py +7 -3
  60. mirascope/core/mistral/__init__.py +2 -2
  61. mirascope/core/mistral/_utils/_calculate_cost.py +10 -9
  62. mirascope/core/mistral/call_response.py +9 -3
  63. mirascope/core/mistral/call_response_chunk.py +5 -0
  64. mirascope/core/mistral/stream.py +3 -1
  65. mirascope/core/openai/__init__.py +2 -2
  66. mirascope/core/openai/_utils/_calculate_cost.py +78 -37
  67. mirascope/core/openai/call_params.py +13 -0
  68. mirascope/core/openai/call_response.py +14 -3
  69. mirascope/core/openai/call_response_chunk.py +12 -0
  70. mirascope/core/openai/stream.py +6 -4
  71. mirascope/core/vertex/__init__.py +1 -1
  72. mirascope/core/vertex/_utils/_calculate_cost.py +1 -0
  73. mirascope/core/vertex/_utils/_convert_message_params.py +1 -1
  74. mirascope/core/vertex/call_response.py +9 -3
  75. mirascope/core/vertex/call_response_chunk.py +5 -0
  76. mirascope/core/vertex/stream.py +3 -1
  77. mirascope/integrations/_middleware_factory.py +6 -6
  78. mirascope/integrations/logfire/_utils.py +1 -1
  79. mirascope/llm/__init__.py +3 -1
  80. mirascope/llm/_protocols.py +5 -5
  81. mirascope/llm/call_response.py +16 -9
  82. mirascope/llm/llm_call.py +53 -25
  83. mirascope/llm/stream.py +43 -31
  84. mirascope/retries/__init__.py +1 -1
  85. mirascope/tools/__init__.py +2 -2
  86. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/METADATA +2 -2
  87. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/RECORD +89 -88
  88. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/WHEEL +0 -0
  89. {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/licenses/LICENSE +0 -0
@@ -2,7 +2,10 @@
2
2
 
3
3
 
4
4
  def calculate_cost(
5
- input_tokens: int | float | None, output_tokens: int | float | None, model: str
5
+ input_tokens: int | float | None,
6
+ cached_tokens: int | float | None,
7
+ output_tokens: int | float | None,
8
+ model: str,
6
9
  ) -> float | None:
7
10
  """Calculate the cost of a Gemini API call.
8
11
 
@@ -96,7 +96,7 @@ def convert_message_params(
96
96
  elif part.type == "audio_url":
97
97
  if part.url.startswith(("https://", "http://")):
98
98
  audio = _load_media(part.url)
99
- audio_type = get_audio_type(audio)
99
+ audio_type = f"audio/{get_audio_type(audio)}"
100
100
  if audio_type not in [
101
101
  "audio/wav",
102
102
  "audio/mp3",
@@ -122,6 +122,12 @@ class GeminiCallResponse(
122
122
  """Returns the number of input tokens."""
123
123
  return None
124
124
 
125
+ @computed_field
126
+ @property
127
+ def cached_tokens(self) -> None:
128
+ """Returns the number of cached tokens."""
129
+ return None
130
+
125
131
  @computed_field
126
132
  @property
127
133
  def output_tokens(self) -> None:
@@ -132,7 +138,9 @@ class GeminiCallResponse(
132
138
  @property
133
139
  def cost(self) -> float | None:
134
140
  """Returns the cost of the call."""
135
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
141
+ return calculate_cost(
142
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
143
+ )
136
144
 
137
145
  @computed_field
138
146
  @cached_property
@@ -140,7 +148,6 @@ class GeminiCallResponse(
140
148
  """Returns the models's response as a message parameter."""
141
149
  return {"role": "model", "parts": self.response.parts} # pyright: ignore [reportReturnType]
142
150
 
143
- @computed_field
144
151
  @cached_property
145
152
  def tools(self) -> list[GeminiTool] | None:
146
153
  """Returns the list of tools for the 0th candidate's 0th content part."""
@@ -157,7 +164,6 @@ class GeminiCallResponse(
157
164
 
158
165
  return extracted_tools
159
166
 
160
- @computed_field
161
167
  @cached_property
162
168
  def tool(self) -> GeminiTool | None:
163
169
  """Returns the 0th tool for the 0th candidate's 0th content part.
@@ -78,6 +78,11 @@ class GeminiCallResponseChunk(
78
78
  """Returns the number of input tokens."""
79
79
  return None
80
80
 
81
+ @property
82
+ def cached_tokens(self) -> None:
83
+ """Returns the number of cached tokens."""
84
+ return None
85
+
81
86
  @property
82
87
  def output_tokens(self) -> None:
83
88
  """Returns the number of output tokens."""
@@ -69,7 +69,9 @@ class GeminiStream(
69
69
  @property
70
70
  def cost(self) -> float | None:
71
71
  """Returns the cost of the call."""
72
- return calculate_cost(self.input_tokens, self.output_tokens, self.model)
72
+ return calculate_cost(
73
+ self.input_tokens, self.cached_tokens, self.output_tokens, self.model
74
+ )
73
75
 
74
76
  def _construct_message_param(
75
77
  self, tool_calls: list[FunctionCall] | None = None, content: str | None = None
@@ -17,13 +17,13 @@ from .tool import GoogleTool
17
17
  GoogleMessageParam: TypeAlias = ContentDict | FunctionResponse | BaseMessageParam
18
18
 
19
19
  __all__ = [
20
- "call",
21
- "GoogleDynamicConfig",
22
20
  "GoogleCallParams",
23
21
  "GoogleCallResponse",
24
22
  "GoogleCallResponseChunk",
23
+ "GoogleDynamicConfig",
25
24
  "GoogleMessageParam",
26
25
  "GoogleStream",
27
26
  "GoogleTool",
27
+ "call",
28
28
  "google_call",
29
29
  ]
@@ -2,7 +2,10 @@
2
2
 
3
3
 
4
4
  def calculate_cost(
5
- input_tokens: int | float | None, output_tokens: int | float | None, model: str
5
+ input_tokens: int | float | None,
6
+ cached_tokens: int | float | None,
7
+ output_tokens: int | float | None,
8
+ model: str,
6
9
  ) -> float | None:
7
10
  """Calculate the cost of a Google API call.
8
11
 
@@ -10,16 +13,31 @@ def calculate_cost(
10
13
 
11
14
  Pricing (per 1M tokens):
12
15
 
13
- Model Input (<128K) Output (<128K) Input (>128K) Output (>128K)
14
- gemini-2.0-flash $0.10 $0.40 $0.10 $0.40
15
- gemini-2.0-flash-lite $0.075 $0.30 $0.075 $0.30
16
- gemini-1.5-flash $0.075 $0.30 $0.15 $0.60
17
- gemini-1.5-flash-8b $0.0375 $0.15 $0.075 $0.30
18
- gemini-1.5-pro $1.25 $5.00 $2.50 $10.00
19
- gemini-1.0-pro $0.50 $1.50 $0.50 $1.50
16
+ Model Input (<128K) Output (<128K) Input (>128K) Output (>128K) Cached
17
+ gemini-2.0-pro $1.25 $5.00 $2.50 $10.00 $0.625
18
+ gemini-2.0-pro-preview-1206 $1.25 $5.00 $2.50 $10.00 $0.625
19
+ gemini-2.0-flash $0.10 $0.40 $0.10 $0.40 $0.0375
20
+ gemini-2.0-flash-latest $0.10 $0.40 $0.10 $0.40 $0.0375
21
+ gemini-2.0-flash-001 $0.10 $0.40 $0.10 $0.40 $0.0375
22
+ gemini-2.0-flash-lite $0.075 $0.30 $0.075 $0.30 $0.0375
23
+ gemini-2.0-flash-lite-preview-02-05 $0.075 $0.30 $0.075 $0.30 $0.0375
24
+ gemini-1.5-pro $1.25 $5.00 $2.50 $10.00 $0.625
25
+ gemini-1.5-pro-latest $1.25 $5.00 $2.50 $10.00 $0.625
26
+ gemini-1.5-pro-001 $1.25 $5.00 $2.50 $10.00 $0.625
27
+ gemini-1.5-pro-002 $1.25 $5.00 $2.50 $10.00 $0.625
28
+ gemini-1.5-flash $0.075 $0.30 $0.15 $0.60 $0.0375
29
+ gemini-1.5-flash-latest $0.075 $0.30 $0.15 $0.60 $0.0375
30
+ gemini-1.5-flash-001 $0.075 $0.30 $0.15 $0.60 $0.0375
31
+ gemini-1.5-flash-002 $0.075 $0.30 $0.15 $0.60 $0.0375
32
+ gemini-1.5-flash-8b $0.0375 $0.15 $0.075 $0.30 $0.025
33
+ gemini-1.5-flash-8b-latest $0.0375 $0.15 $0.075 $0.30 $0.025
34
+ gemini-1.5-flash-8b-001 $0.0375 $0.15 $0.075 $0.30 $0.025
35
+ gemini-1.5-flash-8b-002 $0.0375 $0.15 $0.075 $0.30 $0.025
36
+ gemini-1.0-pro $0.50 $1.50 $0.50 $1.50 $0.00
20
37
 
21
38
  Args:
22
39
  input_tokens: Number of input tokens
40
+ cached_tokens: Number of cached tokens
23
41
  output_tokens: Number of output tokens
24
42
  model: Model name to use for pricing calculation
25
43
 
@@ -27,47 +45,154 @@ def calculate_cost(
27
45
  Total cost in USD or None if invalid input
28
46
  """
29
47
  pricing = {
48
+ "gemini-2.0-pro": {
49
+ "prompt_short": 0.000_001_25,
50
+ "completion_short": 0.000_005,
51
+ "prompt_long": 0.000_002_5,
52
+ "completion_long": 0.000_01,
53
+ "cached": 0.000_000_625,
54
+ },
55
+ "gemini-2.0-pro-preview-1206": {
56
+ "prompt_short": 0.000_001_25,
57
+ "completion_short": 0.000_005,
58
+ "prompt_long": 0.000_002_5,
59
+ "completion_long": 0.000_01,
60
+ "cached": 0.000_000_625,
61
+ },
30
62
  "gemini-2.0-flash": {
31
63
  "prompt_short": 0.000_000_10,
32
64
  "completion_short": 0.000_000_40,
33
65
  "prompt_long": 0.000_000_10,
34
66
  "completion_long": 0.000_000_40,
67
+ "cached": 0.000_000_037_5,
68
+ },
69
+ "gemini-2.0-flash-latest": {
70
+ "prompt_short": 0.000_000_10,
71
+ "completion_short": 0.000_000_40,
72
+ "prompt_long": 0.000_000_10,
73
+ "completion_long": 0.000_000_40,
74
+ "cached": 0.000_000_037_5,
75
+ },
76
+ "gemini-2.0-flash-001": {
77
+ "prompt_short": 0.000_000_10,
78
+ "completion_short": 0.000_000_40,
79
+ "prompt_long": 0.000_000_10,
80
+ "completion_long": 0.000_000_40,
81
+ "cached": 0.000_000_037_5,
35
82
  },
36
83
  "gemini-2.0-flash-lite": {
37
84
  "prompt_short": 0.000_000_075,
38
85
  "completion_short": 0.000_000_30,
39
86
  "prompt_long": 0.000_000_075,
40
87
  "completion_long": 0.000_000_30,
88
+ "cached": 0.000_000_037_5,
89
+ },
90
+ "gemini-2.0-flash-lite-preview-02-05": {
91
+ "prompt_short": 0.000_000_075,
92
+ "completion_short": 0.000_000_30,
93
+ "prompt_long": 0.000_000_075,
94
+ "completion_long": 0.000_000_30,
95
+ "cached": 0.000_000_037_5,
96
+ },
97
+ "gemini-1.5-pro": {
98
+ "prompt_short": 0.000_001_25,
99
+ "completion_short": 0.000_005,
100
+ "prompt_long": 0.000_002_5,
101
+ "completion_long": 0.000_01,
102
+ "cached": 0.000_000_625,
103
+ },
104
+ "gemini-1.5-pro-latest": {
105
+ "prompt_short": 0.000_001_25,
106
+ "completion_short": 0.000_005,
107
+ "prompt_long": 0.000_002_5,
108
+ "completion_long": 0.000_01,
109
+ "cached": 0.000_000_625,
110
+ },
111
+ "gemini-1.5-pro-001": {
112
+ "prompt_short": 0.000_001_25,
113
+ "completion_short": 0.000_005,
114
+ "prompt_long": 0.000_002_5,
115
+ "completion_long": 0.000_01,
116
+ "cached": 0.000_000_625,
117
+ },
118
+ "gemini-1.5-pro-002": {
119
+ "prompt_short": 0.000_001_25,
120
+ "completion_short": 0.000_005,
121
+ "prompt_long": 0.000_002_5,
122
+ "completion_long": 0.000_01,
123
+ "cached": 0.000_000_625,
41
124
  },
42
125
  "gemini-1.5-flash": {
43
126
  "prompt_short": 0.000_000_075,
44
127
  "completion_short": 0.000_000_30,
45
128
  "prompt_long": 0.000_000_15,
46
129
  "completion_long": 0.000_000_60,
130
+ "cached": 0.000_000_037_5,
131
+ },
132
+ "gemini-1.5-flash-latest": {
133
+ "prompt_short": 0.000_000_075,
134
+ "completion_short": 0.000_000_30,
135
+ "prompt_long": 0.000_000_15,
136
+ "completion_long": 0.000_000_60,
137
+ "cached": 0.000_000_037_5,
138
+ },
139
+ "gemini-1.5-flash-001": {
140
+ "prompt_short": 0.000_000_075,
141
+ "completion_short": 0.000_000_30,
142
+ "prompt_long": 0.000_000_15,
143
+ "completion_long": 0.000_000_60,
144
+ "cached": 0.000_000_037_5,
145
+ },
146
+ "gemini-1.5-flash-002": {
147
+ "prompt_short": 0.000_000_075,
148
+ "completion_short": 0.000_000_30,
149
+ "prompt_long": 0.000_000_15,
150
+ "completion_long": 0.000_000_60,
151
+ "cached": 0.000_000_037_5,
47
152
  },
48
153
  "gemini-1.5-flash-8b": {
49
154
  "prompt_short": 0.000_000_037_5,
50
155
  "completion_short": 0.000_000_15,
51
156
  "prompt_long": 0.000_000_075,
52
157
  "completion_long": 0.000_000_30,
158
+ "cached": 0.000_000_025,
53
159
  },
54
- "gemini-1.5-pro": {
55
- "prompt_short": 0.000_001_25,
56
- "completion_short": 0.000_005,
57
- "prompt_long": 0.000_002_5,
58
- "completion_long": 0.000_01,
160
+ "gemini-1.5-flash-8b-latest": {
161
+ "prompt_short": 0.000_000_037_5,
162
+ "completion_short": 0.000_000_15,
163
+ "prompt_long": 0.000_000_075,
164
+ "completion_long": 0.000_000_30,
165
+ "cached": 0.000_000_025,
166
+ },
167
+ "gemini-1.5-flash-8b-001": {
168
+ "prompt_short": 0.000_000_037_5,
169
+ "completion_short": 0.000_000_15,
170
+ "prompt_long": 0.000_000_075,
171
+ "completion_long": 0.000_000_30,
172
+ "cached": 0.000_000_025,
173
+ },
174
+ "gemini-1.5-flash-8b-002": {
175
+ "prompt_short": 0.000_000_037_5,
176
+ "completion_short": 0.000_000_15,
177
+ "prompt_long": 0.000_000_075,
178
+ "completion_long": 0.000_000_30,
179
+ "cached": 0.000_000_025,
59
180
  },
60
181
  "gemini-1.0-pro": {
61
182
  "prompt_short": 0.000_000_5,
62
183
  "completion_short": 0.000_001_5,
63
184
  "prompt_long": 0.000_000_5,
64
185
  "completion_long": 0.000_001_5,
186
+ "cached": 0.000_000,
65
187
  },
66
188
  }
67
189
 
68
190
  if input_tokens is None or output_tokens is None:
69
191
  return None
70
192
 
193
+ if cached_tokens is None:
194
+ cached_tokens = 0
195
+
71
196
  try:
72
197
  model_pricing = pricing[model]
73
198
  except KeyError:
@@ -77,12 +202,14 @@ def calculate_cost(
77
202
  use_long_context = input_tokens > 128_000
78
203
 
79
204
  prompt_price = model_pricing["prompt_long" if use_long_context else "prompt_short"]
205
+ cached_price = model_pricing["cached"]
80
206
  completion_price = model_pricing[
81
207
  "completion_long" if use_long_context else "completion_short"
82
208
  ]
83
209
 
84
210
  prompt_cost = input_tokens * prompt_price
211
+ cached_cost = cached_tokens * cached_price
85
212
  completion_cost = output_tokens * completion_price
86
- total_cost = prompt_cost + completion_cost
213
+ total_cost = prompt_cost + cached_cost + completion_cost
87
214
 
88
215
  return total_cost
@@ -1,21 +1,37 @@
1
1
  """Utility for converting `BaseMessageParam` to `ContentsType`"""
2
2
 
3
+ import asyncio
3
4
  import base64
4
5
  import io
6
+ from concurrent.futures import ThreadPoolExecutor
5
7
 
6
- import PIL.Image
7
8
  from google.genai import Client
8
- from google.genai.types import BlobDict, ContentDict, FileDataDict, PartDict
9
+ from google.genai.types import (
10
+ BlobDict,
11
+ ContentDict,
12
+ FileDataDict,
13
+ PartDict,
14
+ )
9
15
 
10
16
  from ...base import BaseMessageParam
11
- from ...base._utils import get_audio_type
17
+ from ...base._utils import get_audio_type, get_image_type
12
18
  from ...base._utils._parse_content_template import _load_media
19
+ from ._validate_media_type import _check_audio_media_type, _check_image_media_type
13
20
 
14
21
 
15
- def convert_message_params(
22
+ def _over_file_size_limit(size: int) -> bool:
23
+ """Check if the total file size exceeds the limit (10mb).
24
+
25
+ Google limit is 20MB but base64 adds 33% to the size.
26
+ """
27
+ return size > 10 * 1024 * 1024 # 10MB
28
+
29
+
30
+ async def _convert_message_params_async(
16
31
  message_params: list[BaseMessageParam | ContentDict], client: Client
17
32
  ) -> list[ContentDict]:
18
33
  converted_message_params = []
34
+ total_payload_size = 0
19
35
  for message_param in message_params:
20
36
  if not isinstance(message_param, BaseMessageParam):
21
37
  converted_message_params.append(message_param)
@@ -39,139 +55,108 @@ def convert_message_params(
39
55
  )
40
56
  else:
41
57
  converted_content = []
42
- for part in content:
58
+ must_upload: dict[int, BlobDict] = {}
59
+ for index, part in enumerate(content):
43
60
  if part.type == "text":
44
61
  converted_content.append(PartDict(text=part.text))
45
62
  elif part.type == "image":
46
- if part.media_type not in [
47
- "image/jpeg",
48
- "image/png",
49
- "image/webp",
50
- "image/heic",
51
- "image/heif",
52
- ]:
53
- raise ValueError(
54
- f"Unsupported image media type: {part.media_type}. "
55
- "Google currently only supports JPEG, PNG, WebP, HEIC, "
56
- "and HEIF images."
57
- )
58
- converted_content.append(
59
- PartDict(
60
- inline_data=BlobDict(
61
- data=part.image, mime_type=part.media_type
62
- )
63
- )
64
- )
63
+ _check_image_media_type(part.media_type)
64
+ blob_dict = BlobDict(data=part.image, mime_type=part.media_type)
65
+ converted_content.append(PartDict(inline_data=blob_dict))
66
+ image_size = len(part.image)
67
+ total_payload_size += image_size
68
+ if _over_file_size_limit(total_payload_size):
69
+ must_upload[index] = blob_dict
70
+ total_payload_size -= image_size
65
71
  elif part.type == "image_url":
66
72
  if (
67
- part.url.startswith(("https://", "http://"))
68
- and "generativelanguage.googleapis.com" not in part.url
73
+ client.vertexai
74
+ or not part.url.startswith(("https://", "http://"))
75
+ or "generativelanguage.googleapis.com" in part.url
69
76
  ):
70
- downloaded_image = io.BytesIO(_load_media(part.url))
71
- image = PIL.Image.open(downloaded_image)
72
- media_type = (
73
- PIL.Image.MIME[image.format]
74
- if image.format
75
- else "image/unknown"
76
- )
77
- if media_type not in [
78
- "image/jpeg",
79
- "image/png",
80
- "image/webp",
81
- "image/heic",
82
- "image/heif",
83
- ]:
84
- raise ValueError(
85
- f"Unsupported image media type: {media_type}. "
86
- "Google currently only supports JPEG, PNG, WebP, HEIC, "
87
- "and HEIF images."
77
+ converted_content.append(
78
+ PartDict(
79
+ file_data=FileDataDict(
80
+ file_uri=part.url, mime_type=None
81
+ )
88
82
  )
89
- if client.vertexai:
90
- uri = part.url
91
- else:
92
- downloaded_image.seek(0)
93
- file_ref = client.files.upload(
94
- file=downloaded_image, config={"mime_type": media_type}
95
- )
96
- uri = file_ref.uri
97
- media_type = file_ref.mime_type
83
+ )
98
84
  else:
99
- uri = part.url
100
- media_type = None
101
-
102
- converted_content.append(
103
- PartDict(
104
- file_data=FileDataDict(file_uri=uri, mime_type=media_type)
85
+ downloaded_image = _load_media(part.url)
86
+ media_type = f"image/{get_image_type(downloaded_image)}"
87
+ _check_image_media_type(media_type)
88
+ blob_dict = BlobDict(
89
+ data=downloaded_image, mime_type=media_type
105
90
  )
106
- )
91
+ converted_content.append(PartDict(inline_data=blob_dict))
92
+ image_size = len(downloaded_image)
93
+ total_payload_size += image_size
94
+ if _over_file_size_limit(total_payload_size):
95
+ must_upload[index] = blob_dict
96
+ total_payload_size -= image_size
107
97
  elif part.type == "audio":
108
- if part.media_type not in [
109
- "audio/wav",
110
- "audio/mp3",
111
- "audio/aiff",
112
- "audio/aac",
113
- "audio/ogg",
114
- "audio/flac",
115
- ]:
116
- raise ValueError(
117
- f"Unsupported audio media type: {part.media_type}. "
118
- "Google currently only supports WAV, MP3, AIFF, AAC, OGG, "
119
- "and FLAC audio file types."
120
- )
121
- converted_content.append(
122
- PartDict(
123
- inline_data=BlobDict(
124
- data=part.audio
125
- if isinstance(part.audio, bytes)
126
- else base64.b64decode(part.audio),
127
- mime_type=part.media_type,
128
- )
129
- )
98
+ _check_audio_media_type(part.media_type)
99
+ audio_data = (
100
+ part.audio
101
+ if isinstance(part.audio, bytes)
102
+ else base64.b64decode(part.audio)
130
103
  )
104
+ blob_dict = BlobDict(data=audio_data, mime_type=part.media_type)
105
+ converted_content.append(PartDict(inline_data=blob_dict))
106
+ audio_size = len(audio_data)
107
+ total_payload_size += audio_size
108
+ if _over_file_size_limit(total_payload_size):
109
+ must_upload[index] = blob_dict
110
+ total_payload_size -= audio_size
131
111
  elif part.type == "audio_url":
132
112
  if (
133
- part.url.startswith(("https://", "http://"))
134
- and "generativelanguage.googleapis.com" not in part.url
113
+ client.vertexai
114
+ or not part.url.startswith(("https://", "http://"))
115
+ or "generativelanguage.googleapis.com" in part.url
135
116
  ):
136
- downloaded_audio = _load_media(part.url)
137
- audio_type = get_audio_type(downloaded_audio)
138
- if audio_type not in [
139
- "audio/wav",
140
- "audio/mp3",
141
- "audio/aiff",
142
- "audio/aac",
143
- "audio/ogg",
144
- "audio/flac",
145
- ]:
146
- raise ValueError(
147
- f"Unsupported audio media type: {audio_type}. "
148
- "Google currently only supports WAV, MP3, AIFF, AAC, OGG, "
149
- "and FLAC audio file types."
117
+ converted_content.append(
118
+ PartDict(
119
+ file_data=FileDataDict(
120
+ file_uri=part.url, mime_type=None
121
+ )
150
122
  )
151
- if client.vertexai:
152
- uri = part.url
153
- else:
154
- downloaded_audio = io.BytesIO(downloaded_audio)
155
- downloaded_audio.seek(0)
156
- file_ref = client.files.upload(
157
- file=downloaded_audio, config={"mime_type": audio_type}
158
- )
159
- uri = file_ref.uri
160
- media_type = file_ref.mime_type
123
+ )
161
124
  else:
162
- uri = part.url
163
- audio_type = None
164
-
165
- converted_content.append(
166
- PartDict(
167
- file_data=FileDataDict(file_uri=uri, mime_type=audio_type)
125
+ downloaded_audio = _load_media(part.url)
126
+ media_type = f"audio/{get_audio_type(downloaded_audio)}"
127
+ _check_audio_media_type(media_type)
128
+ blob_dict = BlobDict(
129
+ data=downloaded_audio, mime_type=media_type
168
130
  )
169
- )
131
+ converted_content.append(PartDict(inline_data=blob_dict))
132
+ audio_size = len(downloaded_audio)
133
+ total_payload_size += audio_size
134
+ if _over_file_size_limit(total_payload_size):
135
+ must_upload[index] = blob_dict
136
+ total_payload_size -= audio_size
170
137
  else:
171
138
  raise ValueError(
172
139
  "Google currently only supports text, image, and audio parts. "
173
140
  f"Part provided: {part.type}"
174
141
  )
142
+
143
+ if must_upload:
144
+ indices, blob_dicts = zip(*must_upload.items(), strict=True)
145
+ upload_tasks = [
146
+ client.aio.files.upload(
147
+ file=io.BytesIO(blob_dict["data"]),
148
+ config={"mime_type": blob_dict.get("mime_type", None)},
149
+ )
150
+ for blob_dict in blob_dicts
151
+ ]
152
+ file_refs = await asyncio.gather(*upload_tasks)
153
+ for index, file_ref in zip(indices, file_refs, strict=True):
154
+ converted_content[index] = PartDict(
155
+ file_data=FileDataDict(
156
+ file_uri=file_ref.uri, mime_type=file_ref.mime_type
157
+ )
158
+ )
159
+
175
160
  converted_message_params.append(
176
161
  {
177
162
  "role": role if role == "user" else "model",
@@ -179,3 +164,23 @@ def convert_message_params(
179
164
  }
180
165
  )
181
166
  return converted_message_params
167
+
168
+
169
+ def convert_message_params(
170
+ message_params: list[BaseMessageParam | ContentDict], client: Client
171
+ ) -> list[ContentDict]:
172
+ """Convert message params to Google's ContentDict format.
173
+
174
+ If called from sync context, uses asyncio.run().
175
+ If called from async context, uses the current event loop.
176
+ """
177
+ try:
178
+ asyncio.get_running_loop()
179
+ with ThreadPoolExecutor(max_workers=1) as executor:
180
+ future = executor.submit(
181
+ asyncio.run, _convert_message_params_async(message_params, client)
182
+ )
183
+ return future.result()
184
+ except RuntimeError:
185
+ ...
186
+ return asyncio.run(_convert_message_params_async(message_params, client))