nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +50 -22
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +54 -27
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +68 -17
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +62 -22
- nat/cli/entrypoint.py +8 -10
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +74 -66
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +452 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +19 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +106 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/get_memory_tool.py +1 -1
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/METADATA +42 -18
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/RECORD +238 -196
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc2.dist-info}/top_level.txt +0 -0
nat/tool/github_tools.py
ADDED
@@ -0,0 +1,450 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+from typing import Literal
+
+from pydantic import BaseModel
+from pydantic import Field
+from pydantic import PositiveInt
+from pydantic import computed_field
+from pydantic import field_validator
+
+from nat.builder.builder import Builder
+from nat.builder.function import FunctionGroup
+from nat.builder.function_info import FunctionInfo
+from nat.cli.register_workflow import register_function
+from nat.cli.register_workflow import register_function_group
+from nat.data_models.function import FunctionBaseConfig
+from nat.data_models.function import FunctionGroupBaseConfig
+
+
+class GithubCreateIssueModel(BaseModel):
+    title: str = Field(description="The title of the GitHub Issue")
+    body: str = Field(description="The body of the GitHub Issue")
+
+
+class GithubCreateIssueModelList(BaseModel):
+    issues: list[GithubCreateIssueModel] = Field(default_factory=list,
+                                                 description=("A list of GitHub issues, "
+                                                              "each with a title and a body"))
+
+
+class GithubGetIssueModel(BaseModel):
+    state: Literal["open", "closed", "all"] | None = Field(default="open",
+                                                           description="Issue state used in issue query filter")
+    assignee: str | None = Field(default=None, description="Assignee name used in issue query filter")
+    creator: str | None = Field(default=None, description="Creator name used in issue query filter")
+    mentioned: str | None = Field(default=None, description="Name of person mentioned in issue")
+    labels: list[str] | None = Field(default=None, description="A list of labels that are assigned to the issue")
+    since: str | None = Field(default=None,
+                              description="Only show results that were last updated after the given time.")
+
+    @classmethod
+    @field_validator('since', mode='before')
+    def validate_since(cls, v):
+        if v is None:
+            return v
+        try:
+            # Parse the string to a datetime object
+            parsed_date = datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
+            # Return the formatted string
+            return parsed_date.isoformat() + 'Z'
+        except ValueError as e:
+            raise ValueError("since must be in ISO 8601 format: YYYY-MM-DDTHH:MM:SSZ") from e
+
+
+class GithubGetIssueModelList(BaseModel):
+    filter_parameters: list[GithubGetIssueModel] = Field(default_factory=list,
+                                                         description=("A list of query params when fetching issues "
+                                                                      "each of type GithubGetIssueModel"))
+
+
+class GithubUpdateIssueModel(BaseModel):
+    issue_number: str = Field(description="The issue number that will be updated")
+    title: str | None = Field(default=None, description="The title of the GitHub Issue")
+    body: str | None = Field(default=None, description="The body of the GitHub Issue")
+    state: Literal["open", "closed"] | None = Field(default=None, description="The new state of the issue")
+
+    state_reason: Literal["completed", "not_planned", "reopened"] | None = Field(
+        default=None, description="The reason for changing the state of the issue")
+
+    labels: list[str] | None = Field(default=None, description="A list of labels to assign to the issue")
+    assignees: list[str] | None = Field(default=None, description="A list of assignees to assign to the issue")
+
+
+class GithubUpdateIssueModelList(BaseModel):
+    issues: list[GithubUpdateIssueModel] = Field(default_factory=list,
+                                                 description=("A list of GitHub issues each "
+                                                              "of type GithubUpdateIssueModel"))
+
+
+class GithubCreatePullModel(BaseModel):
+    title: str = Field(description="Title of the pull request")
+    body: str = Field(description="Description of the pull request")
+    source_branch: str = Field(description="The name of the branch containing your changes", serialization_alias="head")
+    target_branch: str = Field(description="The name of the branch you want to merge into", serialization_alias="base")
+    assignees: list[str] | None = Field(default=None,
+                                        description="List of GitHub usernames to assign to the PR. "
+                                        "Always the current user")
+    reviewers: list[str] | None = Field(default=None, description="List of GitHub usernames to request review from")
+
+
+class GithubCreatePullList(BaseModel):
+    pull_details: list[GithubCreatePullModel] = Field(
+        default_factory=list, description=("A list of params used for creating the PR in GitHub"))
+
+
+class GithubGetPullsModel(BaseModel):
+    state: Literal["open", "closed", "all"] | None = Field(default="open",
+                                                           description="Issue state used in issue query filter")
+    head: str | None = Field(default=None,
+                             description="Filters pulls by head user or head organization and branch name")
+    base: str | None = Field(default=None, description="Filters pull by branch name")
+
+
+class GithubGetPullsModelList(BaseModel):
+    filter_parameters: list[GithubGetPullsModel] = Field(
+        default_factory=list,
+        description=("A list of query params when fetching pull requests "
+                     "each of type GithubGetPullsModel"))
+
+
+class GithubCommitCodeModel(BaseModel):
+    branch: str = Field(description="The branch of the remote repo to which the code will be committed")
+    commit_msg: str = Field(description="Message with which the code will be committed to the remote repo")
+    local_path: str = Field(description="Local filepath of the file that has been updated and "
+                            "needs to be committed to the remote repo")
+    remote_path: str = Field(description="Remote filepath of the updated file in GitHub. Path is relative to "
+                             "root of current repository")
+
+
+class GithubCommitCodeModelList(BaseModel):
+    updated_files: list[GithubCommitCodeModel] = Field(default_factory=list,
+                                                       description=("A list of local filepaths and commit messages"))
+
+
+class GithubGroupConfig(FunctionGroupBaseConfig, name="github"):
+    """Function group for GitHub repository operations.
+
+    Exposes issue, pull request, and commit operations with shared configuration.
+    """
+    repo_name: str = Field(description="The repository name in the format 'owner/repo'")
+    timeout: int = Field(default=300, description="Timeout in seconds for GitHub API requests")
+    # Required for commit function
+    local_repo_dir: str | None = Field(default=None,
+                                       description="Absolute path to the local clone. Required for 'commit' function")
+
+
+@register_function_group(config_type=GithubGroupConfig)
+async def github_tool(config: GithubGroupConfig, _builder: Builder):
+    """Register the `github` function group with shared configuration.
+
+    Implements:
+    - create_issue, get_issue, update_issue
+    - create_pull, get_pull
+    - commit
+    """
+    import base64
+    import json
+    import os
+
+    import httpx
+
+    token: str | None = None
+    for env_var in ["GITHUB_TOKEN", "GITHUB_PAT", "GH_TOKEN"]:
+        token = os.getenv(env_var)
+        if token:
+            break
+
+    if not token:
+        raise ValueError("No GitHub token found in environment variables. Please set one of the following"
+                         "environment variables: GITHUB_TOKEN, GITHUB_PAT, GH_TOKEN")
+
+    headers = {
+        "Authorization": f"Bearer {token}",
+        "Accept": "application/vnd.github+json",
+        "User-Agent": "NeMo-Agent-Toolkit",
+    }
+
+    async with httpx.AsyncClient(timeout=config.timeout, headers=headers) as client:
+
+        # Issues
+        async def create_issue(issues_list: GithubCreateIssueModelList) -> str:
+            url = f"https://api.github.com/repos/{config.repo_name}/issues"
+            results = []
+            for issue in issues_list.issues:
+                payload = issue.model_dump(exclude_unset=True)
+                response = await client.post(url, json=payload)
+                response.raise_for_status()
+                results.append(response.json())
+            return json.dumps(results)
+
+        async def get_issue(issues_list: GithubGetIssueModelList) -> str:
+            url = f"https://api.github.com/repos/{config.repo_name}/issues"
+            results = []
+            for issue in issues_list.filter_parameters:
+                params = issue.model_dump(exclude_unset=True, exclude_none=True)
+                response = await client.get(url, params=params)
+                response.raise_for_status()
+                results.append(response.json())
+            return json.dumps(results)
+
+        async def update_issue(issues_list: GithubUpdateIssueModelList) -> str:
+            url = f"https://api.github.com/repos/{config.repo_name}/issues"
+            results = []
+            for issue in issues_list.issues:
+                payload = issue.model_dump(exclude_unset=True, exclude_none=True)
+                issue_number = payload.pop("issue_number")
+                issue_url = f"{url}/{issue_number}"
+                response = await client.patch(issue_url, json=payload)
+                response.raise_for_status()
+                results.append(response.json())
+            return json.dumps(results)
+
+        # Pull requests
+        async def create_pull(pull_list: GithubCreatePullList) -> str:
+            results = []
+            pr_url = f"https://api.github.com/repos/{config.repo_name}/pulls"
+
+            for pull_detail in pull_list.pull_details:
+
+                pr_data = pull_detail.model_dump(
+                    include={"title", "body", "source_branch", "target_branch"},
+                    by_alias=True,
+                )
+                pr_response = await client.post(pr_url, json=pr_data)
+                pr_response.raise_for_status()
+                pr_number = pr_response.json()["number"]
+
+                result = {"pull_request": pr_response.json()}
+
+                if pull_detail.assignees:
+                    assignees_url = f"https://api.github.com/repos/{config.repo_name}/issues/{pr_number}/assignees"
+                    assignees_data = {"assignees": pull_detail.assignees}
+                    assignees_response = await client.post(assignees_url, json=assignees_data)
+                    assignees_response.raise_for_status()
+                    result["assignees"] = assignees_response.json()
+
+                if pull_detail.reviewers:
+                    reviewers_url = f"https://api.github.com/repos/{config.repo_name}/pulls/{pr_number}/requested_reviewers"
+                    reviewers_data = {"reviewers": pull_detail.reviewers}
+                    reviewers_response = await client.post(reviewers_url, json=reviewers_data)
+                    reviewers_response.raise_for_status()
+                    result["reviewers"] = reviewers_response.json()
+
+                results.append(result)
+
+            return json.dumps(results)
+
+        async def get_pull(pull_list: GithubGetPullsModelList) -> str:
+            url = f"https://api.github.com/repos/{config.repo_name}/pulls"
+            results = []
+            for pull_params in pull_list.filter_parameters:
+                params = pull_params.model_dump(exclude_unset=True, exclude_none=True)
+                response = await client.get(url, params=params)
+                response.raise_for_status()
+                results.append(response.json())
+
+            return json.dumps(results)
+
+        # Commits (commit updated files)
+        async def commit(updated_file_list: GithubCommitCodeModelList) -> str:
+            if not config.local_repo_dir:
+                raise ValueError("'local_repo_dir' must be set in the github function group config to use 'commit'")
+
+            results = []
+            for updated_file in updated_file_list.updated_files:
+                branch = updated_file.branch
+                commit_msg = updated_file.commit_msg
+                local_path = updated_file.local_path
+                remote_path = updated_file.remote_path
+
+                # Read content from the local file (secure + binary-safe)
+                safe_root = os.path.realpath(config.local_repo_dir)
+                candidate = os.path.realpath(os.path.join(config.local_repo_dir, local_path))
+                if not candidate.startswith(safe_root + os.sep):
+                    raise ValueError(f"local_path '{local_path}' resolves outside local_repo_dir")
+                if not os.path.isfile(candidate):
+                    raise FileNotFoundError(f"File not found: {candidate}")
+                with open(candidate, "rb") as f:
+                    content_bytes = f.read()
+                content_b64 = base64.b64encode(content_bytes).decode("ascii")
+
+                # 1) Create blob
+                blob_url = f"https://api.github.com/repos/{config.repo_name}/git/blobs"
+                blob_data = {"content": content_b64, "encoding": "base64"}
+                blob_response = await client.post(blob_url, json=blob_data)
+                blob_response.raise_for_status()
+                blob_sha = blob_response.json()["sha"]
+
+                # 2) Get current ref (parent commit SHA)
+                ref_url = f"https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}"
+                ref_response = await client.get(ref_url)
+                ref_response.raise_for_status()
+                parent_commit_sha = ref_response.json()["object"]["sha"]
+
+                # 3) Get parent commit to retrieve its tree SHA
+                parent_commit_url = f"https://api.github.com/repos/{config.repo_name}/git/commits/{parent_commit_sha}"
+                parent_commit_resp = await client.get(parent_commit_url)
+                parent_commit_resp.raise_for_status()
+                base_tree_sha = parent_commit_resp.json()["tree"]["sha"]
+
+                # 4) Create tree
+                tree_url = f"https://api.github.com/repos/{config.repo_name}/git/trees"
+                tree_data = {
+                    "base_tree": base_tree_sha,
+                    "tree": [{
+                        "path": remote_path, "mode": "100644", "type": "blob", "sha": blob_sha
+                    }],
+                }
+                tree_response = await client.post(tree_url, json=tree_data)
+                tree_response.raise_for_status()
+                tree_sha = tree_response.json()["sha"]
+
+                # 5) Create commit
+                commit_url = f"https://api.github.com/repos/{config.repo_name}/git/commits"
+                commit_data = {"message": commit_msg, "tree": tree_sha, "parents": [parent_commit_sha]}
+                commit_response = await client.post(commit_url, json=commit_data)
+                commit_response.raise_for_status()
+                commit_sha = commit_response.json()["sha"]
+
+                # 6) Update ref
+                update_ref_url = f"https://api.github.com/repos/{config.repo_name}/git/refs/heads/{branch}"
+                update_ref_data = {"sha": commit_sha, "force": False}
+                update_ref_response = await client.patch(update_ref_url, json=update_ref_data)
+                update_ref_response.raise_for_status()
+
+                results.append({
+                    "blob_resp": blob_response.json(),
+                    "parent_commit": parent_commit_resp.json(),
+                    "new_tree": tree_response.json(),
+                    "commit_resp": commit_response.json(),
+                    "update_ref_resp": update_ref_response.json(),
+                })
+
+            return json.dumps(results)
+
+        group = FunctionGroup(config=config)
+
+        group.add_function("create_issue",
+                           create_issue,
+                           description=f"Creates a GitHub issue in the repo named {config.repo_name}",
+                           input_schema=GithubCreateIssueModelList)
+        group.add_function("get_issue",
+                           get_issue,
+                           description=f"Fetches a particular GitHub issue in the repo named {config.repo_name}",
+                           input_schema=GithubGetIssueModelList)
+        group.add_function("update_issue",
+                           update_issue,
+                           description=f"Updates a GitHub issue in the repo named {config.repo_name}",
+                           input_schema=GithubUpdateIssueModelList)
+        group.add_function("create_pull",
+                           create_pull,
+                           description="Creates a pull request with assignees and reviewers in"
+                           f"the GitHub repository named {config.repo_name}",
+                           input_schema=GithubCreatePullList)
+        group.add_function("get_pull",
+                           get_pull,
+                           description="Fetches the files for a particular GitHub pull request"
+                           f"in the repo named {config.repo_name}",
+                           input_schema=GithubGetPullsModelList)
+        group.add_function("commit",
+                           commit,
+                           description="Commits and pushes modified code to a GitHub repository"
+                           f"in the repo named {config.repo_name}",
+                           input_schema=GithubCommitCodeModelList)
+
+        yield group
+
+
+class GithubFilesGroupConfig(FunctionBaseConfig, name="github_files_tool"):
+    timeout: int = Field(default=5, description="Timeout in seconds for HTTP requests")
+
+
+@register_function(config_type=GithubFilesGroupConfig)
+async def github_files_tool(config: GithubFilesGroupConfig, _builder: Builder):
+
+    import re
+
+    import httpx
+
+    class FileMetadata(BaseModel):
+        repo_path: str
+        file_path: str
+        start: str | None = Field(default=None)
+        end: str | None = Field(default=None)
+
+        @computed_field
+        @property
+        def start_line(self) -> PositiveInt | None:
+            return int(self.start) if self.start else None
+
+        @computed_field
+        @property
+        def end_line(self) -> PositiveInt | None:
+            return int(self.end) if self.end else None
+
+    async with httpx.AsyncClient(timeout=config.timeout) as client:
+
+        async def get(url_text: str) -> str:
+            """
+            Returns the text of a github file using a github url starting with https://github.com and ending
+            with a specific file. If a line reference is provided (#L409), the text of the line is returned.
+            If a range of lines is provided (#L409-L417), the text of the lines is returned.
+
+            Examples:
+            - https://github.com/org/repo/blob/main/README.md -> Returns full text of the README.md file
+            - https://github.com/org/repo/blob/main/README.md#L409 -> Returns the 409th line of the README.md file
+            - https://github.com/org/repo/blob/main/README.md#L409-L417 -> Returns lines 409-417 of the README.md file
+            """
+
+            pattern = r"https://github\.com/(?P<repo_path>[^/]*/[^/]*)/blob/(?P<file_path>[^?#]*)(?:#L(?P<start>\d+)(?:-L(?P<end>\d+))?)?"
+            match = re.match(pattern, url_text)
+            if not match:
+                return ("Invalid github url. Please provide a valid github url. "
+                        "Example: 'https://github.com/org/repo/blob/main/README.md' "
+                        "or 'https://github.com/org/repo/blob/main/README.md#L409' "
+                        "or 'https://github.com/org/repo/blob/main/README.md#L409-L417'")
+
+            file_metadata = FileMetadata(**match.groupdict())
+
+            # The following URL is the raw URL of the file. refs/heads/ always points to the top commit of the branch
+            raw_url = f"https://raw.githubusercontent.com/{file_metadata.repo_path}/refs/heads/{file_metadata.file_path}"
+            try:
+                response = await client.get(raw_url)
+                response.raise_for_status()
+            except httpx.TimeoutException:
+                return f"Timeout encountered when retrieving resource: {raw_url}"
+
+            if file_metadata.start_line is None:
+                return f"```{response.text}\n```"
+
+            lines = response.text.splitlines()
+
+            if file_metadata.start_line > len(lines):
+                return f"Error: Line {file_metadata.start_line} is out of range for the file {file_metadata.file_path}"
+
+            if file_metadata.end_line is None:
+                return f"```{lines[file_metadata.start_line - 1]}\n```"
+
+            if file_metadata.end_line > len(lines):
+                return f"Error: Line {file_metadata.end_line} is out of range for the file {file_metadata.file_path}"
+
+            selected_lines = lines[file_metadata.start_line - 1:file_metadata.end_line]
+            response_text = "\n".join(selected_lines)
+            return f"```{response_text}\n```"
+
+        yield FunctionInfo.from_fn(get, description=get.__doc__)
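Note: the following is a hypothetical usage sketch, not part of the diff. It shows how the input for the new `create_issue` function registered by the `github` function group above could be built from the Pydantic models added in this file; the issue title and body values are illustrative only.

# Hedged sketch: constructing the create_issue input from the models defined above.
from nat.tool.github_tools import GithubCreateIssueModel
from nat.tool.github_tools import GithubCreateIssueModelList

payload = GithubCreateIssueModelList(issues=[
    GithubCreateIssueModel(title="Example issue title", body="Example issue body"),
])

# Serializes to the JSON shape the function expects:
# {"issues": [{"title": "Example issue title", "body": "Example issue body"}]}
print(payload.model_dump_json())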
nat/tool/memory_tools/get_memory_tool.py
CHANGED
@@ -67,6 +67,6 @@ async def get_memory_tool(config: GetToolConfig, builder: Builder):
 
         except Exception as e:
 
-            raise ToolException(f"Error
+            raise ToolException(f"Error retrieving memory: {e}") from e
 
     yield FunctionInfo.from_fn(_arun, description=config.description)
nat/tool/nvidia_rag.py
CHANGED
@@ -86,7 +86,7 @@ async def nvidia_rag_tool(config: NVIDIARAGToolConfig, builder: Builder):
                 [await aformat_document(doc, document_prompt) for doc in docs])
             return parsed_output
         except Exception as e:
-            logger.exception("Error while running the tool"
+            logger.exception("Error while running the tool")
             return f"Error while running the tool: {e}"
 
     yield FunctionInfo.from_fn(
nat/tool/register.py
CHANGED
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# pylint: disable=unused-import
 # flake8: noqa
 
 # Import any tools which need to be automatically registered here
@@ -25,14 +24,8 @@ from . import nvidia_rag
 from . import retriever
 from . import server_tools
 from .code_execution import register
-from .github_tools import create_github_commit
-from .github_tools import create_github_issue
-from .github_tools import create_github_pr
-from .github_tools import get_github_file
-from .github_tools import get_github_issue
-from .github_tools import get_github_pr
-from .github_tools import update_github_issue
-from .mcp import mcp_tool
+from .github_tools import github_tool
+from .github_tools import github_files_tool
 from .memory_tools import add_memory_tool
 from .memory_tools import delete_memory_tool
 from .memory_tools import get_memory_tool
nat/tool/retriever.py
CHANGED
@@ -78,8 +78,9 @@ async def retriever_tool(config: RetrieverConfig, builder: Builder):
 
     except RetrieverError as e:
         if config.raise_errors:
-
-
+            logger.error("Retriever threw an error: %s.", e)
+            raise
+        logger.exception("Retriever threw an error: %s. Returning an empty response.", e)
         return RetrieverOutput(results=[])
 
     yield FunctionInfo.from_fn(
nat/utils/callable_utils.py
ADDED
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from collections.abc import Callable
+from typing import Any
+
+
+async def ainvoke_any(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
+    """Execute any type of callable and return the result.
+
+    Handles synchronous functions, asynchronous functions, generators,
+    and async generators uniformly, returning the final result value.
+
+    Args:
+        func (Callable[..., Any]): The function to execute (sync/async function, generator, etc.)
+
+    Returns:
+        Any: The result of executing the callable
+    """
+    # Execute the function
+    result_value = func(*args, **kwargs)
+
+    # Handle different return types
+    if inspect.iscoroutine(result_value):
+        # Async function - await the coroutine
+        return await result_value
+
+    if inspect.isgenerator(result_value):
+        # Sync generator - consume until StopIteration and get return value
+        try:
+            while True:
+                next(result_value)
+        except StopIteration as e:
+            # Return the generator's return value, or None if not provided
+            return e.value
+
+    if inspect.isasyncgen(result_value):
+        # Async generator - consume all values and return the last one
+        last_value = None
+        async for value in result_value:
+            last_value = value
+        return last_value
+
+    # Direct value from sync function (most common case)
+    return result_value
+
+
+def is_async_callable(func: Callable[..., Any]) -> bool:
+    """Check if a function is async (coroutine function or async generator function).
+
+    Args:
+        func (Callable[..., Any]): The function to check
+
+    Returns:
+        bool: True if the function is async, False otherwise
+    """
+    return inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
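Note: the following is a hypothetical usage sketch, not part of the diff. It exercises the new `ainvoke_any` and `is_async_callable` helpers added in nat/utils/callable_utils.py above; the `*_add` callables are illustrative only.

# Hedged sketch: ainvoke_any dispatches on what the callable returns.
import asyncio

from nat.utils.callable_utils import ainvoke_any
from nat.utils.callable_utils import is_async_callable


def sync_add(a, b):
    return a + b


async def async_add(a, b):
    return a + b


def gen_add(a, b):
    yield "progress"  # intermediate yields are consumed and discarded
    return a + b      # surfaced to the caller via StopIteration.value


async def main():
    print(await ainvoke_any(sync_add, 1, 2))   # 3
    print(await ainvoke_any(async_add, 1, 2))  # 3
    print(await ainvoke_any(gen_add, 1, 2))    # 3
    print(is_async_callable(async_add))        # True
    print(is_async_callable(sync_add))         # False


asyncio.run(main())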
nat/utils/data_models/schema_validator.py
CHANGED
@@ -21,7 +21,7 @@ from ..exception_handlers.schemas import yaml_exception_handler
 
 
 @schema_exception_handler
-def validate_schema(metadata, Schema):  # pylint: disable=invalid-name
+def validate_schema(metadata, Schema):
 
     try:
         return Schema(**metadata)
@@ -31,7 +31,7 @@ def validate_schema(metadata, Schema): # pylint: disable=invalid-name
 
 
 @yaml_exception_handler
-def validate_yaml(ctx, param, value):  # pylint: disable=unused-argument
+def validate_yaml(ctx, param, value):
     """
     Validate that the file is a valid YAML file
 
@@ -52,7 +52,7 @@ def validate_yaml(ctx, param, value): # pylint: disable=unused-argument
     if value is None:
         return None
 
-    with open(value,
+    with open(value, encoding="utf-8") as f:
         yaml.safe_load(f)
 
     return value