nvidia-nat 1.3.0a20250827__py3-none-any.whl → 1.3.0a20250829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nat/agent/base.py +12 -7
- nat/agent/dual_node.py +7 -2
- nat/agent/react_agent/agent.py +15 -14
- nat/agent/react_agent/register.py +5 -1
- nat/agent/rewoo_agent/agent.py +23 -32
- nat/agent/rewoo_agent/register.py +8 -4
- nat/agent/tool_calling_agent/agent.py +15 -20
- nat/agent/tool_calling_agent/register.py +6 -2
- nat/builder/context.py +7 -2
- nat/builder/eval_builder.py +2 -2
- nat/builder/function.py +8 -8
- nat/builder/workflow_builder.py +21 -24
- nat/cli/cli_utils/config_override.py +1 -1
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +1 -1
- nat/cli/commands/start.py +1 -1
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/workflow_commands.py +4 -4
- nat/cli/entrypoint.py +3 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/gated_field_mixin.py +12 -14
- nat/data_models/temperature_mixin.py +1 -1
- nat/data_models/thinking_mixin.py +68 -0
- nat/data_models/top_p_mixin.py +1 -1
- nat/eval/evaluate.py +6 -6
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +1 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/swe_bench_evaluator/evaluate.py +5 -5
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +2 -2
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +1 -1
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +3 -3
- nat/front_ends/fastapi/message_handler.py +2 -2
- nat/front_ends/fastapi/message_validator.py +8 -10
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +1 -1
- nat/llm/aws_bedrock_llm.py +10 -9
- nat/llm/azure_openai_llm.py +9 -1
- nat/llm/nim_llm.py +2 -1
- nat/llm/openai_llm.py +2 -1
- nat/llm/utils/thinking.py +215 -0
- nat/observability/exporter/base_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +8 -9
- nat/observability/exporter_manager.py +5 -5
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/processor/batching_processor.py +4 -6
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +5 -5
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +2 -2
- nat/profiler/decorators/function_tracking.py +125 -0
- nat/profiler/profile_runner.py +1 -1
- nat/profiler/utils.py +1 -1
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -1
- nat/registry_handlers/pypi/pypi_handler.py +3 -3
- nat/registry_handlers/rest/rest_handler.py +4 -4
- nat/retriever/milvus/retriever.py +1 -1
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/runtime/loader.py +1 -1
- nat/runtime/runner.py +2 -2
- nat/settings/global_settings.py +1 -1
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/retriever.py +3 -2
- nat/utils/io/yaml_tools.py +1 -1
- nat/utils/reactive/observer.py +2 -2
- nat/utils/settings/global_settings.py +2 -2
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/METADATA +3 -1
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/RECORD +87 -81
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/entry_points.txt +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0a20250827.dist-info → nvidia_nat-1.3.0a20250829.dist-info}/top_level.txt +0 -0
nat/builder/workflow_builder.py
CHANGED
|
@@ -420,8 +420,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
420
420
|
# Wrap in the correct wrapper
|
|
421
421
|
return tool_wrapper_reg.build_fn(fn_name, fn.instance, self)
|
|
422
422
|
except Exception as e:
|
|
423
|
-
logger.error("Error fetching tool `%s
|
|
424
|
-
raise
|
|
423
|
+
logger.error("Error fetching tool `%s`: %s", fn_name, e)
|
|
424
|
+
raise
|
|
425
425
|
|
|
426
426
|
@override
|
|
427
427
|
async def add_llm(self, name: str | LLMRef, config: LLMBaseConfig):
|
|
@@ -436,8 +436,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
436
436
|
|
|
437
437
|
self._llms[name] = ConfiguredLLM(config=config, instance=info_obj)
|
|
438
438
|
except Exception as e:
|
|
439
|
-
logger.error("Error adding llm `%s` with config `%s
|
|
440
|
-
raise
|
|
439
|
+
logger.error("Error adding llm `%s` with config `%s`: %s", name, config, e)
|
|
440
|
+
raise
|
|
441
441
|
|
|
442
442
|
@override
|
|
443
443
|
async def get_llm(self, llm_name: str | LLMRef, wrapper_type: LLMFrameworkEnum | str):
|
|
@@ -457,8 +457,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
457
457
|
# Return a frameworks specific client
|
|
458
458
|
return client
|
|
459
459
|
except Exception as e:
|
|
460
|
-
logger.error("Error getting llm `%s` with wrapper `%s
|
|
461
|
-
raise
|
|
460
|
+
logger.error("Error getting llm `%s` with wrapper `%s`: %s", llm_name, wrapper_type, e)
|
|
461
|
+
raise
|
|
462
462
|
|
|
463
463
|
@override
|
|
464
464
|
def get_llm_config(self, llm_name: str | LLMRef) -> LLMBaseConfig:
|
|
@@ -508,8 +508,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
508
508
|
|
|
509
509
|
return info_obj
|
|
510
510
|
except Exception as e:
|
|
511
|
-
logger.error("Error adding authentication `%s` with config `%s
|
|
512
|
-
raise
|
|
511
|
+
logger.error("Error adding authentication `%s` with config `%s`: %s", name, config, e)
|
|
512
|
+
raise
|
|
513
513
|
|
|
514
514
|
@override
|
|
515
515
|
async def get_auth_provider(self, auth_provider_name: str) -> AuthProviderBase:
|
|
@@ -552,9 +552,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
552
552
|
|
|
553
553
|
self._embedders[name] = ConfiguredEmbedder(config=config, instance=info_obj)
|
|
554
554
|
except Exception as e:
|
|
555
|
-
logger.error("Error adding embedder `%s` with config `%s
|
|
556
|
-
|
|
557
|
-
raise e
|
|
555
|
+
logger.error("Error adding embedder `%s` with config `%s`: %s", name, config, e)
|
|
556
|
+
raise
|
|
558
557
|
|
|
559
558
|
@override
|
|
560
559
|
async def get_embedder(self, embedder_name: str | EmbedderRef, wrapper_type: LLMFrameworkEnum | str):
|
|
@@ -574,8 +573,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
574
573
|
# Return a frameworks specific client
|
|
575
574
|
return client
|
|
576
575
|
except Exception as e:
|
|
577
|
-
logger.error("Error getting embedder `%s` with wrapper `%s
|
|
578
|
-
raise
|
|
576
|
+
logger.error("Error getting embedder `%s` with wrapper `%s`: %s", embedder_name, wrapper_type, e)
|
|
577
|
+
raise
|
|
579
578
|
|
|
580
579
|
@override
|
|
581
580
|
def get_embedder_config(self, embedder_name: str | EmbedderRef) -> EmbedderBaseConfig:
|
|
@@ -660,9 +659,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
660
659
|
self._retrievers[name] = ConfiguredRetriever(config=config, instance=info_obj)
|
|
661
660
|
|
|
662
661
|
except Exception as e:
|
|
663
|
-
logger.error("Error adding retriever `%s` with config `%s
|
|
664
|
-
|
|
665
|
-
raise e
|
|
662
|
+
logger.error("Error adding retriever `%s` with config `%s`: %s", name, config, e)
|
|
663
|
+
raise
|
|
666
664
|
|
|
667
665
|
# return info_obj
|
|
668
666
|
|
|
@@ -687,8 +685,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
687
685
|
# Return a frameworks specific client
|
|
688
686
|
return client
|
|
689
687
|
except Exception as e:
|
|
690
|
-
logger.error("Error getting retriever `%s` with wrapper `%s
|
|
691
|
-
raise
|
|
688
|
+
logger.error("Error getting retriever `%s` with wrapper `%s`: %s", retriever_name, wrapper_type, e)
|
|
689
|
+
raise
|
|
692
690
|
|
|
693
691
|
@override
|
|
694
692
|
async def get_retriever_config(self, retriever_name: str | RetrieverRef) -> RetrieverBaseConfig:
|
|
@@ -712,9 +710,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
712
710
|
self._ttc_strategies[name] = ConfiguredTTCStrategy(config=config, instance=info_obj)
|
|
713
711
|
|
|
714
712
|
except Exception as e:
|
|
715
|
-
logger.error("Error adding TTC strategy `%s` with config `%s
|
|
716
|
-
|
|
717
|
-
raise e
|
|
713
|
+
logger.error("Error adding TTC strategy `%s` with config `%s`: %s", name, config, e)
|
|
714
|
+
raise
|
|
718
715
|
|
|
719
716
|
@override
|
|
720
717
|
async def get_ttc_strategy(self,
|
|
@@ -742,8 +739,8 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
742
739
|
|
|
743
740
|
return instance
|
|
744
741
|
except Exception as e:
|
|
745
|
-
logger.error("Error getting TTC strategy `%s
|
|
746
|
-
raise
|
|
742
|
+
logger.error("Error getting TTC strategy `%s`: %s", strategy_name, e)
|
|
743
|
+
raise
|
|
747
744
|
|
|
748
745
|
@override
|
|
749
746
|
async def get_ttc_strategy_config(self,
|
|
@@ -820,7 +817,7 @@ class WorkflowBuilder(Builder, AbstractAsyncContextManager):
|
|
|
820
817
|
else:
|
|
821
818
|
logger.error("No remaining components to build")
|
|
822
819
|
|
|
823
|
-
logger.error("Original error:", exc_info=
|
|
820
|
+
logger.error("Original error: %s", original_error, exc_info=True)
|
|
824
821
|
|
|
825
822
|
def _log_build_failure_component(self,
|
|
826
823
|
failing_component: ComponentInstanceData,
|
|
nat/cli/cli_utils/config_override.py
CHANGED
|
@@ -213,7 +213,7 @@ def load_and_override_config(config_file: Path, overrides: tuple[tuple[str, str]
|
|
|
213
213
|
yaml.dump(effective_config, default_flow_style=False),
|
|
214
214
|
)
|
|
215
215
|
except Exception as e:
|
|
216
|
-
logger.
|
|
216
|
+
logger.error("Modified configuration failed validation: %s", e)
|
|
217
217
|
raise click.BadParameter(f"Modified configuration failed validation: {str(e)}")
|
|
218
218
|
finally:
|
|
219
219
|
# Clean up the temporary file
|
|
nat/cli/commands/object_store/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
nat/cli/commands/object_store/object_store.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import importlib
|
|
18
|
+
import logging
|
|
19
|
+
import mimetypes
|
|
20
|
+
import time
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import click
|
|
24
|
+
|
|
25
|
+
from nat.builder.workflow_builder import WorkflowBuilder
|
|
26
|
+
from nat.data_models.object_store import ObjectStoreBaseConfig
|
|
27
|
+
from nat.object_store.interfaces import ObjectStore
|
|
28
|
+
from nat.object_store.models import ObjectStoreItem
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
STORE_CONFIGS = {
|
|
33
|
+
"s3": {
|
|
34
|
+
"module": "nat.plugins.s3.object_store", "config_class": "S3ObjectStoreClientConfig"
|
|
35
|
+
},
|
|
36
|
+
"mysql": {
|
|
37
|
+
"module": "nat.plugins.mysql.object_store", "config_class": "MySQLObjectStoreClientConfig"
|
|
38
|
+
},
|
|
39
|
+
"redis": {
|
|
40
|
+
"module": "nat.plugins.redis.object_store", "config_class": "RedisObjectStoreClientConfig"
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def get_object_store_config(**kwargs) -> ObjectStoreBaseConfig:
|
|
46
|
+
"""Process common object store arguments and return the config class"""
|
|
47
|
+
store_type = kwargs.pop("store_type")
|
|
48
|
+
config = STORE_CONFIGS[store_type]
|
|
49
|
+
module = importlib.import_module(config["module"])
|
|
50
|
+
config_class = getattr(module, config["config_class"])
|
|
51
|
+
return config_class(**kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
async def upload_file(object_store: ObjectStore, file_path: Path, key: str):
|
|
55
|
+
"""
|
|
56
|
+
Upload a single file to object store.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
object_store: The object store instance to use.
|
|
60
|
+
file_path: The path to the file to upload.
|
|
61
|
+
key: The key to upload the file to.
|
|
62
|
+
"""
|
|
63
|
+
try:
|
|
64
|
+
data = await asyncio.to_thread(file_path.read_bytes)
|
|
65
|
+
|
|
66
|
+
item = ObjectStoreItem(data=data,
|
|
67
|
+
content_type=mimetypes.guess_type(str(file_path))[0],
|
|
68
|
+
metadata={
|
|
69
|
+
"original_filename": file_path.name,
|
|
70
|
+
"file_size": str(len(data)),
|
|
71
|
+
"file_extension": file_path.suffix,
|
|
72
|
+
"upload_timestamp": str(int(time.time()))
|
|
73
|
+
})
|
|
74
|
+
|
|
75
|
+
# Upload using upsert to allow overwriting
|
|
76
|
+
await object_store.upsert_object(key, item)
|
|
77
|
+
click.echo(f"✅ Uploaded: {file_path.name} -> {key}")
|
|
78
|
+
|
|
79
|
+
except Exception as e:
|
|
80
|
+
raise RuntimeError(f"Failed to upload {file_path.name}:\n{e}") from e
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def object_store_command_decorator(async_func):
|
|
84
|
+
"""
|
|
85
|
+
Decorator that handles the common object store command pattern.
|
|
86
|
+
|
|
87
|
+
The decorated function should take (store: ObjectStore, kwargs) as parameters
|
|
88
|
+
and return an exit code (0 for success).
|
|
89
|
+
"""
|
|
90
|
+
|
|
91
|
+
@click.pass_context
|
|
92
|
+
def wrapper(ctx: click.Context, **kwargs):
|
|
93
|
+
config = ctx.obj["store_config"]
|
|
94
|
+
|
|
95
|
+
async def work():
|
|
96
|
+
async with WorkflowBuilder() as builder:
|
|
97
|
+
await builder.add_object_store(name="store", config=config)
|
|
98
|
+
store = await builder.get_object_store_client("store")
|
|
99
|
+
return await async_func(store, **kwargs)
|
|
100
|
+
|
|
101
|
+
try:
|
|
102
|
+
exit_code = asyncio.run(work())
|
|
103
|
+
except Exception as e:
|
|
104
|
+
raise click.ClickException(f"Command failed: {e}") from e
|
|
105
|
+
if exit_code != 0:
|
|
106
|
+
raise click.ClickException(f"Command failed with exit code {exit_code}")
|
|
107
|
+
return exit_code
|
|
108
|
+
|
|
109
|
+
return wrapper
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@click.command(name="upload", help="Upload a directory to an object store.")
|
|
113
|
+
@click.argument("local_dir",
|
|
114
|
+
type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
|
|
115
|
+
required=True)
|
|
116
|
+
@click.help_option("--help", "-h")
|
|
117
|
+
@object_store_command_decorator
|
|
118
|
+
async def upload_command(store: ObjectStore, local_dir: Path, **_kwargs):
|
|
119
|
+
"""
|
|
120
|
+
Upload a directory to an object store.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
local_dir: The local directory to upload.
|
|
124
|
+
store: The object store to use.
|
|
125
|
+
_kwargs: Additional keyword arguments.
|
|
126
|
+
"""
|
|
127
|
+
try:
|
|
128
|
+
click.echo(f"📁 Processing directory: {local_dir}")
|
|
129
|
+
file_count = 0
|
|
130
|
+
|
|
131
|
+
# Process each file recursively
|
|
132
|
+
for file_path in local_dir.rglob("*"):
|
|
133
|
+
if file_path.is_file():
|
|
134
|
+
key = file_path.relative_to(local_dir).as_posix()
|
|
135
|
+
await upload_file(store, file_path, key)
|
|
136
|
+
file_count += 1
|
|
137
|
+
|
|
138
|
+
click.echo(f"✅ Directory uploaded successfully! {file_count} files uploaded.")
|
|
139
|
+
return 0
|
|
140
|
+
|
|
141
|
+
except Exception as e:
|
|
142
|
+
raise click.ClickException(f"❌ Failed to upload directory {local_dir}:\n {e}") from e
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@click.command(name="delete", help="Delete files from an object store.")
|
|
146
|
+
@click.argument("keys", type=str, required=True, nargs=-1)
|
|
147
|
+
@click.help_option("--help", "-h")
|
|
148
|
+
@object_store_command_decorator
|
|
149
|
+
async def delete_command(store: ObjectStore, keys: list[str], **_kwargs):
|
|
150
|
+
"""
|
|
151
|
+
Delete files from an object store.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
store: The object store to use.
|
|
155
|
+
keys: The keys to delete.
|
|
156
|
+
_kwargs: Additional keyword arguments.
|
|
157
|
+
"""
|
|
158
|
+
deleted_count = 0
|
|
159
|
+
failed_count = 0
|
|
160
|
+
for key in keys:
|
|
161
|
+
try:
|
|
162
|
+
await store.delete_object(key)
|
|
163
|
+
click.echo(f"✅ Deleted: {key}")
|
|
164
|
+
deleted_count += 1
|
|
165
|
+
except Exception as e:
|
|
166
|
+
click.echo(f"❌ Failed to delete {key}: {e}")
|
|
167
|
+
failed_count += 1
|
|
168
|
+
|
|
169
|
+
click.echo(f"✅ Deletion completed! {deleted_count} keys deleted. {failed_count} keys failed to delete.")
|
|
170
|
+
return 0 if failed_count == 0 else 1
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@click.group(name="object-store", invoke_without_command=False, help="Manage object store operations.")
|
|
174
|
+
def object_store_command(**_kwargs):
|
|
175
|
+
"""Manage object store operations including uploading files and directories."""
|
|
176
|
+
pass
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def register_object_store_commands():
|
|
180
|
+
|
|
181
|
+
@click.group(name="s3", invoke_without_command=False, help="S3 object store operations.")
|
|
182
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
183
|
+
@click.option("--endpoint-url", type=str, help="S3 endpoint URL")
|
|
184
|
+
@click.option("--access-key", type=str, help="S3 access key")
|
|
185
|
+
@click.option("--secret-key", type=str, help="S3 secret key")
|
|
186
|
+
@click.option("--region", type=str, help="S3 region")
|
|
187
|
+
@click.pass_context
|
|
188
|
+
def s3(ctx: click.Context, **kwargs):
|
|
189
|
+
ctx.ensure_object(dict)
|
|
190
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="s3", **kwargs)
|
|
191
|
+
|
|
192
|
+
@click.group(name="mysql", invoke_without_command=False, help="MySQL object store operations.")
|
|
193
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
194
|
+
@click.option("--host", type=str, help="MySQL host")
|
|
195
|
+
@click.option("--port", type=int, help="MySQL port")
|
|
196
|
+
@click.option("--db", type=str, help="MySQL database name")
|
|
197
|
+
@click.option("--username", type=str, help="MySQL username")
|
|
198
|
+
@click.option("--password", type=str, help="MySQL password")
|
|
199
|
+
@click.pass_context
|
|
200
|
+
def mysql(ctx: click.Context, **kwargs):
|
|
201
|
+
ctx.ensure_object(dict)
|
|
202
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="mysql", **kwargs)
|
|
203
|
+
|
|
204
|
+
@click.group(name="redis", invoke_without_command=False, help="Redis object store operations.")
|
|
205
|
+
@click.argument("bucket_name", type=str, required=True)
|
|
206
|
+
@click.option("--host", type=str, help="Redis host")
|
|
207
|
+
@click.option("--port", type=int, help="Redis port")
|
|
208
|
+
@click.option("--db", type=int, help="Redis db")
|
|
209
|
+
@click.pass_context
|
|
210
|
+
def redis(ctx: click.Context, **kwargs):
|
|
211
|
+
ctx.ensure_object(dict)
|
|
212
|
+
ctx.obj["store_config"] = get_object_store_config(store_type="redis", **kwargs)
|
|
213
|
+
|
|
214
|
+
commands = {"s3": s3, "mysql": mysql, "redis": redis}
|
|
215
|
+
|
|
216
|
+
for store_type, config in STORE_CONFIGS.items():
|
|
217
|
+
try:
|
|
218
|
+
importlib.import_module(config["module"])
|
|
219
|
+
command = commands[store_type]
|
|
220
|
+
object_store_command.add_command(command, name=store_type)
|
|
221
|
+
command.add_command(upload_command, name="upload")
|
|
222
|
+
command.add_command(delete_command, name="delete")
|
|
223
|
+
except ImportError:
|
|
224
|
+
pass
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
register_object_store_commands()
|
|
nat/cli/commands/registry/publish.py
CHANGED
|
@@ -40,7 +40,7 @@ async def publish_artifact(registry_handler_config: RegistryHandlerBaseConfig, p
|
|
|
40
40
|
try:
|
|
41
41
|
artifact = build_artifact(package_root=package_root)
|
|
42
42
|
except Exception as e:
|
|
43
|
-
logger.exception("Error building artifact: %s", e
|
|
43
|
+
logger.exception("Error building artifact: %s", e)
|
|
44
44
|
return
|
|
45
45
|
await stack.enter_async_context(registry_handler.publish(artifact=artifact))
|
|
46
46
|
|
|
@@ -82,7 +82,7 @@ def publish(channel: str, config_file: str, package_root: str) -> None:
|
|
|
82
82
|
logger.error("Publish channel '%s' has not been configured.", channel)
|
|
83
83
|
return
|
|
84
84
|
except Exception as e:
|
|
85
|
-
logger.exception("Error loading user settings: %s", e
|
|
85
|
+
logger.exception("Error loading user settings: %s", e)
|
|
86
86
|
return
|
|
87
87
|
|
|
88
88
|
asyncio.run(publish_artifact(registry_handler_config=publish_channel_config, package_root=package_root))
|
|
nat/cli/commands/registry/pull.py
CHANGED
|
@@ -66,7 +66,7 @@ async def pull_artifact(registry_handler_config: RegistryHandlerBaseConfig, pack
|
|
|
66
66
|
validated_packages = PullRequestPackages(packages=package_list)
|
|
67
67
|
|
|
68
68
|
except Exception as e:
|
|
69
|
-
logger.exception("Error processing package names: %s", e
|
|
69
|
+
logger.exception("Error processing package names: %s", e)
|
|
70
70
|
return
|
|
71
71
|
|
|
72
72
|
await stack.enter_async_context(registry_handler.pull(packages=validated_packages))
|
|
@@ -112,7 +112,7 @@ def pull(channel: str, config_file: str, packages: str) -> None:
|
|
|
112
112
|
logger.error("Pull channel '%s' has not been configured.", channel)
|
|
113
113
|
return
|
|
114
114
|
except Exception as e:
|
|
115
|
-
logger.exception("Error loading user settings: %s", e
|
|
115
|
+
logger.exception("Error loading user settings: %s", e)
|
|
116
116
|
return
|
|
117
117
|
|
|
118
118
|
asyncio.run(pull_artifact(pull_channel_config, packages))
|
|
nat/cli/commands/registry/remove.py
CHANGED
|
@@ -41,7 +41,7 @@ async def remove_artifact(registry_handler_config: RegistryHandlerBaseConfig, pa
|
|
|
41
41
|
try:
|
|
42
42
|
package_name_list = PackageNameVersionList(**{"packages": packages})
|
|
43
43
|
except Exception as e:
|
|
44
|
-
logger.exception("Invalid package format: '%s'", e
|
|
44
|
+
logger.exception("Invalid package format: '%s'", e)
|
|
45
45
|
|
|
46
46
|
await stack.enter_async_context(registry_handler.remove(packages=package_name_list))
|
|
47
47
|
|
|
@@ -102,7 +102,7 @@ def remove(channel: str, config_file: str, packages: str) -> None:
|
|
|
102
102
|
logger.error("Remove channel '%s' has not been configured.", channel)
|
|
103
103
|
return
|
|
104
104
|
except Exception as e:
|
|
105
|
-
logger.exception("Error loading user settings: %s", e
|
|
105
|
+
logger.exception("Error loading user settings: %s", e)
|
|
106
106
|
return
|
|
107
107
|
|
|
108
108
|
asyncio.run(remove_artifact(registry_handler_config=remove_channel_config, packages=packages_versions))
|
|
nat/cli/commands/registry/search.py
CHANGED
|
@@ -140,7 +140,7 @@ def search(config_file: str,
|
|
|
140
140
|
logger.error("Search channel '%s' has not been configured.", channel)
|
|
141
141
|
return
|
|
142
142
|
except Exception as e:
|
|
143
|
-
logger.exception("Error loading user settings: %s", e
|
|
143
|
+
logger.exception("Error loading user settings: %s", e)
|
|
144
144
|
return
|
|
145
145
|
|
|
146
146
|
asyncio.run(
|
nat/cli/commands/start.py
CHANGED
|
@@ -224,7 +224,7 @@ class StartCommandGroup(click.Group):
|
|
|
224
224
|
return asyncio.run(run_plugin())
|
|
225
225
|
|
|
226
226
|
except Exception as e:
|
|
227
|
-
logger.error("Failed to initialize workflow"
|
|
227
|
+
logger.error("Failed to initialize workflow")
|
|
228
228
|
raise click.ClickException(str(e)) from e
|
|
229
229
|
|
|
230
230
|
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
|
nat/cli/commands/uninstall.py
CHANGED
|
@@ -44,7 +44,7 @@ async def uninstall_packages(packages: list[dict[str, str]]) -> None:
|
|
|
44
44
|
try:
|
|
45
45
|
package_name_list = PackageNameVersionList(**{"packages": packages})
|
|
46
46
|
except Exception as e:
|
|
47
|
-
logger.exception("Error validating package format: %s", e
|
|
47
|
+
logger.exception("Error validating package format: %s", e)
|
|
48
48
|
return
|
|
49
49
|
|
|
50
50
|
async with AsyncExitStack() as stack:
|
|
nat/cli/commands/workflow/workflow_commands.py
CHANGED
|
@@ -97,7 +97,7 @@ def find_package_root(package_name: str) -> Path | None:
|
|
|
97
97
|
try:
|
|
98
98
|
info = json.loads(direct_url)
|
|
99
99
|
except json.JSONDecodeError:
|
|
100
|
-
logger.
|
|
100
|
+
logger.exception("Malformed direct_url.json for package: %s", package_name)
|
|
101
101
|
return None
|
|
102
102
|
|
|
103
103
|
if not info.get("dir_info", {}).get("editable"):
|
|
@@ -271,7 +271,7 @@ def create_command(workflow_name: str, install: bool, workflow_dir: str, descrip
|
|
|
271
271
|
|
|
272
272
|
click.echo(f"Workflow '{workflow_name}' created successfully in '{new_workflow_dir}'.")
|
|
273
273
|
except Exception as e:
|
|
274
|
-
logger.exception("An error occurred while creating the workflow: %s", e
|
|
274
|
+
logger.exception("An error occurred while creating the workflow: %s", e)
|
|
275
275
|
click.echo(f"An error occurred while creating the workflow: {e}")
|
|
276
276
|
|
|
277
277
|
|
|
@@ -307,7 +307,7 @@ def reinstall_command(workflow_name):
|
|
|
307
307
|
|
|
308
308
|
click.echo(f"Workflow '{workflow_name}' reinstalled successfully.")
|
|
309
309
|
except Exception as e:
|
|
310
|
-
logger.exception("An error occurred while reinstalling the workflow: %s", e
|
|
310
|
+
logger.exception("An error occurred while reinstalling the workflow: %s", e)
|
|
311
311
|
click.echo(f"An error occurred while reinstalling the workflow: {e}")
|
|
312
312
|
|
|
313
313
|
|
|
@@ -354,7 +354,7 @@ def delete_command(workflow_name: str):
|
|
|
354
354
|
|
|
355
355
|
click.echo(f"Workflow '{workflow_name}' deleted successfully.")
|
|
356
356
|
except Exception as e:
|
|
357
|
-
logger.exception("An error occurred while deleting the workflow: %s", e
|
|
357
|
+
logger.exception("An error occurred while deleting the workflow: %s", e)
|
|
358
358
|
click.echo(f"An error occurred while deleting the workflow: {e}")
|
|
359
359
|
|
|
360
360
|
|
nat/cli/entrypoint.py
CHANGED
|
@@ -33,6 +33,7 @@ import nest_asyncio
|
|
|
33
33
|
from .commands.configure.configure import configure_command
|
|
34
34
|
from .commands.evaluate import eval_command
|
|
35
35
|
from .commands.info.info import info_command
|
|
36
|
+
from .commands.object_store.object_store import object_store_command
|
|
36
37
|
from .commands.registry.registry import registry_command
|
|
37
38
|
from .commands.sizing.sizing import sizing
|
|
38
39
|
from .commands.start import start_command
|
|
@@ -107,11 +108,12 @@ cli.add_command(uninstall_command, name="uninstall")
|
|
|
107
108
|
cli.add_command(validate_command, name="validate")
|
|
108
109
|
cli.add_command(workflow_command, name="workflow")
|
|
109
110
|
cli.add_command(sizing, name="sizing")
|
|
111
|
+
cli.add_command(object_store_command, name="object-store")
|
|
110
112
|
|
|
111
113
|
# Aliases
|
|
112
114
|
cli.add_command(start_command.get_command(None, "console"), name="run") # type: ignore
|
|
113
115
|
cli.add_command(start_command.get_command(None, "fastapi"), name="serve") # type: ignore
|
|
114
|
-
cli.add_command(start_command.get_command(None, "mcp"), name="mcp")
|
|
116
|
+
cli.add_command(start_command.get_command(None, "mcp"), name="mcp") # type: ignore
|
|
115
117
|
|
|
116
118
|
|
|
117
119
|
@cli.result_callback()
|
|
nat/data_models/discovery_metadata.py
CHANGED
|
@@ -177,7 +177,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
177
177
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
178
178
|
version = ""
|
|
179
179
|
except Exception as e:
|
|
180
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e
|
|
180
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
|
|
181
181
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
182
182
|
|
|
183
183
|
description = generate_config_type_docs(config_type=config_type)
|
|
@@ -217,7 +217,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
217
217
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
218
218
|
version = ""
|
|
219
219
|
except Exception as e:
|
|
220
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e
|
|
220
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", fn, e)
|
|
221
221
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
222
222
|
|
|
223
223
|
if isinstance(wrapper_type, LLMFrameworkEnum):
|
|
@@ -252,7 +252,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
252
252
|
description = ""
|
|
253
253
|
package_version = package_version or ""
|
|
254
254
|
except Exception as e:
|
|
255
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e
|
|
255
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", package_name, e)
|
|
256
256
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
257
257
|
|
|
258
258
|
return DiscoveryMetadata(package=package_name,
|
|
@@ -290,7 +290,7 @@ class DiscoveryMetadata(BaseModel):
|
|
|
290
290
|
logger.warning("Package metadata not found for %s", distro_name)
|
|
291
291
|
version = ""
|
|
292
292
|
except Exception as e:
|
|
293
|
-
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e
|
|
293
|
+
logger.exception("Encountered issue extracting module metadata for %s: %s", config_type, e)
|
|
294
294
|
return DiscoveryMetadata(status=DiscoveryStatusEnum.FAILURE)
|
|
295
295
|
|
|
296
296
|
wrapper_type = wrapper_type.value if isinstance(wrapper_type, LLMFrameworkEnum) else wrapper_type
|
|
nat/data_models/gated_field_mixin.py
CHANGED
|
@@ -16,8 +16,6 @@
|
|
|
16
16
|
from collections.abc import Sequence
|
|
17
17
|
from dataclasses import dataclass
|
|
18
18
|
from re import Pattern
|
|
19
|
-
from typing import Generic
|
|
20
|
-
from typing import TypeVar
|
|
21
19
|
|
|
22
20
|
from pydantic import model_validator
|
|
23
21
|
|
|
@@ -33,10 +31,7 @@ class GatedFieldMixinConfig:
|
|
|
33
31
|
keys: Sequence[str]
|
|
34
32
|
|
|
35
33
|
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
class GatedFieldMixin(Generic[T]):
|
|
34
|
+
class GatedFieldMixin:
|
|
40
35
|
"""
|
|
41
36
|
A mixin that gates a field based on specified keys.
|
|
42
37
|
|
|
@@ -46,7 +41,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
46
41
|
----------
|
|
47
42
|
field_name: `str`
|
|
48
43
|
The name of the field.
|
|
49
|
-
default_if_supported: `
|
|
44
|
+
default_if_supported: `object | None`
|
|
50
45
|
The default value of the field if it is supported for the key.
|
|
51
46
|
keys: `Sequence[str]`
|
|
52
47
|
A sequence of keys that are used to validate the field.
|
|
@@ -61,7 +56,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
61
56
|
def __init_subclass__(
|
|
62
57
|
cls,
|
|
63
58
|
field_name: str | None = None,
|
|
64
|
-
default_if_supported:
|
|
59
|
+
default_if_supported: object | None = None,
|
|
65
60
|
keys: Sequence[str] | None = None,
|
|
66
61
|
unsupported: Sequence[Pattern[str]] | None = None,
|
|
67
62
|
supported: Sequence[Pattern[str]] | None = None,
|
|
@@ -90,7 +85,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
90
85
|
def _setup_direct_mixin(
|
|
91
86
|
cls,
|
|
92
87
|
field_name: str,
|
|
93
|
-
default_if_supported:
|
|
88
|
+
default_if_supported: object | None,
|
|
94
89
|
unsupported: Sequence[Pattern[str]] | None,
|
|
95
90
|
supported: Sequence[Pattern[str]] | None,
|
|
96
91
|
keys: Sequence[str],
|
|
@@ -135,7 +130,7 @@ class GatedFieldMixin(Generic[T]):
|
|
|
135
130
|
def _create_gated_field_validator(
|
|
136
131
|
cls,
|
|
137
132
|
field_name: str,
|
|
138
|
-
default_if_supported:
|
|
133
|
+
default_if_supported: object | None,
|
|
139
134
|
unsupported: Sequence[Pattern[str]] | None,
|
|
140
135
|
supported: Sequence[Pattern[str]] | None,
|
|
141
136
|
keys: Sequence[str],
|
|
@@ -167,16 +162,19 @@ class GatedFieldMixin(Generic[T]):
|
|
|
167
162
|
keys: Sequence[str],
|
|
168
163
|
) -> bool:
|
|
169
164
|
"""Check if a specific field is supported based on its configuration and keys."""
|
|
165
|
+
seen = False
|
|
170
166
|
for key in keys:
|
|
171
167
|
if not hasattr(instance, key):
|
|
172
168
|
continue
|
|
169
|
+
seen = True
|
|
173
170
|
value = str(getattr(instance, key))
|
|
174
171
|
if supported is not None:
|
|
175
|
-
|
|
172
|
+
if any(p.search(value) for p in supported):
|
|
173
|
+
return True
|
|
176
174
|
elif unsupported is not None:
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
return True
|
|
175
|
+
if any(p.search(value) for p in unsupported):
|
|
176
|
+
return False
|
|
177
|
+
return True if not seen else (unsupported is not None)
|
|
180
178
|
|
|
181
179
|
@classmethod
|
|
182
180
|
def _find_blocking_key(
|
|
nat/data_models/temperature_mixin.py
CHANGED
|
@@ -23,7 +23,7 @@ from nat.data_models.gated_field_mixin import GatedFieldMixin
|
|
|
23
23
|
|
|
24
24
|
class TemperatureMixin(
|
|
25
25
|
BaseModel,
|
|
26
|
-
GatedFieldMixin
|
|
26
|
+
GatedFieldMixin,
|
|
27
27
|
field_name="temperature",
|
|
28
28
|
default_if_supported=0.0,
|
|
29
29
|
keys=("model_name", "model", "azure_deployment"),
|