nvidia-nat-langchain 1.3.0a20250826__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/plugins/langchain/llm.py +92 -36
- {nvidia_nat_langchain-1.3.0a20250826.dist-info → nvidia_nat_langchain-1.3.0a20250828.dist-info}/METADATA +2 -2
- {nvidia_nat_langchain-1.3.0a20250826.dist-info → nvidia_nat_langchain-1.3.0a20250828.dist-info}/RECORD +6 -6
- {nvidia_nat_langchain-1.3.0a20250826.dist-info → nvidia_nat_langchain-1.3.0a20250828.dist-info}/WHEEL +0 -0
- {nvidia_nat_langchain-1.3.0a20250826.dist-info → nvidia_nat_langchain-1.3.0a20250828.dist-info}/entry_points.txt +0 -0
- {nvidia_nat_langchain-1.3.0a20250826.dist-info → nvidia_nat_langchain-1.3.0a20250828.dist-info}/top_level.txt +0 -0
nat/plugins/langchain/llm.py
CHANGED
@@ -13,15 +13,96 @@
|
|
13
13
|
# See the License for the specific language governing permissions and
|
14
14
|
# limitations under the License.
|
15
15
|
|
16
|
+
from collections.abc import Sequence
|
17
|
+
from typing import TypeVar
|
18
|
+
|
16
19
|
from nat.builder.builder import Builder
|
17
20
|
from nat.builder.framework_enum import LLMFrameworkEnum
|
18
21
|
from nat.cli.register_workflow import register_llm_client
|
22
|
+
from nat.data_models.llm import LLMBaseConfig
|
19
23
|
from nat.data_models.retry_mixin import RetryMixin
|
24
|
+
from nat.data_models.thinking_mixin import ThinkingMixin
|
20
25
|
from nat.llm.aws_bedrock_llm import AWSBedrockModelConfig
|
21
26
|
from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
|
22
27
|
from nat.llm.nim_llm import NIMModelConfig
|
23
28
|
from nat.llm.openai_llm import OpenAIModelConfig
|
29
|
+
from nat.llm.utils.thinking import BaseThinkingInjector
|
30
|
+
from nat.llm.utils.thinking import FunctionArgumentWrapper
|
31
|
+
from nat.llm.utils.thinking import patch_with_thinking
|
24
32
|
from nat.utils.exception_handlers.automatic_retries import patch_with_retry
|
33
|
+
from nat.utils.type_utils import override
|
34
|
+
|
35
|
+
# Type variable for the LangChain client accepted and returned by `_patch_llm_based_on_config`,
# so callers get back a value typed like the client they passed in.
ModelType = TypeVar("ModelType")
|
36
|
+
|
37
|
+
|
38
|
+
def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
    """
    Apply the optional behavior patches requested by ``llm_config`` to a LangChain chat client.

    Patches are applied in this order:

    1. If ``llm_config`` is a ``ThinkingMixin`` with a non-``None`` ``thinking_system_prompt``, the
       client's ``invoke``/``ainvoke``/``stream``/``astream`` methods are patched (via
       ``patch_with_thinking``) so that a system message with that prompt is prepended to the input.
    2. If ``llm_config`` is a ``RetryMixin``, the (possibly thinking-patched) client is wrapped with
       ``patch_with_retry`` using the config's retry count, status codes and error messages.

    Args:
        client: The LangChain chat model instance to patch.
        llm_config: The LLM configuration that decides which patches apply.

    Returns:
        The patched client, or ``client`` unchanged if no patch applies.
    """
    # LangChain imports are local, matching the lazy-import style used by the client factories.
    from langchain_core.language_models import LanguageModelInput
    from langchain_core.messages import BaseMessage
    from langchain_core.messages import HumanMessage
    from langchain_core.messages import SystemMessage
    from langchain_core.messages import convert_to_messages
    from langchain_core.prompt_values import PromptValue

    class LangchainThinkingInjector(BaseThinkingInjector):

        @override
        def inject(self, messages: LanguageModelInput, *args, **kwargs) -> FunctionArgumentWrapper:
            """
            Inject a system prompt into the messages.

            The messages are the first (non-object) argument to the function.
            The rest of the arguments are passed through unchanged.

            Args:
                messages: The messages to inject the system prompt into.
                *args: The rest of the arguments to the function.
                **kwargs: The rest of the keyword arguments to the function.

            Returns:
                FunctionArgumentWrapper: An object that contains the transformed args and kwargs.
                Inputs of a type not handled below are forwarded unchanged.

            Raises:
                ValueError: If ``messages`` is a sequence whose elements cannot be coerced into
                    LangChain messages.
            """
            system_message = SystemMessage(content=self.system_prompt)
            if isinstance(messages, BaseMessage):
                new_messages = [system_message, messages]
            elif isinstance(messages, PromptValue):
                new_messages = [system_message, *messages.to_messages()]
            elif isinstance(messages, str):
                new_messages = [system_message, HumanMessage(content=messages)]
            elif isinstance(messages, Sequence):
                # LanguageModelInput accepts any MessageLikeRepresentation in a sequence (BaseMessage,
                # str, (role, content) tuples, dicts), not only BaseMessage. Coerce them the same way
                # LangChain does instead of rejecting input the unpatched client would accept.
                try:
                    converted = convert_to_messages(messages)
                except (NotImplementedError, ValueError) as exc:
                    raise ValueError("Unsupported sequence element types for LanguageModelInput; "
                                     "expected Sequence[MessageLikeRepresentation].") from exc
                new_messages = [system_message, *converted]
            else:
                # Unrecognized input types are forwarded untouched so the wrapped model reports any error.
                return FunctionArgumentWrapper(messages, *args, **kwargs)
            return FunctionArgumentWrapper(new_messages, *args, **kwargs)

    if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
        client = patch_with_thinking(
            client,
            LangchainThinkingInjector(
                system_prompt=llm_config.thinking_system_prompt,
                function_names=[
                    "invoke",
                    "ainvoke",
                    "stream",
                    "astream",
                ],
            ))

    # Retry wrapping is applied last so each retry re-runs the thinking-patched call.
    if isinstance(llm_config, RetryMixin):
        client = patch_with_retry(client,
                                  retries=llm_config.num_retries,
                                  retry_codes=llm_config.retry_on_status_codes,
                                  retry_on_messages=llm_config.retry_on_errors)

    return client
|
25
106
|
|
26
107
|
|
27
108
|
@register_llm_client(config_type=AWSBedrockModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
@@ -31,13 +112,7 @@ async def aws_bedrock_langchain(llm_config: AWSBedrockModelConfig, _builder: Bui
|
|
31
112
|
|
32
113
|
client = ChatBedrockConverse(**llm_config.model_dump(exclude={"type", "context_size"}, by_alias=True))
|
33
114
|
|
34
|
-
|
35
|
-
client = patch_with_retry(client,
|
36
|
-
retries=llm_config.num_retries,
|
37
|
-
retry_codes=llm_config.retry_on_status_codes,
|
38
|
-
retry_on_messages=llm_config.retry_on_errors)
|
39
|
-
|
40
|
-
yield client
|
115
|
+
yield _patch_llm_based_on_config(client, llm_config)
|
41
116
|
|
42
117
|
|
43
118
|
@register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
@@ -47,13 +122,7 @@ async def azure_openai_langchain(llm_config: AzureOpenAIModelConfig, _builder: B
|
|
47
122
|
|
48
123
|
client = AzureChatOpenAI(**llm_config.model_dump(exclude={"type"}, by_alias=True))
|
49
124
|
|
50
|
-
|
51
|
-
client = patch_with_retry(client,
|
52
|
-
retries=llm_config.num_retries,
|
53
|
-
retry_codes=llm_config.retry_on_status_codes,
|
54
|
-
retry_on_messages=llm_config.retry_on_errors)
|
55
|
-
|
56
|
-
yield client
|
125
|
+
yield _patch_llm_based_on_config(client, llm_config)
|
57
126
|
|
58
127
|
|
59
128
|
@register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
@@ -61,15 +130,13 @@ async def nim_langchain(llm_config: NIMModelConfig, _builder: Builder):
|
|
61
130
|
|
62
131
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA
|
63
132
|
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
retry_codes=llm_config.retry_on_status_codes,
|
70
|
-
retry_on_messages=llm_config.retry_on_errors)
|
133
|
+
# prefer max_completion_tokens over max_tokens
|
134
|
+
client = ChatNVIDIA(
|
135
|
+
**llm_config.model_dump(exclude={"type", "max_tokens"}, by_alias=True),
|
136
|
+
max_completion_tokens=llm_config.max_tokens,
|
137
|
+
)
|
71
138
|
|
72
|
-
yield client
|
139
|
+
yield _patch_llm_based_on_config(client, llm_config)
|
73
140
|
|
74
141
|
|
75
142
|
@register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
@@ -77,18 +144,7 @@ async def openai_langchain(llm_config: OpenAIModelConfig, _builder: Builder):
|
|
77
144
|
|
78
145
|
from langchain_openai import ChatOpenAI
|
79
146
|
|
80
|
-
#
|
81
|
-
|
82
|
-
default_kwargs = {"stream_usage": True}
|
83
|
-
|
84
|
-
kwargs = {**default_kwargs, **llm_config.model_dump(exclude={"type"}, by_alias=True)}
|
85
|
-
|
86
|
-
client = ChatOpenAI(**kwargs)
|
87
|
-
|
88
|
-
if isinstance(llm_config, RetryMixin):
|
89
|
-
client = patch_with_retry(client,
|
90
|
-
retries=llm_config.num_retries,
|
91
|
-
retry_codes=llm_config.retry_on_status_codes,
|
92
|
-
retry_on_messages=llm_config.retry_on_errors)
|
147
|
+
# If stream_usage is specified, it will override the default value of True.
|
148
|
+
client = ChatOpenAI(stream_usage=True, **llm_config.model_dump(exclude={"type"}, by_alias=True))
|
93
149
|
|
94
|
-
yield client
|
150
|
+
yield _patch_llm_based_on_config(client, llm_config)
|
@@ -1,12 +1,12 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: nvidia-nat-langchain
|
3
|
-
Version: 1.3.
|
3
|
+
Version: 1.3.0a20250828
|
4
4
|
Summary: Subpackage for LangChain and LangGraph integration in NeMo Agent toolkit
|
5
5
|
Keywords: ai,rag,agents
|
6
6
|
Classifier: Programming Language :: Python
|
7
7
|
Requires-Python: <3.13,>=3.11
|
8
8
|
Description-Content-Type: text/markdown
|
9
|
-
Requires-Dist: nvidia-nat==v1.3.
|
9
|
+
Requires-Dist: nvidia-nat==v1.3.0a20250828
|
10
10
|
Requires-Dist: langchain-aws~=0.2.1
|
11
11
|
Requires-Dist: langchain-core~=0.3.7
|
12
12
|
Requires-Dist: langchain-nvidia-ai-endpoints~=0.3.5
|
@@ -1,7 +1,7 @@
|
|
1
1
|
nat/meta/pypi.md,sha256=-RewrXPwhrT6398iluvXb5lefn18PybmvRFhmZF7KVI,1124
|
2
2
|
nat/plugins/langchain/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
nat/plugins/langchain/embedder.py,sha256=7bHXRcLwCeqG6ZQZQ35UAb46k5bN3UUL02OtG7RdL4Y,3282
|
4
|
-
nat/plugins/langchain/llm.py,sha256=
|
4
|
+
nat/plugins/langchain/llm.py,sha256=VbBq5wuTz8VjQjiUVE-bBfsQr2Zkx5wwNiTbvi-Ngaw,6537
|
5
5
|
nat/plugins/langchain/register.py,sha256=jgq6wSJoGQIZFJhS8RbUs25cLgNJjCkFu4M6qaWJS_4,906
|
6
6
|
nat/plugins/langchain/retriever.py,sha256=SWbXXOezEUuPACnmSSU497NAmEVEMj2SrFJGodkRg34,2644
|
7
7
|
nat/plugins/langchain/tool_wrapper.py,sha256=Zgb2_XB4bEhjPPeqS-ZH_OJT_pcQmteX7u03N_qCLfc,2121
|
@@ -10,8 +10,8 @@ nat/plugins/langchain/tools/code_generation_tool.py,sha256=qL3HBiOQzVPLw4EiUOWes
|
|
10
10
|
nat/plugins/langchain/tools/register.py,sha256=uemxqLxcNk1bGX4crV52oMphLTZWonStzkXwTZeG2Rw,889
|
11
11
|
nat/plugins/langchain/tools/tavily_internet_search.py,sha256=AnnLRY1xSU4DOzxbB8nFZRjHngXpqatPVOJ7yWV7jVw,2612
|
12
12
|
nat/plugins/langchain/tools/wikipedia_search.py,sha256=431YwLsjoC_mdvMZ_gY0Q37Uqaue2ASnAHpwr4jWCaU,2197
|
13
|
-
nvidia_nat_langchain-1.3.
|
14
|
-
nvidia_nat_langchain-1.3.
|
15
|
-
nvidia_nat_langchain-1.3.
|
16
|
-
nvidia_nat_langchain-1.3.
|
17
|
-
nvidia_nat_langchain-1.3.
|
13
|
+
nvidia_nat_langchain-1.3.0a20250828.dist-info/METADATA,sha256=hBxPAKOpltN-iXiHXlSr49pnRNtEgS8sngYGYxuLE8I,1735
|
14
|
+
nvidia_nat_langchain-1.3.0a20250828.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
15
|
+
nvidia_nat_langchain-1.3.0a20250828.dist-info/entry_points.txt,sha256=4deXsMn97I012HhDw0UjoqcZ8eEoZ7BnqaRx5QmzebY,123
|
16
|
+
nvidia_nat_langchain-1.3.0a20250828.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
|
17
|
+
nvidia_nat_langchain-1.3.0a20250828.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|