mirascope 1.18.2__py3-none-any.whl → 1.18.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mirascope/__init__.py +20 -1
- mirascope/beta/openai/__init__.py +1 -1
- mirascope/beta/openai/realtime/__init__.py +1 -1
- mirascope/beta/openai/realtime/tool.py +1 -1
- mirascope/beta/rag/__init__.py +2 -2
- mirascope/beta/rag/base/__init__.py +2 -2
- mirascope/beta/rag/weaviate/__init__.py +1 -1
- mirascope/core/__init__.py +29 -6
- mirascope/core/anthropic/__init__.py +3 -3
- mirascope/core/anthropic/_utils/_calculate_cost.py +114 -47
- mirascope/core/anthropic/call_response.py +9 -3
- mirascope/core/anthropic/call_response_chunk.py +7 -0
- mirascope/core/anthropic/stream.py +3 -1
- mirascope/core/azure/__init__.py +2 -2
- mirascope/core/azure/_utils/_calculate_cost.py +4 -1
- mirascope/core/azure/call_response.py +9 -3
- mirascope/core/azure/call_response_chunk.py +5 -0
- mirascope/core/azure/stream.py +3 -1
- mirascope/core/base/__init__.py +11 -9
- mirascope/core/base/_utils/__init__.py +10 -10
- mirascope/core/base/_utils/_get_common_usage.py +8 -4
- mirascope/core/base/_utils/_get_create_fn_or_async_create_fn.py +2 -2
- mirascope/core/base/_utils/_protocols.py +9 -8
- mirascope/core/base/call_response.py +22 -22
- mirascope/core/base/call_response_chunk.py +12 -1
- mirascope/core/base/stream.py +24 -21
- mirascope/core/base/tool.py +7 -5
- mirascope/core/base/types.py +22 -5
- mirascope/core/bedrock/__init__.py +3 -3
- mirascope/core/bedrock/_utils/_calculate_cost.py +4 -1
- mirascope/core/bedrock/call_response.py +8 -3
- mirascope/core/bedrock/call_response_chunk.py +5 -0
- mirascope/core/bedrock/stream.py +3 -1
- mirascope/core/cohere/__init__.py +2 -2
- mirascope/core/cohere/_utils/_calculate_cost.py +4 -3
- mirascope/core/cohere/call_response.py +9 -3
- mirascope/core/cohere/call_response_chunk.py +5 -0
- mirascope/core/cohere/stream.py +3 -1
- mirascope/core/gemini/__init__.py +2 -2
- mirascope/core/gemini/_utils/_calculate_cost.py +4 -1
- mirascope/core/gemini/_utils/_convert_message_params.py +1 -1
- mirascope/core/gemini/call_response.py +9 -3
- mirascope/core/gemini/call_response_chunk.py +5 -0
- mirascope/core/gemini/stream.py +3 -1
- mirascope/core/google/__init__.py +2 -2
- mirascope/core/google/_utils/_calculate_cost.py +141 -14
- mirascope/core/google/_utils/_convert_message_params.py +120 -115
- mirascope/core/google/_utils/_message_param_converter.py +34 -33
- mirascope/core/google/_utils/_validate_media_type.py +34 -0
- mirascope/core/google/call_response.py +38 -10
- mirascope/core/google/call_response_chunk.py +17 -9
- mirascope/core/google/stream.py +20 -2
- mirascope/core/groq/__init__.py +2 -2
- mirascope/core/groq/_utils/_calculate_cost.py +12 -11
- mirascope/core/groq/call_response.py +9 -3
- mirascope/core/groq/call_response_chunk.py +5 -0
- mirascope/core/groq/stream.py +3 -1
- mirascope/core/litellm/__init__.py +1 -1
- mirascope/core/litellm/_utils/_setup_call.py +7 -3
- mirascope/core/mistral/__init__.py +2 -2
- mirascope/core/mistral/_utils/_calculate_cost.py +10 -9
- mirascope/core/mistral/call_response.py +9 -3
- mirascope/core/mistral/call_response_chunk.py +5 -0
- mirascope/core/mistral/stream.py +3 -1
- mirascope/core/openai/__init__.py +2 -2
- mirascope/core/openai/_utils/_calculate_cost.py +78 -37
- mirascope/core/openai/call_params.py +13 -0
- mirascope/core/openai/call_response.py +14 -3
- mirascope/core/openai/call_response_chunk.py +12 -0
- mirascope/core/openai/stream.py +6 -4
- mirascope/core/vertex/__init__.py +1 -1
- mirascope/core/vertex/_utils/_calculate_cost.py +1 -0
- mirascope/core/vertex/_utils/_convert_message_params.py +1 -1
- mirascope/core/vertex/call_response.py +9 -3
- mirascope/core/vertex/call_response_chunk.py +5 -0
- mirascope/core/vertex/stream.py +3 -1
- mirascope/integrations/_middleware_factory.py +6 -6
- mirascope/integrations/logfire/_utils.py +1 -1
- mirascope/llm/__init__.py +3 -1
- mirascope/llm/_protocols.py +5 -5
- mirascope/llm/call_response.py +16 -9
- mirascope/llm/llm_call.py +53 -25
- mirascope/llm/stream.py +43 -31
- mirascope/retries/__init__.py +1 -1
- mirascope/tools/__init__.py +2 -2
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/METADATA +2 -2
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/RECORD +89 -88
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/WHEEL +0 -0
- {mirascope-1.18.2.dist-info → mirascope-1.18.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -2,7 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def calculate_cost(
|
|
5
|
-
input_tokens: int | float | None,
|
|
5
|
+
input_tokens: int | float | None,
|
|
6
|
+
cached_tokens: int | float | None,
|
|
7
|
+
output_tokens: int | float | None,
|
|
8
|
+
model: str,
|
|
6
9
|
) -> float | None:
|
|
7
10
|
"""Calculate the cost of a Gemini API call.
|
|
8
11
|
|
|
@@ -96,7 +96,7 @@ def convert_message_params(
|
|
|
96
96
|
elif part.type == "audio_url":
|
|
97
97
|
if part.url.startswith(("https://", "http://")):
|
|
98
98
|
audio = _load_media(part.url)
|
|
99
|
-
audio_type = get_audio_type(audio)
|
|
99
|
+
audio_type = f"audio/{get_audio_type(audio)}"
|
|
100
100
|
if audio_type not in [
|
|
101
101
|
"audio/wav",
|
|
102
102
|
"audio/mp3",
|
|
@@ -122,6 +122,12 @@ class GeminiCallResponse(
|
|
|
122
122
|
"""Returns the number of input tokens."""
|
|
123
123
|
return None
|
|
124
124
|
|
|
125
|
+
@computed_field
|
|
126
|
+
@property
|
|
127
|
+
def cached_tokens(self) -> None:
|
|
128
|
+
"""Returns the number of cached tokens."""
|
|
129
|
+
return None
|
|
130
|
+
|
|
125
131
|
@computed_field
|
|
126
132
|
@property
|
|
127
133
|
def output_tokens(self) -> None:
|
|
@@ -132,7 +138,9 @@ class GeminiCallResponse(
|
|
|
132
138
|
@property
|
|
133
139
|
def cost(self) -> float | None:
|
|
134
140
|
"""Returns the cost of the call."""
|
|
135
|
-
return calculate_cost(
|
|
141
|
+
return calculate_cost(
|
|
142
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
143
|
+
)
|
|
136
144
|
|
|
137
145
|
@computed_field
|
|
138
146
|
@cached_property
|
|
@@ -140,7 +148,6 @@ class GeminiCallResponse(
|
|
|
140
148
|
"""Returns the models's response as a message parameter."""
|
|
141
149
|
return {"role": "model", "parts": self.response.parts} # pyright: ignore [reportReturnType]
|
|
142
150
|
|
|
143
|
-
@computed_field
|
|
144
151
|
@cached_property
|
|
145
152
|
def tools(self) -> list[GeminiTool] | None:
|
|
146
153
|
"""Returns the list of tools for the 0th candidate's 0th content part."""
|
|
@@ -157,7 +164,6 @@ class GeminiCallResponse(
|
|
|
157
164
|
|
|
158
165
|
return extracted_tools
|
|
159
166
|
|
|
160
|
-
@computed_field
|
|
161
167
|
@cached_property
|
|
162
168
|
def tool(self) -> GeminiTool | None:
|
|
163
169
|
"""Returns the 0th tool for the 0th candidate's 0th content part.
|
|
@@ -78,6 +78,11 @@ class GeminiCallResponseChunk(
|
|
|
78
78
|
"""Returns the number of input tokens."""
|
|
79
79
|
return None
|
|
80
80
|
|
|
81
|
+
@property
|
|
82
|
+
def cached_tokens(self) -> None:
|
|
83
|
+
"""Returns the number of cached tokens."""
|
|
84
|
+
return None
|
|
85
|
+
|
|
81
86
|
@property
|
|
82
87
|
def output_tokens(self) -> None:
|
|
83
88
|
"""Returns the number of output tokens."""
|
mirascope/core/gemini/stream.py
CHANGED
|
@@ -69,7 +69,9 @@ class GeminiStream(
|
|
|
69
69
|
@property
|
|
70
70
|
def cost(self) -> float | None:
|
|
71
71
|
"""Returns the cost of the call."""
|
|
72
|
-
return calculate_cost(
|
|
72
|
+
return calculate_cost(
|
|
73
|
+
self.input_tokens, self.cached_tokens, self.output_tokens, self.model
|
|
74
|
+
)
|
|
73
75
|
|
|
74
76
|
def _construct_message_param(
|
|
75
77
|
self, tool_calls: list[FunctionCall] | None = None, content: str | None = None
|
|
@@ -17,13 +17,13 @@ from .tool import GoogleTool
|
|
|
17
17
|
GoogleMessageParam: TypeAlias = ContentDict | FunctionResponse | BaseMessageParam
|
|
18
18
|
|
|
19
19
|
__all__ = [
|
|
20
|
-
"call",
|
|
21
|
-
"GoogleDynamicConfig",
|
|
22
20
|
"GoogleCallParams",
|
|
23
21
|
"GoogleCallResponse",
|
|
24
22
|
"GoogleCallResponseChunk",
|
|
23
|
+
"GoogleDynamicConfig",
|
|
25
24
|
"GoogleMessageParam",
|
|
26
25
|
"GoogleStream",
|
|
27
26
|
"GoogleTool",
|
|
27
|
+
"call",
|
|
28
28
|
"google_call",
|
|
29
29
|
]
|
|
@@ -2,7 +2,10 @@
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def calculate_cost(
|
|
5
|
-
input_tokens: int | float | None,
|
|
5
|
+
input_tokens: int | float | None,
|
|
6
|
+
cached_tokens: int | float | None,
|
|
7
|
+
output_tokens: int | float | None,
|
|
8
|
+
model: str,
|
|
6
9
|
) -> float | None:
|
|
7
10
|
"""Calculate the cost of a Google API call.
|
|
8
11
|
|
|
@@ -10,16 +13,31 @@ def calculate_cost(
|
|
|
10
13
|
|
|
11
14
|
Pricing (per 1M tokens):
|
|
12
15
|
|
|
13
|
-
Model
|
|
14
|
-
gemini-2.0-
|
|
15
|
-
gemini-2.0-
|
|
16
|
-
gemini-
|
|
17
|
-
gemini-
|
|
18
|
-
gemini-
|
|
19
|
-
gemini-
|
|
16
|
+
Model Input (<128K) Output (<128K) Input (>128K) Output (>128K) Cached
|
|
17
|
+
gemini-2.0-pro $1.25 $5.00 $2.50 $10.00 $0.625
|
|
18
|
+
gemini-2.0-pro-preview-1206 $1.25 $5.00 $2.50 $10.00 $0.625
|
|
19
|
+
gemini-2.0-flash $0.10 $0.40 $0.10 $0.40 $0.0375
|
|
20
|
+
gemini-2.0-flash-latest $0.10 $0.40 $0.10 $0.40 $0.0375
|
|
21
|
+
gemini-2.0-flash-001 $0.10 $0.40 $0.10 $0.40 $0.0375
|
|
22
|
+
gemini-2.0-flash-lite $0.075 $0.30 $0.075 $0.30 $0.0375
|
|
23
|
+
gemini-2.0-flash-lite-preview-02-05 $0.075 $0.30 $0.075 $0.30 $0.0375
|
|
24
|
+
gemini-1.5-pro $1.25 $5.00 $2.50 $10.00 $0.625
|
|
25
|
+
gemini-1.5-pro-latest $1.25 $5.00 $2.50 $10.00 $0.625
|
|
26
|
+
gemini-1.5-pro-001 $1.25 $5.00 $2.50 $10.00 $0.625
|
|
27
|
+
gemini-1.5-pro-002 $1.25 $5.00 $2.50 $10.00 $0.625
|
|
28
|
+
gemini-1.5-flash $0.075 $0.30 $0.15 $0.60 $0.0375
|
|
29
|
+
gemini-1.5-flash-latest $0.075 $0.30 $0.15 $0.60 $0.0375
|
|
30
|
+
gemini-1.5-flash-001 $0.075 $0.30 $0.15 $0.60 $0.0375
|
|
31
|
+
gemini-1.5-flash-002 $0.075 $0.30 $0.15 $0.60 $0.0375
|
|
32
|
+
gemini-1.5-flash-8b $0.0375 $0.15 $0.075 $0.30 $0.025
|
|
33
|
+
gemini-1.5-flash-8b-latest $0.0375 $0.15 $0.075 $0.30 $0.025
|
|
34
|
+
gemini-1.5-flash-8b-001 $0.0375 $0.15 $0.075 $0.30 $0.025
|
|
35
|
+
gemini-1.5-flash-8b-002 $0.0375 $0.15 $0.075 $0.30 $0.025
|
|
36
|
+
gemini-1.0-pro $0.50 $1.50 $0.50 $1.50 $0.00
|
|
20
37
|
|
|
21
38
|
Args:
|
|
22
39
|
input_tokens: Number of input tokens
|
|
40
|
+
cached_tokens: Number of cached tokens
|
|
23
41
|
output_tokens: Number of output tokens
|
|
24
42
|
model: Model name to use for pricing calculation
|
|
25
43
|
|
|
@@ -27,47 +45,154 @@ def calculate_cost(
|
|
|
27
45
|
Total cost in USD or None if invalid input
|
|
28
46
|
"""
|
|
29
47
|
pricing = {
|
|
48
|
+
"gemini-2.0-pro": {
|
|
49
|
+
"prompt_short": 0.000_001_25,
|
|
50
|
+
"completion_short": 0.000_005,
|
|
51
|
+
"prompt_long": 0.000_002_5,
|
|
52
|
+
"completion_long": 0.000_01,
|
|
53
|
+
"cached": 0.000_000_625,
|
|
54
|
+
},
|
|
55
|
+
"gemini-2.0-pro-preview-1206": {
|
|
56
|
+
"prompt_short": 0.000_001_25,
|
|
57
|
+
"completion_short": 0.000_005,
|
|
58
|
+
"prompt_long": 0.000_002_5,
|
|
59
|
+
"completion_long": 0.000_01,
|
|
60
|
+
"cached": 0.000_000_625,
|
|
61
|
+
},
|
|
30
62
|
"gemini-2.0-flash": {
|
|
31
63
|
"prompt_short": 0.000_000_10,
|
|
32
64
|
"completion_short": 0.000_000_40,
|
|
33
65
|
"prompt_long": 0.000_000_10,
|
|
34
66
|
"completion_long": 0.000_000_40,
|
|
67
|
+
"cached": 0.000_000_037_5,
|
|
68
|
+
},
|
|
69
|
+
"gemini-2.0-flash-latest": {
|
|
70
|
+
"prompt_short": 0.000_000_10,
|
|
71
|
+
"completion_short": 0.000_000_40,
|
|
72
|
+
"prompt_long": 0.000_000_10,
|
|
73
|
+
"completion_long": 0.000_000_40,
|
|
74
|
+
"cached": 0.000_000_037_5,
|
|
75
|
+
},
|
|
76
|
+
"gemini-2.0-flash-001": {
|
|
77
|
+
"prompt_short": 0.000_000_10,
|
|
78
|
+
"completion_short": 0.000_000_40,
|
|
79
|
+
"prompt_long": 0.000_000_10,
|
|
80
|
+
"completion_long": 0.000_000_40,
|
|
81
|
+
"cached": 0.000_000_037_5,
|
|
35
82
|
},
|
|
36
83
|
"gemini-2.0-flash-lite": {
|
|
37
84
|
"prompt_short": 0.000_000_075,
|
|
38
85
|
"completion_short": 0.000_000_30,
|
|
39
86
|
"prompt_long": 0.000_000_075,
|
|
40
87
|
"completion_long": 0.000_000_30,
|
|
88
|
+
"cached": 0.000_000_037_5,
|
|
89
|
+
},
|
|
90
|
+
"gemini-2.0-flash-lite-preview-02-05": {
|
|
91
|
+
"prompt_short": 0.000_000_075,
|
|
92
|
+
"completion_short": 0.000_000_30,
|
|
93
|
+
"prompt_long": 0.000_000_075,
|
|
94
|
+
"completion_long": 0.000_000_30,
|
|
95
|
+
"cached": 0.000_000_037_5,
|
|
96
|
+
},
|
|
97
|
+
"gemini-1.5-pro": {
|
|
98
|
+
"prompt_short": 0.000_001_25,
|
|
99
|
+
"completion_short": 0.000_005,
|
|
100
|
+
"prompt_long": 0.000_002_5,
|
|
101
|
+
"completion_long": 0.000_01,
|
|
102
|
+
"cached": 0.000_000_625,
|
|
103
|
+
},
|
|
104
|
+
"gemini-1.5-pro-latest": {
|
|
105
|
+
"prompt_short": 0.000_001_25,
|
|
106
|
+
"completion_short": 0.000_005,
|
|
107
|
+
"prompt_long": 0.000_002_5,
|
|
108
|
+
"completion_long": 0.000_01,
|
|
109
|
+
"cached": 0.000_000_625,
|
|
110
|
+
},
|
|
111
|
+
"gemini-1.5-pro-001": {
|
|
112
|
+
"prompt_short": 0.000_001_25,
|
|
113
|
+
"completion_short": 0.000_005,
|
|
114
|
+
"prompt_long": 0.000_002_5,
|
|
115
|
+
"completion_long": 0.000_01,
|
|
116
|
+
"cached": 0.000_000_625,
|
|
117
|
+
},
|
|
118
|
+
"gemini-1.5-pro-002": {
|
|
119
|
+
"prompt_short": 0.000_001_25,
|
|
120
|
+
"completion_short": 0.000_005,
|
|
121
|
+
"prompt_long": 0.000_002_5,
|
|
122
|
+
"completion_long": 0.000_01,
|
|
123
|
+
"cached": 0.000_000_625,
|
|
41
124
|
},
|
|
42
125
|
"gemini-1.5-flash": {
|
|
43
126
|
"prompt_short": 0.000_000_075,
|
|
44
127
|
"completion_short": 0.000_000_30,
|
|
45
128
|
"prompt_long": 0.000_000_15,
|
|
46
129
|
"completion_long": 0.000_000_60,
|
|
130
|
+
"cached": 0.000_000_037_5,
|
|
131
|
+
},
|
|
132
|
+
"gemini-1.5-flash-latest": {
|
|
133
|
+
"prompt_short": 0.000_000_075,
|
|
134
|
+
"completion_short": 0.000_000_30,
|
|
135
|
+
"prompt_long": 0.000_000_15,
|
|
136
|
+
"completion_long": 0.000_000_60,
|
|
137
|
+
"cached": 0.000_000_037_5,
|
|
138
|
+
},
|
|
139
|
+
"gemini-1.5-flash-001": {
|
|
140
|
+
"prompt_short": 0.000_000_075,
|
|
141
|
+
"completion_short": 0.000_000_30,
|
|
142
|
+
"prompt_long": 0.000_000_15,
|
|
143
|
+
"completion_long": 0.000_000_60,
|
|
144
|
+
"cached": 0.000_000_037_5,
|
|
145
|
+
},
|
|
146
|
+
"gemini-1.5-flash-002": {
|
|
147
|
+
"prompt_short": 0.000_000_075,
|
|
148
|
+
"completion_short": 0.000_000_30,
|
|
149
|
+
"prompt_long": 0.000_000_15,
|
|
150
|
+
"completion_long": 0.000_000_60,
|
|
151
|
+
"cached": 0.000_000_037_5,
|
|
47
152
|
},
|
|
48
153
|
"gemini-1.5-flash-8b": {
|
|
49
154
|
"prompt_short": 0.000_000_037_5,
|
|
50
155
|
"completion_short": 0.000_000_15,
|
|
51
156
|
"prompt_long": 0.000_000_075,
|
|
52
157
|
"completion_long": 0.000_000_30,
|
|
158
|
+
"cached": 0.000_000_025,
|
|
53
159
|
},
|
|
54
|
-
"gemini-1.5-
|
|
55
|
-
"prompt_short": 0.
|
|
56
|
-
"completion_short": 0.
|
|
57
|
-
"prompt_long": 0.
|
|
58
|
-
"completion_long": 0.
|
|
160
|
+
"gemini-1.5-flash-8b-latest": {
|
|
161
|
+
"prompt_short": 0.000_000_037_5,
|
|
162
|
+
"completion_short": 0.000_000_15,
|
|
163
|
+
"prompt_long": 0.000_000_075,
|
|
164
|
+
"completion_long": 0.000_000_30,
|
|
165
|
+
"cached": 0.000_000_025,
|
|
166
|
+
},
|
|
167
|
+
"gemini-1.5-flash-8b-001": {
|
|
168
|
+
"prompt_short": 0.000_000_037_5,
|
|
169
|
+
"completion_short": 0.000_000_15,
|
|
170
|
+
"prompt_long": 0.000_000_075,
|
|
171
|
+
"completion_long": 0.000_000_30,
|
|
172
|
+
"cached": 0.000_000_025,
|
|
173
|
+
},
|
|
174
|
+
"gemini-1.5-flash-8b-002": {
|
|
175
|
+
"prompt_short": 0.000_000_037_5,
|
|
176
|
+
"completion_short": 0.000_000_15,
|
|
177
|
+
"prompt_long": 0.000_000_075,
|
|
178
|
+
"completion_long": 0.000_000_30,
|
|
179
|
+
"cached": 0.000_000_025,
|
|
59
180
|
},
|
|
60
181
|
"gemini-1.0-pro": {
|
|
61
182
|
"prompt_short": 0.000_000_5,
|
|
62
183
|
"completion_short": 0.000_001_5,
|
|
63
184
|
"prompt_long": 0.000_000_5,
|
|
64
185
|
"completion_long": 0.000_001_5,
|
|
186
|
+
"cached": 0.000_000,
|
|
65
187
|
},
|
|
66
188
|
}
|
|
67
189
|
|
|
68
190
|
if input_tokens is None or output_tokens is None:
|
|
69
191
|
return None
|
|
70
192
|
|
|
193
|
+
if cached_tokens is None:
|
|
194
|
+
cached_tokens = 0
|
|
195
|
+
|
|
71
196
|
try:
|
|
72
197
|
model_pricing = pricing[model]
|
|
73
198
|
except KeyError:
|
|
@@ -77,12 +202,14 @@ def calculate_cost(
|
|
|
77
202
|
use_long_context = input_tokens > 128_000
|
|
78
203
|
|
|
79
204
|
prompt_price = model_pricing["prompt_long" if use_long_context else "prompt_short"]
|
|
205
|
+
cached_price = model_pricing["cached"]
|
|
80
206
|
completion_price = model_pricing[
|
|
81
207
|
"completion_long" if use_long_context else "completion_short"
|
|
82
208
|
]
|
|
83
209
|
|
|
84
210
|
prompt_cost = input_tokens * prompt_price
|
|
211
|
+
cached_cost = cached_tokens * cached_price
|
|
85
212
|
completion_cost = output_tokens * completion_price
|
|
86
|
-
total_cost = prompt_cost + completion_cost
|
|
213
|
+
total_cost = prompt_cost + cached_cost + completion_cost
|
|
87
214
|
|
|
88
215
|
return total_cost
|
|
@@ -1,21 +1,37 @@
|
|
|
1
1
|
"""Utility for converting `BaseMessageParam` to `ContentsType`"""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import base64
|
|
4
5
|
import io
|
|
6
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
5
7
|
|
|
6
|
-
import PIL.Image
|
|
7
8
|
from google.genai import Client
|
|
8
|
-
from google.genai.types import
|
|
9
|
+
from google.genai.types import (
|
|
10
|
+
BlobDict,
|
|
11
|
+
ContentDict,
|
|
12
|
+
FileDataDict,
|
|
13
|
+
PartDict,
|
|
14
|
+
)
|
|
9
15
|
|
|
10
16
|
from ...base import BaseMessageParam
|
|
11
|
-
from ...base._utils import get_audio_type
|
|
17
|
+
from ...base._utils import get_audio_type, get_image_type
|
|
12
18
|
from ...base._utils._parse_content_template import _load_media
|
|
19
|
+
from ._validate_media_type import _check_audio_media_type, _check_image_media_type
|
|
13
20
|
|
|
14
21
|
|
|
15
|
-
def
|
|
22
|
+
def _over_file_size_limit(size: int) -> bool:
|
|
23
|
+
"""Check if the total file size exceeds the limit (10mb).
|
|
24
|
+
|
|
25
|
+
Google limit is 20MB but base64 adds 33% to the size.
|
|
26
|
+
"""
|
|
27
|
+
return size > 10 * 1024 * 1024 # 10MB
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
async def _convert_message_params_async(
|
|
16
31
|
message_params: list[BaseMessageParam | ContentDict], client: Client
|
|
17
32
|
) -> list[ContentDict]:
|
|
18
33
|
converted_message_params = []
|
|
34
|
+
total_payload_size = 0
|
|
19
35
|
for message_param in message_params:
|
|
20
36
|
if not isinstance(message_param, BaseMessageParam):
|
|
21
37
|
converted_message_params.append(message_param)
|
|
@@ -39,139 +55,108 @@ def convert_message_params(
|
|
|
39
55
|
)
|
|
40
56
|
else:
|
|
41
57
|
converted_content = []
|
|
42
|
-
|
|
58
|
+
must_upload: dict[int, BlobDict] = {}
|
|
59
|
+
for index, part in enumerate(content):
|
|
43
60
|
if part.type == "text":
|
|
44
61
|
converted_content.append(PartDict(text=part.text))
|
|
45
62
|
elif part.type == "image":
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
f"Unsupported image media type: {part.media_type}. "
|
|
55
|
-
"Google currently only supports JPEG, PNG, WebP, HEIC, "
|
|
56
|
-
"and HEIF images."
|
|
57
|
-
)
|
|
58
|
-
converted_content.append(
|
|
59
|
-
PartDict(
|
|
60
|
-
inline_data=BlobDict(
|
|
61
|
-
data=part.image, mime_type=part.media_type
|
|
62
|
-
)
|
|
63
|
-
)
|
|
64
|
-
)
|
|
63
|
+
_check_image_media_type(part.media_type)
|
|
64
|
+
blob_dict = BlobDict(data=part.image, mime_type=part.media_type)
|
|
65
|
+
converted_content.append(PartDict(inline_data=blob_dict))
|
|
66
|
+
image_size = len(part.image)
|
|
67
|
+
total_payload_size += image_size
|
|
68
|
+
if _over_file_size_limit(total_payload_size):
|
|
69
|
+
must_upload[index] = blob_dict
|
|
70
|
+
total_payload_size -= image_size
|
|
65
71
|
elif part.type == "image_url":
|
|
66
72
|
if (
|
|
67
|
-
|
|
68
|
-
|
|
73
|
+
client.vertexai
|
|
74
|
+
or not part.url.startswith(("https://", "http://"))
|
|
75
|
+
or "generativelanguage.googleapis.com" in part.url
|
|
69
76
|
):
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
else "image/unknown"
|
|
76
|
-
)
|
|
77
|
-
if media_type not in [
|
|
78
|
-
"image/jpeg",
|
|
79
|
-
"image/png",
|
|
80
|
-
"image/webp",
|
|
81
|
-
"image/heic",
|
|
82
|
-
"image/heif",
|
|
83
|
-
]:
|
|
84
|
-
raise ValueError(
|
|
85
|
-
f"Unsupported image media type: {media_type}. "
|
|
86
|
-
"Google currently only supports JPEG, PNG, WebP, HEIC, "
|
|
87
|
-
"and HEIF images."
|
|
77
|
+
converted_content.append(
|
|
78
|
+
PartDict(
|
|
79
|
+
file_data=FileDataDict(
|
|
80
|
+
file_uri=part.url, mime_type=None
|
|
81
|
+
)
|
|
88
82
|
)
|
|
89
|
-
|
|
90
|
-
uri = part.url
|
|
91
|
-
else:
|
|
92
|
-
downloaded_image.seek(0)
|
|
93
|
-
file_ref = client.files.upload(
|
|
94
|
-
file=downloaded_image, config={"mime_type": media_type}
|
|
95
|
-
)
|
|
96
|
-
uri = file_ref.uri
|
|
97
|
-
media_type = file_ref.mime_type
|
|
83
|
+
)
|
|
98
84
|
else:
|
|
99
|
-
|
|
100
|
-
media_type =
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
file_data=FileDataDict(file_uri=uri, mime_type=media_type)
|
|
85
|
+
downloaded_image = _load_media(part.url)
|
|
86
|
+
media_type = f"image/{get_image_type(downloaded_image)}"
|
|
87
|
+
_check_image_media_type(media_type)
|
|
88
|
+
blob_dict = BlobDict(
|
|
89
|
+
data=downloaded_image, mime_type=media_type
|
|
105
90
|
)
|
|
106
|
-
|
|
91
|
+
converted_content.append(PartDict(inline_data=blob_dict))
|
|
92
|
+
image_size = len(downloaded_image)
|
|
93
|
+
total_payload_size += image_size
|
|
94
|
+
if _over_file_size_limit(total_payload_size):
|
|
95
|
+
must_upload[index] = blob_dict
|
|
96
|
+
total_payload_size -= image_size
|
|
107
97
|
elif part.type == "audio":
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
"audio/ogg",
|
|
114
|
-
"audio/flac",
|
|
115
|
-
]:
|
|
116
|
-
raise ValueError(
|
|
117
|
-
f"Unsupported audio media type: {part.media_type}. "
|
|
118
|
-
"Google currently only supports WAV, MP3, AIFF, AAC, OGG, "
|
|
119
|
-
"and FLAC audio file types."
|
|
120
|
-
)
|
|
121
|
-
converted_content.append(
|
|
122
|
-
PartDict(
|
|
123
|
-
inline_data=BlobDict(
|
|
124
|
-
data=part.audio
|
|
125
|
-
if isinstance(part.audio, bytes)
|
|
126
|
-
else base64.b64decode(part.audio),
|
|
127
|
-
mime_type=part.media_type,
|
|
128
|
-
)
|
|
129
|
-
)
|
|
98
|
+
_check_audio_media_type(part.media_type)
|
|
99
|
+
audio_data = (
|
|
100
|
+
part.audio
|
|
101
|
+
if isinstance(part.audio, bytes)
|
|
102
|
+
else base64.b64decode(part.audio)
|
|
130
103
|
)
|
|
104
|
+
blob_dict = BlobDict(data=audio_data, mime_type=part.media_type)
|
|
105
|
+
converted_content.append(PartDict(inline_data=blob_dict))
|
|
106
|
+
audio_size = len(audio_data)
|
|
107
|
+
total_payload_size += audio_size
|
|
108
|
+
if _over_file_size_limit(total_payload_size):
|
|
109
|
+
must_upload[index] = blob_dict
|
|
110
|
+
total_payload_size -= audio_size
|
|
131
111
|
elif part.type == "audio_url":
|
|
132
112
|
if (
|
|
133
|
-
|
|
134
|
-
|
|
113
|
+
client.vertexai
|
|
114
|
+
or not part.url.startswith(("https://", "http://"))
|
|
115
|
+
or "generativelanguage.googleapis.com" in part.url
|
|
135
116
|
):
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
"audio/aiff",
|
|
142
|
-
"audio/aac",
|
|
143
|
-
"audio/ogg",
|
|
144
|
-
"audio/flac",
|
|
145
|
-
]:
|
|
146
|
-
raise ValueError(
|
|
147
|
-
f"Unsupported audio media type: {audio_type}. "
|
|
148
|
-
"Google currently only supports WAV, MP3, AIFF, AAC, OGG, "
|
|
149
|
-
"and FLAC audio file types."
|
|
117
|
+
converted_content.append(
|
|
118
|
+
PartDict(
|
|
119
|
+
file_data=FileDataDict(
|
|
120
|
+
file_uri=part.url, mime_type=None
|
|
121
|
+
)
|
|
150
122
|
)
|
|
151
|
-
|
|
152
|
-
uri = part.url
|
|
153
|
-
else:
|
|
154
|
-
downloaded_audio = io.BytesIO(downloaded_audio)
|
|
155
|
-
downloaded_audio.seek(0)
|
|
156
|
-
file_ref = client.files.upload(
|
|
157
|
-
file=downloaded_audio, config={"mime_type": audio_type}
|
|
158
|
-
)
|
|
159
|
-
uri = file_ref.uri
|
|
160
|
-
media_type = file_ref.mime_type
|
|
123
|
+
)
|
|
161
124
|
else:
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
file_data=FileDataDict(file_uri=uri, mime_type=audio_type)
|
|
125
|
+
downloaded_audio = _load_media(part.url)
|
|
126
|
+
media_type = f"audio/{get_audio_type(downloaded_audio)}"
|
|
127
|
+
_check_audio_media_type(media_type)
|
|
128
|
+
blob_dict = BlobDict(
|
|
129
|
+
data=downloaded_audio, mime_type=media_type
|
|
168
130
|
)
|
|
169
|
-
|
|
131
|
+
converted_content.append(PartDict(inline_data=blob_dict))
|
|
132
|
+
audio_size = len(downloaded_audio)
|
|
133
|
+
total_payload_size += audio_size
|
|
134
|
+
if _over_file_size_limit(total_payload_size):
|
|
135
|
+
must_upload[index] = blob_dict
|
|
136
|
+
total_payload_size -= audio_size
|
|
170
137
|
else:
|
|
171
138
|
raise ValueError(
|
|
172
139
|
"Google currently only supports text, image, and audio parts. "
|
|
173
140
|
f"Part provided: {part.type}"
|
|
174
141
|
)
|
|
142
|
+
|
|
143
|
+
if must_upload:
|
|
144
|
+
indices, blob_dicts = zip(*must_upload.items(), strict=True)
|
|
145
|
+
upload_tasks = [
|
|
146
|
+
client.aio.files.upload(
|
|
147
|
+
file=io.BytesIO(blob_dict["data"]),
|
|
148
|
+
config={"mime_type": blob_dict.get("mime_type", None)},
|
|
149
|
+
)
|
|
150
|
+
for blob_dict in blob_dicts
|
|
151
|
+
]
|
|
152
|
+
file_refs = await asyncio.gather(*upload_tasks)
|
|
153
|
+
for index, file_ref in zip(indices, file_refs, strict=True):
|
|
154
|
+
converted_content[index] = PartDict(
|
|
155
|
+
file_data=FileDataDict(
|
|
156
|
+
file_uri=file_ref.uri, mime_type=file_ref.mime_type
|
|
157
|
+
)
|
|
158
|
+
)
|
|
159
|
+
|
|
175
160
|
converted_message_params.append(
|
|
176
161
|
{
|
|
177
162
|
"role": role if role == "user" else "model",
|
|
@@ -179,3 +164,23 @@ def convert_message_params(
|
|
|
179
164
|
}
|
|
180
165
|
)
|
|
181
166
|
return converted_message_params
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def convert_message_params(
|
|
170
|
+
message_params: list[BaseMessageParam | ContentDict], client: Client
|
|
171
|
+
) -> list[ContentDict]:
|
|
172
|
+
"""Convert message params to Google's ContentDict format.
|
|
173
|
+
|
|
174
|
+
If called from sync context, uses asyncio.run().
|
|
175
|
+
If called from async context, uses the current event loop.
|
|
176
|
+
"""
|
|
177
|
+
try:
|
|
178
|
+
asyncio.get_running_loop()
|
|
179
|
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
180
|
+
future = executor.submit(
|
|
181
|
+
asyncio.run, _convert_message_params_async(message_params, client)
|
|
182
|
+
)
|
|
183
|
+
return future.result()
|
|
184
|
+
except RuntimeError:
|
|
185
|
+
...
|
|
186
|
+
return asyncio.run(_convert_message_params_async(message_params, client))
|