crewplus-0.2.89-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- crewplus/__init__.py +10 -0
- crewplus/callbacks/__init__.py +1 -0
- crewplus/callbacks/async_langfuse_handler.py +166 -0
- crewplus/services/__init__.py +21 -0
- crewplus/services/azure_chat_model.py +145 -0
- crewplus/services/feedback.md +55 -0
- crewplus/services/feedback_manager.py +267 -0
- crewplus/services/gemini_chat_model.py +884 -0
- crewplus/services/init_services.py +57 -0
- crewplus/services/model_load_balancer.py +264 -0
- crewplus/services/schemas/feedback.py +61 -0
- crewplus/services/tracing_manager.py +182 -0
- crewplus/utils/__init__.py +4 -0
- crewplus/utils/schema_action.py +7 -0
- crewplus/utils/schema_document_updater.py +173 -0
- crewplus/utils/tracing_util.py +55 -0
- crewplus/vectorstores/milvus/__init__.py +5 -0
- crewplus/vectorstores/milvus/milvus_schema_manager.py +270 -0
- crewplus/vectorstores/milvus/schema_milvus.py +586 -0
- crewplus/vectorstores/milvus/vdb_service.py +917 -0
- crewplus-0.2.89.dist-info/METADATA +144 -0
- crewplus-0.2.89.dist-info/RECORD +29 -0
- crewplus-0.2.89.dist-info/WHEEL +4 -0
- crewplus-0.2.89.dist-info/entry_points.txt +4 -0
- crewplus-0.2.89.dist-info/licenses/LICENSE +21 -0
- docs/GeminiChatModel.md +247 -0
- docs/ModelLoadBalancer.md +134 -0
- docs/VDBService.md +238 -0
- docs/index.md +23 -0
crewplus/services/gemini_chat_model.py
@@ -0,0 +1,884 @@
import os
import asyncio
import logging
from typing import Any, Dict, Iterator, List, Optional, AsyncIterator, Union, Tuple
from google import genai
from google.genai import types
from google.oauth2 import service_account
import base64
import requests
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
    AsyncCallbackManagerForLLMRun
)
from pydantic import Field, SecretStr
from langchain_core.utils import convert_to_secret_str
from .tracing_manager import TracingManager, TracingContext

class GeminiChatModel(BaseChatModel):
    """Custom chat model for Google Gemini, supporting text, image, and video.

    This model provides a robust interface to Google's Gemini Pro and Flash models,
    handling various data formats for multimodal inputs while maintaining compatibility
    with the LangChain ecosystem.

    It supports standard invocation, streaming, and asynchronous operations.
    By default, it uses Google AI with an API key. It can also be configured to use
    Google Cloud Vertex AI.

    **Authentication:**
    - **Google AI (Default):** The `google_api_key` parameter or the `GOOGLE_API_KEY`
      environment variable is used.
    - **Vertex AI:** To use Vertex AI, set `use_vertex_ai=True` and provide
      GCP configuration (`project_id`, `location`). Authentication is handled
      via `service_account_file`, `credentials`, or Application Default Credentials (ADC).

    **Tracing Integration:**
    Tracing (e.g., with Langfuse) is automatically enabled when the respective
    environment variables are set. For Langfuse:
    - LANGFUSE_PUBLIC_KEY: Your Langfuse public key
    - LANGFUSE_SECRET_KEY: Your Langfuse secret key
    - LANGFUSE_HOST: Langfuse host URL (optional, defaults to https://cloud.langfuse.com)

    You can also configure it explicitly or disable it. Session and user tracking
    can be set per call via metadata.

    Attributes:
        model_name (str): The Google model name to use (e.g., "gemini-1.5-flash").
        google_api_key (Optional[SecretStr]): Your Google API key.
        temperature (Optional[float]): The sampling temperature for generation.
        max_tokens (Optional[int]): The maximum number of tokens to generate.
        top_p (Optional[float]): The top-p (nucleus) sampling parameter.
        top_k (Optional[int]): The top-k sampling parameter.
        logger (Optional[logging.Logger]): An optional logger instance.
        enable_tracing (Optional[bool]): Enable/disable all tracing (auto-detect if None).
        use_vertex_ai (bool): If True, uses Vertex AI instead of Google AI Platform. Defaults to False.
        project_id (Optional[str]): GCP Project ID, required for Vertex AI.
        location (Optional[str]): GCP Location for Vertex AI (e.g., "us-central1").
        service_account_file (Optional[str]): Path to GCP service account JSON for Vertex AI.
        credentials (Optional[Any]): GCP credentials object for Vertex AI (alternative to file).

    Example:
        .. code-block:: python

            # Set Langfuse environment variables (optional)
            import os
            os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."
            os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."
            os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com"  # EU region or self-hosted
            # os.environ["LANGFUSE_HOST"] = "https://us.cloud.langfuse.com"  # US region

            from crewplus.services import GeminiChatModel
            from langchain_core.messages import HumanMessage
            import base64
            import logging

            # Initialize the model with optional logger
            logger = logging.getLogger("my_app.gemini")
            model = GeminiChatModel(model_name="gemini-2.0-flash", logger=logger)

            # --- Text-only usage (automatically traced if env vars set) ---
            response = model.invoke("Hello, how are you?")
            print("Text response:", response.content)

            # --- Tracing with session/user tracking (for Langfuse) ---
            response = model.invoke(
                "What is AI?",
                config={
                    "metadata": {
                        "langfuse_session_id": "chat-session-123",
                        "langfuse_user_id": "user-456"
                    }
                }
            )

            # --- Image processing with base64 data URI ---
            # Replace with a path to your image
            image_path = "path/to/your/image.jpg"
            try:
                with open(image_path, "rb") as image_file:
                    encoded_string = base64.b64encode(image_file.read()).decode('utf-8')

                image_message = HumanMessage(
                    content=[
                        {"type": "text", "text": "What is in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{encoded_string}"
                            }
                        },
                    ]
                )
                image_response = model.invoke([image_message])
                print("Image response (base64):", image_response.content)
            except FileNotFoundError:
                print(f"Image file not found at {image_path}, skipping base64 example.")


            # --- Image processing with URL ---
            url_message = HumanMessage(
                content=[
                    {"type": "text", "text": "Describe this image:"},
                    {
                        "type": "image_url",
                        "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
                    },
                ]
            )
            url_response = model.invoke([url_message])
            print("Image response (URL):", url_response.content)

            # --- Video processing with file path (>=20MB) ---
            # Upload large videos via the Files API first; a google-genai client is
            # needed for the upload (it reads GOOGLE_API_KEY from the environment).
            from google import genai
            client = genai.Client()
            video_path = "path/to/your/video.mp4"
            video_file = client.files.upload(file=video_path)

            try:
                video_message = HumanMessage(
                    content=[
                        {"type": "text", "text": "Summarize this video."},
                        {"type": "video_file", "file": video_file},
                    ]
                )
                video_response = model.invoke([video_message])
                print("Video response (file path):", video_response.content)
            except Exception as e:
                print(f"Video processing with file path failed: {e}")

            # --- Video processing with raw bytes (<20MB) ---
            video_path = "path/to/your/video.mp4"
            try:
                with open(video_path, "rb") as video_file:
                    video_bytes = video_file.read()

                video_message = HumanMessage(
                    content=[
                        {"type": "text", "text": "What is happening in this video?"},
                        {
                            "type": "video_file",
                            "data": video_bytes,
                            "mime_type": "video/mp4"
                        },
                    ]
                )
                video_response = model.invoke([video_message])
                print("Video response (bytes):", video_response.content)
            except FileNotFoundError:
                print(f"Video file not found at {video_path}, skipping bytes example.")
            except Exception as e:
                print(f"Video processing with bytes failed: {e}")

            # --- Streaming usage (works with text, images, and video) ---
            print("Streaming response:")
            for chunk in model.stream([url_message]):
                print(chunk.content, end="", flush=True)

            # --- Traditional Langfuse callback approach still works ---
            from langfuse.langchain import CallbackHandler
            langfuse_handler = CallbackHandler(
                session_id="session-123",
                user_id="user-456"
            )
            response = model.invoke(
                "Hello with manual callback",
                config={"callbacks": [langfuse_handler]}
            )

            # --- Disable Langfuse for specific calls ---
            response = model.invoke(
                "Hello without tracing",
                config={"metadata": {"tracing_disabled": True}}
            )

    Example (Vertex AI):
        .. code-block:: python

            # Assumes GCP environment is configured (e.g., gcloud auth application-default login)
            # or environment variables are set:
            # os.environ["GCP_PROJECT_ID"] = "your-gcp-project-id"
            # os.environ["GCP_LOCATION"] = "us-central1"
            # os.environ["GCP_SERVICE_ACCOUNT_FILE"] = "path/to/your/service-account-key.json"

            vertex_model = GeminiChatModel(
                model_name="gemini-1.5-flash-001",
                use_vertex_ai=True,
            )
            response = vertex_model.invoke("Hello from Vertex AI!")
            print(response.content)
    """

    # Model configuration
    model_name: str = Field(default="gemini-2.5-flash", description="The Google model name to use")
    google_api_key: Optional[SecretStr] = Field(default=None, description="Google API key")
    temperature: Optional[float] = Field(default=0.7, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, description="Maximum tokens to generate")
    top_p: Optional[float] = Field(default=None, description="Top-p sampling parameter")
    top_k: Optional[int] = Field(default=None, description="Top-k sampling parameter")

    # Vertex AI specific configuration
    use_vertex_ai: bool = Field(default=False, description="Use Vertex AI instead of Google AI Platform")
    project_id: Optional[str] = Field(default=None, description="Google Cloud Project ID for Vertex AI")
    location: Optional[str] = Field(default=None, description="Google Cloud Location for Vertex AI (e.g., 'us-central1')")
    service_account_file: Optional[str] = Field(default=None, description="Path to Google Cloud service account key file")
    credentials: Optional[Any] = Field(default=None, description="Google Cloud credentials object", exclude=True)

    # Configuration for tracing and logging
    logger: Optional[logging.Logger] = Field(default=None, description="Optional logger instance", exclude=True)
    enable_tracing: Optional[bool] = Field(default=None, description="Enable tracing (auto-detect if None)")

    # Internal clients and managers
    _client: Optional[genai.Client] = None
    _tracing_manager: Optional[TracingManager] = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Initialize logger
        if self.logger is None:
            self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
            if not self.logger.handlers:
                self.logger.addHandler(logging.StreamHandler())
                self.logger.setLevel(logging.INFO)

        self._initialize_client()

        self._tracing_manager = TracingManager(self)

    def _initialize_client(self):
        """Initializes the Google GenAI client for either Google AI or Vertex AI."""
        if self.use_vertex_ai:
            self._init_vertex_ai_client()
        else:
            self._init_google_ai_client()

    def _init_google_ai_client(self):
        """Initializes the client for Google AI Platform."""
        # Get API key from environment if not provided
        if self.google_api_key is None:
            api_key = os.getenv("GOOGLE_API_KEY")
            if api_key:
                self.google_api_key = convert_to_secret_str(api_key)

        # Initialize the Google GenAI client
        if self.google_api_key:
            self._client = genai.Client(api_key=self.google_api_key.get_secret_value())
            self.logger.info(f"Initialized GeminiChatModel with model: {self.model_name} for Google AI")
        else:
            error_msg = "Google API key is required. Set GOOGLE_API_KEY environment variable or pass google_api_key parameter."
            self.logger.error(error_msg)
            raise ValueError(error_msg)

    def _init_vertex_ai_client(self):
        """Initializes the client for Vertex AI."""
        # Get config from environment if not provided
        if self.project_id is None:
            self.project_id = os.getenv("GCP_PROJECT_ID")
        if self.location is None:
            self.location = os.getenv("GCP_LOCATION")

        if not self.project_id or not self.location:
            error_msg = "For Vertex AI, 'project_id' and 'location' are required."
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        creds = self.credentials
        if creds is None:
            # Get service account file from env if not provided
            sa_file = self.service_account_file or os.getenv("GCP_SERVICE_ACCOUNT_FILE")
            self.logger.debug(f"Service account file: {sa_file}")
            if sa_file:
                try:
                    creds = service_account.Credentials.from_service_account_file(
                        sa_file,
                        scopes=['https://www.googleapis.com/auth/cloud-platform']
                    )
                except Exception as e:
                    error_msg = f"Failed to load credentials from service account file '{sa_file}': {e}"
                    self.logger.error(error_msg)
                    raise ValueError(error_msg)

        # If creds is still None, the client will use Application Default Credentials (ADC).

        try:
            self._client = genai.Client(
                vertexai=True,
                project=self.project_id,
                location=self.location,
                credentials=creds,
            )
            self.logger.info(
                f"Initialized GeminiChatModel with model: {self.model_name} for Vertex AI "
                f"(Project: {self.project_id}, Location: {self.location})"
            )
        except Exception as e:
            error_msg = f"Failed to initialize GenAI Client for Vertex AI: {e}"
            self.logger.error(error_msg, exc_info=True)
            raise ValueError(error_msg)

    def get_model_identifier(self) -> str:
        """Return a string identifying this model for tracing and logging."""
        return f"{self.__class__.__name__} (model='{self.model_name}')"

    def invoke(self, input, config=None, **kwargs):
        """Override invoke to add tracing callbacks automatically."""
        config = self._tracing_manager.add_sync_callbacks_to_config(config)
        return super().invoke(input, config=config, **kwargs)

    async def ainvoke(self, input, config=None, **kwargs):
        """Override ainvoke to add tracing callbacks automatically."""
        config = self._tracing_manager.add_async_callbacks_to_config(config)
        return await super().ainvoke(input, config=config, **kwargs)

    def stream(self, input, config=None, **kwargs):
        """Override stream to add tracing callbacks automatically."""
        config = self._tracing_manager.add_sync_callbacks_to_config(config)
        return super().stream(input, config=config, **kwargs)

    async def astream(self, input, config=None, **kwargs):
        """Override astream to add tracing callbacks automatically."""
        config = self._tracing_manager.add_async_callbacks_to_config(config)
        # This override must itself be an async generator, so we iterate the
        # parent async generator and re-yield its chunks instead of returning it.
        async for chunk in super().astream(input, config=config, **kwargs):
            yield chunk

    @property
    def _llm_type(self) -> str:
        """Return identifier for the model type."""
        return "custom_google_genai"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters for tracing."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }

    def _convert_messages(self, messages: List[BaseMessage]) -> Union[types.ContentListUnion, types.ContentListUnionDict]:
        """
        Converts LangChain messages to a format suitable for the GenAI API.
        - For single, multi-part HumanMessage, returns a direct list of parts (e.g., [File, "text"]).
        - For multi-turn chats, returns a list of Content objects.
        - For simple text, returns a string.
        """
        self.logger.debug(f"Converting {len(messages)} messages.")

        # Filter out system messages (handled in generation_config)
        chat_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]

        # Case 1: A single HumanMessage. This is the most common path for single prompts.
        if len(chat_messages) == 1 and isinstance(chat_messages[0], HumanMessage):
            content = chat_messages[0].content
            # For a simple string, return it directly.
            if isinstance(content, str):
                return content
            # For a list of parts, parse them into a direct list for the API.
            return list(self._parse_message_content(content, is_simple=True))

        # Case 2: Multi-turn chat history. This requires a list of Content objects.
        self.logger.debug("Handling as a multi-turn chat conversation.")
        genai_contents: List[types.Content] = []
        for msg in chat_messages:
            role = "model" if isinstance(msg, AIMessage) else "user"
            parts = []

            # Process each part and ensure proper typing
            for part in self._parse_message_content(msg.content, is_simple=False):
                if isinstance(part, types.File):
                    # put File directly into types.Content
                    parts.append(part)
                elif isinstance(part, types.Part):
                    parts.append(part)
                else:
                    self.logger.warning(f"Unexpected part type: {type(part)}")

            if parts:
                genai_contents.append(types.Content(parts=parts, role=role))

        # If there's only one Content object, return it directly instead of a list
        if len(genai_contents) == 1:
            return genai_contents[0]

        return genai_contents

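    # Illustrative note on _convert_messages above: a plain string prompt is passed
    # through unchanged, a single multimodal HumanMessage becomes a flat list of parts
    # (e.g. ["Describe this image:", types.Part(...)]), and a multi-turn history becomes
    # types.Content objects with "user"/"model" roles (a lone Content is unwrapped).
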
    def _create_image_part(self, image_info: Dict[str, Any]) -> Union[types.Part, types.File]:
        """Creates a GenAI Part or File from various image source formats."""
        self.logger.debug(f"Creating image part from info: {list(image_info.keys())}")

        if "path" in image_info:
            return self._client.files.upload(file=image_info["path"])

        if "data" in image_info:
            data = image_info["data"]
            if image_info.get("source_type") == "base64":
                data = base64.b64decode(data)
            return types.Part.from_bytes(data=data, mime_type=image_info["mime_type"])

        url = image_info.get("image_url", image_info.get("url"))
        if isinstance(url, dict):
            url = url.get("url")

        if not url:
            raise ValueError(f"Invalid image info, requires 'path', 'data', or 'url'. Received: {image_info}")

        if url.startswith("data:"):
            header, encoded = url.split(",", 1)
            mime_type = header.split(":", 1)[-1].split(";", 1)[0]
            image_data = base64.b64decode(encoded)
            return types.Part.from_bytes(data=image_data, mime_type=mime_type)
        else:
            response = requests.get(url)
            response.raise_for_status()
            mime_type = response.headers.get("Content-Type", "image/jpeg")
            return types.Part.from_bytes(data=response.content, mime_type=mime_type)

    def _create_video_part(self, video_info: Dict[str, Any]) -> Union[types.Part, types.File]:
        """Creates a Google GenAI Part or File from video information.

        Supports multiple video input formats:
        - File object: {"type": "video_file", "file": file_object}
        - File path: {"type": "video_file", "path": "/path/to/video.mp4"}
        - Raw bytes: {"type": "video_file", "data": video_bytes, "mime_type": "video/mp4"}
        - URL/URI: {"type": "video_file", "url": "https://example.com/video.mp4"}
        - YouTube URL: {"type": "video_file", "url": "https://www.youtube.com/watch?v=..."}
        - URL with offset: {"type": "video_file", "url": "...", "start_offset": "12s", "end_offset": "50s"}

        Args:
            video_info: Dictionary containing video information

        Returns:
            Either a types.Part or File object for Google GenAI

        Raises:
            FileNotFoundError: If video file path doesn't exist
            ValueError: If video_info is invalid or missing required fields
        """
        self.logger.debug(f"Creating video part from info: {list(video_info.keys())}")

        # Handle pre-uploaded file object
        if "file" in video_info:
            if isinstance(video_info["file"], types.File):
                return video_info["file"]
            else:
                raise ValueError(f"The 'file' key must contain a google.genai.File object, but got {type(video_info['file'])}")

        if "path" in video_info:
            self.logger.debug(f"Uploading video file from path: {video_info['path']}")

            uploaded_file = self._client.files.upload(file=video_info["path"])

            self.logger.debug(f"Uploaded video file: {uploaded_file}")

            return uploaded_file

        mime_type = video_info.get("mime_type")

        if "data" in video_info:
            data = video_info["data"]
            if not mime_type:
                raise ValueError("'mime_type' is required when providing video data.")
            max_size = 20 * 1024 * 1024  # 20MB
            if len(data) > max_size:
                raise ValueError(f"Video data size ({len(data)} bytes) exceeds 20MB limit for inline data.")
            return types.Part(inline_data=types.Blob(data=data, mime_type=mime_type))

        url = video_info.get("url")
        if not url:
            raise ValueError(f"Invalid video info, requires 'path', 'data', 'url', or 'file'. Received: {video_info}")

        mime_type = video_info.get("mime_type", "video/mp4")

        # Handle video offsets
        start_offset = video_info.get("start_offset")
        end_offset = video_info.get("end_offset")

        self.logger.debug(f"Video offsets: {start_offset} to {end_offset}.")

        if start_offset or end_offset:
            video_metadata = types.VideoMetadata(start_offset=start_offset, end_offset=end_offset)
            return types.Part(
                file_data=types.FileData(file_uri=url, mime_type=mime_type),
                video_metadata=video_metadata
            )

        return types.Part(file_data=types.FileData(file_uri=url, mime_type=mime_type))

    def _parse_message_content(
        self, content: Union[str, List[Union[str, Dict]]], *, is_simple: bool = True
    ) -> Iterator[Union[str, types.Part, types.File]]:
        """
        Parses LangChain message content and yields parts for Google GenAI.

        Args:
            content: The message content to parse.
            is_simple: If True, yields raw objects where possible (e.g., str, File)
                       for single-turn efficiency. If False, ensures all yielded
                       parts are `types.Part` by converting raw strings and
                       Files as needed, which is required for multi-turn chat.

        Supports both standard LangChain formats and enhanced video formats:
        - Text: "string" or {"type": "text", "text": "content"}
        - Image: {"type": "image_url", "image_url": "url"} or {"type": "image_url", "image_url": {"url": "url"}}
        - Video: {"type": "video_file", ...} or {"type": "video", ...}
        """
        if isinstance(content, str):
            yield content if is_simple else types.Part(text=content)
            return

        if not isinstance(content, list):
            self.logger.warning(f"Unsupported content format: {type(content)}")
            return

        for i, part_spec in enumerate(content):
            try:
                if isinstance(part_spec, str):
                    yield part_spec if is_simple else types.Part(text=part_spec)
                    continue

                if isinstance(part_spec, types.File):
                    if is_simple:
                        yield part_spec
                    else:
                        yield types.Part(file_data=types.FileData(
                            mime_type=part_spec.mime_type,
                            file_uri=part_spec.uri
                        ))
                    continue

                if not isinstance(part_spec, dict):
                    self.logger.warning(f"Skipping non-dict part in content list: {type(part_spec)}")
                    continue

                part_type = part_spec.get("type", "").lower()

                if part_type == "text":
                    if text_content := part_spec.get("text"):
                        yield text_content if is_simple else types.Part(text=text_content)
                elif part_type in ("image", "image_url"):
                    yield self._create_image_part(part_spec)
                elif part_type in ("video", "video_file"):
                    yield self._create_video_part(part_spec)
                else:
                    self.logger.debug(f"Part with unknown type '{part_type}' was ignored at index {i}.")
            except Exception as e:
                self.logger.error(f"Failed to process message part at index {i}: {part_spec}. Error: {e}", exc_info=True)

    def _prepare_generation_config(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Prepares the generation configuration, including system instructions."""
        # Base config from model parameters
        config = {
            "temperature": self.temperature,
            "max_output_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        if stop:
            config["stop_sequences"] = stop

        # Handle system instructions
        system_prompts = [msg.content for msg in messages if isinstance(msg, SystemMessage) and msg.content]
        if system_prompts:
            system_prompt_str = "\n\n".join(system_prompts)
            config["system_instruction"] = system_prompt_str

        # Filter out None values before returning
        return {k: v for k, v in config.items() if v is not None}

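    # Illustrative example for _prepare_generation_config above: with the default
    # temperature=0.7, no stop sequences, and a SystemMessage("You are terse."),
    # the returned config is {"temperature": 0.7, "system_instruction": "You are terse."};
    # parameters left unset are dropped rather than sent as None.
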
    def _trim_for_logging(self, contents: Any) -> Any:
        """Helper to trim large binary data from logging payloads."""
        if isinstance(contents, str):
            return contents

        if isinstance(contents, types.Content):
            return {
                "role": contents.role,
                "parts": [self._trim_part(part) for part in contents.parts]
            }

        if isinstance(contents, list):
            return [self._trim_for_logging(item) for item in contents]

        return contents

    def _trim_part(self, part: types.Part) -> dict:
        """Trims individual part data for safe logging."""
        part_dict = {}
        if part.text:
            part_dict["text"] = part.text
        if part.inline_data:
            part_dict["inline_data"] = {
                "mime_type": part.inline_data.mime_type,
                "data_size": f"{len(part.inline_data.data)} bytes"
            }
        if part.file_data:
            part_dict["file_data"] = {
                "mime_type": part.file_data.mime_type,
                "file_uri": part.file_data.file_uri
            }
        return part_dict

    def _map_usage_metadata(self, usage_metadata: Any) -> Optional[dict]:
        """
        Maps Google's rich usage metadata to LangChain's expected format,
        including detailed breakdowns by modality.
        """
        if not usage_metadata:
            return None

        # --- Basic Token Counts ---
        input_tokens = getattr(usage_metadata, "prompt_token_count", 0)
        output_tokens = getattr(usage_metadata, "candidates_token_count", 0)
        thoughts_tokens = getattr(usage_metadata, "thoughts_token_count", 0)
        total_tokens = getattr(usage_metadata, "total_token_count", 0)

        # In some cases, total_tokens is not provided, so we calculate it
        if total_tokens == 0 and (input_tokens > 0 or output_tokens > 0):
            total_tokens = input_tokens + output_tokens

        # --- Detailed Token Counts (The Fix) ---
        input_details = {}
        # The `prompt_tokens_details` is a list of ModalityTokenCount objects.
        # We convert it to a dictionary.
        if prompt_details_list := getattr(usage_metadata, "prompt_tokens_details", None):
            for detail in prompt_details_list:
                # Convert enum e.g., <MediaModality.TEXT: 'TEXT'> to "text"
                modality_key = detail.modality.name.lower()
                input_details[modality_key] = detail.token_count

        # Add cached tokens to input details if present
        #if cached_tokens := getattr(usage_metadata, "cached_content_token_count", 0):
        #    input_details["cached_content"] = cached_tokens

        output_details = {}
        # The `candidates_tokens_details` is also a list, so we convert it.
        if candidate_details_list := getattr(usage_metadata, "candidates_tokens_details", None):
            for detail in candidate_details_list:
                modality_key = detail.modality.name.lower()
                output_details[modality_key] = detail.token_count

        # --- Construct the final dictionary ---
        final_metadata = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "thoughts_tokens": thoughts_tokens,
            "total_tokens": total_tokens,
        }

        ## COMMENTED BEGIN: This is not working as expected.
        # if input_details:
        #     final_metadata["input_token_details"] = input_details
        # if output_details:
        #     final_metadata["output_token_details"] = output_details
        ## COMMENTED END

        return final_metadata

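    # Illustrative example for _map_usage_metadata above: usage metadata with
    # prompt_token_count=12, candidates_token_count=34 and no total_token_count
    # maps to {"input_tokens": 12, "output_tokens": 34, "thoughts_tokens": 0, "total_tokens": 46}.
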
    def _extract_usage_metadata(self, response) -> Optional[Any]:
        """Extracts the raw usage_metadata object from a Google GenAI response."""
        if hasattr(response, 'usage_metadata') and response.usage_metadata:
            self.logger.debug(f"[_extract_usage_metadata] Found usage_metadata: {response.usage_metadata}")
            return response.usage_metadata
        return None

    def _create_chat_generation_chunk(self, chunk_response) -> ChatGenerationChunk:
        """Creates a ChatGenerationChunk for streaming."""
        # For streaming, we do not include usage metadata in individual chunks
        # to prevent merge conflicts. The final, aggregated response will contain
        # the full usage details for callbacks like Langfuse.
        return ChatGenerationChunk(
            message=AIMessageChunk(
                content=chunk_response.text,
                response_metadata={"model_name": self.model_name},
            ),
            generation_info=None,
        )

    def _create_chat_result_with_usage(self, response) -> ChatResult:
        """Creates a ChatResult with usage metadata for Langfuse tracking."""
        generated_text = response.text
        finish_reason = response.candidates[0].finish_reason.name if response.candidates else None

        # Use the new mapping function here for invoke calls
        usage_metadata = self._extract_usage_metadata(response)
        usage_dict = self._map_usage_metadata(usage_metadata) or {}

        message = AIMessage(
            content=generated_text,
            response_metadata={
                "model_name": self.model_name,
                "finish_reason": finish_reason,
                **usage_dict
            }
        )

        generation = ChatGeneration(
            message=message,
            generation_info={"token_usage": usage_dict} if usage_dict else None
        )

        # We also construct the llm_output dictionary in the format expected
        # by LangChain callback handlers, with a specific "token_usage" key.
        chat_result = ChatResult(
            generations=[generation],
            llm_output={
                "token_usage": usage_dict,
                "model_name": self.model_name
            } if usage_dict else {
                "model_name": self.model_name
            }
        )

        return chat_result

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generates a chat response from a list of messages."""
        self.logger.info(f"Generating response for {len(messages)} messages.")

        # Remove the problematic add_handler call - callbacks are now handled in invoke methods

        contents = self._convert_messages(messages)
        config = self._prepare_generation_config(messages, stop)

        try:
            response = self._client.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=config,
                **kwargs,
            )

            return self._create_chat_result_with_usage(response)

        except Exception as e:
            self.logger.error(f"Error generating content with Google GenAI: {e}", exc_info=True)
            raise ValueError(f"Error during generation: {e}")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generates a chat response."""
        self.logger.info(f"Async generating response for {len(messages)} messages.")

        contents = self._convert_messages(messages)
        config = self._prepare_generation_config(messages, stop)

        try:
            response = await self._client.aio.models.generate_content(
                model=self.model_name,
                contents=contents,
                config=config,
                **kwargs,
            )

            return self._create_chat_result_with_usage(response)

        except Exception as e:
            self.logger.error(f"Error during async generation: {e}", exc_info=True)
            raise ValueError(f"Error during async generation: {e}")

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Streams the chat response and properly handles final usage metadata."""
        self.logger.info(f"Streaming response for {len(messages)} messages.")

        contents = self._convert_messages(messages)
        config = self._prepare_generation_config(messages, stop)

        try:
            stream = self._client.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
                **kwargs,
            )

            final_usage_metadata = None
            for chunk_response in stream:
                if chunk_response.usage_metadata:
                    final_usage_metadata = self._extract_usage_metadata(chunk_response)

                if chunk_response.text:
                    yield self._create_chat_generation_chunk(chunk_response)

            # **FIX:** Yield a final chunk with the mapped usage data
            if final_usage_metadata:
                lc_usage_metadata = self._map_usage_metadata(final_usage_metadata)
                if lc_usage_metadata:
                    yield ChatGenerationChunk(
                        message=AIMessageChunk(content="", usage_metadata=lc_usage_metadata)
                    )

        except Exception as e:
            self.logger.error(f"Error streaming content: {e}", exc_info=True)
            raise ValueError(f"Error during streaming: {e}")

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously streams the chat response and properly handles final usage metadata."""
        self.logger.info(f"Async streaming response for {len(messages)} messages.")

        contents = self._convert_messages(messages)
        config = self._prepare_generation_config(messages, stop)

        try:
            stream = await self._client.aio.models.generate_content_stream(
                model=self.model_name,
                contents=contents,
                config=config,
                **kwargs,
            )

            final_usage_metadata = None
            async for chunk_response in stream:
                if chunk_response.usage_metadata:
                    final_usage_metadata = self._extract_usage_metadata(chunk_response)

                if chunk_response.text:
                    yield self._create_chat_generation_chunk(chunk_response)

            # **FIX:** Yield a final chunk with the mapped usage data
            if final_usage_metadata:
                lc_usage_metadata = self._map_usage_metadata(final_usage_metadata)
                if lc_usage_metadata:
                    yield ChatGenerationChunk(
                        message=AIMessageChunk(content="", usage_metadata=lc_usage_metadata)
                    )

        except Exception as e:
            self.logger.error(f"Error during async streaming: {e}", exc_info=True)
            raise ValueError(f"Error during async streaming: {e}")