langchain-google-genai 0.0.5.tar.gz → 0.0.7.tar.gz

This diff compares publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
--- langchain_google_genai-0.0.5/PKG-INFO
+++ langchain_google_genai-0.0.7/PKG-INFO
@@ -1,18 +1,22 @@
  Metadata-Version: 2.1
  Name: langchain-google-genai
- Version: 0.0.5
+ Version: 0.0.7
  Summary: An integration package connecting Google's genai package and LangChain
- Home-page: https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai
+ Home-page: https://github.com/langchain-ai/langchain
+ License: MIT
  Requires-Python: >=3.9,<4.0
+ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
  Provides-Extra: images
  Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
  Requires-Dist: langchain-core (>=0.1,<0.2)
  Requires-Dist: pillow (>=10.1.0,<11.0.0) ; extra == "images"
- Project-URL: Repository, https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai
+ Project-URL: Repository, https://github.com/langchain-ai/langchain
+ Project-URL: Source Code, https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai
  Description-Content-Type: text/markdown
 
  # langchain-google-genai
--- /dev/null
+++ langchain_google_genai-0.0.7/langchain_google_genai/_function_utils.py
@@ -0,0 +1,135 @@
+ from __future__ import annotations
+
+ from typing import (
+     Dict,
+     List,
+     Type,
+     Union,
+ )
+
+ from langchain_core.pydantic_v1 import BaseModel
+ from langchain_core.tools import BaseTool
+ from langchain_core.utils.json_schema import dereference_refs
+
+ FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
+
+ TYPE_ENUM = {
+     "string": 1,
+     "number": 2,
+     "integer": 3,
+     "boolean": 4,
+     "array": 5,
+     "object": 6,
+ }
+
+
+ def convert_to_genai_function_declarations(
+     function_calls: List[FunctionCallType],
+ ) -> Dict:
+     function_declarations = []
+     for fc in function_calls:
+         function_declarations.append(_convert_to_genai_function(fc))
+     return {
+         "function_declarations": function_declarations,
+     }
+
+
+ def _convert_to_genai_function(fc: FunctionCallType) -> Dict:
+     """
+     Produce
+
+     {
+       "name": "get_weather",
+       "description": "Determine weather in my location",
+       "parameters": {
+         "properties": {
+           "location": {
+             "description": "The city and state e.g. San Francisco, CA",
+             "type_": 1
+           },
+           "unit": { "enum": ["c", "f"], "type_": 1 }
+         },
+         "required": ["location"],
+         "type_": 6
+       }
+     }
+
+     """
+     if isinstance(fc, BaseTool):
+         return _convert_tool_to_genai_function(fc)
+     elif isinstance(fc, type) and issubclass(fc, BaseModel):
+         return _convert_pydantic_to_genai_function(fc)
+     elif isinstance(fc, dict):
+         return {
+             **fc,
+             "parameters": {
+                 "properties": {
+                     k: {
+                         "type_": TYPE_ENUM[v["type"]],
+                         "description": v.get("description"),
+                     }
+                     for k, v in fc["parameters"]["properties"].items()
+                 },
+                 "required": fc["parameters"].get("required", []),
+                 "type_": TYPE_ENUM[fc["parameters"]["type"]],
+             },
+         }
+     else:
+         raise ValueError(f"Unsupported function call type {fc}")
+
+
+ def _convert_tool_to_genai_function(tool: BaseTool) -> Dict:
+     if tool.args_schema:
+         schema = dereference_refs(tool.args_schema.schema())
+         schema.pop("definitions", None)
+
+         return {
+             "name": tool.name or schema["title"],
+             "description": tool.description or schema["description"],
+             "parameters": {
+                 "properties": {
+                     k: {
+                         "type_": TYPE_ENUM[v["type"]],
+                         "description": v.get("description"),
+                     }
+                     for k, v in schema["properties"].items()
+                 },
+                 "required": schema["required"],
+                 "type_": TYPE_ENUM[schema["type"]],
+             },
+         }
+     else:
+         return {
+             "name": tool.name,
+             "description": tool.description,
+             "parameters": {
+                 "properties": {
+                     "__arg1": {"type": "string"},
+                 },
+                 "required": ["__arg1"],
+                 "type_": TYPE_ENUM["object"],
+             },
+         }
+
+
+ def _convert_pydantic_to_genai_function(
+     pydantic_model: Type[BaseModel],
+ ) -> Dict:
+     schema = dereference_refs(pydantic_model.schema())
+     schema.pop("definitions", None)
+
+     return {
+         "name": schema["title"],
+         "description": schema.get("description", ""),
+         "parameters": {
+             "properties": {
+                 k: {
+                     "type_": TYPE_ENUM[v["type"]],
+                     "description": v.get("description"),
+                 }
+                 for k, v in schema["properties"].items()
+             },
+             "required": schema["required"],
+             "type_": TYPE_ENUM[schema["type"]],
+         },
+     }
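Note: a minimal sketch of what the new helper produces, assuming a hypothetical `GetWeather` Pydantic model (names and values are illustrative; the `type_` integers come from the `TYPE_ENUM` mapping above):

    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_google_genai._function_utils import (
        convert_to_genai_function_declarations,
    )

    class GetWeather(BaseModel):
        """Determine weather in my location"""
        location: str = Field(description="The city and state e.g. San Francisco, CA")

    tools = convert_to_genai_function_declarations([GetWeather])
    # {'function_declarations': [{'name': 'GetWeather',
    #   'description': 'Determine weather in my location',
    #   'parameters': {'properties': {'location': {'type_': 1,
    #       'description': 'The city and state e.g. San Francisco, CA'}},
    #     'required': ['location'], 'type_': 6}}]}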
--- langchain_google_genai-0.0.5/langchain_google_genai/chat_models.py
+++ langchain_google_genai-0.0.7/langchain_google_genai/chat_models.py
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
  import base64
+ import json
  import logging
  import os
  from io import BytesIO
@@ -21,9 +22,12 @@ from typing import (
  )
  from urllib.parse import urlparse
 
+ import google.api_core
+
  # TODO: remove ignore once the google package is published with types
  import google.generativeai as genai  # type: ignore[import]
  import requests
+ from google.ai.generativelanguage_v1beta import FunctionCall
  from langchain_core.callbacks.manager import (
      AsyncCallbackManagerForLLMRun,
      CallbackManagerForLLMRun,
@@ -35,12 +39,13 @@ from langchain_core.messages import (
      BaseMessage,
      ChatMessage,
      ChatMessageChunk,
+     FunctionMessage,
      HumanMessage,
      HumanMessageChunk,
      SystemMessage,
  )
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+ from langchain_core.pydantic_v1 import SecretStr, root_validator
  from langchain_core.utils import get_from_dict_or_env
  from tenacity import (
      before_sleep_log,
@@ -51,6 +56,10 @@ from tenacity import (
  )
 
  from langchain_google_genai._common import GoogleGenerativeAIError
+ from langchain_google_genai._function_utils import (
+     convert_to_genai_function_declarations,
+ )
+ from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
 
  IMAGE_TYPES: Tuple = ()
  try:
@@ -87,8 +96,6 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
          Callable[[Any], Any]: A retry decorator configured for handling specific
          Google API exceptions.
      """
-     import google.api_core.exceptions
-
      multiplier = 2
      min_seconds = 1
      max_seconds = 60
@@ -123,14 +130,22 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
          Any: The result from the chat generation method.
      """
      retry_decorator = _create_retry_decorator()
-     from google.api_core.exceptions import InvalidArgument  # type: ignore
 
      @retry_decorator
      def _chat_with_retry(**kwargs: Any) -> Any:
          try:
              return generation_method(**kwargs)
-         except InvalidArgument as e:
-             # Do not retry for these errors.
+         # Do not retry for these errors.
+         except google.api_core.exceptions.FailedPrecondition as exc:
+             if "location is not supported" in exc.message:
+                 error_msg = (
+                     "Your location is not supported by google-generativeai "
+                     "at the moment. Try to use ChatVertexAI LLM from "
+                     "langchain_google_vertexai."
+                 )
+                 raise ValueError(error_msg)
+
+         except google.api_core.exceptions.InvalidArgument as e:
              raise ChatGoogleGenerativeAIError(
                  f"Invalid argument provided to Gemini: {e}"
              ) from e
@@ -312,14 +327,20 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
              continue
          elif isinstance(message, AIMessage):
              role = "model"
+             # TODO: Handle AIMessage with function call
+             parts = _convert_to_parts(message.content)
          elif isinstance(message, HumanMessage):
              role = "user"
+             parts = _convert_to_parts(message.content)
+         elif isinstance(message, FunctionMessage):
+             role = "user"
+             # TODO: Handle FunctionMessage
+             parts = _convert_to_parts(message.content)
          else:
              raise ValueError(
                  f"Unexpected message with type {type(message)} at the position {i}."
              )
 
-         parts = _convert_to_parts(message.content)
          if raw_system_message:
              if role == "model":
                  raise ValueError(
@@ -332,15 +353,36 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
      return messages
 
 
- def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
+ def _retrieve_function_call_response(
+     parts: List[genai.types.PartType],
+ ) -> Optional[Dict]:
+     for idx, part in enumerate(parts):
+         if part.function_call and part.function_call.name:
+             fc: FunctionCall = part.function_call
+             return {
+                 "function_call": {
+                     "name": fc.name,
+                     "arguments": json.dumps(
+                         dict(fc.args.items())
+                     ),  # dump to match other function calling llms for now
+                 }
+             }
+     return None
+
+
+ def _parts_to_content(
+     parts: List[genai.types.PartType],
+ ) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
      """Converts a list of Gemini API Part objects into a list of LangChain messages."""
+     function_call_resp = _retrieve_function_call_response(parts)
+
      if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
          # Simple text response. The typical response
-         return parts[0].text
+         return parts[0].text, function_call_resp
      elif not parts:
          logger.warning("Gemini produced an empty response.")
-         return ""
-     messages = []
+         return "", function_call_resp
+     messages: List[Union[Dict[Any, Any], str]] = []
      for part in parts:
          if part.text is not None:
              messages.append(
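Note: the `function_call` dict built above is deliberately shaped like the OpenAI-style payload, with `arguments` JSON-encoded. A consumer-side sketch (values are illustrative):

    import json

    additional_kwargs = {
        "function_call": {
            "name": "GetWeather",
            "arguments": '{"location": "San Francisco, CA"}',
        }
    }
    # Decode the JSON-encoded arguments back into a dict:
    args = json.loads(additional_kwargs["function_call"]["arguments"])
    assert args["location"] == "San Francisco, CA"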
@@ -352,7 +394,7 @@ def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], st
          else:
              # TODO: Handle inline_data if that's a thing?
              raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-     return messages
+     return messages, function_call_resp
 
 
  def _response_to_result(
@@ -379,16 +421,24 @@ def _response_to_result(
          "model": ai_msg_t,
          "user": human_msg_t,
      }
+
      for candidate in response.candidates:
          content = candidate.content
-         parts_content = _parts_to_content(content.parts)
+         parts_content, additional_kwargs = _parts_to_content(content.parts)
          if content.role not in role_map:
              logger.warning(
                  f"Unrecognized role: {content.role}. Treating as a ChatMessage."
              )
-             msg = chat_msg_t(content=parts_content, role=content.role)
+             msg = chat_msg_t(
+                 content=parts_content,
+                 role=content.role,
+                 additional_kwargs=additional_kwargs or {},
+             )
          else:
-             msg = role_map[content.role](content=parts_content)
+             msg = role_map[content.role](
+                 content=parts_content,
+                 additional_kwargs=additional_kwargs or {},
+             )
          generation_info = {}
          if candidate.finish_reason:
              generation_info["finish_reason"] = candidate.finish_reason.name
@@ -409,7 +459,7 @@ def _response_to_result(
      return ChatResult(generations=generations, llm_output=llm_output)
 
 
- class ChatGoogleGenerativeAI(BaseChatModel):
+ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
      """`Google Generative AI` Chat models API.
 
      To use, you must have either:
@@ -427,40 +477,8 @@ class ChatGoogleGenerativeAI(BaseChatModel):
 
      """
 
-     model: str = Field(
-         ...,
-         description="""The name of the model to use.
- Supported examples:
-     - gemini-pro""",
-     )
-     max_output_tokens: int = Field(default=None, description="Max output tokens")
-
      client: Any  #: :meta private:
-     google_api_key: Optional[SecretStr] = None
-     temperature: Optional[float] = None
-     """Run inference with this temperature. Must by in the closed
-     interval [0.0, 1.0]."""
-     top_k: Optional[int] = None
-     """Decode using top-k sampling: consider the set of top_k most probable tokens.
-     Must be positive."""
-     top_p: Optional[int] = None
-     """The maximum cumulative probability of tokens to consider when sampling.
-
-     The model uses combined Top-k and nucleus sampling.
-
-     Tokens are sorted based on their assigned probabilities so
-     that only the most likely tokens are considered. Top-k
-     sampling directly limits the maximum number of tokens to
-     consider, while Nucleus sampling limits number of tokens
-     based on the cumulative probability.
-
-     Note: The default value varies by model, see the
-     `Model.top_p` attribute of the `Model` returned the
-     `genai.get_model` function.
-     """
-     n: int = Field(default=1, alias="candidate_count")
-     """Number of chat completions to generate for each prompt. Note that the API may
-     not return the full n completions if duplicates are generated."""
+
      convert_system_message_to_human: bool = False
      """Whether to merge any leading SystemMessage into the following HumanMessage.
 
@@ -478,22 +496,24 @@ Supported examples:
      def _llm_type(self) -> str:
          return "chat-google-generative-ai"
 
-     @property
-     def _is_geminiai(self) -> bool:
-         return self.model is not None and "gemini" in self.model
-
      @classmethod
      def is_lc_serializable(self) -> bool:
          return True
 
      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
+         """Validates params and passes them to google-generativeai package."""
          google_api_key = get_from_dict_or_env(
              values, "google_api_key", "GOOGLE_API_KEY"
          )
          if isinstance(google_api_key, SecretStr):
              google_api_key = google_api_key.get_secret_value()
-         genai.configure(api_key=google_api_key)
+
+         genai.configure(
+             api_key=google_api_key,
+             transport=values.get("transport"),
+             client_options=values.get("client_options"),
+         )
          if (
              values.get("temperature") is not None
              and not 0 <= values["temperature"] <= 1
@@ -546,7 +566,11 @@ Supported examples:
          run_manager: Optional[CallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = _chat_with_retry(
              content=message,
              **params,
@@ -561,7 +585,11 @@ Supported examples:
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = await _achat_with_retry(
              content=message,
              **params,
@@ -576,7 +604,11 @@ Supported examples:
          run_manager: Optional[CallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> Iterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = _chat_with_retry(
              content=message,
              **params,
@@ -603,7 +635,11 @@ Supported examples:
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> AsyncIterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          async for chunk in await _achat_with_retry(
              content=message,
              **params,
@@ -628,11 +664,37 @@ Supported examples:
          stop: Optional[List[str]] = None,
          **kwargs: Any,
      ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+         client = self.client
+         functions = kwargs.pop("functions", None)
+         if functions:
+             tools = convert_to_genai_function_declarations(functions)
+             client = genai.GenerativeModel(model_name=self.model, tools=tools)
+
          params = self._prepare_params(stop, **kwargs)
          history = _parse_chat_history(
              messages,
              convert_system_message_to_human=self.convert_system_message_to_human,
          )
          message = history.pop()
-         chat = self.client.start_chat(history=history)
+         chat = client.start_chat(history=history)
          return params, chat, message
+
+     def get_num_tokens(self, text: str) -> int:
+         """Get the number of tokens present in the text.
+
+         Useful for checking if an input will fit in a model's context window.
+
+         Args:
+             text: The string input to tokenize.
+
+         Returns:
+             The integer number of tokens in the text.
+         """
+         if self._model_family == GoogleModelFamily.GEMINI:
+             result = self.client.count_tokens(text)
+             token_count = result.total_tokens
+         else:
+             result = self.client.count_text_tokens(model=self.model, prompt=text)
+             token_count = result["token_count"]
+
+         return token_count
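Note: a minimal end-to-end sketch of the new `functions` kwarg, assuming GOOGLE_API_KEY is set and using a hypothetical `GetWeather` model; on a function-call reply the call surfaces in `additional_kwargs`:

    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_google_genai import ChatGoogleGenerativeAI

    class GetWeather(BaseModel):
        """Determine weather in my location"""
        location: str = Field(description="The city and state e.g. San Francisco, CA")

    llm = ChatGoogleGenerativeAI(model="gemini-pro")
    msg = llm.invoke("What is the weather in San Francisco?", functions=[GetWeather])
    print(msg.additional_kwargs.get("function_call"))
    # Gemini chat models can now also count tokens instead of raising:
    print(llm.get_num_tokens("What is the weather in San Francisco?"))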
--- langchain_google_genai-0.0.5/langchain_google_genai/embeddings.py
+++ langchain_google_genai-0.0.7/langchain_google_genai/embeddings.py
@@ -43,16 +43,32 @@ class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings):
          description="The Google API key to use. If not provided, "
          "the GOOGLE_API_KEY environment variable will be used.",
      )
+     client_options: Optional[Dict] = Field(
+         None,
+         description=(
+             "A dictionary of client options to pass to the Google API client, "
+             "such as `api_endpoint`."
+         ),
+     )
+     transport: Optional[str] = Field(
+         None,
+         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
+     )
 
      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
-         """Validates that the python package exists in environment."""
+         """Validates params and passes them to google-generativeai package."""
          google_api_key = get_from_dict_or_env(
              values, "google_api_key", "GOOGLE_API_KEY"
          )
          if isinstance(google_api_key, SecretStr):
              google_api_key = google_api_key.get_secret_value()
-         genai.configure(api_key=google_api_key)
+
+         genai.configure(
+             api_key=google_api_key,
+             transport=values.get("transport"),
+             client_options=values.get("client_options"),
+         )
          return values
 
      def _embed(
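Note: a sketch of the new `transport` and `client_options` knobs on the embeddings class; the endpoint value is illustrative:

    from langchain_google_genai import GoogleGenerativeAIEmbeddings

    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        transport="rest",  # one of "rest", "grpc", "grpc_asyncio"
        client_options={"api_endpoint": "generativelanguage.googleapis.com"},
    )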
--- langchain_google_genai-0.0.5/langchain_google_genai/llms.py
+++ langchain_google_genai-0.0.7/langchain_google_genai/llms.py
@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ from enum import Enum, auto
  from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
  import google.api_core
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validat
  from langchain_core.utils import get_from_dict_or_env
 
 
+ class GoogleModelFamily(str, Enum):
+     GEMINI = auto()
+     PALM = auto()
+
+     @classmethod
+     def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+         if "gemini" in value.lower():
+             return GoogleModelFamily.GEMINI
+         elif "text-bison" in value.lower():
+             return GoogleModelFamily.PALM
+         return None
+
+
  def _create_retry_decorator(
      llm: BaseLLM,
      *,
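Note: `_missing_` makes the enum match on substrings, so arbitrary model strings coerce to a family (an unmatched name raises ValueError, since `_missing_` returns None). A quick sketch:

    assert GoogleModelFamily("gemini-pro") is GoogleModelFamily.GEMINI
    assert GoogleModelFamily("models/text-bison-001") is GoogleModelFamily.PALM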
@@ -56,21 +70,25 @@ def _completion_with_retry(
          prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
      ) -> Any:
          generation_config = kwargs.get("generation_config", {})
-         if is_gemini:
-             return llm.client.generate_content(
-                 contents=prompt, stream=stream, generation_config=generation_config
-             )
-         return llm.client.generate_text(prompt=prompt, **kwargs)
+         error_msg = (
+             "Your location is not supported by google-generativeai at the moment. "
+             "Try to use VertexAI LLM from langchain_google_vertexai"
+         )
+         try:
+             if is_gemini:
+                 return llm.client.generate_content(
+                     contents=prompt, stream=stream, generation_config=generation_config
+                 )
+             return llm.client.generate_text(prompt=prompt, **kwargs)
+         except google.api_core.exceptions.FailedPrecondition as exc:
+             if "location is not supported" in exc.message:
+                 raise ValueError(error_msg)
 
      return _completion_with_retry(
          prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
      )
 
 
- def _is_gemini_model(model_name: str) -> bool:
-     return "gemini" in model_name
-
-
  def _strip_erroneous_leading_spaces(text: str) -> str:
      """Strip erroneous leading spaces from text.
 
@@ -84,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
      return text
 
 
- class GoogleGenerativeAI(BaseLLM, BaseModel):
-     """Google GenerativeAI models.
+ class _BaseGoogleGenerativeAI(BaseModel):
+     """Base class for Google Generative AI LLMs"""
 
-     Example:
-         .. code-block:: python
-
-             from langchain_google_genai import GoogleGenerativeAI
-             llm = GoogleGenerativeAI(model="gemini-pro")
-     """
-
-     client: Any  #: :meta private:
      model: str = Field(
          ...,
          description="""The name of the model to use.
@@ -121,19 +131,42 @@ Supported examples:
      not return the full n completions if duplicates are generated."""
      max_retries: int = 6
      """The maximum number of retries to make when generating."""
-
-     @property
-     def is_gemini(self) -> bool:
-         """Returns whether a model is belongs to a Gemini family or not."""
-         return _is_gemini_model(self.model)
+     client_options: Optional[Dict] = Field(
+         None,
+         description=(
+             "A dictionary of client options to pass to the Google API client, "
+             "such as `api_endpoint`."
+         ),
+     )
+     transport: Optional[str] = Field(
+         None,
+         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
+     )
 
      @property
      def lc_secrets(self) -> Dict[str, str]:
          return {"google_api_key": "GOOGLE_API_KEY"}
 
+     @property
+     def _model_family(self) -> str:
+         return GoogleModelFamily(self.model)
+
+
+ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+     """Google GenerativeAI models.
+
+     Example:
+         .. code-block:: python
+
+             from langchain_google_genai import GoogleGenerativeAI
+             llm = GoogleGenerativeAI(model="gemini-pro")
+     """
+
+     client: Any  #: :meta private:
+
      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
-         """Validate api key, python package exists."""
+         """Validates params and passes them to google-generativeai package."""
          google_api_key = get_from_dict_or_env(
              values, "google_api_key", "GOOGLE_API_KEY"
          )
@@ -142,9 +175,13 @@ Supported examples:
          if isinstance(google_api_key, SecretStr):
              google_api_key = google_api_key.get_secret_value()
 
-         genai.configure(api_key=google_api_key)
+         genai.configure(
+             api_key=google_api_key,
+             transport=values.get("transport"),
+             client_options=values.get("client_options"),
+         )
 
-         if _is_gemini_model(model_name):
+         if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
              values["client"] = genai.GenerativeModel(model_name=model_name)
          else:
              values["client"] = genai
@@ -180,7 +217,7 @@ Supported examples:
              "candidate_count": self.n,
          }
          for prompt in prompts:
-             if self.is_gemini:
+             if self._model_family == GoogleModelFamily.GEMINI:
                  res = _completion_with_retry(
                      self,
                      prompt=prompt,
@@ -256,7 +293,11 @@ Supported examples:
          Returns:
              The integer number of tokens in the text.
          """
-         if self.is_gemini:
-             raise ValueError("Counting tokens is not yet supported!")
-         result = self.client.count_text_tokens(model=self.model, prompt=text)
-         return result["token_count"]
+         if self._model_family == GoogleModelFamily.GEMINI:
+             result = self.client.count_tokens(text)
+             token_count = result.total_tokens
+         else:
+             result = self.client.count_text_tokens(model=self.model, prompt=text)
+             token_count = result["token_count"]
+
+         return token_count
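Note: with this change `get_num_tokens` no longer raises for Gemini models; a usage sketch, assuming GOOGLE_API_KEY is set:

    from langchain_google_genai import GoogleGenerativeAI

    llm = GoogleGenerativeAI(model="gemini-pro")
    n = llm.get_num_tokens("The quick brown fox")  # integer token count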
--- langchain_google_genai-0.0.5/pyproject.toml
+++ langchain_google_genai-0.0.7/pyproject.toml
@@ -1,10 +1,14 @@
  [tool.poetry]
  name = "langchain-google-genai"
- version = "0.0.5"
+ version = "0.0.7"
  description = "An integration package connecting Google's genai package and LangChain"
  authors = []
  readme = "README.md"
- repository = "https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai"
+ repository = "https://github.com/langchain-ai/langchain"
+ license = "MIT"
+
+ [tool.poetry.urls]
+ "Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/google-genai"
 
  [tool.poetry.dependencies]
  python = ">=3.9,<4.0"