langchain-google-genai 0.0.10rc0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain-google-genai has been flagged as potentially problematic; see the registry's advisory page for details.
- langchain_google_genai/__init__.py +15 -0
- langchain_google_genai/_genai_extension.py +618 -0
- langchain_google_genai/chat_models.py +17 -10
- langchain_google_genai/embeddings.py +26 -12
- langchain_google_genai/genai_aqa.py +134 -0
- langchain_google_genai/google_vector_store.py +493 -0
- langchain_google_genai/llms.py +22 -12
- {langchain_google_genai-0.0.10rc0.dist-info → langchain_google_genai-1.0.1.dist-info}/METADATA +32 -2
- langchain_google_genai-1.0.1.dist-info/RECORD +15 -0
- langchain_google_genai-0.0.10rc0.dist-info/RECORD +0 -12
- {langchain_google_genai-0.0.10rc0.dist-info → langchain_google_genai-1.0.1.dist-info}/LICENSE +0 -0
- {langchain_google_genai-0.0.10rc0.dist-info → langchain_google_genai-1.0.1.dist-info}/WHEEL +0 -0
|
@@ -58,12 +58,27 @@ embeddings.embed_query("hello, world!")
|
|
|
58
58
|
from langchain_google_genai._enums import HarmBlockThreshold, HarmCategory
|
|
59
59
|
from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
|
|
60
60
|
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
|
61
|
+
from langchain_google_genai.genai_aqa import (
|
|
62
|
+
AqaInput,
|
|
63
|
+
AqaOutput,
|
|
64
|
+
GenAIAqa,
|
|
65
|
+
)
|
|
66
|
+
from langchain_google_genai.google_vector_store import (
|
|
67
|
+
DoesNotExistsException,
|
|
68
|
+
GoogleVectorStore,
|
|
69
|
+
)
|
|
61
70
|
from langchain_google_genai.llms import GoogleGenerativeAI
|
|
62
71
|
|
|
63
72
|
__all__ = [
|
|
73
|
+
"AqaInput",
|
|
74
|
+
"AqaOutput",
|
|
64
75
|
"ChatGoogleGenerativeAI",
|
|
76
|
+
"DoesNotExistsException",
|
|
77
|
+
"GenAIAqa",
|
|
65
78
|
"GoogleGenerativeAIEmbeddings",
|
|
66
79
|
"GoogleGenerativeAI",
|
|
80
|
+
"GoogleVectorStore",
|
|
67
81
|
"HarmBlockThreshold",
|
|
68
82
|
"HarmCategory",
|
|
83
|
+
"DoesNotExistsException",
|
|
69
84
|
]
|
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
"""Temporary high-level library of the Google GenerativeAI API.
|
|
2
|
+
|
|
3
|
+
The content of this file should eventually go into the Python package
|
|
4
|
+
google.generativeai.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import datetime
|
|
8
|
+
import logging
|
|
9
|
+
import re
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import Any, Dict, Iterator, List, MutableSequence, Optional
|
|
12
|
+
|
|
13
|
+
import google.ai.generativelanguage as genai
|
|
14
|
+
import langchain_core
|
|
15
|
+
from google.api_core import client_options as client_options_lib
|
|
16
|
+
from google.api_core import exceptions as gapi_exception
|
|
17
|
+
from google.api_core import gapic_v1
|
|
18
|
+
from google.auth import credentials, exceptions # type: ignore
|
|
19
|
+
from google.protobuf import timestamp_pb2
|
|
20
|
+
|
|
21
|
+
_logger = logging.getLogger(__name__)
|
|
22
|
+
_DEFAULT_API_ENDPOINT = "generativelanguage.googleapis.com"
|
|
23
|
+
_USER_AGENT = f"langchain/{langchain_core.__version__}"
|
|
24
|
+
_DEFAULT_PAGE_SIZE = 20
|
|
25
|
+
_DEFAULT_GENERATE_SERVICE_MODEL = "models/aqa"
|
|
26
|
+
_MAX_REQUEST_PER_CHUNK = 100
|
|
27
|
+
_NAME_REGEX = re.compile(r"^corpora/([^/]+?)(/documents/([^/]+?)(/chunks/([^/]+?))?)?$")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class EntityName:
    """Structured form of a Generative AI retriever resource name.

    A name identifies a corpus, a document within a corpus, or a chunk
    within a document, rendered as
    ``corpora/<id>[/documents/<id>[/chunks/<id>]]``.
    """

    corpus_id: str
    document_id: Optional[str] = None
    chunk_id: Optional[str] = None

    def __post_init__(self) -> None:
        # A chunk only makes sense inside a document.
        if self.chunk_id is not None and self.document_id is None:
            raise ValueError(f"Chunk must have document ID but found {self}")

    @classmethod
    def from_str(cls, encoded: str) -> "EntityName":
        """Parse an encoded resource name string into an ``EntityName``."""
        parsed = _NAME_REGEX.match(encoded)
        if not parsed:
            raise ValueError(f"Invalid entity name: {encoded}")
        return cls(
            corpus_id=parsed.group(1),
            document_id=parsed.group(3),
            chunk_id=parsed.group(5),
        )

    def __repr__(self) -> str:
        # Render the deepest level that is populated.
        segments = ["corpora", self.corpus_id]
        if self.document_id is not None:
            segments += ["documents", self.document_id]
            if self.chunk_id is not None:
                segments += ["chunks", self.chunk_id]
        return "/".join(segments)

    def __str__(self) -> str:
        return repr(self)

    def is_corpus(self) -> bool:
        """True when this name refers to a corpus (no document part)."""
        return self.document_id is None

    def is_document(self) -> bool:
        """True when this name refers to a document (no chunk part)."""
        return self.document_id is not None and self.chunk_id is None

    def is_chunk(self) -> bool:
        """True when this name refers to a chunk."""
        return self.chunk_id is not None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class Corpus:
    """Client-side view of a server-side corpus resource."""

    # Fully qualified resource name, e.g. "corpora/<corpus_id>".
    name: str
    display_name: Optional[str]
    create_time: Optional[timestamp_pb2.Timestamp]
    update_time: Optional[timestamp_pb2.Timestamp]

    @property
    def corpus_id(self) -> str:
        """Corpus ID parsed out of the resource `name`."""
        name = EntityName.from_str(self.name)
        return name.corpus_id

    @classmethod
    def from_corpus(cls, c: genai.Corpus) -> "Corpus":
        """Build a `Corpus` from the protobuf message returned by the API."""
        return cls(
            name=c.name,
            display_name=c.display_name,
            create_time=c.create_time,
            update_time=c.update_time,
        )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
@dataclass
class Document:
    """Client-side view of a server-side document resource."""

    # Fully qualified resource name:
    # "corpora/<corpus_id>/documents/<document_id>".
    name: str
    display_name: Optional[str]
    create_time: Optional[timestamp_pb2.Timestamp]
    update_time: Optional[timestamp_pb2.Timestamp]
    custom_metadata: Optional[MutableSequence[genai.CustomMetadata]]

    @property
    def corpus_id(self) -> str:
        """Corpus ID parsed out of the resource `name`."""
        name = EntityName.from_str(self.name)
        return name.corpus_id

    @property
    def document_id(self) -> str:
        """Document ID parsed out of the resource `name`."""
        name = EntityName.from_str(self.name)
        # A document name always carries a document segment.
        assert isinstance(name.document_id, str)
        return name.document_id

    @classmethod
    def from_document(cls, d: genai.Document) -> "Document":
        """Build a `Document` from the protobuf message returned by the API."""
        return cls(
            name=d.name,
            display_name=d.display_name,
            create_time=d.create_time,
            update_time=d.update_time,
            custom_metadata=d.custom_metadata,
        )
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
@dataclass
class Config:
    """Global configuration for Google Generative AI API.

    Normally, the defaults should work fine. Use this to pass Google Auth
    credentials such as using a service account. Refer to the auth
    credentials documentation:
    https://developers.google.com/identity/protocols/oauth2/service-account#creatinganaccount.

    Attributes:
        api_endpoint: The Google Generative API endpoint address.
        user_agent: The user agent to use for logging.
        page_size: For paging RPCs, how many entities to return per RPC.
        testing: Are the unit tests running?
        auth_credentials: For setting credentials such as using service accounts.
    """

    api_endpoint: str = _DEFAULT_API_ENDPOINT
    user_agent: str = _USER_AGENT
    page_size: int = _DEFAULT_PAGE_SIZE
    testing: bool = False
    auth_credentials: Optional[credentials.Credentials] = None
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def set_config(config: Config) -> None:
    """Set global defaults for operations with Google Generative AI API."""
    global _config
    _config = config


def get_config() -> Config:
    """Return the currently active global `Config`."""
    return _config


# Module-level default configuration; replace via `set_config`.
_config = Config()
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class TestCredentials(credentials.Credentials):
    """Credentials that do not provide any authentication information.

    Useful for unit tests where the credentials are not used.
    """

    @property
    def expired(self) -> bool:
        """Returns `False`, test credentials never expire."""
        return False

    @property
    def valid(self) -> bool:
        """Returns `True`, test credentials are always valid."""
        return True

    def refresh(self, request: Any) -> None:
        """Raises :class:``InvalidOperation``, test credentials cannot be
        refreshed.
        """
        raise exceptions.InvalidOperation("Test credentials cannot be refreshed.")

    def apply(self, headers: Any, token: Any = None) -> None:
        """Anonymous credentials do nothing to the request.

        The optional ``token`` argument is not supported.

        Raises:
            google.auth.exceptions.InvalidValue: If a token was specified.
        """
        if token is not None:
            raise exceptions.InvalidValue("Test credentials don't support tokens.")

    def before_request(self, request: Any, method: Any, url: Any, headers: Any) -> None:
        """Test credentials do nothing to the request."""
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _get_credentials() -> Optional[credentials.Credentials]:
    """Pick the credentials to hand to API clients.

    Test runs (`_config.testing`) get fake credentials. Otherwise any
    explicitly configured credentials are used; `None` means the clients
    will let the google.auth package infer credentials from the
    environment.
    """
    cfg = _config
    if cfg.testing:
        return TestCredentials()
    # Falls back to None so google.auth's default resolution kicks in.
    return cfg.auth_credentials or None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def build_semantic_retriever() -> genai.RetrieverServiceClient:
    """Build a `RetrieverServiceClient` honoring the global `Config`.

    Credentials come from `_get_credentials()`; endpoint and user agent
    come from the active `_config`. (The original read the module
    constant `_USER_AGENT` here, silently ignoring `Config.user_agent`.)
    """
    # `creds` instead of `credentials`, to avoid shadowing the imported
    # google.auth `credentials` module.
    creds = _get_credentials()
    return genai.RetrieverServiceClient(
        credentials=creds,
        client_info=gapic_v1.client_info.ClientInfo(user_agent=_config.user_agent),
        client_options=client_options_lib.ClientOptions(
            api_endpoint=_config.api_endpoint
        ),
    )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def build_generative_service() -> genai.GenerativeServiceClient:
    """Build a `GenerativeServiceClient` honoring the global `Config`.

    Credentials come from `_get_credentials()`; endpoint and user agent
    come from the active `_config`. (The original read the module
    constant `_USER_AGENT` here, silently ignoring `Config.user_agent`.)
    """
    # `creds` instead of `credentials`, to avoid shadowing the imported
    # google.auth `credentials` module.
    creds = _get_credentials()
    return genai.GenerativeServiceClient(
        credentials=creds,
        client_info=gapic_v1.client_info.ClientInfo(user_agent=_config.user_agent),
        client_options=client_options_lib.ClientOptions(
            api_endpoint=_config.api_endpoint
        ),
    )
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def list_corpora(
    *,
    client: genai.RetrieverServiceClient,
) -> Iterator[Corpus]:
    """Yield every corpus visible to the caller, paging transparently."""
    request = genai.ListCorporaRequest(page_size=_config.page_size)
    for raw_corpus in client.list_corpora(request):
        yield Corpus.from_corpus(raw_corpus)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def get_corpus(
    *,
    corpus_id: str,
    client: genai.RetrieverServiceClient,
) -> Optional[Corpus]:
    """Fetch a corpus by ID, returning `None` when it does not exist.

    If the corpus does not exist, the server returns a permission error
    (it does not reveal nonexistence), so `PermissionDenied` is treated
    as "not found". Catching it directly replaces the original broad
    `except Exception` + isinstance re-raise.
    """
    try:
        corpus = client.get_corpus(
            genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id)))
        )
    except gapi_exception.PermissionDenied as e:
        _logger.warning(f"Corpus {corpus_id} not found: {e}")
        return None
    return Corpus.from_corpus(corpus)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def create_corpus(
    *,
    corpus_id: Optional[str] = None,
    display_name: Optional[str] = None,
    client: genai.RetrieverServiceClient,
) -> Corpus:
    """Create a corpus, optionally with a caller-chosen ID and display name.

    When `corpus_id` is omitted the name is left unset so the server
    assigns one; when `display_name` is omitted a timestamped placeholder
    title is used.
    """
    name: Optional[str] = (
        str(EntityName(corpus_id=corpus_id)) if corpus_id is not None else None
    )
    new_display_name = display_name or f"Untitled {datetime.datetime.now()}"

    new_corpus = client.create_corpus(
        genai.CreateCorpusRequest(
            corpus=genai.Corpus(name=name, display_name=new_display_name)
        )
    )
    return Corpus.from_corpus(new_corpus)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
def delete_corpus(
    *,
    corpus_id: str,
    client: genai.RetrieverServiceClient,
) -> None:
    """Delete a corpus; `force=True` also removes its documents and chunks."""
    client.delete_corpus(
        genai.DeleteCorpusRequest(name=str(EntityName(corpus_id=corpus_id)), force=True)
    )
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def list_documents(
    *,
    corpus_id: str,
    client: genai.RetrieverServiceClient,
) -> Iterator[Document]:
    """Yield every document in a corpus, paging transparently.

    Uses the configured page size for consistency with `list_corpora`
    (the original hard-coded `_DEFAULT_PAGE_SIZE`, ignoring
    `Config.page_size`).
    """
    for document in client.list_documents(
        genai.ListDocumentsRequest(
            parent=str(EntityName(corpus_id=corpus_id)), page_size=_config.page_size
        )
    ):
        yield Document.from_document(document)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def get_document(
    *,
    corpus_id: str,
    document_id: str,
    client: genai.RetrieverServiceClient,
) -> Optional[Document]:
    """Fetch a document by ID, returning `None` when it does not exist.

    Catches `NotFound` directly instead of the original broad
    `except Exception` + isinstance re-raise; any other error propagates.
    """
    try:
        document = client.get_document(
            genai.GetDocumentRequest(
                name=str(EntityName(corpus_id=corpus_id, document_id=document_id))
            )
        )
    except gapi_exception.NotFound as e:
        _logger.warning(f"Document {document_id} in corpus {corpus_id} not found: {e}")
        return None
    return Document.from_document(document)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def create_document(
    *,
    corpus_id: str,
    document_id: Optional[str] = None,
    display_name: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    client: genai.RetrieverServiceClient,
) -> Document:
    """Create a document under a corpus.

    When `document_id` is omitted the name is left unset so the server
    assigns one; an omitted `display_name` becomes a timestamped
    placeholder title; `metadata` entries are converted to
    `CustomMetadata`.
    """
    name: Optional[str] = (
        str(EntityName(corpus_id=corpus_id, document_id=document_id))
        if document_id is not None
        else None
    )
    new_display_name = display_name or f"Untitled {datetime.datetime.now()}"
    new_metadatas = _convert_to_metadata(metadata) if metadata else None

    new_document = client.create_document(
        genai.CreateDocumentRequest(
            parent=str(EntityName(corpus_id=corpus_id)),
            document=genai.Document(
                name=name, display_name=new_display_name, custom_metadata=new_metadatas
            ),
        )
    )
    return Document.from_document(new_document)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def delete_document(
    *,
    corpus_id: str,
    document_id: str,
    client: genai.RetrieverServiceClient,
) -> None:
    """Delete a document; `force=True` also removes its chunks."""
    client.delete_document(
        genai.DeleteDocumentRequest(
            name=str(EntityName(corpus_id=corpus_id, document_id=document_id)),
            force=True,
        )
    )
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def batch_create_chunk(
    *,
    corpus_id: str,
    document_id: str,
    texts: List[str],
    metadatas: Optional[List[Dict[str, Any]]] = None,
    client: genai.RetrieverServiceClient,
) -> List[genai.Chunk]:
    """Create one chunk per text under a document, batching the RPCs.

    Requests are accumulated and flushed every `_MAX_REQUEST_PER_CHUNK`
    entries, with a final flush for the remainder. Returns the created
    chunks in creation order.

    Raises:
        ValueError: If `metadatas` is given but its length differs from
            `texts`.
    """
    if metadatas is None:
        metadatas = [{} for _ in texts]
    if len(texts) != len(metadatas):
        raise ValueError(
            f"metadatas's length {len(metadatas)} "
            f"and texts's length {len(texts)} are mismatched"
        )

    doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id))

    created_chunks: List[genai.Chunk] = []

    batch_request = genai.BatchCreateChunksRequest(
        parent=doc_name,
        requests=[],
    )
    for text, metadata in zip(texts, metadatas):
        batch_request.requests.append(
            genai.CreateChunkRequest(
                parent=doc_name,
                chunk=genai.Chunk(
                    data=genai.ChunkData(string_value=text),
                    custom_metadata=_convert_to_metadata(metadata),
                ),
            )
        )

        # Flush once the batch reaches the server's per-request limit.
        if len(batch_request.requests) >= _MAX_REQUEST_PER_CHUNK:
            response = client.batch_create_chunks(batch_request)
            created_chunks.extend(list(response.chunks))
            # Prepare a new batch for next round.
            batch_request = genai.BatchCreateChunksRequest(
                parent=doc_name,
                requests=[],
            )

    # Process left over.
    if len(batch_request.requests) > 0:
        response = client.batch_create_chunks(batch_request)
        created_chunks.extend(list(response.chunks))

    return created_chunks
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def delete_chunk(
    *,
    corpus_id: str,
    document_id: str,
    chunk_id: str,
    client: genai.RetrieverServiceClient,
) -> None:
    """Delete a single chunk identified by corpus, document and chunk IDs."""
    client.delete_chunk(
        genai.DeleteChunkRequest(
            name=str(
                EntityName(
                    corpus_id=corpus_id, document_id=document_id, chunk_id=chunk_id
                )
            )
        )
    )
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def query_corpus(
    *,
    corpus_id: str,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    client: genai.RetrieverServiceClient,
) -> List[genai.RelevantChunk]:
    """Semantic search over a whole corpus; returns the top `k` chunks.

    `filter` is a key/value equality filter converted to metadata filters.
    """
    request = genai.QueryCorpusRequest(
        name=str(EntityName(corpus_id=corpus_id)),
        query=query,
        metadata_filters=_convert_filter(filter),
        results_count=k,
    )
    response = client.query_corpus(request)
    return list(response.relevant_chunks)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def query_document(
    *,
    corpus_id: str,
    document_id: str,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    client: genai.RetrieverServiceClient,
) -> List[genai.RelevantChunk]:
    """Semantic search scoped to one document; returns the top `k` chunks.

    `filter` is a key/value equality filter converted to metadata filters.
    """
    request = genai.QueryDocumentRequest(
        name=str(EntityName(corpus_id=corpus_id, document_id=document_id)),
        query=query,
        metadata_filters=_convert_filter(filter),
        results_count=k,
    )
    response = client.query_document(request)
    return list(response.relevant_chunks)
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
@dataclass
class Passage:
    """A single passage attributed as a source of a grounded answer."""

    # The passage's text content.
    text: str
    # The passage identifier reported by the service.
    id: str
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
@dataclass
class GroundedAnswer:
    """An answer together with the passages it is attributed to."""

    # The generated answer text.
    answer: str
    # Passages the service attributed the answer to.
    attributed_passages: List[Passage]
    # Service-estimated probability that the question is answerable
    # from the provided passages, when reported.
    answerable_probability: Optional[float]
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
@dataclass
class GenerateAnswerError(Exception):
    """Raised when answer generation stops for a reason other than STOP."""

    # Why generation stopped (e.g. SAFETY, RECITATION, MAX_TOKENS).
    finish_reason: genai.Candidate.FinishReason
    # Human-readable description derived from the finish reason.
    finish_message: str
    # Safety ratings attached to the failed candidate.
    safety_ratings: MutableSequence[genai.SafetyRating]

    def __str__(self) -> str:
        return (
            f"finish_reason: {self.finish_reason} "
            f"finish_message: {self.finish_message} "
            f"safety ratings: {self.safety_ratings}"
        )
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def generate_answer(
    *,
    prompt: str,
    passages: List[str],
    answer_style: int = genai.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE,
    safety_settings: Optional[List[genai.SafetySetting]] = None,
    temperature: Optional[float] = None,
    client: genai.GenerativeServiceClient,
) -> GroundedAnswer:
    """Generate an answer to `prompt` grounded in the given `passages`.

    Args:
        prompt: The question to answer.
        passages: Texts the answer must be grounded in, passed inline.
        answer_style: One of `genai.GenerateAnswerRequest.AnswerStyle`.
        safety_settings: Optional safety settings. Defaults to none.
            (Changed from a mutable `[]` default — a Python anti-pattern —
            to `None`; behavior for all callers is unchanged.)
        temperature: Optional sampling temperature.
        client: The generative service client to call.

    Returns:
        The grounded answer with its attributed passages.

    Raises:
        GenerateAnswerError: If generation finished for any reason other
            than a normal STOP (safety block, recitation, max tokens, ...).
    """
    # TODO: Consider passing in the corpus ID instead of the actual
    # passages.
    response = client.generate_answer(
        genai.GenerateAnswerRequest(
            contents=[
                genai.Content(parts=[genai.Part(text=prompt)]),
            ],
            model=_DEFAULT_GENERATE_SERVICE_MODEL,
            answer_style=answer_style,
            safety_settings=safety_settings or [],
            temperature=temperature,
            inline_passages=genai.GroundingPassages(
                passages=[
                    genai.GroundingPassage(
                        # IDs here takes alphanumeric only. No dashes allowed.
                        id=str(index),
                        content=genai.Content(parts=[genai.Part(text=chunk)]),
                    )
                    for index, chunk in enumerate(passages)
                ]
            ),
        )
    )

    if response.answer.finish_reason != genai.Candidate.FinishReason.STOP:
        finish_message = _get_finish_message(response.answer)
        raise GenerateAnswerError(
            finish_reason=response.answer.finish_reason,
            finish_message=finish_message,
            safety_ratings=response.answer.safety_ratings,
        )

    # A successful answer is expected to come back as exactly one part.
    assert len(response.answer.content.parts) == 1
    return GroundedAnswer(
        answer=response.answer.content.parts[0].text,
        attributed_passages=[
            Passage(
                text=passage.content.parts[0].text,
                id=passage.source_id.grounding_passage.passage_id,
            )
            for passage in response.answer.grounding_attributions
            if len(passage.content.parts) > 0
        ],
        answerable_probability=response.answerable_probability,
    )
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
# TODO: Use candidate.finish_message when that field is launched.
# For now, we derive this message from other existing fields.
def _get_finish_message(candidate: genai.Candidate) -> str:
    """Map a candidate's finish reason to a human-readable message."""
    messages: Dict[int, str] = {
        genai.Candidate.FinishReason.MAX_TOKENS: "Maximum token in context window reached",  # noqa: E501
        genai.Candidate.FinishReason.SAFETY: "Blocked because of safety",
        genai.Candidate.FinishReason.RECITATION: "Blocked because of recitation",
    }
    # Unknown reasons fall back to a generic message.
    return messages.get(candidate.finish_reason, "Unexpected generation error")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def _convert_to_metadata(metadata: Dict[str, Any]) -> List[genai.CustomMetadata]:
    """Convert a plain metadata dict into `CustomMetadata` entries.

    Strings become `string_value`; ints and floats become `numeric_value`.
    Any other value type raises `ValueError`.
    """

    def _to_custom(key: str, value: Any) -> genai.CustomMetadata:
        if isinstance(value, str):
            return genai.CustomMetadata(key=key, string_value=value)
        if isinstance(value, (float, int)):
            return genai.CustomMetadata(key=key, numeric_value=value)
        raise ValueError(f"Metadata value {value} is not supported")

    return [_to_custom(key, value) for key, value in metadata.items()]
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def _convert_filter(fs: Optional[Dict[str, Any]]) -> List[genai.MetadataFilter]:
    """Convert a key/value dict into equality `MetadataFilter`s.

    `None` yields an empty filter list. String values compare as
    `string_value`, numeric values as `numeric_value`; anything else
    raises `ValueError`.
    """
    if fs is None:
        return []
    assert isinstance(fs, dict)

    def _equality(value: Any) -> genai.Condition:
        if isinstance(value, str):
            return genai.Condition(
                operation=genai.Condition.Operator.EQUAL, string_value=value
            )
        if isinstance(value, (float, int)):
            return genai.Condition(
                operation=genai.Condition.Operator.EQUAL, numeric_value=value
            )
        raise ValueError(f"Filter value {value} is not supported")

    return [
        genai.MetadataFilter(key=key, conditions=[_equality(value)])
        for key, value in fs.items()
    ]
|
|
@@ -483,17 +483,24 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|
|
483
483
|
@root_validator()
|
|
484
484
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
485
485
|
"""Validates params and passes them to google-generativeai package."""
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
486
|
+
if values.get("credentials"):
|
|
487
|
+
genai.configure(
|
|
488
|
+
credentials=values.get("credentials"),
|
|
489
|
+
transport=values.get("transport"),
|
|
490
|
+
client_options=values.get("client_options"),
|
|
491
|
+
)
|
|
492
|
+
else:
|
|
493
|
+
google_api_key = get_from_dict_or_env(
|
|
494
|
+
values, "google_api_key", "GOOGLE_API_KEY"
|
|
495
|
+
)
|
|
496
|
+
if isinstance(google_api_key, SecretStr):
|
|
497
|
+
google_api_key = google_api_key.get_secret_value()
|
|
491
498
|
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
499
|
+
genai.configure(
|
|
500
|
+
api_key=google_api_key,
|
|
501
|
+
transport=values.get("transport"),
|
|
502
|
+
client_options=values.get("client_options"),
|
|
503
|
+
)
|
|
497
504
|
if (
|
|
498
505
|
values.get("temperature") is not None
|
|
499
506
|
and not 0 <= values["temperature"] <= 1
|