nvidia-nat-langchain 1.3.0a20250826__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,15 +13,96 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from collections.abc import Sequence
17
+ from typing import TypeVar
18
+
16
19
  from nat.builder.builder import Builder
17
20
  from nat.builder.framework_enum import LLMFrameworkEnum
18
21
  from nat.cli.register_workflow import register_llm_client
22
+ from nat.data_models.llm import LLMBaseConfig
19
23
  from nat.data_models.retry_mixin import RetryMixin
24
+ from nat.data_models.thinking_mixin import ThinkingMixin
20
25
  from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
21
26
  from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
22
27
  from nat.llm.nim_llm import NIMModelConfig
23
28
  from nat.llm.openai_llm import OpenAIModelConfig
29
+ from nat.llm.utils.thinking import BaseThinkingInjector
30
+ from nat.llm.utils.thinking import FunctionArgumentWrapper
31
+ from nat.llm.utils.thinking import patch_with_thinking
24
32
  from nat.utils.exception_handlers.automatic_retries import patch_with_retry
33
+ from nat.utils.type_utils import override
34
+
35
+ ModelType = TypeVar("ModelType")
36
+
37
+
38
+ def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
39
+
40
+ from langchain_core.language_models import LanguageModelInput
41
+ from langchain_core.messages import BaseMessage
42
+ from langchain_core.messages import HumanMessage
43
+ from langchain_core.messages import SystemMessage
44
+ from langchain_core.prompt_values import PromptValue
45
+
46
+ class LangchainThinkingInjector(BaseThinkingInjector):
47
+
48
+ @override
49
+ def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgumentWrapper:
50
+ """
51
+ Inject a system prompt into the messages.
52
+
53
+ The messages are the first (non-object) argument to the function.
54
+ The rest of the arguments are passed through unchanged.
55
+
56
+ Args:
57
+ messages: The messages to inject the system prompt into.
58
+ *args: The rest of the arguments to the function.
59
+ **kwargs: The rest of the keyword arguments to the function.
60
+
61
+ Returns:
62
+ FunctionArgumentWrapper: An object that contains the transformed args and kwargs.
63
+
64
+ Raises:
65
+ ValueError: If the messages are not a valid type for LanguageModelInput.
66
+ """
67
+ system_message = SystemMessage(content=self.system_prompt)
68
+ if isinstance(messages, BaseMessage):
69
+ new_messages = [system_message, messages]
70
+ return FunctionArgumentWrapper(new_messages, *args, **kwargs)
71
+ elif isinstance(messages, PromptValue):
72
+ new_messages = [system_message, *messages.to_messages()]
73
+ return FunctionArgumentWrapper(new_messages, *args, **kwargs)
74
+ elif isinstance(messages, str):
75
+ new_messages = [system_message, HumanMessage(content=messages)]
76
+ return FunctionArgumentWrapper(new_messages, *args, **kwargs)
77
+ elif isinstance(messages, Sequence):
78
+ if all(isinstance(m, BaseMessage) for m in messages):
79
+ new_messages = [system_message, *list(messages)]
80
+ return FunctionArgumentWrapper(new_messages, *args, **kwargs)
81
+ raise ValueError(
82
+ "Unsupported sequence element types for LanguageModelInput; expected Sequence[BaseMessage].")
83
+ else:
84
+ return FunctionArgumentWrapper(messages, *args, **kwargs)
85
+
86
+ if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
87
+ client = patch_with_thinking(
88
+ client,
89
+ LangchainThinkingInjector(
90
+ system_prompt=llm_config.thinking_system_prompt,
91
+ function_names=[
92
+ "invoke",
93
+ "ainvoke",
94
+ "stream",
95
+ "astream",
96
+ ],
97
+ ))
98
+
99
+ if isinstance(llm_config, RetryMixin):
100
+ client = patch_with_retry(client,
101
+ retries=llm_config.num_retries,
102
+ retry_codes=llm_config.retry_on_status_codes,
103
+ retry_on_messages=llm_config.retry_on_errors)
104
+
105
+ return client
25
106
 
26
107
 
27
108
  @register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -31,13 +112,7 @@ async def aws_bedrock_langchain(llm_config: AWSBedrockModelConfig, _builder: Bui
31
112
 
32
113
  client = ChatBedrockConverse(**llm_config.model_dump(exclude={"type", "context_size"}, by_alias=True))
33
114
 
34
- if isinstance(llm_config, RetryMixin):
35
- client = patch_with_retry(client,
36
- retries=llm_config.num_retries,
37
- retry_codes=llm_config.retry_on_status_codes,
38
- retry_on_messages=llm_config.retry_on_errors)
39
-
40
- yield client
115
+ yield _patch_llm_based_on_config(client, llm_config)
41
116
 
42
117
 
43
118
  @register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -47,13 +122,7 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B
47
122
 
48
123
  client = AzureChatOpenAI(**llm_config.model_dump(exclude={"type"}, by_alias=True))
49
124
 
50
- if isinstance(llm_config, RetryMixin):
51
- client = patch_with_retry(client,
52
- retries=llm_config.num_retries,
53
- retry_codes=llm_config.retry_on_status_codes,
54
- retry_on_messages=llm_config.retry_on_errors)
55
-
56
- yield client
125
+ yield _patch_llm_based_on_config(client, llm_config)
57
126
 
58
127
 
59
128
  @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -61,15 +130,13 @@ async def nim_langchain(llm_config: NIMModelConfig, _builder: Builder):
61
130
 
62
131
  from langchain_nvidia_ai_endpoints import ChatNVIDIA
63
132
 
64
- client = ChatNVIDIA(**llm_config.model_dump(exclude={"type"}, by_alias=True))
65
-
66
- if isinstance(llm_config, RetryMixin):
67
- client = patch_with_retry(client,
68
- retries=llm_config.num_retries,
69
- retry_codes=llm_config.retry_on_status_codes,
70
- retry_on_messages=llm_config.retry_on_errors)
133
+ # prefer max_completion_tokens over max_tokens
134
+ client = ChatNVIDIA(
135
+ **llm_config.model_dump(exclude={"type", "max_tokens"}, by_alias=True),
136
+ max_completion_tokens=llm_config.max_tokens,
137
+ )
71
138
 
72
- yield client
139
+ yield _patch_llm_based_on_config(client, llm_config)
73
140
 
74
141
 
75
142
  @register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
@@ -77,18 +144,7 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder):
77
144
 
78
145
  from langchain_openai import ChatOpenAI
79
146
 
80
- # Default kwargs for OpenAI to include usage metadata in the response. If the user has set stream_usage to False, we
81
- # will not include this.
82
- default_kwargs = {"stream_usage": True}
83
-
84
- kwargs = {**default_kwargs, **llm_config.model_dump(exclude={"type"}, by_alias=True)}
85
-
86
- client = ChatOpenAI(**kwargs)
87
-
88
- if isinstance(llm_config, RetryMixin):
89
- client = patch_with_retry(client,
90
- retries=llm_config.num_retries,
91
- retry_codes=llm_config.retry_on_status_codes,
92
- retry_on_messages=llm_config.retry_on_errors)
147
+ # If stream_usage is specified, it will override the default value of True.
148
+ client = ChatOpenAI(stream_usage=True, **llm_config.model_dump(exclude={"type"}, by_alias=True))
93
149
 
94
- yield client
150
+ yield _patch_llm_based_on_config(client, llm_config)
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-langchain
3
- Version: 1.3.0a20250826
3
+ Version: 1.3.0a20250828
4
4
  Summary: Subpackage for LangChain and LangGraph integration in NeMo Agent toolkit
5
5
  Keywords: ai,rag,agents
6
6
  Classifier: Programming Language :: Python
7
7
  Requires-Python: <3.13,>=3.11
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: nvidia-nat==v1.3.0a20250826
9
+ Requires-Dist: nvidia-nat==v1.3.0a20250828
10
10
  Requires-Dist: langchain-aws~=0.2.1
11
11
  Requires-Dist: langchain-core~=0.3.7
12
12
  Requires-Dist: langchain-nvidia-ai-endpoints~=0.3.5
@@ -1,7 +1,7 @@
1
1
  nat/meta/pypi.md,sha256=-RewrXPwhrT6398iluvXb5lefn18PybmvRFhmZF7KVI,1124
2
2
  nat/plugins/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  nat/plugins/langchain/embedder.py,sha256=7bHXRcLwCeqG6ZQZQ35UAb46k5bN3UUL02OtG7RdL4Y,3282
4
- nat/plugins/langchain/llm.py,sha256=HJZsk0i72ZN0QTb1hZSUV4IsLmE0Ab0JOYYnvtqJfno,4013
4
+ nat/plugins/langchain/llm.py,sha256=VbBq5wuTz8VjQjiUVE-bBfsQr2Zkx5wwNiTbvi-Ngaw,6537
5
5
  nat/plugins/langchain/register.py,sha256=jgq6wSJoGQIZFJhS8RbUs25cLgNJjCkFu4M6qaWJS_4,906
6
6
  nat/plugins/langchain/retriever.py,sha256=SWbXXOezEUuPACnmSSU497NAmEVEMj2SrFJGodkRg34,2644
7
7
  nat/plugins/langchain/tool_wrapper.py,sha256=Zgb2_XB4bEhjPPeqS-ZH_OJT_pcQmteX7u03N_qCLfc,2121
@@ -10,8 +10,8 @@ nat/plugins/langchain/tools/code_generation_tool.py,sha256=qL3HBiOQzVPLw4EiUOWes
10
10
  nat/plugins/langchain/tools/register.py,sha256=uemxqLxcNk1bGX4crV52oMphLTZWonStzkXwTZeG2Rw,889
11
11
  nat/plugins/langchain/tools/tavily_internet_search.py,sha256=AnnLRY1xSU4DOzxbB8nFZRjHngXpqatPVOJ7yWV7jVw,2612
12
12
  nat/plugins/langchain/tools/wikipedia_search.py,sha256=431YwLsjoC_mdvMZ_gY0Q37Uqaue2ASnAHpwr4jWCaU,2197
13
- nvidia_nat_langchain-1.3.0a20250826.dist-info/METADATA,sha256=ZOIis3NGacByh78Sqvro8zVYc6qCus0BEGqxQ8yfNSI,1735
14
- nvidia_nat_langchain-1.3.0a20250826.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
- nvidia_nat_langchain-1.3.0a20250826.dist-info/entry_points.txt,sha256=4deXsMn97I012HhDw0UjoqcZ8eEoZ7BnqaRx5QmzebY,123
16
- nvidia_nat_langchain-1.3.0a20250826.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
17
- nvidia_nat_langchain-1.3.0a20250826.dist-info/RECORD,,
13
+ nvidia_nat_langchain-1.3.0a20250828.dist-info/METADATA,sha256=hBxPAKOpltN-iXiHXlSr49pnRNtEgS8sngYGYxuLE8I,1735
14
+ nvidia_nat_langchain-1.3.0a20250828.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
15
+ nvidia_nat_langchain-1.3.0a20250828.dist-info/entry_points.txt,sha256=4deXsMn97I012HhDw0UjoqcZ8eEoZ7BnqaRx5QmzebY,123
16
+ nvidia_nat_langchain-1.3.0a20250828.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
17
+ nvidia_nat_langchain-1.3.0a20250828.dist-info/RECORD,,