langchain-google-genai 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langchain_google_genai/_function_utils.py (new file)
@@ -0,0 +1,116 @@
+ from __future__ import annotations
+
+ from typing import (
+     Dict,
+     List,
+     Type,
+     Union,
+ )
+
+ import google.ai.generativelanguage as glm
+ from langchain_core.pydantic_v1 import BaseModel
+ from langchain_core.tools import BaseTool
+ from langchain_core.utils.json_schema import dereference_refs
+
+ FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]
+
+ TYPE_ENUM = {
+     "string": glm.Type.STRING,
+     "number": glm.Type.NUMBER,
+     "integer": glm.Type.INTEGER,
+     "boolean": glm.Type.BOOLEAN,
+     "array": glm.Type.ARRAY,
+     "object": glm.Type.OBJECT,
+ }
+
+
+ def convert_to_genai_function_declarations(
+     function_calls: List[FunctionCallType],
+ ) -> List[glm.Tool]:
+     return [
+         glm.Tool(
+             function_declarations=[_convert_to_genai_function(fc)],
+         )
+         for fc in function_calls
+     ]
+
+
+ def _convert_to_genai_function(fc: FunctionCallType) -> glm.FunctionDeclaration:
+     if isinstance(fc, BaseTool):
+         return _convert_tool_to_genai_function(fc)
+     elif isinstance(fc, type) and issubclass(fc, BaseModel):
+         return _convert_pydantic_to_genai_function(fc)
+     elif isinstance(fc, dict):
+         return glm.FunctionDeclaration(
+             name=fc["name"],
+             description=fc.get("description"),
+             parameters={
+                 "properties": {
+                     k: {
+                         "type_": TYPE_ENUM[v["type"]],
+                         "description": v.get("description"),
+                     }
+                     for k, v in fc["parameters"]["properties"].items()
+                 },
+                 "required": fc["parameters"].get("required", []),
+                 "type_": TYPE_ENUM[fc["parameters"]["type"]],
+             },
+         )
+     else:
+         raise ValueError(f"Unsupported function call type {fc}")
+
+
+ def _convert_tool_to_genai_function(tool: BaseTool) -> glm.FunctionDeclaration:
+     if tool.args_schema:
+         schema = dereference_refs(tool.args_schema.schema())
+         schema.pop("definitions", None)
+
+         return glm.FunctionDeclaration(
+             name=tool.name or schema["title"],
+             description=tool.description or schema["description"],
+             parameters={
+                 "properties": {
+                     k: {
+                         "type_": TYPE_ENUM[v["type"]],
+                         "description": v.get("description"),
+                     }
+                     for k, v in schema["properties"].items()
+                 },
+                 "required": schema["required"],
+                 "type_": TYPE_ENUM[schema["type"]],
+             },
+         )
+     else:
+         return glm.FunctionDeclaration(
+             name=tool.name,
+             description=tool.description,
+             parameters={
+                 "properties": {
+                     "__arg1": {"type_": TYPE_ENUM["string"]},
+                 },
+                 "required": ["__arg1"],
+                 "type_": TYPE_ENUM["object"],
+             },
+         )
+
+
+ def _convert_pydantic_to_genai_function(
+     pydantic_model: Type[BaseModel],
+ ) -> glm.FunctionDeclaration:
+     schema = dereference_refs(pydantic_model.schema())
+     schema.pop("definitions", None)
+     return glm.FunctionDeclaration(
+         name=schema["title"],
+         description=schema.get("description", ""),
+         parameters={
+             "properties": {
+                 k: {
+                     "type_": TYPE_ENUM[v["type"]],
+                     "description": v.get("description"),
+                 }
+                 for k, v in schema["properties"].items()
+             },
+             "required": schema["required"],
+             "type_": TYPE_ENUM[schema["type"]],
+         },
+     )
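The new module above maps LangChain tools, Pydantic models, and plain dicts onto the glm.Tool / glm.FunctionDeclaration protos that Gemini function calling expects. A minimal sketch of calling convert_to_genai_function_declarations with a Pydantic schema (the Multiply model is illustrative, not part of the package):

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_genai._function_utils import (
    convert_to_genai_function_declarations,
)


class Multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(description="First factor")
    b: int = Field(description="Second factor")


# Each input becomes one glm.Tool wrapping a single FunctionDeclaration.
tools = convert_to_genai_function_declarations([Multiply])
print(tools[0].function_declarations[0].name)  # "Multiply"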

langchain_google_genai/chat_models.py
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
  import base64
+ import json
  import logging
  import os
  from io import BytesIO
@@ -15,14 +16,17 @@ from typing import (
      Optional,
      Sequence,
      Tuple,
-     Type,
      Union,
      cast,
  )
  from urllib.parse import urlparse
 
+ import google.ai.generativelanguage as glm
+ import google.api_core
+
  # TODO: remove ignore once the google package is published with types
  import google.generativeai as genai  # type: ignore[import]
+ import proto  # type: ignore[import]
  import requests
  from langchain_core.callbacks.manager import (
      AsyncCallbackManagerForLLMRun,
@@ -33,14 +37,12 @@ from langchain_core.messages import (
      AIMessage,
      AIMessageChunk,
      BaseMessage,
-     ChatMessage,
-     ChatMessageChunk,
+     FunctionMessage,
      HumanMessage,
-     HumanMessageChunk,
      SystemMessage,
  )
  from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+ from langchain_core.pydantic_v1 import SecretStr, root_validator
  from langchain_core.utils import get_from_dict_or_env
  from tenacity import (
      before_sleep_log,
@@ -51,6 +53,10 @@ from tenacity import (
  )
 
  from langchain_google_genai._common import GoogleGenerativeAIError
+ from langchain_google_genai._function_utils import (
+     convert_to_genai_function_declarations,
+ )
+ from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
 
  IMAGE_TYPES: Tuple = ()
  try:
@@ -87,8 +93,6 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
          Callable[[Any], Any]: A retry decorator configured for handling specific
          Google API exceptions.
      """
-     import google.api_core.exceptions
-
      multiplier = 2
      min_seconds = 1
      max_seconds = 60
@@ -123,14 +127,22 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
          Any: The result from the chat generation method.
      """
      retry_decorator = _create_retry_decorator()
-     from google.api_core.exceptions import InvalidArgument  # type: ignore
 
      @retry_decorator
      def _chat_with_retry(**kwargs: Any) -> Any:
          try:
              return generation_method(**kwargs)
-         except InvalidArgument as e:
-             # Do not retry for these errors.
+         # Do not retry for these errors.
+         except google.api_core.exceptions.FailedPrecondition as exc:
+             if "location is not supported" in exc.message:
+                 error_msg = (
+                     "Your location is not supported by google-generativeai "
+                     "at the moment. Try to use ChatVertexAI LLM from "
+                     "langchain_google_vertexai."
+                 )
+                 raise ValueError(error_msg)
+
+         except google.api_core.exceptions.InvalidArgument as e:
              raise ChatGoogleGenerativeAIError(
                  f"Invalid argument provided to Gemini: {e}"
              ) from e
@@ -312,14 +324,47 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
              continue
          elif isinstance(message, AIMessage):
              role = "model"
+             raw_function_call = message.additional_kwargs.get("function_call")
+             if raw_function_call:
+                 function_call = glm.FunctionCall(
+                     {
+                         "name": raw_function_call["name"],
+                         "args": json.loads(raw_function_call["arguments"]),
+                     }
+                 )
+                 parts = [glm.Part(function_call=function_call)]
+             else:
+                 parts = _convert_to_parts(message.content)
          elif isinstance(message, HumanMessage):
              role = "user"
+             parts = _convert_to_parts(message.content)
+         elif isinstance(message, FunctionMessage):
+             role = "user"
+             response: Any
+             if not isinstance(message.content, str):
+                 response = message.content
+             else:
+                 try:
+                     response = json.loads(message.content)
+                 except json.JSONDecodeError:
+                     response = message.content  # leave as str representation
+             parts = [
+                 glm.Part(
+                     function_response=glm.FunctionResponse(
+                         name=message.name,
+                         response=(
+                             {"output": response}
+                             if not isinstance(response, dict)
+                             else response
+                         ),
+                     )
+                 )
+             ]
          else:
              raise ValueError(
                  f"Unexpected message with type {type(message)} at the position {i}."
              )
 
-         parts = _convert_to_parts(message.content)
          if raw_system_message:
              if role == "model":
                  raise ValueError(
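With the branches above, an AIMessage that carries additional_kwargs["function_call"] is replayed to Gemini as a glm.Part(function_call=...), and a FunctionMessage is sent back as a glm.Part(function_response=...), with non-dict content wrapped as {"output": ...}. A rough sketch of the history shapes this code expects (the tool name and values are made up for illustration):

import json

from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage

history = [
    HumanMessage(content="What is 2 times 3?"),
    # Replayed as glm.Part(function_call=...) by the AIMessage branch above.
    AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "multiply",
                "arguments": json.dumps({"a": 2, "b": 3}),
            }
        },
    ),
    # "6" parses as JSON to 6, which is not a dict, so the FunctionMessage branch
    # wraps it as {"output": 6} inside glm.FunctionResponse.
    FunctionMessage(name="multiply", content="6"),
]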
@@ -332,71 +377,51 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
      return messages
 
 
- def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], str]:
-     """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-     if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-         # Simple text response. The typical response
-         return parts[0].text
-     elif not parts:
-         logger.warning("Gemini produced an empty response.")
-         return ""
-     messages = []
-     for part in parts:
-         if part.text is not None:
-             messages.append(
-                 {
-                     "type": "text",
-                     "text": part.text,
-                 }
-             )
-         else:
-             # TODO: Handle inline_data if that's a thing?
-             raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-     return messages
+ def _parse_response_candidate(
+     response_candidate: glm.Candidate, stream: bool
+ ) -> AIMessage:
+     first_part = response_candidate.content.parts[0]
+     if first_part.function_call:
+         function_call = proto.Message.to_dict(first_part.function_call)
+         function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+         return (AIMessageChunk if stream else AIMessage)(
+             content="", additional_kwargs={"function_call": function_call}
+         )
+     else:
+         parts = response_candidate.content.parts
+
+         if len(parts) == 1 and parts[0].text:
+             content: Union[str, List[Union[str, Dict]]] = parts[0].text
+         else:
+             content = [proto.Message.to_dict(part) for part in parts]
+         return (AIMessageChunk if stream else AIMessage)(
+             content=content, additional_kwargs={}
+         )
 
 
  def _response_to_result(
-     response: genai.types.GenerateContentResponse,
-     ai_msg_t: Type[BaseMessage] = AIMessage,
-     human_msg_t: Type[BaseMessage] = HumanMessage,
-     chat_msg_t: Type[BaseMessage] = ChatMessage,
-     generation_t: Type[ChatGeneration] = ChatGeneration,
+     response: glm.GenerateContentResponse,
+     stream: bool = False,
  ) -> ChatResult:
      """Converts a PaLM API response into a LangChain ChatResult."""
-     llm_output = {}
-     if response.prompt_feedback:
-         try:
-             prompt_feedback = type(response.prompt_feedback).to_dict(
-                 response.prompt_feedback, use_integers_for_enums=False
-             )
-             llm_output["prompt_feedback"] = prompt_feedback
-         except Exception as e:
-             logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
 
      generations: List[ChatGeneration] = []
 
-     role_map = {
-         "model": ai_msg_t,
-         "user": human_msg_t,
-     }
      for candidate in response.candidates:
-         content = candidate.content
-         parts_content = _parts_to_content(content.parts)
-         if content.role not in role_map:
-             logger.warning(
-                 f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-             )
-             msg = chat_msg_t(content=parts_content, role=content.role)
-         else:
-             msg = role_map[content.role](content=parts_content)
          generation_info = {}
          if candidate.finish_reason:
              generation_info["finish_reason"] = candidate.finish_reason.name
-         if candidate.safety_ratings:
-             generation_info["safety_ratings"] = [
-                 type(rating).to_dict(rating) for rating in candidate.safety_ratings
-             ]
-         generations.append(generation_t(message=msg, generation_info=generation_info))
+         generation_info["safety_ratings"] = [
+             proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+             for safety_rating in candidate.safety_ratings
+         ]
+         generations.append(
+             (ChatGenerationChunk if stream else ChatGeneration)(
+                 message=_parse_response_candidate(candidate, stream=stream),
+                 generation_info=generation_info,
+             )
+         )
      if not response.candidates:
          # Likely a "prompt feedback" violation (e.g., toxic input)
          # Raising an error would be different than how OpenAI handles it,
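In the other direction, _parse_response_candidate above repacks a Gemini function call into an OpenAI-style additional_kwargs["function_call"] dict, with the proto args re-serialized to a JSON string. A hand-built stand-in for such an output, just to show the shape a caller consumes (name and arguments are illustrative):

import json

from langchain_core.messages import AIMessage

ai_msg = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {"name": "multiply", "arguments": '{"a": 2, "b": 3}'}
    },
)

call = ai_msg.additional_kwargs["function_call"]
args = json.loads(call["arguments"])  # arguments arrive as a JSON string
print(call["name"], args)  # multiply {'a': 2, 'b': 3}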
@@ -405,11 +430,16 @@ def _response_to_result(
              "Gemini produced an empty response. Continuing with empty message\n"
              f"Feedback: {response.prompt_feedback}"
          )
-         generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+         generations = [
+             (ChatGenerationChunk if stream else ChatGeneration)(
+                 message=(AIMessageChunk if stream else AIMessage)(content=""),
+                 generation_info={},
+             )
+         ]
      return ChatResult(generations=generations, llm_output=llm_output)
 
 
- class ChatGoogleGenerativeAI(BaseChatModel):
+ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
      """`Google Generative AI` Chat models API.
 
      To use, you must have either:
@@ -427,53 +457,13 @@ class ChatGoogleGenerativeAI(BaseChatModel):
 
      """
 
-     model: str = Field(
-         ...,
-         description="""The name of the model to use.
- Supported examples:
-     - gemini-pro""",
-     )
-     max_output_tokens: int = Field(default=None, description="Max output tokens")
-
      client: Any  #: :meta private:
-     google_api_key: Optional[SecretStr] = None
-     temperature: Optional[float] = None
-     """Run inference with this temperature. Must by in the closed
-        interval [0.0, 1.0]."""
-     top_k: Optional[int] = None
-     """Decode using top-k sampling: consider the set of top_k most probable tokens.
-        Must be positive."""
-     top_p: Optional[int] = None
-     """The maximum cumulative probability of tokens to consider when sampling.
-
-        The model uses combined Top-k and nucleus sampling.
-
-        Tokens are sorted based on their assigned probabilities so
-        that only the most likely tokens are considered. Top-k
-        sampling directly limits the maximum number of tokens to
-        consider, while Nucleus sampling limits number of tokens
-        based on the cumulative probability.
-
-        Note: The default value varies by model, see the
-        `Model.top_p` attribute of the `Model` returned the
-        `genai.get_model` function.
-     """
-     n: int = Field(default=1, alias="candidate_count")
-     """Number of chat completions to generate for each prompt. Note that the API may
-     not return the full n completions if duplicates are generated."""
+
      convert_system_message_to_human: bool = False
      """Whether to merge any leading SystemMessage into the following HumanMessage.
 
      Gemini does not support system messages; any unsupported messages will
      raise an error."""
-     client_options: Optional[Dict] = Field(
-         None,
-         description="Client options to pass to the Google API client.",
-     )
-     transport: Optional[str] = Field(
-         None,
-         description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
-     )
 
      class Config:
          allow_population_by_field_name = True
@@ -486,10 +476,6 @@ Supported examples:
      def _llm_type(self) -> str:
          return "chat-google-generative-ai"
 
-     @property
-     def _is_geminiai(self) -> bool:
-         return self.model is not None and "gemini" in self.model
-
      @classmethod
      def is_lc_serializable(self) -> bool:
          return True
@@ -560,7 +546,11 @@ Supported examples:
          run_manager: Optional[CallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = _chat_with_retry(
              content=message,
              **params,
@@ -575,7 +565,11 @@ Supported examples:
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> ChatResult:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = await _achat_with_retry(
              content=message,
              **params,
@@ -590,7 +584,11 @@ Supported examples:
          run_manager: Optional[CallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> Iterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          response: genai.types.GenerateContentResponse = _chat_with_retry(
              content=message,
              **params,
@@ -598,17 +596,11 @@ Supported examples:
              stream=True,
          )
          for chunk in response:
-             _chat_result = _response_to_result(
-                 chunk,
-                 ai_msg_t=AIMessageChunk,
-                 human_msg_t=HumanMessageChunk,
-                 chat_msg_t=ChatMessageChunk,
-                 generation_t=ChatGenerationChunk,
-             )
+             _chat_result = _response_to_result(chunk, stream=True)
              gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-             yield gen
              if run_manager:
                  run_manager.on_llm_new_token(gen.text)
+             yield gen
 
      async def _astream(
          self,
@@ -617,24 +609,22 @@ Supported examples:
          run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
          **kwargs: Any,
      ) -> AsyncIterator[ChatGenerationChunk]:
-         params, chat, message = self._prepare_chat(messages, stop=stop)
+         params, chat, message = self._prepare_chat(
+             messages,
+             stop=stop,
+             functions=kwargs.get("functions"),
+         )
          async for chunk in await _achat_with_retry(
              content=message,
              **params,
              generation_method=chat.send_message_async,
              stream=True,
          ):
-             _chat_result = _response_to_result(
-                 chunk,
-                 ai_msg_t=AIMessageChunk,
-                 human_msg_t=HumanMessageChunk,
-                 chat_msg_t=ChatMessageChunk,
-                 generation_t=ChatGenerationChunk,
-             )
+             _chat_result = _response_to_result(chunk, stream=True)
              gen = cast(ChatGenerationChunk, _chat_result.generations[0])
-             yield gen
              if run_manager:
                  await run_manager.on_llm_new_token(gen.text)
+             yield gen
 
      def _prepare_chat(
          self,
@@ -642,11 +632,37 @@ Supported examples:
          stop: Optional[List[str]] = None,
          **kwargs: Any,
      ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+         client = self.client
+         functions = kwargs.pop("functions", None)
+         if functions:
+             tools = convert_to_genai_function_declarations(functions)
+             client = genai.GenerativeModel(model_name=self.model, tools=tools)
+
          params = self._prepare_params(stop, **kwargs)
          history = _parse_chat_history(
              messages,
              convert_system_message_to_human=self.convert_system_message_to_human,
          )
          message = history.pop()
-         chat = self.client.start_chat(history=history)
+         chat = client.start_chat(history=history)
          return params, chat, message
+
+     def get_num_tokens(self, text: str) -> int:
+         """Get the number of tokens present in the text.
+
+         Useful for checking if an input will fit in a model's context window.
+
+         Args:
+             text: The string input to tokenize.
+
+         Returns:
+             The integer number of tokens in the text.
+         """
+         if self._model_family == GoogleModelFamily.GEMINI:
+             result = self.client.count_tokens(text)
+             token_count = result.total_tokens
+         else:
+             result = self.client.count_text_tokens(model=self.model, prompt=text)
+             token_count = result["token_count"]
+
+         return token_count
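Taken together, _prepare_chat now accepts a functions kwarg (forwarded from _generate, _agenerate, _stream, and _astream) that is converted with convert_to_genai_function_declarations and bound to a fresh genai.GenerativeModel, and the chat model gains get_num_tokens. A hedged end-to-end sketch, assuming a valid GOOGLE_API_KEY is set and an illustrative Pydantic tool; whether the model actually emits a function_call depends on the prompt:

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_genai import ChatGoogleGenerativeAI


class Multiply(BaseModel):
    """Multiply two integers."""

    a: int = Field(description="First factor")
    b: int = Field(description="Second factor")


llm = ChatGoogleGenerativeAI(model="gemini-pro")  # reads GOOGLE_API_KEY from the env

# Token counting now works for Gemini models via client.count_tokens.
print(llm.get_num_tokens("What is 2 times 3?"))

# `functions` flows through kwargs into _prepare_chat above; if the model elects
# to call the tool, the reply carries additional_kwargs["function_call"].
ai_msg = llm.invoke("What is 2 times 3?", functions=[Multiply])
print(ai_msg.additional_kwargs.get("function_call"))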

langchain_google_genai/llms.py
@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ from enum import Enum, auto
  from typing import Any, Callable, Dict, Iterator, List, Optional, Union
 
  import google.api_core
@@ -15,6 +16,19 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validat
  from langchain_core.utils import get_from_dict_or_env
 
 
+ class GoogleModelFamily(str, Enum):
+     GEMINI = auto()
+     PALM = auto()
+
+     @classmethod
+     def _missing_(cls, value: Any) -> Optional["GoogleModelFamily"]:
+         if "gemini" in value.lower():
+             return GoogleModelFamily.GEMINI
+         elif "text-bison" in value.lower():
+             return GoogleModelFamily.PALM
+         return None
+
+
  def _create_retry_decorator(
      llm: BaseLLM,
      *,
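Because _missing_ is defined, any model-name string coerces to a family when passed to the enum constructor, which is what the _model_family property and validate_environment below rely on. A small illustration, assuming the enum behaves as defined above:

from langchain_google_genai.llms import GoogleModelFamily

# Any name containing "gemini" maps to GEMINI; "text-bison" names map to PALM.
assert GoogleModelFamily("gemini-pro") is GoogleModelFamily.GEMINI
assert GoogleModelFamily("models/text-bison-001") is GoogleModelFamily.PALM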
@@ -56,21 +70,25 @@ def _completion_with_retry(
          prompt: LanguageModelInput, is_gemini: bool, stream: bool, **kwargs: Any
      ) -> Any:
          generation_config = kwargs.get("generation_config", {})
-         if is_gemini:
-             return llm.client.generate_content(
-                 contents=prompt, stream=stream, generation_config=generation_config
-             )
-         return llm.client.generate_text(prompt=prompt, **kwargs)
+         error_msg = (
+             "Your location is not supported by google-generativeai at the moment. "
+             "Try to use VertexAI LLM from langchain_google_vertexai"
+         )
+         try:
+             if is_gemini:
+                 return llm.client.generate_content(
+                     contents=prompt, stream=stream, generation_config=generation_config
+                 )
+             return llm.client.generate_text(prompt=prompt, **kwargs)
+         except google.api_core.exceptions.FailedPrecondition as exc:
+             if "location is not supported" in exc.message:
+                 raise ValueError(error_msg)
 
      return _completion_with_retry(
          prompt=prompt, is_gemini=is_gemini, stream=stream, **kwargs
      )
 
 
- def _is_gemini_model(model_name: str) -> bool:
-     return "gemini" in model_name
-
-
  def _strip_erroneous_leading_spaces(text: str) -> str:
      """Strip erroneous leading spaces from text.
 
@@ -84,17 +102,9 @@ def _strip_erroneous_leading_spaces(text: str) -> str:
      return text
 
 
- class GoogleGenerativeAI(BaseLLM, BaseModel):
-     """Google GenerativeAI models.
-
-     Example:
-         .. code-block:: python
+ class _BaseGoogleGenerativeAI(BaseModel):
+     """Base class for Google Generative AI LLMs"""
 
-             from langchain_google_genai import GoogleGenerativeAI
-             llm = GoogleGenerativeAI(model="gemini-pro")
-     """
-
-     client: Any  #: :meta private:
      model: str = Field(
          ...,
          description="""The name of the model to use.
@@ -133,15 +143,39 @@ Supported examples:
          description="A string, one of: [`rest`, `grpc`, `grpc_asyncio`].",
      )
 
-     @property
-     def is_gemini(self) -> bool:
-         """Returns whether a model is belongs to a Gemini family or not."""
-         return _is_gemini_model(self.model)
-
      @property
      def lc_secrets(self) -> Dict[str, str]:
          return {"google_api_key": "GOOGLE_API_KEY"}
 
+     @property
+     def _model_family(self) -> str:
+         return GoogleModelFamily(self.model)
+
+     @property
+     def _identifying_params(self) -> Dict[str, Any]:
+         """Get the identifying parameters."""
+         return {
+             "model": self.model,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "max_output_tokens": self.max_output_tokens,
+             "candidate_count": self.n,
+         }
+
+
+ class GoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseLLM):
+     """Google GenerativeAI models.
+
+     Example:
+         .. code-block:: python
+
+             from langchain_google_genai import GoogleGenerativeAI
+             llm = GoogleGenerativeAI(model="gemini-pro")
+     """
+
+     client: Any  #: :meta private:
+
      @root_validator()
      def validate_environment(cls, values: Dict) -> Dict:
          """Validates params and passes them to google-generativeai package."""
@@ -159,7 +193,7 @@ Supported examples:
              client_options=values.get("client_options"),
          )
 
-         if _is_gemini_model(model_name):
+         if GoogleModelFamily(model_name) == GoogleModelFamily.GEMINI:
              values["client"] = genai.GenerativeModel(model_name=model_name)
          else:
              values["client"] = genai
@@ -195,7 +229,7 @@ Supported examples:
              "candidate_count": self.n,
          }
          for prompt in prompts:
-             if self.is_gemini:
+             if self._model_family == GoogleModelFamily.GEMINI:
                  res = _completion_with_retry(
                      self,
                      prompt=prompt,
@@ -271,7 +305,11 @@ Supported examples:
          Returns:
              The integer number of tokens in the text.
          """
-         if self.is_gemini:
-             raise ValueError("Counting tokens is not yet supported!")
-         result = self.client.count_text_tokens(model=self.model, prompt=text)
-         return result["token_count"]
+         if self._model_family == GoogleModelFamily.GEMINI:
+             result = self.client.count_tokens(text)
+             token_count = result.total_tokens
+         else:
+             result = self.client.count_text_tokens(model=self.model, prompt=text)
+             token_count = result["token_count"]
+
+         return token_count

langchain_google_genai-0.0.8.dist-info/METADATA
@@ -1,13 +1,16 @@
  Metadata-Version: 2.1
  Name: langchain-google-genai
- Version: 0.0.6
+ Version: 0.0.8
  Summary: An integration package connecting Google's genai package and LangChain
  Home-page: https://github.com/langchain-ai/langchain
+ License: MIT
  Requires-Python: >=3.9,<4.0
+ Classifier: License :: OSI Approved :: MIT License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
  Provides-Extra: images
  Requires-Dist: google-generativeai (>=0.3.1,<0.4.0)
  Requires-Dist: langchain-core (>=0.1,<0.2)

langchain_google_genai-0.0.8.dist-info/RECORD (new file)
@@ -0,0 +1,11 @@
+ langchain_google_genai/__init__.py,sha256=cDMb1xbsenQtYBACNP0dYPwA7Rt015MT7HC_XP3X-4Y,2304
+ langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
+ langchain_google_genai/_function_utils.py,sha256=9IVMPQq5lQB8F_whG3mrGOM_tjmP-TMEY3URHfJnjgI,3640
+ langchain_google_genai/chat_models.py,sha256=flq1xYC2OYoijOtU9qJy12CZwJ8sdY3va3x0004pl1M,23208
+ langchain_google_genai/embeddings.py,sha256=EMa-sDGXUpAPMSyjA2-YXF_TGrlSlqljNeqysAh574s,3951
+ langchain_google_genai/llms.py,sha256=Kk7fCrWbbfR1tpiFRYBX6fgEhgAYq7HQVSNvR6UvWqY,10990
+ langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ langchain_google_genai-0.0.8.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
+ langchain_google_genai-0.0.8.dist-info/METADATA,sha256=mwV8mJuT-jVGcq-HQa7IiWjd6HDvVOIVa01lSTe1R-A,2851
+ langchain_google_genai-0.0.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+ langchain_google_genai-0.0.8.dist-info/RECORD,,

langchain_google_genai-0.0.8.dist-info/WHEEL
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 1.7.0
+ Generator: poetry-core 1.8.1
  Root-Is-Purelib: true
  Tag: py3-none-any

langchain_google_genai-0.0.6.dist-info/RECORD (removed)
@@ -1,10 +0,0 @@
- langchain_google_genai/__init__.py,sha256=cDMb1xbsenQtYBACNP0dYPwA7Rt015MT7HC_XP3X-4Y,2304
- langchain_google_genai/_common.py,sha256=1r0VrrBSTZfGprmICZ5OV-W5SK31jKRFFCNE3vJ3jmk,136
- langchain_google_genai/chat_models.py,sha256=451I60uXY24h7bpTxkF3HNccdY0vIRm6wly1PL9URXk,22894
- langchain_google_genai/embeddings.py,sha256=EMa-sDGXUpAPMSyjA2-YXF_TGrlSlqljNeqysAh574s,3951
- langchain_google_genai/llms.py,sha256=xoypoTqhw-p3_1Htk8yURoTpRtjVjFvsqP1WPWwk8eg,9707
- langchain_google_genai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- langchain_google_genai-0.0.6.dist-info/LICENSE,sha256=DppmdYJVSc1jd0aio6ptnMUn5tIHrdAhQ12SclEBfBg,1072
- langchain_google_genai-0.0.6.dist-info/METADATA,sha256=9IOHkgYT2g3XbCU0kdxETvOJaK5z4FjckdV6I1hmn2w,2736
- langchain_google_genai-0.0.6.dist-info/WHEEL,sha256=d2fvjOD7sXsVzChCqf0Ty0JbHKBaLYwDbGQDwQTnJ50,88
- langchain_google_genai-0.0.6.dist-info/RECORD,,