nvidia-nat 1.2.1rc1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +27 -18
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +81 -50
- nat/agent/react_agent/register.py +59 -40
- nat/agent/reasoning_agent/reasoning_agent.py +17 -15
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +327 -149
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +64 -46
- nat/agent/tool_calling_agent/agent.py +152 -29
- nat/agent/tool_calling_agent/register.py +61 -38
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +69 -36
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +10 -6
- nat/builder/context.py +70 -18
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/intermediate_step_manager.py +6 -2
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +327 -79
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +14 -13
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +5 -2
- nat/cli/commands/workflow/templates/register.py.j2 +2 -3
- nat/cli/commands/workflow/templates/workflow.py.j2 +35 -21
- nat/cli/commands/workflow/workflow_commands.py +105 -19
- nat/cli/entrypoint.py +17 -11
- nat/cli/main.py +3 -0
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +79 -10
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +196 -67
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +42 -18
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/span.py +41 -3
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/azure_openai_embedder.py +46 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +2 -3
- nat/embedder/register.py +1 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +9 -6
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/execute_score_select_function.py +1 -1
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +3 -3
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +19 -7
- nat/front_ends/fastapi/auth_flow_handlers/http_flow_handler.py +1 -1
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +455 -282
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +74 -50
- nat/front_ends/fastapi/message_validator.py +20 -21
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +47 -3
- nat/front_ends/mcp/mcp_front_end_plugin.py +48 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +120 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +57 -0
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +5 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +35 -15
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +22 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +14 -7
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/forecasting/models/linear_model.py +1 -1
- nat/profiler/forecasting/models/random_forest_regressor.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/nested_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +8 -9
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +164 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +395 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +105 -8
- nat/runtime/session.py +69 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/memory_tools/add_memory_tool.py +3 -3
- nat/tool/memory_tools/delete_memory_tool.py +3 -4
- nat/tool/memory_tools/get_memory_tool.py +4 -4
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/decorators.py +210 -0
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +12 -3
- nat/utils/type_utils.py +9 -5
- nvidia_nat-1.3.0.dist-info/METADATA +195 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/RECORD +244 -200
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- nvidia_nat-1.2.1rc1.dist-info/METADATA +0 -365
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.2.1rc1.dist-info → nvidia_nat-1.3.0.dist-info}/top_level.txt +0 -0
nat/builder/context.py
CHANGED
|
@@ -19,6 +19,7 @@ from collections.abc import Awaitable
|
|
|
19
19
|
from collections.abc import Callable
|
|
20
20
|
from contextlib import contextmanager
|
|
21
21
|
from contextvars import ContextVar
|
|
22
|
+
from functools import cached_property
|
|
22
23
|
|
|
23
24
|
from nat.builder.intermediate_step_manager import IntermediateStepManager
|
|
24
25
|
from nat.builder.user_interaction_manager import UserInteractionManager
|
|
@@ -31,6 +32,7 @@ from nat.data_models.intermediate_step import IntermediateStep
|
|
|
31
32
|
from nat.data_models.intermediate_step import IntermediateStepPayload
|
|
32
33
|
from nat.data_models.intermediate_step import IntermediateStepType
|
|
33
34
|
from nat.data_models.intermediate_step import StreamEventData
|
|
35
|
+
from nat.data_models.intermediate_step import TraceMetadata
|
|
34
36
|
from nat.data_models.invocation_node import InvocationNode
|
|
35
37
|
from nat.runtime.user_metadata import RequestAttributes
|
|
36
38
|
from nat.utils.reactive.subject import Subject
|
|
@@ -38,13 +40,13 @@ from nat.utils.reactive.subject import Subject
|
|
|
38
40
|
|
|
39
41
|
class Singleton(type):
|
|
40
42
|
|
|
41
|
-
def __init__(cls, name, bases, dict):
|
|
42
|
-
super(
|
|
43
|
+
def __init__(cls, name, bases, dict):
|
|
44
|
+
super().__init__(name, bases, dict)
|
|
43
45
|
cls.instance = None
|
|
44
46
|
|
|
45
47
|
def __call__(cls, *args, **kw):
|
|
46
48
|
if cls.instance is None:
|
|
47
|
-
cls.instance = super(
|
|
49
|
+
cls.instance = super().__call__(*args, **kw)
|
|
48
50
|
return cls.instance
|
|
49
51
|
|
|
50
52
|
|
|
@@ -65,14 +67,15 @@ class ContextState(metaclass=Singleton):
|
|
|
65
67
|
|
|
66
68
|
def __init__(self):
|
|
67
69
|
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
|
|
70
|
+
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
|
|
71
|
+
self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
|
|
72
|
+
self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
|
|
68
73
|
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
|
|
69
74
|
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
|
|
70
|
-
self.
|
|
71
|
-
self.
|
|
72
|
-
self.
|
|
73
|
-
|
|
74
|
-
function_name="root"))
|
|
75
|
-
self.active_span_id_stack: ContextVar[list[str]] = ContextVar("active_span_id_stack", default=["root"])
|
|
75
|
+
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
|
|
76
|
+
self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None)
|
|
77
|
+
self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None)
|
|
78
|
+
self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None)
|
|
76
79
|
|
|
77
80
|
# Default is a lambda no-op which returns NoneType
|
|
78
81
|
self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]]
|
|
@@ -83,6 +86,30 @@ class ContextState(metaclass=Singleton):
|
|
|
83
86
|
Awaitable[AuthenticatedContext]]
|
|
84
87
|
| None] = ContextVar("user_auth_callback", default=None)
|
|
85
88
|
|
|
89
|
+
@property
|
|
90
|
+
def metadata(self) -> ContextVar[RequestAttributes]:
|
|
91
|
+
if self._metadata.get() is None:
|
|
92
|
+
self._metadata.set(RequestAttributes())
|
|
93
|
+
return typing.cast(ContextVar[RequestAttributes], self._metadata)
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def active_function(self) -> ContextVar[InvocationNode]:
|
|
97
|
+
if self._active_function.get() is None:
|
|
98
|
+
self._active_function.set(InvocationNode(function_id="root", function_name="root"))
|
|
99
|
+
return typing.cast(ContextVar[InvocationNode], self._active_function)
|
|
100
|
+
|
|
101
|
+
@property
|
|
102
|
+
def event_stream(self) -> ContextVar[Subject[IntermediateStep]]:
|
|
103
|
+
if self._event_stream.get() is None:
|
|
104
|
+
self._event_stream.set(Subject())
|
|
105
|
+
return typing.cast(ContextVar[Subject[IntermediateStep]], self._event_stream)
|
|
106
|
+
|
|
107
|
+
@property
|
|
108
|
+
def active_span_id_stack(self) -> ContextVar[list[str]]:
|
|
109
|
+
if self._active_span_id_stack.get() is None:
|
|
110
|
+
self._active_span_id_stack.set(["root"])
|
|
111
|
+
return typing.cast(ContextVar[list[str]], self._active_span_id_stack)
|
|
112
|
+
|
|
86
113
|
@staticmethod
|
|
87
114
|
def get() -> "ContextState":
|
|
88
115
|
return ContextState()
|
|
@@ -96,14 +123,14 @@ class Context:
|
|
|
96
123
|
@property
|
|
97
124
|
def input_message(self):
|
|
98
125
|
"""
|
|
99
|
-
|
|
126
|
+
Retrieves the input message from the context state.
|
|
100
127
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
128
|
+
The input_message property is used to access the message stored in the
|
|
129
|
+
context state. This property returns the message as it is currently
|
|
130
|
+
maintained in the context.
|
|
104
131
|
|
|
105
|
-
|
|
106
|
-
|
|
132
|
+
Returns:
|
|
133
|
+
str: The input message retrieved from the context state.
|
|
107
134
|
"""
|
|
108
135
|
return self._context_state.input_message.get()
|
|
109
136
|
|
|
@@ -141,7 +168,7 @@ class Context:
|
|
|
141
168
|
"""
|
|
142
169
|
return UserInteractionManager(self._context_state)
|
|
143
170
|
|
|
144
|
-
@
|
|
171
|
+
@cached_property
|
|
145
172
|
def intermediate_step_manager(self) -> IntermediateStepManager:
|
|
146
173
|
"""
|
|
147
174
|
Retrieves the intermediate step manager instance from the current context state.
|
|
@@ -165,8 +192,32 @@ class Context:
|
|
|
165
192
|
"""
|
|
166
193
|
return self._context_state.conversation_id.get()
|
|
167
194
|
|
|
195
|
+
@property
|
|
196
|
+
def user_message_id(self) -> str | None:
|
|
197
|
+
"""
|
|
198
|
+
This property retrieves the user message ID which is the unique identifier for the current user message.
|
|
199
|
+
"""
|
|
200
|
+
return self._context_state.user_message_id.get()
|
|
201
|
+
|
|
202
|
+
@property
|
|
203
|
+
def workflow_run_id(self) -> str | None:
|
|
204
|
+
"""
|
|
205
|
+
Returns a stable identifier for the current workflow/agent invocation (UUID string).
|
|
206
|
+
"""
|
|
207
|
+
return self._context_state.workflow_run_id.get()
|
|
208
|
+
|
|
209
|
+
@property
|
|
210
|
+
def workflow_trace_id(self) -> int | None:
|
|
211
|
+
"""
|
|
212
|
+
Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
|
|
213
|
+
"""
|
|
214
|
+
return self._context_state.workflow_trace_id.get()
|
|
215
|
+
|
|
168
216
|
@contextmanager
|
|
169
|
-
def push_active_function(self,
|
|
217
|
+
def push_active_function(self,
|
|
218
|
+
function_name: str,
|
|
219
|
+
input_data: typing.Any | None,
|
|
220
|
+
metadata: dict[str, typing.Any] | TraceMetadata | None = None):
|
|
170
221
|
"""
|
|
171
222
|
Set the 'active_function' in context, push an invocation node,
|
|
172
223
|
AND create an OTel child span for that function call.
|
|
@@ -187,7 +238,8 @@ class Context:
|
|
|
187
238
|
IntermediateStepPayload(UUID=current_function_id,
|
|
188
239
|
event_type=IntermediateStepType.FUNCTION_START,
|
|
189
240
|
name=function_name,
|
|
190
|
-
data=StreamEventData(input=input_data)
|
|
241
|
+
data=StreamEventData(input=input_data),
|
|
242
|
+
metadata=metadata))
|
|
191
243
|
|
|
192
244
|
manager = ActiveFunctionContextManager()
|
|
193
245
|
|
nat/builder/eval_builder.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
+
import asyncio
|
|
16
17
|
import dataclasses
|
|
17
18
|
import logging
|
|
18
19
|
from contextlib import asynccontextmanager
|
|
@@ -61,7 +62,7 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
61
62
|
# Store the evaluator
|
|
62
63
|
self._evaluators[name] = ConfiguredEvaluator(config=config, instance=info_obj)
|
|
63
64
|
except Exception as e:
|
|
64
|
-
logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config
|
|
65
|
+
logger.error("Error %s adding evaluator `%s` with config `%s`", e, name, config)
|
|
65
66
|
raise
|
|
66
67
|
|
|
67
68
|
@override
|
|
@@ -90,17 +91,20 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
90
91
|
return self.eval_general_config.output_dir
|
|
91
92
|
|
|
92
93
|
@override
|
|
93
|
-
def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
94
|
-
tools = []
|
|
94
|
+
async def get_all_tools(self, wrapper_type: LLMFrameworkEnum | str):
|
|
95
95
|
tool_wrapper_reg = self._registry.get_tool_wrapper(llm_framework=wrapper_type)
|
|
96
|
-
|
|
97
|
-
|
|
96
|
+
|
|
97
|
+
async def get_tool(fn_name: str):
|
|
98
|
+
fn = await self.get_function(fn_name)
|
|
98
99
|
try:
|
|
99
|
-
|
|
100
|
+
return tool_wrapper_reg.build_fn(fn_name, fn, self)
|
|
100
101
|
except Exception:
|
|
101
|
-
logger.exception("Error fetching tool `%s`", fn_name
|
|
102
|
+
logger.exception("Error fetching tool `%s`", fn_name)
|
|
103
|
+
return None
|
|
102
104
|
|
|
103
|
-
|
|
105
|
+
tasks = [get_tool(fn_name) for fn_name in self._functions]
|
|
106
|
+
tools = await asyncio.gather(*tasks, return_exceptions=False)
|
|
107
|
+
return [tool for tool in tools if tool is not None]
|
|
104
108
|
|
|
105
109
|
def _log_build_failure_evaluator(self,
|
|
106
110
|
failing_evaluator_name: str,
|
|
@@ -127,11 +131,12 @@ class WorkflowEvalBuilder(WorkflowBuilder, EvalBuilder):
|
|
|
127
131
|
remaining_components,
|
|
128
132
|
original_error)
|
|
129
133
|
|
|
130
|
-
|
|
134
|
+
@override
|
|
135
|
+
async def populate_builder(self, config: Config, skip_workflow: bool = False):
|
|
131
136
|
# Skip setting workflow if workflow config is EmptyFunctionConfig
|
|
132
|
-
skip_workflow = isinstance(config.workflow, EmptyFunctionConfig)
|
|
137
|
+
skip_workflow = skip_workflow or isinstance(config.workflow, EmptyFunctionConfig)
|
|
133
138
|
|
|
134
|
-
await super().populate_builder(config, skip_workflow)
|
|
139
|
+
await super().populate_builder(config, skip_workflow=skip_workflow)
|
|
135
140
|
|
|
136
141
|
# Initialize progress tracking for evaluators
|
|
137
142
|
completed_evaluators = []
|
nat/builder/framework_enum.py
CHANGED
nat/builder/front_end.py
CHANGED
|
@@ -37,7 +37,7 @@ class FrontEndBase(typing.Generic[FrontEndConfigT], ABC):
|
|
|
37
37
|
|
|
38
38
|
super().__init__()
|
|
39
39
|
|
|
40
|
-
self._full_config:
|
|
40
|
+
self._full_config: Config = full_config
|
|
41
41
|
self._front_end_config: FrontEndConfigT = typing.cast(FrontEndConfigT, full_config.general.front_end)
|
|
42
42
|
|
|
43
43
|
@property
|
nat/builder/function.py
CHANGED
|
@@ -14,12 +14,14 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
import re
|
|
17
18
|
import typing
|
|
18
19
|
from abc import ABC
|
|
19
20
|
from abc import abstractmethod
|
|
20
21
|
from collections.abc import AsyncGenerator
|
|
21
22
|
from collections.abc import Awaitable
|
|
22
23
|
from collections.abc import Callable
|
|
24
|
+
from collections.abc import Sequence
|
|
23
25
|
|
|
24
26
|
from pydantic import BaseModel
|
|
25
27
|
|
|
@@ -29,7 +31,9 @@ from nat.builder.function_base import InputT
|
|
|
29
31
|
from nat.builder.function_base import SingleOutputT
|
|
30
32
|
from nat.builder.function_base import StreamingOutputT
|
|
31
33
|
from nat.builder.function_info import FunctionInfo
|
|
34
|
+
from nat.data_models.function import EmptyFunctionConfig
|
|
32
35
|
from nat.data_models.function import FunctionBaseConfig
|
|
36
|
+
from nat.data_models.function import FunctionGroupBaseConfig
|
|
33
37
|
|
|
34
38
|
_InvokeFnT = Callable[[InputT], Awaitable[SingleOutputT]]
|
|
35
39
|
_StreamFnT = Callable[[InputT], AsyncGenerator[StreamingOutputT]]
|
|
@@ -155,8 +159,8 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
155
159
|
|
|
156
160
|
return result
|
|
157
161
|
except Exception as e:
|
|
158
|
-
logger.error("Error with ainvoke in function with input: %s.", value,
|
|
159
|
-
raise
|
|
162
|
+
logger.error("Error with ainvoke in function with input: %s. Error: %s", value, e)
|
|
163
|
+
raise
|
|
160
164
|
|
|
161
165
|
@typing.final
|
|
162
166
|
async def acall_invoke(self, *args, **kwargs):
|
|
@@ -186,14 +190,14 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
186
190
|
input_obj = self.input_schema(*args, **kwargs)
|
|
187
191
|
|
|
188
192
|
return await self.ainvoke(value=input_obj)
|
|
189
|
-
except Exception
|
|
193
|
+
except Exception:
|
|
190
194
|
logger.error(
|
|
191
195
|
"Error in acall_invoke() converting input to function schema. Both args and kwargs were "
|
|
192
196
|
"supplied which could not be converted to the input schema. args: %s\nkwargs: %s\nschema: %s",
|
|
193
197
|
args,
|
|
194
198
|
kwargs,
|
|
195
199
|
self.input_schema)
|
|
196
|
-
raise
|
|
200
|
+
raise
|
|
197
201
|
|
|
198
202
|
@abstractmethod
|
|
199
203
|
async def _astream(self, value: InputT) -> AsyncGenerator[StreamingOutputT]:
|
|
@@ -252,8 +256,8 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
252
256
|
manager.set_output(final_output)
|
|
253
257
|
|
|
254
258
|
except Exception as e:
|
|
255
|
-
logger.error("Error with astream in function with input: %s.", value,
|
|
256
|
-
raise
|
|
259
|
+
logger.error("Error with astream in function with input: %s. Error: %s", value, e)
|
|
260
|
+
raise
|
|
257
261
|
|
|
258
262
|
@typing.final
|
|
259
263
|
async def acall_stream(self, *args, **kwargs):
|
|
@@ -287,14 +291,14 @@ class Function(FunctionBase[InputT, StreamingOutputT, SingleOutputT], ABC):
|
|
|
287
291
|
|
|
288
292
|
async for x in self.astream(value=input_obj):
|
|
289
293
|
yield x
|
|
290
|
-
except Exception
|
|
294
|
+
except Exception:
|
|
291
295
|
logger.error(
|
|
292
296
|
"Error in acall_stream() converting input to function schema. Both args and kwargs were "
|
|
293
297
|
"supplied which could not be converted to the input schema. args: %s\nkwargs: %s\nschema: %s",
|
|
294
298
|
args,
|
|
295
299
|
kwargs,
|
|
296
300
|
self.input_schema)
|
|
297
|
-
raise
|
|
301
|
+
raise
|
|
298
302
|
|
|
299
303
|
|
|
300
304
|
class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
|
|
@@ -342,3 +346,369 @@ class LambdaFunction(Function[InputT, StreamingOutputT, SingleOutputT]):
|
|
|
342
346
|
pass
|
|
343
347
|
|
|
344
348
|
return FunctionImpl(config=config, info=info, instance_name=instance_name)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class FunctionGroup:
|
|
352
|
+
"""
|
|
353
|
+
A group of functions that can be used together, sharing the same configuration, context, and resources.
|
|
354
|
+
"""
|
|
355
|
+
|
|
356
|
+
def __init__(self,
|
|
357
|
+
*,
|
|
358
|
+
config: FunctionGroupBaseConfig,
|
|
359
|
+
instance_name: str | None = None,
|
|
360
|
+
filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None):
|
|
361
|
+
"""
|
|
362
|
+
Creates a new function group.
|
|
363
|
+
|
|
364
|
+
Parameters
|
|
365
|
+
----------
|
|
366
|
+
config : FunctionGroupBaseConfig
|
|
367
|
+
The configuration for the function group.
|
|
368
|
+
instance_name : str | None, optional
|
|
369
|
+
The name of the function group. If not provided, the type of the function group will be used.
|
|
370
|
+
filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
|
|
371
|
+
A callback function to additionally filter the functions in the function group dynamically when
|
|
372
|
+
the functions are accessed via any accessor method.
|
|
373
|
+
"""
|
|
374
|
+
self._config = config
|
|
375
|
+
self._instance_name = instance_name or config.type
|
|
376
|
+
self._functions: dict[str, Function] = dict()
|
|
377
|
+
self._filter_fn = filter_fn
|
|
378
|
+
self._per_function_filter_fn: dict[str, Callable[[str], Awaitable[bool]]] = dict()
|
|
379
|
+
|
|
380
|
+
def add_function(self,
|
|
381
|
+
name: str,
|
|
382
|
+
fn: Callable,
|
|
383
|
+
*,
|
|
384
|
+
input_schema: type[BaseModel] | None = None,
|
|
385
|
+
description: str | None = None,
|
|
386
|
+
converters: list[Callable] | None = None,
|
|
387
|
+
filter_fn: Callable[[str], Awaitable[bool]] | None = None):
|
|
388
|
+
"""
|
|
389
|
+
Adds a function to the function group.
|
|
390
|
+
|
|
391
|
+
Parameters
|
|
392
|
+
----------
|
|
393
|
+
name : str
|
|
394
|
+
The name of the function.
|
|
395
|
+
fn : Callable
|
|
396
|
+
The function to add to the function group.
|
|
397
|
+
input_schema : type[BaseModel] | None, optional
|
|
398
|
+
The input schema for the function.
|
|
399
|
+
description : str | None, optional
|
|
400
|
+
The description of the function.
|
|
401
|
+
converters : list[Callable] | None, optional
|
|
402
|
+
The converters to use for the function.
|
|
403
|
+
filter_fn : Callable[[str], Awaitable[bool]] | None, optional
|
|
404
|
+
A callback to determine if the function should be included in the function group. The
|
|
405
|
+
callback will be called with the function name. The callback is invoked dynamically when
|
|
406
|
+
the functions are accessed via any accessor method such as `get_accessible_functions`,
|
|
407
|
+
`get_included_functions`, `get_excluded_functions`, `get_all_functions`.
|
|
408
|
+
|
|
409
|
+
Raises
|
|
410
|
+
------
|
|
411
|
+
ValueError
|
|
412
|
+
When the function name is empty or blank.
|
|
413
|
+
When the function name contains invalid characters.
|
|
414
|
+
When the function already exists in the function group.
|
|
415
|
+
"""
|
|
416
|
+
if not name.strip():
|
|
417
|
+
raise ValueError("Function name cannot be empty or blank")
|
|
418
|
+
if not re.match(r"^[a-zA-Z0-9_.-]+$", name):
|
|
419
|
+
raise ValueError(
|
|
420
|
+
f"Function name can only contain letters, numbers, underscores, periods, and hyphens: {name}")
|
|
421
|
+
if name in self._functions:
|
|
422
|
+
raise ValueError(f"Function {name} already exists in function group {self._instance_name}")
|
|
423
|
+
|
|
424
|
+
info = FunctionInfo.from_fn(fn, input_schema=input_schema, description=description, converters=converters)
|
|
425
|
+
full_name = self._get_fn_name(name)
|
|
426
|
+
lambda_fn = LambdaFunction.from_info(config=EmptyFunctionConfig(), info=info, instance_name=full_name)
|
|
427
|
+
self._functions[name] = lambda_fn
|
|
428
|
+
if filter_fn:
|
|
429
|
+
self._per_function_filter_fn[name] = filter_fn
|
|
430
|
+
|
|
431
|
+
def get_config(self) -> FunctionGroupBaseConfig:
|
|
432
|
+
"""
|
|
433
|
+
Returns the configuration for the function group.
|
|
434
|
+
|
|
435
|
+
Returns
|
|
436
|
+
-------
|
|
437
|
+
FunctionGroupBaseConfig
|
|
438
|
+
The configuration for the function group.
|
|
439
|
+
"""
|
|
440
|
+
return self._config
|
|
441
|
+
|
|
442
|
+
def _get_fn_name(self, name: str) -> str:
    """
    Build the fully-qualified name of a member function.

    Parameters
    ----------
    name : str
        The short name of the function inside the group.

    Returns
    -------
    str
        The group's instance name and ``name`` joined by a period,
        i.e. ``"<instance_name>.<name>"``.
    """
    group_prefix = self._instance_name
    return f"{group_prefix}.{name}"
|
|
444
|
+
|
|
445
|
+
async def _fn_should_be_included(self, name: str) -> bool:
    """
    Evaluate the per-function filter registered for ``name``, if there is one.

    Parameters
    ----------
    name : str
        The short name of the function inside the group.

    Returns
    -------
    bool
        ``True`` when no per-function filter is registered for ``name``. Otherwise, the
        result of awaiting that filter with ``name``.
    """
    try:
        predicate = self._per_function_filter_fn[name]
    except KeyError:
        # No per-function filter registered: keep the function by default.
        return True
    return await predicate(name)
|
|
449
|
+
|
|
450
|
+
async def _get_all_but_excluded_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
) -> dict[str, Function]:
    """
    Return the group's functions minus those removed by configuration or filters.

    A function is kept only when all of the following hold. They are checked in this order,
    so a per-function filter is never consulted for a configured exclusion:

    1. its name is not listed in the configuration's ``exclude`` list;
    2. its per-function filter, if any, accepts it;
    3. its name is returned by the group-level filter function.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
        A callback that receives all function names in the group and returns the names to keep.
        When omitted, the group's own filter function is used. If that is also unset, every name
        passes this step.

    Returns
    -------
    dict[str, Function]
        The kept functions in registration order, keyed by their fully-qualified name.

    Raises
    ------
    ValueError
        When ``exclude`` names a function that does not exist in the group.
    """
    excluded_names = set(self._config.exclude)
    unknown = excluded_names - set(self._functions.keys())
    if unknown:
        raise ValueError(f"Unknown excluded functions: {sorted(unknown)}")

    active_filter = self._filter_fn if filter_fn is None else filter_fn
    if active_filter is None:

        async def _keep_all(names: Sequence[str]) -> Sequence[str]:
            return names

        active_filter = _keep_all

    allowed = set(await active_filter(list(self._functions.keys())))

    kept = {}
    for fn_name in self._functions:
        # `and` short-circuits, so the per-function filter only runs for names not excluded by config.
        if (fn_name not in excluded_names and await self._fn_should_be_included(fn_name) and fn_name in allowed):
            kept[self._get_fn_name(fn_name)] = self._functions[fn_name]
    return kept
|
|
485
|
+
|
|
486
|
+
async def get_accessible_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
) -> dict[str, Function]:
    """
    Return the functions of the group that callers are allowed to use.

    The configuration decides which accessor does the work:

    - a non-empty ``include`` list delegates to ``get_included_functions``;
    - otherwise, a non-empty ``exclude`` list returns every function except the excluded ones;
    - otherwise, ``get_all_functions`` is used.

    The chosen accessor additionally applies the group-level filter and any per-function filters.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
        A callback that additionally filters the functions dynamically. It is forwarded unchanged to
        the chosen accessor, which falls back to the group's filter function when it is omitted. If no
        filter function is set either, no functions are removed by this step.

    Returns
    -------
    dict[str, Function]
        A dictionary of all accessible functions in the function group.

    Raises
    ------
    ValueError
        When the configuration includes or excludes functions that are not found in the group.
    """
    config = self._config
    if config.include:
        accessor = self.get_included_functions
    elif config.exclude:
        accessor = self._get_all_but_excluded_functions
    else:
        accessor = self.get_all_functions
    return await accessor(filter_fn=filter_fn)
|
|
523
|
+
|
|
524
|
+
async def get_excluded_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
) -> dict[str, Function]:
    """
    Return every function of the group that is filtered out.

    A function counts as excluded when any of the following holds. They are checked in this order,
    and checking stops at the first match:

    1. its name is listed in the configuration's ``exclude`` list;
    2. its per-function filter, if one is registered, returns a falsy value;
    3. its name is not returned by the group-level filter function.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
        A callback that receives all function names in the group and returns the names to keep.
        If not provided, the function group's filter function is used. If neither is set, step 3
        excludes nothing, so only the configuration and per-function filters contribute to the result.

    Returns
    -------
    dict[str, Function]
        The excluded functions in registration order, keyed by their fully-qualified name.

    Raises
    ------
    ValueError
        When the function group is configured to exclude functions that are not found in the group.
    """
    excluded_names = set(self._config.exclude)
    unknown = excluded_names - set(self._functions.keys())
    if unknown:
        raise ValueError(f"Unknown excluded functions: {sorted(unknown)}")

    active_filter = self._filter_fn if filter_fn is None else filter_fn
    if active_filter is None:

        async def _keep_all(names: Sequence[str]) -> Sequence[str]:
            return names

        active_filter = _keep_all

    allowed = set(await active_filter(list(self._functions.keys())))

    dropped = {}
    for fn_name in self._functions:
        # `or` short-circuits, so later checks only run when earlier ones did not already exclude the name.
        if (fn_name in excluded_names or not await self._fn_should_be_included(fn_name) or fn_name not in allowed):
            dropped[self._get_fn_name(fn_name)] = self._functions[fn_name]
    return dropped
|
|
580
|
+
|
|
581
|
+
async def get_included_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
) -> dict[str, Function]:
    """
    Returns a dictionary of all functions in the function group which are:
    - configured to be included and added to the global function registry
    - not configured to be excluded.
    - not filtered out by a filter function.

    Functions are returned in the order they were added to the group. Names returned by the filter
    function that are not configured to be included are ignored.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
        A callback function to additionally filter the functions in the function group dynamically. It
        receives the configured ``include`` names. If not provided then fall back to the function group's
        filter function. If no filter function is set for the function group all included functions
        will be returned.

    Returns
    -------
    dict[str, Function]
        A dictionary of all included functions in the function group, keyed by their fully-qualified name.

    Raises
    ------
    ValueError
        When the function group is configured to include functions that are not found in the group.
    """
    configured = set(self._config.include)
    missing = configured - set(self._functions.keys())
    if missing:
        raise ValueError(f"Unknown included functions: {sorted(missing)}")

    if filter_fn is None:
        if self._filter_fn is None:

            async def identity_filter(x: Sequence[str]) -> Sequence[str]:
                return x

            filter_fn = identity_filter
        else:
            filter_fn = self._filter_fn

    allowed = set(await filter_fn(list(self._config.include)))
    excluded = set(self._config.exclude)

    result = {}
    # Walk the registry rather than the filter output, for three reasons:
    # - ordering is deterministic (set iteration order depends on hash randomization);
    # - a filter that returns unknown names cannot trigger a KeyError;
    # - a filter cannot smuggle in functions that are not configured to be included.
    for name in self._functions:
        if name not in configured or name not in allowed or name in excluded:
            continue
        if await self._fn_should_be_included(name):
            result[self._get_fn_name(name)] = self._functions[name]
    return result
|
|
628
|
+
|
|
629
|
+
async def get_all_functions(
    self,
    filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None = None,
) -> dict[str, Function]:
    """
    Returns a dictionary of all functions in the function group, regardless if they are included or excluded.

    If a filter function has been set, the returned functions will additionally be filtered by the callback.
    Per-function filter functions are applied as well. Functions are returned in the order they were added
    to the group, and names returned by the filter function that do not belong to the group are ignored.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]] | None, optional
        A callback function to additionally filter the functions in the function group dynamically. If not provided
        then fall back to the function group's filter function. If no filter function is set for the function group
        all functions will be returned.

    Returns
    -------
    dict[str, Function]
        A dictionary of all functions in the function group, keyed by their fully-qualified name.
    """
    if filter_fn is None:
        if self._filter_fn is None:

            async def identity_filter(x: Sequence[str]) -> Sequence[str]:
                return x

            filter_fn = identity_filter
        else:
            filter_fn = self._filter_fn

    included = set(await filter_fn(list(self._functions.keys())))
    result = {}
    # Iterate the registry instead of the filter's output set. This keeps ordering deterministic
    # (set iteration order varies with hash randomization) and ignores names the filter returns
    # that are not in the group, which previously raised KeyError.
    for name in self._functions:
        if name not in included:
            continue
        if await self._fn_should_be_included(name):
            result[self._get_fn_name(name)] = self._functions[name]
    return result
|
|
666
|
+
|
|
667
|
+
def set_filter_fn(self, filter_fn: Callable[[Sequence[str]], Awaitable[Sequence[str]]]):
    """
    Install the group-wide filter function.

    The accessor methods use this filter whenever they are called without an explicit
    ``filter_fn`` argument. It replaces any previously installed filter.

    Parameters
    ----------
    filter_fn : Callable[[Sequence[str]], Awaitable[Sequence[str]]]
        An async callback that receives function names and returns the names to keep.
    """
    new_filter = filter_fn
    self._filter_fn = new_filter
|
|
677
|
+
|
|
678
|
+
def set_per_function_filter_fn(self, name: str, filter_fn: Callable[[str], Awaitable[bool]]):
    """
    Register a per-function filter for one function within the function group.

    Any filter already registered for that function is replaced.

    Parameters
    ----------
    name : str
        The short name of the function inside the group.
    filter_fn : Callable[[str], Awaitable[bool]]
        An async callback that is awaited with the function name. It decides whether the
        function is included.

    Raises
    ------
    ValueError
        When the function is not found in the function group.
    """
    if name in self._functions:
        self._per_function_filter_fn[name] = filter_fn
        return
    raise ValueError(f"Function {name} not found in function group {self._instance_name}")
|
|
697
|
+
|
|
698
|
+
def set_instance_name(self, instance_name: str):
    """
    Change the instance name of the function group.

    The instance name is the prefix used when building the fully-qualified names of
    member functions (``"<instance_name>.<function_name>"``).

    Parameters
    ----------
    instance_name : str
        The new instance name for the function group.
    """
    new_name = instance_name
    self._instance_name = new_name
|
|
708
|
+
|
|
709
|
+
@property
def instance_name(self) -> str:
    """
    The instance name of the function group.

    This is the prefix used in the fully-qualified names of member functions.
    """
    current = self._instance_name
    return current
|