bisheng-langchain 1.0.1__py3-none-any.whl → 1.2.0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. bisheng_langchain/agents/chatglm_functions_agent/base.py +6 -3
  2. bisheng_langchain/agents/llm_functions_agent/base.py +6 -3
  3. bisheng_langchain/chains/qa_generation/base.py +1 -1
  4. bisheng_langchain/chains/transform.py +1 -1
  5. bisheng_langchain/chat_models/host_llm.py +13 -8
  6. bisheng_langchain/chat_models/minimax.py +4 -7
  7. bisheng_langchain/chat_models/proxy_llm.py +5 -7
  8. bisheng_langchain/chat_models/qwen.py +5 -7
  9. bisheng_langchain/chat_models/sensetime.py +5 -7
  10. bisheng_langchain/chat_models/wenxin.py +4 -7
  11. bisheng_langchain/chat_models/xunfeiai.py +4 -7
  12. bisheng_langchain/chat_models/zhipuai.py +4 -7
  13. bisheng_langchain/embeddings/host_embedding.py +6 -4
  14. bisheng_langchain/embeddings/huggingfacegte.py +2 -2
  15. bisheng_langchain/embeddings/huggingfacemultilingual.py +2 -2
  16. bisheng_langchain/embeddings/wenxin.py +5 -8
  17. bisheng_langchain/gpts/agent_types/llm_functions_agent.py +6 -78
  18. bisheng_langchain/gpts/agent_types/llm_react_agent.py +2 -5
  19. bisheng_langchain/gpts/tools/api_tools/base.py +5 -7
  20. bisheng_langchain/gpts/tools/api_tools/firecrawl.py +1 -1
  21. bisheng_langchain/gpts/tools/api_tools/flow.py +1 -1
  22. bisheng_langchain/gpts/tools/api_tools/jina.py +6 -4
  23. bisheng_langchain/gpts/tools/api_tools/macro_data.py +3 -3
  24. bisheng_langchain/gpts/tools/api_tools/openapi.py +8 -6
  25. bisheng_langchain/gpts/tools/api_tools/sina.py +1 -1
  26. bisheng_langchain/gpts/tools/api_tools/tianyancha.py +6 -3
  27. bisheng_langchain/gpts/tools/bing_search/tool.py +2 -2
  28. bisheng_langchain/gpts/tools/calculator/tool.py +2 -2
  29. bisheng_langchain/gpts/tools/code_interpreter/tool.py +2 -2
  30. bisheng_langchain/gpts/tools/dalle_image_generator/tool.py +7 -11
  31. bisheng_langchain/gpts/tools/get_current_time/tool.py +1 -1
  32. bisheng_langchain/gpts/tools/message/dingding.py +1 -2
  33. bisheng_langchain/gpts/tools/message/email.py +2 -4
  34. bisheng_langchain/gpts/tools/message/feishu.py +10 -11
  35. bisheng_langchain/gpts/tools/message/wechat.py +2 -3
  36. bisheng_langchain/gpts/tools/sql_agent/tool.py +23 -20
  37. bisheng_langchain/input_output/input.py +7 -11
  38. bisheng_langchain/input_output/output.py +2 -6
  39. bisheng_langchain/memory/redis.py +3 -3
  40. bisheng_langchain/rag/bisheng_rag_chain.py +2 -8
  41. bisheng_langchain/rag/bisheng_rag_tool.py +1 -1
  42. bisheng_langchain/rag/init_retrievers/baseline_vector_retriever.py +1 -1
  43. bisheng_langchain/rag/init_retrievers/keyword_retriever.py +1 -1
  44. bisheng_langchain/rag/init_retrievers/mix_retriever.py +1 -1
  45. bisheng_langchain/rag/init_retrievers/smaller_chunks_retriever.py +2 -2
  46. bisheng_langchain/retrievers/ensemble.py +3 -2
  47. bisheng_langchain/utils/azure_dalle_image_generator.py +3 -2
  48. bisheng_langchain/utils/requests.py +10 -19
  49. bisheng_langchain/vectorstores/retriever.py +4 -7
  50. {bisheng_langchain-1.0.1.dist-info → bisheng_langchain-1.2.0.dev1.dist-info}/METADATA +13 -13
  51. {bisheng_langchain-1.0.1.dist-info → bisheng_langchain-1.2.0.dev1.dist-info}/RECORD +53 -53
  52. {bisheng_langchain-1.0.1.dist-info → bisheng_langchain-1.2.0.dev1.dist-info}/WHEEL +1 -1
  53. {bisheng_langchain-1.0.1.dist-info → bisheng_langchain-1.2.0.dev1.dist-info}/top_level.txt +0 -0
@@ -20,10 +20,12 @@ class JinaTool(BaseModel):
20
20
  """get url from jina api"""
21
21
  url = "https://r.jina.ai/" + target_url
22
22
 
23
- headers = {
24
- "Content-Type": "application/json",
25
- "Authorization": "Bearer " + self.jina_api_key,
26
- }
23
+ headers = None
24
+ if self.jina_api_key and len(self.jina_api_key)>0 :
25
+ headers = {
26
+ "Content-Type": "application/json",
27
+ "Authorization": "Bearer " + self.jina_api_key,
28
+ }
27
29
 
28
30
  response = requests.get(url, headers=headers)
29
31
 
@@ -4,15 +4,15 @@ from typing import Any
4
4
 
5
5
  import pandas as pd
6
6
  import requests
7
- from langchain.pydantic_v1 import BaseModel, Field
7
+ from pydantic import BaseModel, Field
8
8
  from langchain_core.tools import BaseTool
9
9
 
10
10
  from .base import MultArgsSchemaTool
11
11
 
12
12
 
13
13
  class QueryArg(BaseModel):
14
- start_date: str = Field(default='', description='开始月份, 使用YYYY-MM-DD 方式表示', example='2023-01-01')
15
- end_date: str = Field(default='', description='结束月份,使用YYYY-MM-DD 方式表示', example='2023-05-01')
14
+ start_date: str = Field(default='', description='开始月份, 使用YYYY-MM-DD 方式表示', examples=['2023-01-01'])
15
+ end_date: str = Field(default='', description='结束月份,使用YYYY-MM-DD 方式表示', examples=['2023-05-01'])
16
16
 
17
17
 
18
18
  class MacroData(BaseModel):
@@ -9,11 +9,11 @@ from .base import APIToolBase, Field, MultArgsSchemaTool
9
9
 
10
10
  class OpenApiTools(APIToolBase):
11
11
 
12
- api_key: Optional[str]
13
- api_location: Optional[str]
14
- parameter_name: Optional[str]
12
+ api_key: Optional[str] = None
13
+ api_location: Optional[str] = None
14
+ parameter_name: Optional[str] = None
15
15
 
16
- def get_real_path(self, path_params: dict|None):
16
+ def get_real_path(self, path_params: dict | None):
17
17
  path = self.params['path']
18
18
  if path_params:
19
19
  path = path.format(**path_params)
@@ -47,8 +47,10 @@ class OpenApiTools(APIToolBase):
47
47
  # elif self.params['parameter_name']:
48
48
  # params.update({self.params['parameter_name']:self.api_key})
49
49
  api_location = self.params.get('api_location')
50
- if (api_location == "query") or (hasattr(self, 'api_location') and self.api_location == "query"):
51
- parameter_name = getattr(self, 'parameter_name', None) or self.params.get('parameter_name')
50
+ if (api_location == 'query') or (hasattr(self, 'api_location')
51
+ and self.api_location == 'query'):
52
+ parameter_name = getattr(self, 'parameter_name',
53
+ None) or self.params.get('parameter_name')
52
54
  if parameter_name:
53
55
  params.update({parameter_name: self.api_key})
54
56
  return params, json_data, path_params
@@ -7,7 +7,7 @@ import re
7
7
  from datetime import datetime
8
8
  from typing import List, Type
9
9
 
10
- from langchain_core.pydantic_v1 import BaseModel, Field
10
+ from pydantic import BaseModel, Field
11
11
  from loguru import logger
12
12
 
13
13
  from .base import APIToolBase
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
 
4
4
  from typing import Any, Dict, Type
5
5
 
6
+ from pydantic import model_validator, BaseModel, Field
7
+
6
8
  from bisheng_langchain.utils.requests import Requests, RequestsWrapper
7
- from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
8
9
 
9
10
  from .base import APIToolBase
10
11
 
@@ -19,7 +20,8 @@ class CompanyInfo(APIToolBase):
19
20
  api_key: str = None
20
21
  args_schema: Type[BaseModel] = InputArgs
21
22
 
22
- @root_validator(pre=True)
23
+ @model_validator(mode='before')
24
+ @classmethod
23
25
  def build_header(cls, values: Dict[str, Any]) -> Dict[str, Any]:
24
26
  """Build headers that were passed in."""
25
27
  if not values.get('api_key'):
@@ -30,7 +32,8 @@ class CompanyInfo(APIToolBase):
30
32
  values['headers'] = headers
31
33
  return values
32
34
 
33
- @root_validator()
35
+ @model_validator(mode='before')
36
+ @classmethod
34
37
  def validate_environment(cls, values: Dict) -> Dict:
35
38
  """Validate that api key and python package exists in environment."""
36
39
  timeout = values.get('request_timeout', 30)
@@ -2,7 +2,7 @@
2
2
 
3
3
  from typing import Optional, Type
4
4
 
5
- from langchain.pydantic_v1 import BaseModel, Field
5
+ from pydantic import BaseModel, Field
6
6
  from langchain_community.utilities.bing_search import BingSearchAPIWrapper
7
7
  from langchain_core.callbacks import CallbackManagerForToolRun
8
8
  from langchain_core.tools import BaseTool
@@ -43,7 +43,7 @@ class BingSearchResults(BaseTool):
43
43
  "Input should be a search query. Output is a JSON array of the query results"
44
44
  )
45
45
  num_results: int = 5
46
- args_schema = BingSearchInput
46
+ args_schema: Type[BaseModel] = BingSearchInput
47
47
  api_wrapper: BingSearchAPIWrapper
48
48
 
49
49
  def _run(
@@ -2,7 +2,7 @@ import math
2
2
  from math import *
3
3
 
4
4
  import sympy
5
- from langchain.pydantic_v1 import BaseModel, Field
5
+ from pydantic import BaseModel, Field
6
6
  from langchain.tools import tool
7
7
  from sympy import *
8
8
 
@@ -10,7 +10,7 @@ from sympy import *
10
10
  class CalculatorInput(BaseModel):
11
11
  expression: str = Field(
12
12
  description="The input to this tool should be a mathematical expression using only Python's built-in mathematical operators.",
13
- example='200*7',
13
+ examples=['200*7'],
14
14
  )
15
15
 
16
16
 
@@ -15,7 +15,7 @@ from uuid import uuid4
15
15
 
16
16
  import matplotlib
17
17
  from langchain_community.tools import Tool
18
- from langchain_core.pydantic_v1 import BaseModel, Field
18
+ from pydantic import BaseModel, Field
19
19
  from loguru import logger
20
20
 
21
21
  CODE_BLOCK_PATTERN = r"```(\w*)\n(.*?)\n```"
@@ -239,7 +239,7 @@ class CodeInterpreterToolArguments(BaseModel):
239
239
 
240
240
  python_code: str = Field(
241
241
  ...,
242
- example="print('Hello World')",
242
+ examples=["print('Hello World')"],
243
243
  description=(
244
244
  'The pure python script to be evaluated. '
245
245
  'The contents will be in main.py. '
@@ -2,13 +2,11 @@ import logging
2
2
  import os
3
3
  from typing import Any, Dict, Mapping, Optional, Tuple, Type, Union
4
4
 
5
- from langchain.pydantic_v1 import BaseModel, Field
6
- from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
7
5
  from langchain_community.utils.openai import is_openai_v1
8
6
  from langchain_core.callbacks import CallbackManagerForToolRun
9
- from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
10
7
  from langchain_core.tools import BaseTool
11
8
  from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
9
+ from pydantic import ConfigDict, model_validator, BaseModel, Field
12
10
 
13
11
  from bisheng_langchain.utils.azure_dalle_image_generator import AzureDallEWrapper
14
12
 
@@ -26,7 +24,7 @@ class DallEAPIWrapper(BaseModel):
26
24
  2. save your OPENAI_API_KEY in an environment variable
27
25
  """
28
26
 
29
- client: Any #: :meta private:
27
+ client: Any = None #: :meta private:
30
28
  async_client: Any = Field(default=None, exclude=True) #: :meta private:
31
29
  model_name: str = Field(default="dall-e-2", alias="model")
32
30
  model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -59,13 +57,10 @@ class DallEAPIWrapper(BaseModel):
59
57
  http_async_client: Union[Any, None] = None
60
58
  """Optional httpx.AsyncClient. Only used for async invocations. Must specify
61
59
  http_client as well if you'd like a custom client for sync invocations."""
60
+ model_config = ConfigDict(extra='forbid')
62
61
 
63
- class Config:
64
- """Configuration for this pydantic object."""
65
-
66
- extra = Extra.forbid
67
-
68
- @root_validator(pre=True)
62
+ @model_validator(mode='before')
63
+ @classmethod
69
64
  def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
70
65
  """Build extra kwargs from additional params that were passed in."""
71
66
  all_required_field_names = get_pydantic_field_names(cls)
@@ -91,7 +86,8 @@ class DallEAPIWrapper(BaseModel):
91
86
  values["model_kwargs"] = extra
92
87
  return values
93
88
 
94
- @root_validator()
89
+ @model_validator(mode='before')
90
+ @classmethod
95
91
  def validate_environment(cls, values: Dict) -> Dict:
96
92
  """Validate that api key and python package exists in environment."""
97
93
  values["openai_api_key"] = get_from_dict_or_env(values, "openai_api_key", "OPENAI_API_KEY")
@@ -1,7 +1,7 @@
1
1
  from datetime import datetime
2
2
 
3
3
  import pytz
4
- from langchain.pydantic_v1 import BaseModel, Field
4
+ from pydantic import BaseModel, Field
5
5
  from langchain.tools import tool
6
6
 
7
7
 
@@ -1,8 +1,7 @@
1
1
  from typing import Any, Optional, Type
2
2
 
3
3
  import requests
4
- from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
5
- from loguru import logger
4
+ from pydantic import BaseModel, Field
6
5
 
7
6
  from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
8
7
  MultArgsSchemaTool)
@@ -1,11 +1,9 @@
1
- import os
2
1
  import smtplib
3
- from email.mime.application import MIMEApplication
4
2
  from email.mime.multipart import MIMEMultipart
5
3
  from email.mime.text import MIMEText
6
- from typing import Any, Optional
4
+ from typing import Any
7
5
 
8
- from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
6
+ from pydantic import BaseModel, Field
9
7
 
10
8
  from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
11
9
  MultArgsSchemaTool)
@@ -1,29 +1,28 @@
1
1
  from typing import Any, Optional, Type
2
2
 
3
3
  import requests
4
- from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
5
- from loguru import logger
4
+ from pydantic import BaseModel, Field
6
5
 
7
6
  from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
8
7
  MultArgsSchemaTool)
9
8
 
10
9
 
11
10
  class InputArgs(BaseModel):
12
- message: Optional[str] = Field(description="需要发送的钉钉消息")
13
- receive_id: Optional[str] = Field(description="接收的ID")
14
- receive_id_type: Optional[str] = Field(description="接收的ID类型")
15
- container_id: Optional[str] = Field(description="container_id")
16
- start_time: Optional[str] = Field(description="start_time")
17
- end_time: Optional[str] = Field(description="end_time")
11
+ message: Optional[str] = Field(None, description="需要发送的钉钉消息")
12
+ receive_id: Optional[str] = Field(None, description="接收的ID")
13
+ receive_id_type: Optional[str] = Field(None, description="接收的ID类型")
14
+ container_id: Optional[str] = Field(None, description="container_id")
15
+ start_time: Optional[str] = Field(None, description="start_time")
16
+ end_time: Optional[str] = Field(None, description="end_time")
18
17
  # page_token: Optional[str] = Field(description="page_token")
19
- container_id_type: Optional[str] = Field(description="container_id_type")
18
+ container_id_type: Optional[str] = Field(None, description="container_id_type")
20
19
  page_size: Optional[int] = Field(default=20,description="page_size")
21
- page_token: Optional[str] = Field(description="page_token")
20
+ page_token: Optional[str] = Field(None, description="page_token")
22
21
  sort_type: Optional[str] = Field(description="sort_type",default="ByCreateTimeAsc")
23
22
 
24
23
 
25
24
  class FeishuMessageTool(BaseModel):
26
- API_BASE_URL = "https://open.feishu.cn/open-apis"
25
+ API_BASE_URL: str = "https://open.feishu.cn/open-apis"
27
26
  app_id: str = Field(description="app_id")
28
27
  app_secret: str = Field(description="app_secret")
29
28
 
@@ -1,8 +1,7 @@
1
- from typing import Any, Optional, Type
1
+ from typing import Any
2
2
 
3
3
  import requests
4
- from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
5
- from loguru import logger
4
+ from pydantic import BaseModel, Field
6
5
 
7
6
  from bisheng_langchain.gpts.tools.api_tools.base import (APIToolBase,
8
7
  MultArgsSchemaTool)
@@ -11,7 +11,7 @@ from langchain_core.tools import BaseTool, tool
11
11
  from langgraph.constants import END, START
12
12
  from langgraph.graph import add_messages, StateGraph
13
13
  from langgraph.prebuilt import ToolNode
14
- from pydantic import BaseModel, Field
14
+ from pydantic import ConfigDict, BaseModel, Field
15
15
 
16
16
 
17
17
  class State(TypedDict):
@@ -47,8 +47,8 @@ class SubmitFinalAnswer(BaseModel):
47
47
  final_answer: str = Field(..., description="The final answer to the user")
48
48
 
49
49
  class QueryDBTool(BaseTool):
50
- name = "db_query_tool"
51
- description = """Execute a SQL query against the database and get back the result.
50
+ name: str = "db_query_tool"
51
+ description: str = """Execute a SQL query against the database and get back the result.
52
52
  If the query is not correct, an error message will be returned.
53
53
  If an error is returned, rewrite the query, check the query, and try again."""
54
54
 
@@ -64,20 +64,18 @@ class SqlAgentAPIWrapper(BaseModel):
64
64
  llm: BaseLanguageModel = Field(description="llm to use for sql agent")
65
65
  sql_address: str = Field(description="sql database address for SQLDatabase uri")
66
66
 
67
- db: Optional[SQLDatabase]
68
- list_tables_tool: Optional[BaseTool]
69
- get_schema_tool: Optional[BaseTool]
70
- db_query_tool: Optional[BaseTool]
71
- query_check: Optional[Any]
72
- query_gen: Optional[Any]
73
- workflow: Optional[StateGraph]
74
- app: Optional[Any]
75
- schema_llm: Optional[Any]
76
- query_check_llm: Optional[Any]
77
- query_gen_llm: Optional[Any]
78
-
79
- class Config:
80
- arbitrary_types_allowed = True
67
+ db: Optional[SQLDatabase] = None
68
+ list_tables_tool: Optional[BaseTool] = None
69
+ get_schema_tool: Optional[BaseTool] = None
70
+ db_query_tool: Optional[BaseTool] = None
71
+ query_check: Optional[Any] = None
72
+ query_gen: Optional[Any] = None
73
+ workflow: Optional[StateGraph] = None
74
+ app: Optional[Any] = None
75
+ schema_llm: Optional[Any] = None
76
+ query_check_llm: Optional[Any] = None
77
+ query_gen_llm: Optional[Any] = None
78
+ model_config = ConfigDict(arbitrary_types_allowed=True)
81
79
 
82
80
  def __init__(self, **kwargs):
83
81
  super().__init__(**kwargs)
@@ -241,8 +239,8 @@ class SqlAgentInput(BaseModel):
241
239
 
242
240
 
243
241
  class SqlAgentTool(BaseTool):
244
- name = "sql_agent"
245
- description = "回答与 SQL 数据库有关的问题。给定用户问题,将从数据库中获取可用的表以及对应 DDL,生成 SQL 查询语句并进行执行,最终得到执行结果。"
242
+ name: str = "sql_agent"
243
+ description: str = "回答与 SQL 数据库有关的问题。给定用户问题,将从数据库中获取可用的表以及对应 DDL,生成 SQL 查询语句并进行执行,最终得到执行结果。"
246
244
  args_schema: Type[BaseModel] = SqlAgentInput
247
245
  api_wrapper: SqlAgentAPIWrapper
248
246
 
@@ -252,7 +250,12 @@ class SqlAgentTool(BaseTool):
252
250
  run_manager: Optional[CallbackManagerForToolRun] = None,
253
251
  ) -> str:
254
252
  """Use the tool."""
255
- return self.api_wrapper.run(query)
253
+ try:
254
+ res = self.api_wrapper.run(query)
255
+ finally:
256
+ if self.api_wrapper and self.api_wrapper.db:
257
+ self.api_wrapper.db._engine.dispose()
258
+ return res
256
259
 
257
260
 
258
261
  if __name__ == '__main__':
@@ -1,12 +1,12 @@
1
1
 
2
2
  from typing import List, Optional
3
3
 
4
- from pydantic import BaseModel, Extra
4
+ from pydantic import ConfigDict, BaseModel
5
5
 
6
6
 
7
7
  class InputNode(BaseModel):
8
8
  """Input组件,用来控制输入"""
9
- input: Optional[List[str]]
9
+ input: Optional[List[str]] = None
10
10
 
11
11
  def text(self):
12
12
  return self.input
@@ -15,14 +15,10 @@ class InputNode(BaseModel):
15
15
  class VariableNode(BaseModel):
16
16
  """用来设置变量"""
17
17
  # key
18
- variables: Optional[List[str]]
18
+ variables: Optional[List[str]] = None
19
19
  # vaulues
20
20
  variable_value: Optional[List[str]] = []
21
-
22
- class Config:
23
- """Configuration for this pydantic object."""
24
-
25
- extra = Extra.forbid
21
+ model_config = ConfigDict(extra="forbid")
26
22
 
27
23
  def text(self):
28
24
  if self.variable_value:
@@ -36,9 +32,9 @@ class VariableNode(BaseModel):
36
32
 
37
33
 
38
34
  class InputFileNode(BaseModel):
39
- file_path: Optional[str]
40
- file_name: Optional[str]
41
- file_type: Optional[str] # tips for file
35
+ file_path: Optional[str] = None
36
+ file_name: Optional[str] = None
37
+ file_type: Optional[str] = None # tips for file
42
38
  """Output组件,用来控制输出"""
43
39
 
44
40
  def text(self):
@@ -5,7 +5,7 @@ from venv import logger
5
5
  from bisheng_langchain.chains import LoaderOutputChain
6
6
  from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun
7
7
  from langchain.chains.base import Chain
8
- from pydantic import BaseModel, Extra
8
+ from pydantic import ConfigDict, BaseModel
9
9
 
10
10
  _TEXT_COLOR_MAPPING = {
11
11
  'blue': '36;1',
@@ -52,11 +52,7 @@ class Report(Chain):
52
52
 
53
53
  input_key: str = 'report_name' #: :meta private:
54
54
  output_key: str = 'text' #: :meta private:
55
-
56
- class Config:
57
- """Configuration for this pydantic object."""
58
- extra = Extra.forbid
59
- arbitrary_types_allowed = True
55
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
60
56
 
61
57
  @property
62
58
  def input_keys(self) -> List[str]:
@@ -5,8 +5,7 @@ import redis
5
5
  from langchain.memory.chat_memory import BaseChatMemory
6
6
  from langchain_core.messages import (AIMessage, BaseMessage, HumanMessage, get_buffer_string,
7
7
  message_to_dict, messages_from_dict)
8
- from langchain_core.pydantic_v1 import root_validator
9
- from pydantic import Field
8
+ from pydantic import Field, model_validator
10
9
 
11
10
 
12
11
  class ConversationRedisMemory(BaseChatMemory):
@@ -20,7 +19,8 @@ class ConversationRedisMemory(BaseChatMemory):
20
19
  redis_prefix: str = 'redis_buffer_'
21
20
  ttl: Optional[int] = None
22
21
 
23
- @root_validator()
22
+ @model_validator(mode='before')
23
+ @classmethod
24
24
  def validate_environment(cls, values: Dict) -> Dict:
25
25
  redis_url = values.get('redis_url')
26
26
  if not redis_url:
@@ -10,7 +10,7 @@ from langchain_core.callbacks import (AsyncCallbackManagerForChainRun, CallbackM
10
10
  from langchain_core.language_models import BaseLanguageModel
11
11
  from langchain_core.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
12
12
  SystemMessagePromptTemplate)
13
- from langchain_core.pydantic_v1 import Extra, Field
13
+ from pydantic import ConfigDict, Field
14
14
 
15
15
  from .bisheng_rag_tool import BishengRAGTool
16
16
 
@@ -52,13 +52,7 @@ class BishengRetrievalQA(Chain):
52
52
  """Return the source documents or not."""
53
53
  bisheng_rag_tool: BishengRAGTool = Field(default_factory=BishengRAGTool,
54
54
  description='RAG tool')
55
-
56
- class Config:
57
- """Configuration for this pydantic object."""
58
-
59
- extra = Extra.forbid
60
- arbitrary_types_allowed = True
61
- allow_population_by_field_name = True
55
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, validate_by_name=True)
62
56
 
63
57
  @property
64
58
  def input_keys(self) -> List[str]:
@@ -15,7 +15,7 @@ from langchain.chains.combine_documents import create_stuff_documents_chain
15
15
  from langchain_core.callbacks import CallbackManagerForChainRun
16
16
  from langchain_core.language_models.base import LanguageModelLike
17
17
  from langchain_core.prompts import ChatPromptTemplate
18
- from langchain_core.pydantic_v1 import BaseModel, Field
18
+ from pydantic import BaseModel, Field
19
19
  from langchain_core.runnables import RunnableConfig
20
20
  from langchain_core.tools import BaseTool, Tool
21
21
  from langchain_core.vectorstores import VectorStoreRetriever
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
2
2
 
3
3
  from langchain.text_splitter import TextSplitter
4
4
  from langchain_core.documents import Document
5
- from langchain_core.pydantic_v1 import Field
5
+ from pydantic import Field
6
6
  from langchain_core.retrievers import BaseRetriever
7
7
  from loguru import logger
8
8
 
@@ -2,7 +2,7 @@ from typing import Any, List, Optional
2
2
 
3
3
  from langchain.text_splitter import TextSplitter
4
4
  from langchain_core.documents import Document
5
- from langchain_core.pydantic_v1 import Field
5
+ from pydantic import Field
6
6
  from langchain_core.retrievers import BaseRetriever
7
7
  from loguru import logger
8
8
 
@@ -3,7 +3,7 @@ from typing import Any, List, Optional
3
3
  from bisheng_langchain.vectorstores import ElasticKeywordsSearch
4
4
  from langchain.text_splitter import TextSplitter
5
5
  from langchain_core.documents import Document
6
- from langchain_core.pydantic_v1 import Field
6
+ from pydantic import Field
7
7
  from langchain_core.retrievers import BaseRetriever
8
8
 
9
9
 
@@ -3,7 +3,7 @@ from typing import Any, List, Optional
3
3
 
4
4
  from langchain.text_splitter import TextSplitter
5
5
  from langchain_core.documents import Document
6
- from langchain_core.pydantic_v1 import Field
6
+ from pydantic import Field
7
7
  from langchain_core.retrievers import BaseRetriever
8
8
 
9
9
 
@@ -15,7 +15,7 @@ class SmallerChunksVectorRetriever(BaseRetriever):
15
15
  parent_splitter: Optional[TextSplitter] = None
16
16
  """The text splitter to use to create parent documents.
17
17
  If none, then the parent documents will be the raw documents passed in."""
18
- id_key = 'doc_id'
18
+ id_key: str = 'doc_id'
19
19
 
20
20
  def add_documents(
21
21
  self,
@@ -6,13 +6,13 @@ multiple retrievers by using weighted Reciprocal Rank Fusion
6
6
  from typing import Any, Dict, List
7
7
 
8
8
  from langchain_core.documents import Document
9
- from langchain_core.pydantic_v1 import root_validator
10
9
  from langchain_core.retrievers import BaseRetriever
11
10
 
12
11
  from langchain.callbacks.manager import (
13
12
  AsyncCallbackManagerForRetrieverRun,
14
13
  CallbackManagerForRetrieverRun,
15
14
  )
15
+ from pydantic import model_validator
16
16
 
17
17
 
18
18
  class EnsembleRetriever(BaseRetriever):
@@ -33,7 +33,8 @@ class EnsembleRetriever(BaseRetriever):
33
33
  weights: List[float]
34
34
  c: int = 60
35
35
 
36
- @root_validator(pre=True)
36
+ @model_validator(mode='before')
37
+ @classmethod
37
38
  def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
38
39
  if not values.get("weights"):
39
40
  n_retrievers = len(values["retrievers"])
@@ -3,7 +3,7 @@ from typing import Callable, Dict, Optional, Union
3
3
 
4
4
  import openai
5
5
  from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
6
- from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
6
+ from pydantic import Field, SecretStr, model_validator
7
7
  from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
8
8
 
9
9
 
@@ -56,7 +56,8 @@ class AzureDallEWrapper(DallEAPIWrapper):
56
56
  chunk_size: int = 2048
57
57
  """Maximum number of texts to embed in each batch"""
58
58
 
59
- @root_validator()
59
+ @model_validator(mode='before')
60
+ @classmethod
60
61
  def validate_environment(cls, values: Dict) -> Dict:
61
62
  """Validate that api key and python package exists in environment."""
62
63
  # Check OPENAI_KEY for backwards compatibility.
@@ -5,7 +5,7 @@ from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union
5
5
  import aiohttp
6
6
  import requests
7
7
  from loguru import logger
8
- from pydantic import BaseModel, Extra
8
+ from pydantic import ConfigDict, BaseModel
9
9
 
10
10
 
11
11
  class Requests(BaseModel):
@@ -19,12 +19,7 @@ class Requests(BaseModel):
19
19
  aiosession: Optional[aiohttp.ClientSession] = None
20
20
  auth: Optional[Any] = None
21
21
  request_timeout: Union[float, Tuple[float, float]] = 120
22
-
23
- class Config:
24
- """Configuration for this pydantic object."""
25
-
26
- extra = Extra.forbid
27
- arbitrary_types_allowed = True
22
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
28
23
 
29
24
  def get(self, url: str, **kwargs: Any) -> requests.Response:
30
25
  """GET the URL and return the text."""
@@ -87,12 +82,13 @@ class Requests(BaseModel):
87
82
  **kwargs) as response:
88
83
  yield response
89
84
  else:
90
- async with self.aiosession.request(method,
91
- url,
92
- headers=self.headers,
93
- auth=self.auth,
94
- **kwargs) as response:
95
- yield response
85
+ async with self.aiosession:
86
+ async with self.aiosession.request(method,
87
+ url,
88
+ headers=self.headers,
89
+ auth=self.auth,
90
+ **kwargs) as response:
91
+ yield response
96
92
 
97
93
  @asynccontextmanager
98
94
  async def aget(self, url: str, **kwargs: Any) -> AsyncGenerator[aiohttp.ClientResponse, None]:
@@ -139,12 +135,7 @@ class TextRequestsWrapper(BaseModel):
139
135
  aiosession: Optional[aiohttp.ClientSession] = None
140
136
  auth: Optional[Any] = None
141
137
  request_timeout: Union[float, Tuple[float, float]] = 120
142
-
143
- class Config:
144
- """Configuration for this pydantic object."""
145
-
146
- extra = Extra.forbid
147
- arbitrary_types_allowed = True
138
+ model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
148
139
 
149
140
  @property
150
141
  def requests(self) -> Requests: