nvidia-nat 1.3.0rc1__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/prompt_optimizer/register.py +2 -2
- nat/agent/react_agent/register.py +9 -1
- nat/agent/rewoo_agent/register.py +8 -1
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +31 -18
- nat/builder/context.py +22 -6
- nat/cli/commands/mcp/mcp.py +6 -6
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -12
- nat/cli/commands/workflow/templates/register.py.j2 +2 -2
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +54 -10
- nat/cli/main.py +3 -0
- nat/data_models/api_server.py +65 -57
- nat/data_models/span.py +41 -3
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +5 -35
- nat/front_ends/fastapi/message_validator.py +3 -1
- nat/observability/exporter/span_exporter.py +34 -14
- nat/profiler/decorators/framework_wrapper.py +1 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +1 -1
- nat/runtime/runner.py +103 -6
- nat/runtime/session.py +26 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/utils/decorators.py +210 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +1 -3
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +32 -31
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0rc1.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -51,7 +51,7 @@ async def prompt_optimizer_function(config: PromptOptimizerConfig, builder: Buil
|
|
|
51
51
|
from .prompt import mutator_prompt
|
|
52
52
|
except ImportError as exc:
|
|
53
53
|
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
54
|
-
"This error can be resolve by installing nvidia-nat[langchain]") from exc
|
|
54
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
55
55
|
|
|
56
56
|
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
57
57
|
|
|
@@ -111,7 +111,7 @@ async def prompt_recombiner_function(config: PromptRecombinerConfig, builder: Bu
|
|
|
111
111
|
from langchain_core.prompts import PromptTemplate
|
|
112
112
|
except ImportError as exc:
|
|
113
113
|
raise ImportError("langchain-core is not installed. Please install it to use MultiLLMPlanner.\n"
|
|
114
|
-
"This error can be resolve by installing nvidia-nat[langchain].") from exc
|
|
114
|
+
"This error can be resolve by installing \"nvidia-nat[langchain]\".") from exc
|
|
115
115
|
|
|
116
116
|
llm = await builder.get_llm(config.optimizer_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
|
|
117
117
|
|
|
@@ -25,6 +25,7 @@ from nat.cli.register_workflow import register_function
|
|
|
25
25
|
from nat.data_models.agent import AgentBaseConfig
|
|
26
26
|
from nat.data_models.api_server import ChatRequest
|
|
27
27
|
from nat.data_models.api_server import ChatResponse
|
|
28
|
+
from nat.data_models.api_server import Usage
|
|
28
29
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
29
30
|
from nat.data_models.component_ref import FunctionRef
|
|
30
31
|
from nat.data_models.optimizable import OptimizableField
|
|
@@ -149,7 +150,14 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
|
|
|
149
150
|
# get and return the output from the state
|
|
150
151
|
state = ReActGraphState(**state)
|
|
151
152
|
output_message = state.messages[-1]
|
|
152
|
-
|
|
153
|
+
content = str(output_message.content)
|
|
154
|
+
|
|
155
|
+
# Create usage statistics for the response
|
|
156
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
|
|
157
|
+
completion_tokens = len(content.split()) if content else 0
|
|
158
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
159
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
160
|
+
return ChatResponse.from_string(content, usage=usage)
|
|
153
161
|
|
|
154
162
|
except Exception as ex:
|
|
155
163
|
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, str(ex))
|
|
@@ -26,6 +26,7 @@ from nat.cli.register_workflow import register_function
|
|
|
26
26
|
from nat.data_models.agent import AgentBaseConfig
|
|
27
27
|
from nat.data_models.api_server import ChatRequest
|
|
28
28
|
from nat.data_models.api_server import ChatResponse
|
|
29
|
+
from nat.data_models.api_server import Usage
|
|
29
30
|
from nat.data_models.component_ref import FunctionGroupRef
|
|
30
31
|
from nat.data_models.component_ref import FunctionRef
|
|
31
32
|
from nat.utils.type_converter import GlobalTypeConverter
|
|
@@ -157,7 +158,13 @@ async def rewoo_agent_workflow(config: ReWOOAgentWorkflowConfig, builder: Builde
|
|
|
157
158
|
# Ensure output_message is a string
|
|
158
159
|
if isinstance(output_message, list | dict):
|
|
159
160
|
output_message = str(output_message)
|
|
160
|
-
|
|
161
|
+
|
|
162
|
+
# Create usage statistics for the response
|
|
163
|
+
prompt_tokens = sum(len(str(msg.content).split()) for msg in input_message.messages)
|
|
164
|
+
completion_tokens = len(output_message.split()) if output_message else 0
|
|
165
|
+
total_tokens = prompt_tokens + completion_tokens
|
|
166
|
+
usage = Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens)
|
|
167
|
+
return ChatResponse.from_string(output_message, usage=usage)
|
|
161
168
|
|
|
162
169
|
except Exception as ex:
|
|
163
170
|
logger.exception("ReWOO Agent failed with exception: %s", ex)
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
from collections.abc import Awaitable
|
|
17
18
|
from collections.abc import Callable
|
|
18
19
|
from datetime import UTC
|
|
19
20
|
from datetime import datetime
|
|
@@ -35,10 +36,15 @@ logger = logging.getLogger(__name__)
|
|
|
35
36
|
|
|
36
37
|
class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConfig]):
|
|
37
38
|
|
|
38
|
-
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig):
|
|
39
|
+
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None):
|
|
39
40
|
super().__init__(config)
|
|
40
|
-
self._authenticated_tokens: dict[str, AuthResult] = {}
|
|
41
41
|
self._auth_callback = None
|
|
42
|
+
# Always use token storage - defaults to in-memory if not provided
|
|
43
|
+
if token_storage is None:
|
|
44
|
+
from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage
|
|
45
|
+
self._token_storage = InMemoryTokenStorage()
|
|
46
|
+
else:
|
|
47
|
+
self._token_storage = token_storage
|
|
42
48
|
|
|
43
49
|
async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
|
|
44
50
|
refresh_token = auth_result.raw.get("refresh_token")
|
|
@@ -61,7 +67,7 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
61
67
|
raw=new_token_data,
|
|
62
68
|
)
|
|
63
69
|
|
|
64
|
-
self.
|
|
70
|
+
await self._token_storage.store(user_id, new_auth_result)
|
|
65
71
|
except httpx.HTTPStatusError:
|
|
66
72
|
return None
|
|
67
73
|
except httpx.RequestError:
|
|
@@ -74,26 +80,30 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
74
80
|
|
|
75
81
|
def _set_custom_auth_callback(self,
|
|
76
82
|
auth_callback: Callable[[OAuth2AuthCodeFlowProviderConfig, AuthFlowType],
|
|
77
|
-
AuthenticatedContext]):
|
|
83
|
+
Awaitable[AuthenticatedContext]]):
|
|
78
84
|
self._auth_callback = auth_callback
|
|
79
85
|
|
|
80
86
|
async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult:
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
87
|
+
context = Context.get()
|
|
88
|
+
if user_id is None and hasattr(context, "metadata") and hasattr(
|
|
89
|
+
context.metadata, "cookies") and context.metadata.cookies is not None:
|
|
90
|
+
session_id = context.metadata.cookies.get("nat-session", None)
|
|
84
91
|
if not session_id:
|
|
85
92
|
raise RuntimeError("Authentication failed. No session ID found. Cannot identify user.")
|
|
86
93
|
|
|
87
94
|
user_id = session_id
|
|
88
95
|
|
|
89
|
-
if user_id
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
96
|
+
if user_id:
|
|
97
|
+
# Try to retrieve from token storage
|
|
98
|
+
auth_result = await self._token_storage.retrieve(user_id)
|
|
99
|
+
|
|
100
|
+
if auth_result:
|
|
101
|
+
if not auth_result.is_expired():
|
|
102
|
+
return auth_result
|
|
93
103
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
104
|
+
refreshed_auth_result = await self._attempt_token_refresh(user_id, auth_result)
|
|
105
|
+
if refreshed_auth_result:
|
|
106
|
+
return refreshed_auth_result
|
|
97
107
|
|
|
98
108
|
# Try getting callback from the context if that's not set, use the default callback
|
|
99
109
|
try:
|
|
@@ -109,19 +119,22 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
|
|
|
109
119
|
except Exception as e:
|
|
110
120
|
raise RuntimeError(f"Authentication callback failed: {e}") from e
|
|
111
121
|
|
|
112
|
-
|
|
122
|
+
headers = authenticated_context.headers or {}
|
|
123
|
+
auth_header = headers.get("Authorization", "")
|
|
113
124
|
if not auth_header.startswith("Bearer "):
|
|
114
125
|
raise RuntimeError("Invalid Authorization header")
|
|
115
126
|
|
|
116
127
|
token = auth_header.split(" ")[1]
|
|
117
128
|
|
|
129
|
+
# Safely access metadata
|
|
130
|
+
metadata = authenticated_context.metadata or {}
|
|
118
131
|
auth_result = AuthResult(
|
|
119
132
|
credentials=[BearerTokenCred(token=SecretStr(token))],
|
|
120
|
-
token_expires_at=
|
|
121
|
-
raw=
|
|
133
|
+
token_expires_at=metadata.get("expires_at"),
|
|
134
|
+
raw=metadata.get("raw_token") or {},
|
|
122
135
|
)
|
|
123
136
|
|
|
124
137
|
if user_id:
|
|
125
|
-
self.
|
|
138
|
+
await self._token_storage.store(user_id, auth_result)
|
|
126
139
|
|
|
127
140
|
return auth_result
|
nat/builder/context.py
CHANGED
|
@@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton):
|
|
|
67
67
|
def __init__(self):
|
|
68
68
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
69
69
|
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
70
|
+
self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
|
|
71
|
+
self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
|
|
70
72
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
71
73
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
72
74
|
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
@@ -120,14 +122,14 @@ class Context:
|
|
|
120
122
|
@property
|
|
121
123
|
def input_message(self):
|
|
122
124
|
"""
|
|
123
|
-
|
|
125
|
+
Retrieves the input message from the context state.
|
|
124
126
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
127
|
+
The input_message property is used to access the message stored in the
|
|
128
|
+
context state. This property returns the message as it is currently
|
|
129
|
+
maintained in the context.
|
|
128
130
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
+
Returns:
|
|
132
|
+
str: The input message retrieved from the context state.
|
|
131
133
|
"""
|
|
132
134
|
return self._context_state.input_message.get()
|
|
133
135
|
|
|
@@ -196,6 +198,20 @@ class Context:
|
|
|
196
198
|
"""
|
|
197
199
|
return self._context_state.user_message_id.get()
|
|
198
200
|
|
|
201
|
+
@property
|
|
202
|
+
def workflow_run_id(self) -> str | None:
|
|
203
|
+
"""
|
|
204
|
+
Returns a stable identifier for the current workflow/agent invocation (UUID string).
|
|
205
|
+
"""
|
|
206
|
+
return self._context_state.workflow_run_id.get()
|
|
207
|
+
|
|
208
|
+
@property
|
|
209
|
+
def workflow_trace_id(self) -> int | None:
|
|
210
|
+
"""
|
|
211
|
+
Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
|
|
212
|
+
"""
|
|
213
|
+
return self._context_state.workflow_trace_id.get()
|
|
214
|
+
|
|
199
215
|
@contextmanager
|
|
200
216
|
def push_active_function(self,
|
|
201
217
|
function_name: str,
|
nat/cli/commands/mcp/mcp.py
CHANGED
|
@@ -194,7 +194,7 @@ async def _create_mcp_client_config(
|
|
|
194
194
|
auth_user_id: str | None,
|
|
195
195
|
auth_scopes: list[str] | None,
|
|
196
196
|
):
|
|
197
|
-
from nat.plugins.mcp.
|
|
197
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
198
198
|
|
|
199
199
|
if url and transport == "streamable-http" and auth_redirect_uri:
|
|
200
200
|
try:
|
|
@@ -236,8 +236,8 @@ async def list_tools_via_function_group(
|
|
|
236
236
|
try:
|
|
237
237
|
# Ensure the registration side-effects are loaded
|
|
238
238
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
239
|
-
from nat.plugins.mcp.
|
|
240
|
-
from nat.plugins.mcp.
|
|
239
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
240
|
+
from nat.plugins.mcp.client_config import MCPServerConfig
|
|
241
241
|
except ImportError:
|
|
242
242
|
click.echo(
|
|
243
243
|
"MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
|
|
@@ -297,7 +297,7 @@ async def list_tools_via_function_group(
|
|
|
297
297
|
if fn is not None:
|
|
298
298
|
tools.append(to_tool_entry(full, fn))
|
|
299
299
|
else:
|
|
300
|
-
for full, fn in
|
|
300
|
+
for full, fn in fns.items():
|
|
301
301
|
tools.append(to_tool_entry(full, fn))
|
|
302
302
|
|
|
303
303
|
return tools
|
|
@@ -826,8 +826,8 @@ async def call_tool_and_print(command: str | None,
|
|
|
826
826
|
|
|
827
827
|
try:
|
|
828
828
|
from nat.builder.workflow_builder import WorkflowBuilder
|
|
829
|
-
from nat.plugins.mcp.
|
|
830
|
-
from nat.plugins.mcp.
|
|
829
|
+
from nat.plugins.mcp.client_config import MCPClientConfig
|
|
830
|
+
from nat.plugins.mcp.client_config import MCPServerConfig
|
|
831
831
|
except ImportError:
|
|
832
832
|
click.echo(
|
|
833
833
|
"MCP client functionality requires nvidia-nat-mcp package. Install with: uv pip install nvidia-nat-mcp",
|
|
@@ -1,15 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
1
|
+
functions:
|
|
2
|
+
current_datetime:
|
|
3
|
+
_type: current_datetime
|
|
4
|
+
{{python_safe_workflow_name}}:
|
|
5
|
+
_type: {{python_safe_workflow_name}}
|
|
6
|
+
prefix: "Hello:"
|
|
6
7
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
8
|
+
llms:
|
|
9
|
+
nim_llm:
|
|
10
|
+
_type: nim
|
|
11
|
+
model_name: meta/llama-3.1-70b-instruct
|
|
12
|
+
temperature: 0.0
|
|
12
13
|
|
|
13
14
|
workflow:
|
|
14
|
-
_type:
|
|
15
|
-
|
|
15
|
+
_type: react_agent
|
|
16
|
+
llm_name: nim_llm
|
|
17
|
+
tool_names: [current_datetime, {{python_safe_workflow_name}}]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
2
|
|
|
3
|
-
# Import
|
|
4
|
-
from {{package_name}} import {{
|
|
3
|
+
# Import the generated workflow function to trigger registration
|
|
4
|
+
from .{{package_name}} import {{ python_safe_workflow_name }}_function
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
from pydantic import Field
|
|
4
4
|
|
|
5
5
|
from nat.builder.builder import Builder
|
|
6
|
+
from nat.builder.framework_enum import LLMFrameworkEnum
|
|
6
7
|
from nat.builder.function_info import FunctionInfo
|
|
7
8
|
from nat.cli.register_workflow import register_function
|
|
8
9
|
from nat.data_models.function import FunctionBaseConfig
|
|
@@ -12,25 +13,38 @@ logger = logging.getLogger(__name__)
|
|
|
12
13
|
|
|
13
14
|
class {{ workflow_class_name }}(FunctionBaseConfig, name="{{ workflow_name }}"):
|
|
14
15
|
"""
|
|
15
|
-
{{workflow_description}}
|
|
16
|
+
{{ workflow_description }}
|
|
16
17
|
"""
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
18
|
+
prefix: str = Field(default="Echo:", description="Prefix to add before the echoed text.")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@register_function(config_type={{ workflow_class_name }}, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
|
|
22
|
+
async def {{ python_safe_workflow_name }}_function(config: {{ workflow_class_name }}, builder: Builder):
|
|
23
|
+
"""
|
|
24
|
+
Registers a function (addressable via `{{ workflow_name }}` in the configuration).
|
|
25
|
+
This registration ensures a static mapping of the function type, `{{ workflow_name }}`, to the `{{ workflow_class_name }}` configuration object.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
config ({{ workflow_class_name }}): The configuration for the function.
|
|
29
|
+
builder (Builder): The builder object.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
FunctionInfo: The function info object for the function.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Define the function that will be registered.
|
|
36
|
+
async def _echo(text: str) -> str:
|
|
37
|
+
"""
|
|
38
|
+
Takes a text input and echoes back with a pre-defined prefix.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
text (str): The text to echo back.
|
|
42
|
+
|
|
43
|
+
Returns:
|
|
44
|
+
str: The text with the prefix.
|
|
45
|
+
"""
|
|
46
|
+
return f"{config.prefix} {text}"
|
|
47
|
+
|
|
48
|
+
# The callable is wrapped in a FunctionInfo object.
|
|
49
|
+
# The description parameter is used to describe the function.
|
|
50
|
+
yield FunctionInfo.from_fn(_echo, description=_echo.__doc__)
|
|
@@ -27,6 +27,50 @@ from jinja2 import FileSystemLoader
|
|
|
27
27
|
logger = logging.getLogger(__name__)
|
|
28
28
|
|
|
29
29
|
|
|
30
|
+
def _get_nat_version() -> str | None:
|
|
31
|
+
"""
|
|
32
|
+
Get the current NAT version.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
str: The NAT version intended for use in a dependency string.
|
|
36
|
+
None: If the NAT version is not found.
|
|
37
|
+
"""
|
|
38
|
+
from nat.cli.entrypoint import get_version
|
|
39
|
+
|
|
40
|
+
current_version = get_version()
|
|
41
|
+
if current_version == "unknown":
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
version_parts = current_version.split(".")
|
|
45
|
+
if len(version_parts) < 3:
|
|
46
|
+
# If the version somehow doesn't have three parts, return the full version
|
|
47
|
+
return current_version
|
|
48
|
+
|
|
49
|
+
patch = version_parts[2]
|
|
50
|
+
try:
|
|
51
|
+
# If the patch is a number, keep only the major and minor parts
|
|
52
|
+
# Useful for stable releases and adheres to semantic versioning
|
|
53
|
+
_ = int(patch)
|
|
54
|
+
digits_to_keep = 2
|
|
55
|
+
except ValueError:
|
|
56
|
+
# If the patch is not a number, keep all three digits
|
|
57
|
+
# Useful for pre-release versions (and nightly builds)
|
|
58
|
+
digits_to_keep = 3
|
|
59
|
+
|
|
60
|
+
return ".".join(version_parts[:digits_to_keep])
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _is_nat_version_prerelease() -> bool:
|
|
64
|
+
"""
|
|
65
|
+
Check if the NAT version is a prerelease.
|
|
66
|
+
"""
|
|
67
|
+
version = _get_nat_version()
|
|
68
|
+
if version is None:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
return len(version.split(".")) >= 3
|
|
72
|
+
|
|
73
|
+
|
|
30
74
|
def _get_nat_dependency(versioned: bool = True) -> str:
|
|
31
75
|
"""
|
|
32
76
|
Get the NAT dependency string with version.
|
|
@@ -44,16 +88,12 @@ def _get_nat_dependency(versioned: bool = True) -> str:
|
|
|
44
88
|
logger.debug("Using unversioned NAT dependency: %s", dependency)
|
|
45
89
|
return dependency
|
|
46
90
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
if current_version == "unknown":
|
|
51
|
-
logger.warning("Could not detect NAT version, using unversioned dependency")
|
|
91
|
+
version = _get_nat_version()
|
|
92
|
+
if version is None:
|
|
93
|
+
logger.debug("Could not detect NAT version, using unversioned dependency: %s", dependency)
|
|
52
94
|
return dependency
|
|
53
95
|
|
|
54
|
-
|
|
55
|
-
major_minor = ".".join(current_version.split(".")[:2])
|
|
56
|
-
dependency += f"~={major_minor}"
|
|
96
|
+
dependency += f"~={version}"
|
|
57
97
|
logger.debug("Using NAT dependency: %s", dependency)
|
|
58
98
|
return dependency
|
|
59
99
|
|
|
@@ -219,12 +259,16 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
219
259
|
install_cmd = ['uv', 'pip', 'install', '-e', str(new_workflow_dir)]
|
|
220
260
|
else:
|
|
221
261
|
install_cmd = ['pip', 'install', '-e', str(new_workflow_dir)]
|
|
262
|
+
if _is_nat_version_prerelease():
|
|
263
|
+
install_cmd.insert(2, "--pre")
|
|
264
|
+
|
|
265
|
+
python_safe_workflow_name = workflow_name.replace("-", "_")
|
|
222
266
|
|
|
223
267
|
# List of templates and their destinations
|
|
224
268
|
files_to_render = {
|
|
225
269
|
'pyproject.toml.j2': new_workflow_dir / 'pyproject.toml',
|
|
226
270
|
'register.py.j2': base_dir / 'register.py',
|
|
227
|
-
'workflow.py.j2': base_dir / f'{
|
|
271
|
+
'workflow.py.j2': base_dir / f'{python_safe_workflow_name}.py',
|
|
228
272
|
'__init__.py.j2': base_dir / '__init__.py',
|
|
229
273
|
'config.yml.j2': configs_dir / 'config.yml',
|
|
230
274
|
}
|
|
@@ -233,7 +277,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
233
277
|
context = {
|
|
234
278
|
'editable': editable,
|
|
235
279
|
'workflow_name': workflow_name,
|
|
236
|
-
'python_safe_workflow_name':
|
|
280
|
+
'python_safe_workflow_name': python_safe_workflow_name,
|
|
237
281
|
'package_name': package_name,
|
|
238
282
|
'rel_path_to_repo_root': rel_path_to_repo_root,
|
|
239
283
|
'workflow_class_name': f"{_generate_valid_classname(workflow_name)}FunctionConfig",
|
nat/cli/main.py
CHANGED