nvidia-nat-crewai 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250828__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nat/plugins/crewai/llm.py CHANGED
@@ -14,15 +14,49 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import os
17
+ from typing import TypeVar
17
18
 
18
19
  from nat.builder.builder import Builder
19
20
  from nat.builder.framework_enum import LLMFrameworkEnum
20
21
  from nat.cli.register_workflow import register_llm_client
22
+ from nat.data_models.llm import LLMBaseConfig
21
23
  from nat.data_models.retry_mixin import RetryMixin
24
+ from nat.data_models.thinking_mixin import ThinkingMixin
22
25
  from nat.llm.azure_openai_llm import AzureOpenAIModelConfig
23
26
  from nat.llm.nim_llm import NIMModelConfig
24
27
  from nat.llm.openai_llm import OpenAIModelConfig
28
+ from nat.llm.utils.thinking import BaseThinkingInjector
29
+ from nat.llm.utils.thinking import FunctionArgumentWrapper
30
+ from nat.llm.utils.thinking import patch_with_thinking
25
31
  from nat.utils.exception_handlers.automatic_retries import patch_with_retry
32
+ from nat.utils.type_utils import override
33
+
34
ModelType = TypeVar("ModelType")


class _CrewAIThinkingInjector(BaseThinkingInjector):
    """Prepend the configured thinking system prompt to the messages passed to a CrewAI LLM call."""

    @override
    def inject(self, messages: str | list[dict[str, str]], *args, **kwargs) -> FunctionArgumentWrapper:
        """Return the call arguments with the thinking system prompt prepended.

        Args:
            messages: The first argument of the patched call: a list of
                ``{"role": ..., "content": ...}`` dicts. A bare string is wrapped as a
                single user message. NOTE(review): confirm the pinned CrewAI version's
                ``LLM.call`` may receive plain-string prompts.
            *args: Remaining positional arguments, forwarded unchanged.
            **kwargs: Keyword arguments, forwarded unchanged.

        Returns:
            A ``FunctionArgumentWrapper`` carrying the new message list followed by the
            untouched ``args`` / ``kwargs``.
        """
        if isinstance(messages, str):
            # `list + str` raises TypeError, so turn the bare prompt into a user turn first.
            messages = [{"role": "user", "content": messages}]
        # Unpacking into a fresh list accepts any sequence of messages and never mutates
        # the caller's list.
        new_messages = [{"role": "system", "content": self.system_prompt}, *messages]
        return FunctionArgumentWrapper(new_messages, *args, **kwargs)


def _patch_llm_based_on_config(client: ModelType, llm_config: LLMBaseConfig) -> ModelType:
    """Apply the optional thinking-prompt and retry patches requested by ``llm_config``.

    Args:
        client: The CrewAI ``LLM`` instance to patch.
        llm_config: The LLM configuration.
            - Thinking injection is applied only when the config is a ``ThinkingMixin``
              whose ``thinking_system_prompt`` is not None.
            - Retry wrapping is applied only when the config is a ``RetryMixin``.

    Returns:
        The client, patched according to the configuration (returned unchanged when no
        patch applies).
    """
    if isinstance(llm_config, ThinkingMixin) and llm_config.thinking_system_prompt is not None:
        # Only the `call` method is intercepted to receive the injected system prompt.
        client = patch_with_thinking(
            client,
            _CrewAIThinkingInjector(
                system_prompt=llm_config.thinking_system_prompt,
                function_names=["call"],
            ),
        )

    if isinstance(llm_config, RetryMixin):
        # Applied after the thinking patch, so the retry wrapper wraps the thinking-patched client.
        client = patch_with_retry(client,
                                  retries=llm_config.num_retries,
                                  retry_codes=llm_config.retry_on_status_codes,
                                  retry_on_messages=llm_config.retry_on_errors)

    return client
26
60
 
27
61
 
28
62
  @register_llm_client(config_type=AzureOpenAIModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
@@ -41,33 +75,25 @@ async def azure_openai_crewai(llm_config: AzureOpenAIModelConfig, _builder: Buil
41
75
  }, by_alias=True),
42
76
  }
43
77
 
44
- api_key = config_obj.get("api_key") or os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get("AZURE_API_KEY")
78
+ api_key = llm_config.api_key or os.environ.get("AZURE_OPENAI_API_KEY") or os.environ.get("AZURE_API_KEY")
45
79
  if api_key is None:
46
80
  raise ValueError("Azure API key is not set")
47
81
  os.environ["AZURE_API_KEY"] = api_key
48
- api_base = (config_obj.get("azure_endpoint") or os.environ.get("AZURE_OPENAI_ENDPOINT")
82
+ api_base = (llm_config.azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
49
83
  or os.environ.get("AZURE_API_BASE"))
50
84
  if api_base is None:
51
85
  raise ValueError("Azure endpoint is not set")
52
86
  os.environ["AZURE_API_BASE"] = api_base
53
87
 
54
88
  os.environ["AZURE_API_VERSION"] = llm_config.api_version
55
- model = config_obj.get("azure_deployment") or os.environ.get("AZURE_MODEL_DEPLOYMENT")
89
+ model = llm_config.azure_deployment or os.environ.get("AZURE_MODEL_DEPLOYMENT")
56
90
  if model is None:
57
91
  raise ValueError("Azure model deployment is not set")
58
-
59
92
  config_obj["model"] = model
60
93
 
61
94
  client = LLM(**config_obj)
62
95
 
63
- if isinstance(llm_config, RetryMixin):
64
-
65
- client = patch_with_retry(client,
66
- retries=llm_config.num_retries,
67
- retry_codes=llm_config.retry_on_status_codes,
68
- retry_on_messages=llm_config.retry_on_errors)
69
-
70
- yield client
96
+ yield _patch_llm_based_on_config(client, llm_config)
71
97
 
72
98
 
73
99
  @register_llm_client(config_type=NIMModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
@@ -81,28 +107,14 @@ async def nim_crewai(llm_config: NIMModelConfig, _builder: Builder):
81
107
  }
82
108
 
83
109
  # Because CrewAI uses a different environment variable for the API key, we need to set it here manually
84
- if ("api_key" not in config_obj or config_obj["api_key"] is None):
85
-
86
- if ("NVIDIA_NIM_API_KEY" in os.environ):
87
- # Dont need to do anything. User has already set the correct key
88
- pass
89
- else:
90
- nvidai_api_key = os.getenv("NVIDIA_API_KEY")
91
-
92
- if (nvidai_api_key is not None):
93
- # Transfer the key to the correct environment variable for LiteLLM
94
- os.environ["NVIDIA_NIM_API_KEY"] = nvidai_api_key
110
+ if config_obj.get("api_key") is None and "NVIDIA_NIM_API_KEY" not in os.environ:
111
+ nvidia_api_key = os.getenv("NVIDIA_API_KEY")
112
+ if nvidia_api_key is not None:
113
+ os.environ["NVIDIA_NIM_API_KEY"] = nvidia_api_key
95
114
 
96
115
  client = LLM(**config_obj)
97
116
 
98
- if isinstance(llm_config, RetryMixin):
99
-
100
- client = patch_with_retry(client,
101
- retries=llm_config.num_retries,
102
- retry_codes=llm_config.retry_on_status_codes,
103
- retry_on_messages=llm_config.retry_on_errors)
104
-
105
- yield client
117
+ yield _patch_llm_based_on_config(client, llm_config)
106
118
 
107
119
 
108
120
  @register_llm_client(config_type=OpenAIModelConfig, wrapper_type=LLMFrameworkEnum.CREWAI)
@@ -116,11 +128,4 @@ async def openai_crewai(llm_config: OpenAIModelConfig, _builder: Builder):
116
128
 
117
129
  client = LLM(**config_obj)
118
130
 
119
- if isinstance(llm_config, RetryMixin):
120
-
121
- client = patch_with_retry(client,
122
- retries=llm_config.num_retries,
123
- retry_codes=llm_config.retry_on_status_codes,
124
- retry_on_messages=llm_config.retry_on_errors)
125
-
126
- yield client
131
+ yield _patch_llm_based_on_config(client, llm_config)
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nvidia-nat-crewai
3
- Version: 1.3.0a20250827
3
+ Version: 1.3.0a20250828
4
4
  Summary: Subpackage for CrewAI integration in NeMo Agent toolkit
5
5
  Keywords: ai,rag,agents
6
6
  Classifier: Programming Language :: Python
7
7
  Requires-Python: <3.13,>=3.11
8
8
  Description-Content-Type: text/markdown
9
- Requires-Dist: nvidia-nat==v1.3.0a20250827
9
+ Requires-Dist: nvidia-nat==v1.3.0a20250828
10
10
  Requires-Dist: crewai~=0.95.0
11
11
 
12
12
  <!--
@@ -1,11 +1,11 @@
1
1
  nat/meta/pypi.md,sha256=T68FnThRzDGFf1LR8u-okM-r11-skSnKqSyI6HOktQY,1107
2
2
  nat/plugins/crewai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  nat/plugins/crewai/crewai_callback_handler.py,sha256=LDOctDQC9qdba1SVGoVkceCOSYuDj_mnl3HCuq2nIuQ,8382
4
- nat/plugins/crewai/llm.py,sha256=ErsUXsHpQ-chWqFhYjp-OhARAnCl-HdV_tkrrTZzIOU,4665
4
+ nat/plugins/crewai/llm.py,sha256=DlkLN_fCXS7LUuNku4vvOUrMAgIA_om5LaQsKJOAAPk,5024
5
5
  nat/plugins/crewai/register.py,sha256=_R3bhGmz___696_NwyIcpw3koMBiWqIFoWEFJ0VAgXs,831
6
6
  nat/plugins/crewai/tool_wrapper.py,sha256=BNKEPQQCLKtXNzGDAKBLCdmGJXe9lBOVI1hObha8hoI,1569
7
- nvidia_nat_crewai-1.3.0a20250827.dist-info/METADATA,sha256=VEQudpIIMTEppwXQVv-6NPBtX_eTakcyVqPacygMySo,1453
8
- nvidia_nat_crewai-1.3.0a20250827.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- nvidia_nat_crewai-1.3.0a20250827.dist-info/entry_points.txt,sha256=YF5PUdQGr_OUDXB4TykElHJTsKT8yKkuE0bMX5n_RXs,58
10
- nvidia_nat_crewai-1.3.0a20250827.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
11
- nvidia_nat_crewai-1.3.0a20250827.dist-info/RECORD,,
7
+ nvidia_nat_crewai-1.3.0a20250828.dist-info/METADATA,sha256=9B44N7vNvV2u57Py6bbSCurH3vn-5HUTZ9b_eJINOPM,1453
8
+ nvidia_nat_crewai-1.3.0a20250828.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ nvidia_nat_crewai-1.3.0a20250828.dist-info/entry_points.txt,sha256=YF5PUdQGr_OUDXB4TykElHJTsKT8yKkuE0bMX5n_RXs,58
10
+ nvidia_nat_crewai-1.3.0a20250828.dist-info/top_level.txt,sha256=8-CJ2cP6-f0ZReXe5Hzqp-5pvzzHz-5Ds5H2bGqh1-U,4
11
+ nvidia_nat_crewai-1.3.0a20250828.dist-info/RECORD,,