nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/observability/register.py
CHANGED
|
@@ -45,7 +45,7 @@ class FileTelemetryExporterConfig(TelemetryExporterBaseConfig, name="file"):
|
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
@register_telemetry_exporter(config_type=FileTelemetryExporterConfig)
|
|
48
|
-
async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
|
|
48
|
+
async def file_telemetry_exporter(config: FileTelemetryExporterConfig, builder: Builder):
|
|
49
49
|
"""
|
|
50
50
|
Build and return a FileExporter for file-based telemetry export with optional rolling.
|
|
51
51
|
"""
|
|
@@ -68,12 +68,14 @@ class ConsoleLoggingMethodConfig(LoggingBaseConfig, name="console"):
|
|
|
68
68
|
|
|
69
69
|
|
|
70
70
|
@register_logging_method(config_type=ConsoleLoggingMethodConfig)
|
|
71
|
-
async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
|
|
71
|
+
async def console_logging_method(config: ConsoleLoggingMethodConfig, builder: Builder):
|
|
72
72
|
"""
|
|
73
73
|
Build and return a StreamHandler for console-based logging.
|
|
74
74
|
"""
|
|
75
|
+
import sys
|
|
76
|
+
|
|
75
77
|
level = getattr(logging, config.level.upper(), logging.INFO)
|
|
76
|
-
handler = logging.StreamHandler()
|
|
78
|
+
handler = logging.StreamHandler(stream=sys.stdout)
|
|
77
79
|
handler.setLevel(level)
|
|
78
80
|
yield handler
|
|
79
81
|
|
|
@@ -86,7 +88,7 @@ class FileLoggingMethod(LoggingBaseConfig, name="file"):
|
|
|
86
88
|
|
|
87
89
|
|
|
88
90
|
@register_logging_method(config_type=FileLoggingMethod)
|
|
89
|
-
async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
91
|
+
async def file_logging_method(config: FileLoggingMethod, builder: Builder):
|
|
90
92
|
"""
|
|
91
93
|
Build and return a FileHandler for file-based logging.
|
|
92
94
|
"""
|
nat/profiler/calc/calc_runner.py
CHANGED
|
@@ -442,7 +442,7 @@ class CalcRunner:
|
|
|
442
442
|
runtime_fit=self.linear_analyzer.wf_runtime_fit # May be None
|
|
443
443
|
)
|
|
444
444
|
except Exception as e:
|
|
445
|
-
logger.exception("Failed to plot concurrency vs. time metrics: %s", e
|
|
445
|
+
logger.exception("Failed to plot concurrency vs. time metrics: %s", e)
|
|
446
446
|
logger.warning("Skipping plot of concurrency vs. time metrics")
|
|
447
447
|
|
|
448
448
|
def write_output(self, output_dir: Path, calc_runner_output: CalcRunnerOutput):
|
|
@@ -506,11 +506,10 @@ class CalcRunner:
|
|
|
506
506
|
continue
|
|
507
507
|
try:
|
|
508
508
|
calc_output = CalcRunnerOutput.model_validate_json(calc_runner_output_path.read_text())
|
|
509
|
-
except ValidationError
|
|
509
|
+
except ValidationError:
|
|
510
510
|
logger.exception("Failed to validate calc runner output file %s. Skipping job %s.",
|
|
511
511
|
calc_runner_output_path,
|
|
512
|
-
|
|
513
|
-
exc_info=True)
|
|
512
|
+
job_dir.name)
|
|
514
513
|
continue
|
|
515
514
|
|
|
516
515
|
# Extract sizing metrics from calc_data
|
|
@@ -53,7 +53,7 @@ def _extract_tools_schema(invocation_params: dict) -> list:
|
|
|
53
53
|
return tools_schema
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
|
|
56
|
+
class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback):
|
|
57
57
|
"""Callback Handler that tracks NIM info."""
|
|
58
58
|
|
|
59
59
|
total_tokens: int = 0
|
|
@@ -106,7 +106,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback): # p
|
|
|
106
106
|
try:
|
|
107
107
|
model_name = kwargs.get("metadata")["ls_model_name"]
|
|
108
108
|
except Exception as e:
|
|
109
|
-
logger.exception("Error getting model name: %s", e
|
|
109
|
+
logger.exception("Error getting model name: %s", e)
|
|
110
110
|
|
|
111
111
|
run_id = str(kwargs.get("run_id", str(uuid4())))
|
|
112
112
|
self._run_id_to_model_name[run_id] = model_name
|
|
@@ -144,7 +144,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback): # p
|
|
|
144
144
|
try:
|
|
145
145
|
model_name = metadata["ls_model_name"] if metadata else kwargs.get("metadata")["ls_model_name"]
|
|
146
146
|
except Exception as e:
|
|
147
|
-
logger.exception("Error getting model name: %s", e
|
|
147
|
+
logger.exception("Error getting model name: %s", e)
|
|
148
148
|
|
|
149
149
|
run_id = str(run_id)
|
|
150
150
|
self._run_id_to_model_name[run_id] = model_name
|
|
@@ -173,13 +173,13 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback): # p
|
|
|
173
173
|
try:
|
|
174
174
|
model_name = self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "")
|
|
175
175
|
except Exception as e:
|
|
176
|
-
logger.exception("Error getting model name: %s", e
|
|
176
|
+
logger.exception("Error getting model name: %s", e)
|
|
177
177
|
|
|
178
178
|
usage_metadata = {}
|
|
179
179
|
try:
|
|
180
180
|
usage_metadata = kwargs.get("chunk").message.usage_metadata if kwargs.get("chunk") else {}
|
|
181
181
|
except Exception as e:
|
|
182
|
-
logger.exception("Error getting usage metadata: %s", e
|
|
182
|
+
logger.exception("Error getting usage metadata: %s", e)
|
|
183
183
|
|
|
184
184
|
stats = IntermediateStepPayload(
|
|
185
185
|
event_type=IntermediateStepType.LLM_NEW_TOKEN,
|
|
@@ -206,7 +206,7 @@ class LangchainProfilerHandler(AsyncCallbackHandler, BaseProfilerCallback): # p
|
|
|
206
206
|
try:
|
|
207
207
|
model_name = self._run_id_to_model_name.get(str(kwargs.get("run_id", "")), "")
|
|
208
208
|
except Exception as e_inner:
|
|
209
|
-
logger.exception("Error getting model name: %s from outer error %s", e_inner, e
|
|
209
|
+
logger.exception("Error getting model name: %s from outer error %s", e_inner, e)
|
|
210
210
|
|
|
211
211
|
try:
|
|
212
212
|
generation = response.generations[0][0]
|
|
@@ -94,7 +94,7 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
|
|
|
94
94
|
try:
|
|
95
95
|
model_name = payload.get(EventPayload.SERIALIZED)['model']
|
|
96
96
|
except Exception as e:
|
|
97
|
-
logger.exception("Error getting model name: %s", e
|
|
97
|
+
logger.exception("Error getting model name: %s", e)
|
|
98
98
|
|
|
99
99
|
llm_text_input = " ".join(prompts_or_messages) if prompts_or_messages else ""
|
|
100
100
|
|
|
@@ -159,13 +159,13 @@ class LlamaIndexProfilerHandler(BaseCallbackHandler, BaseProfilerCallback):
|
|
|
159
159
|
for block in response.message.blocks:
|
|
160
160
|
llm_text_output += block.text
|
|
161
161
|
except Exception as e:
|
|
162
|
-
logger.exception("Error getting LLM text output: %s", e
|
|
162
|
+
logger.exception("Error getting LLM text output: %s", e)
|
|
163
163
|
|
|
164
164
|
model_name = ""
|
|
165
165
|
try:
|
|
166
166
|
model_name = response.raw.model
|
|
167
167
|
except Exception as e:
|
|
168
|
-
logger.exception("Error getting model name: %s", e
|
|
168
|
+
logger.exception("Error getting model name: %s", e)
|
|
169
169
|
|
|
170
170
|
# Append usage data to NAT usage stats
|
|
171
171
|
with self._lock:
|
|
@@ -86,7 +86,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
|
|
|
86
86
|
|
|
87
87
|
# Gather the appropriate modules/functions based on your builder config
|
|
88
88
|
for llm in self._builder_llms:
|
|
89
|
-
if self._builder_llms[llm].provider_type == 'openai':
|
|
89
|
+
if self._builder_llms[llm].provider_type == 'openai':
|
|
90
90
|
functions_to_patch.extend(["openai_non_streaming", "openai_streaming"])
|
|
91
91
|
|
|
92
92
|
# Grab original reference for the tool call
|
|
@@ -132,7 +132,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
|
|
|
132
132
|
if "text" in item:
|
|
133
133
|
model_input += item["text"]
|
|
134
134
|
except Exception as e:
|
|
135
|
-
logger.exception("Error in getting model input: %s", e
|
|
135
|
+
logger.exception("Error in getting model input: %s", e)
|
|
136
136
|
|
|
137
137
|
input_stats = IntermediateStepPayload(event_type=IntermediateStepType.LLM_START,
|
|
138
138
|
framework=LLMFrameworkEnum.SEMANTIC_KERNEL,
|
|
@@ -232,7 +232,7 @@ class SemanticKernelProfilerHandler(BaseProfilerCallback):
|
|
|
232
232
|
return result
|
|
233
233
|
|
|
234
234
|
except Exception as e:
|
|
235
|
-
logger.
|
|
235
|
+
logger.error("ToolUsage._use error: %s", e)
|
|
236
236
|
raise
|
|
237
237
|
|
|
238
238
|
return patched_tool_call
|
nat/profiler/data_frame_row.py
CHANGED
|
@@ -42,7 +42,7 @@ class DataFrameRow(BaseModel):
|
|
|
42
42
|
framework: str | None
|
|
43
43
|
|
|
44
44
|
@field_validator('llm_text_input', 'llm_text_output', 'llm_new_token', mode='before')
|
|
45
|
-
def cast_to_str(cls, v):
|
|
45
|
+
def cast_to_str(cls, v):
|
|
46
46
|
if v is None:
|
|
47
47
|
return v
|
|
48
48
|
try:
|
|
@@ -13,12 +13,11 @@
|
|
|
13
13
|
# See the License for the specific language governing permissions and
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
|
-
# pylint disable=ungrouped-imports
|
|
17
|
-
|
|
18
16
|
from __future__ import annotations
|
|
19
17
|
|
|
20
18
|
import functools
|
|
21
19
|
import logging
|
|
20
|
+
from collections.abc import AsyncIterator
|
|
22
21
|
from collections.abc import Callable
|
|
23
22
|
from contextlib import AbstractAsyncContextManager as AsyncContextManager
|
|
24
23
|
from contextlib import asynccontextmanager
|
|
@@ -34,35 +33,55 @@ _library_instrumented = {
|
|
|
34
33
|
"crewai": False,
|
|
35
34
|
"semantic_kernel": False,
|
|
36
35
|
"agno": False,
|
|
36
|
+
"adk": False,
|
|
37
37
|
}
|
|
38
38
|
|
|
39
39
|
callback_handler_var: ContextVar[Any | None] = ContextVar("callback_handler_var", default=None)
|
|
40
40
|
|
|
41
41
|
|
|
42
42
|
def set_framework_profiler_handler(
|
|
43
|
-
workflow_llms: dict = None,
|
|
44
|
-
frameworks: list[LLMFrameworkEnum] = None,
|
|
43
|
+
workflow_llms: dict | None = None,
|
|
44
|
+
frameworks: list[LLMFrameworkEnum] | None = None,
|
|
45
45
|
) -> Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
46
46
|
"""
|
|
47
47
|
Decorator that wraps an async context manager function to set up framework-specific profiling.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
workflow_llms (dict | None): A dictionary of workflow LLM configurations.
|
|
51
|
+
frameworks (list[LLMFrameworkEnum] | None): A list of LLM frameworks used in the workflow functions.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
Callable[[Callable[..., AsyncContextManager[Any]]], Callable[..., AsyncContextManager[Any]]]:
|
|
55
|
+
A decorator that wraps the original function with profiling setup.
|
|
48
56
|
"""
|
|
49
57
|
|
|
50
58
|
def decorator(func: Callable[..., AsyncContextManager[Any]]) -> Callable[..., AsyncContextManager[Any]]:
|
|
59
|
+
"""The actual decorator that wraps the function.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
func (Callable[..., AsyncContextManager[Any]]): The function to wrap.
|
|
63
|
+
|
|
64
|
+
Returns:
|
|
65
|
+
Callable[..., AsyncContextManager[Any]]: The wrapped function.
|
|
66
|
+
"""
|
|
51
67
|
|
|
52
68
|
@functools.wraps(func)
|
|
53
69
|
@asynccontextmanager
|
|
54
70
|
async def wrapper(workflow_config, builder):
|
|
55
71
|
|
|
56
|
-
if LLMFrameworkEnum.LANGCHAIN in frameworks
|
|
57
|
-
|
|
58
|
-
|
|
72
|
+
if LLMFrameworkEnum.LANGCHAIN in frameworks:
|
|
73
|
+
# Always set a fresh handler in the current context so callbacks
|
|
74
|
+
# route to the active run. Only register the hook once globally.
|
|
59
75
|
from nat.profiler.callbacks.langchain_callback_handler import LangchainProfilerHandler
|
|
60
76
|
|
|
61
77
|
handler = LangchainProfilerHandler()
|
|
62
78
|
callback_handler_var.set(handler)
|
|
63
|
-
|
|
64
|
-
_library_instrumented["langchain"]
|
|
65
|
-
|
|
79
|
+
|
|
80
|
+
if not _library_instrumented["langchain"]:
|
|
81
|
+
from langchain_core.tracers.context import register_configure_hook
|
|
82
|
+
register_configure_hook(callback_handler_var, inheritable=True)
|
|
83
|
+
_library_instrumented["langchain"] = True
|
|
84
|
+
logger.debug("LangChain/LangGraph callback hook registered")
|
|
66
85
|
|
|
67
86
|
if LLMFrameworkEnum.LLAMA_INDEX in frameworks:
|
|
68
87
|
from llama_index.core import Settings
|
|
@@ -75,8 +94,7 @@ def set_framework_profiler_handler(
|
|
|
75
94
|
logger.debug("LlamaIndex callback handler registered")
|
|
76
95
|
|
|
77
96
|
if LLMFrameworkEnum.CREWAI in frameworks and not _library_instrumented["crewai"]:
|
|
78
|
-
from nat.plugins.crewai.crewai_callback_handler import
|
|
79
|
-
CrewAIProfilerHandler # pylint: disable=ungrouped-imports,line-too-long # noqa E501
|
|
97
|
+
from nat.plugins.crewai.crewai_callback_handler import CrewAIProfilerHandler
|
|
80
98
|
|
|
81
99
|
handler = CrewAIProfilerHandler()
|
|
82
100
|
handler.instrument()
|
|
@@ -99,6 +117,20 @@ def set_framework_profiler_handler(
|
|
|
99
117
|
_library_instrumented["agno"] = True
|
|
100
118
|
logger.info("Agno callback handler registered")
|
|
101
119
|
|
|
120
|
+
if LLMFrameworkEnum.ADK in frameworks and not _library_instrumented["adk"]:
|
|
121
|
+
try:
|
|
122
|
+
from nat.plugins.adk.adk_callback_handler import ADKProfilerHandler
|
|
123
|
+
except ImportError as e:
|
|
124
|
+
logger.warning(
|
|
125
|
+
"ADK profiler not available. " +
|
|
126
|
+
"Install NAT with ADK extras: pip install 'nvidia-nat[adk]'. Error: %s",
|
|
127
|
+
e)
|
|
128
|
+
else:
|
|
129
|
+
handler = ADKProfilerHandler()
|
|
130
|
+
handler.instrument()
|
|
131
|
+
_library_instrumented["adk"] = True
|
|
132
|
+
logger.debug("ADK callback handler registered")
|
|
133
|
+
|
|
102
134
|
# IMPORTANT: actually call the wrapped function as an async context manager
|
|
103
135
|
async with func(workflow_config, builder) as result:
|
|
104
136
|
yield result
|
|
@@ -117,11 +149,28 @@ def chain_wrapped_build_fn(
|
|
|
117
149
|
Convert an original build function into an async context manager that
|
|
118
150
|
wraps it with a single call to set_framework_profiler_handler, passing
|
|
119
151
|
all frameworks at once.
|
|
152
|
+
|
|
153
|
+
Args:
|
|
154
|
+
original_build_fn (Callable[..., AsyncContextManager]): The original build function to wrap.
|
|
155
|
+
workflow_llms (dict): A dictionary of workflow LLM configurations.
|
|
156
|
+
function_frameworks (list[LLMFrameworkEnum]): A list of LLM frameworks used in the workflow functions.
|
|
157
|
+
|
|
158
|
+
Returns:
|
|
159
|
+
Callable[..., AsyncContextManager]: The wrapped build function.
|
|
120
160
|
"""
|
|
121
161
|
|
|
122
162
|
# Define a base async context manager that simply calls the original build function.
|
|
123
163
|
@asynccontextmanager
|
|
124
|
-
async def base_fn(*args, **kwargs):
|
|
164
|
+
async def base_fn(*args, **kwargs) -> AsyncIterator[Any]:
|
|
165
|
+
"""Base async context manager that calls the original build function.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
*args: Positional arguments to pass to the original build function.
|
|
169
|
+
**kwargs: Keyword arguments to pass to the original build function.
|
|
170
|
+
|
|
171
|
+
Yields:
|
|
172
|
+
The result of the original build function.
|
|
173
|
+
"""
|
|
125
174
|
async with original_build_fn(*args, **kwargs) as w:
|
|
126
175
|
yield w
|
|
127
176
|
|
|
@@ -16,7 +16,11 @@
|
|
|
16
16
|
import functools
|
|
17
17
|
import inspect
|
|
18
18
|
import uuid
|
|
19
|
+
from collections.abc import Callable
|
|
19
20
|
from typing import Any
|
|
21
|
+
from typing import TypeVar
|
|
22
|
+
from typing import cast
|
|
23
|
+
from typing import overload
|
|
20
24
|
|
|
21
25
|
from pydantic import BaseModel
|
|
22
26
|
|
|
@@ -36,10 +40,10 @@ def _serialize_data(obj: Any) -> Any:
|
|
|
36
40
|
|
|
37
41
|
if isinstance(obj, dict):
|
|
38
42
|
return {str(k): _serialize_data(v) for k, v in obj.items()}
|
|
39
|
-
if isinstance(obj,
|
|
43
|
+
if isinstance(obj, list | tuple | set):
|
|
40
44
|
return [_serialize_data(item) for item in obj]
|
|
41
45
|
|
|
42
|
-
if isinstance(obj,
|
|
46
|
+
if isinstance(obj, str | int | float | bool | type(None)):
|
|
43
47
|
return obj
|
|
44
48
|
|
|
45
49
|
# Fallback
|
|
@@ -75,7 +79,24 @@ def push_intermediate_step(step_manager: IntermediateStepManager,
|
|
|
75
79
|
step_manager.push_intermediate_step(payload)
|
|
76
80
|
|
|
77
81
|
|
|
78
|
-
|
|
82
|
+
# Type variable for overloads
|
|
83
|
+
F = TypeVar('F', bound=Callable[..., Any])
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Overloads for different function types
|
|
87
|
+
@overload
|
|
88
|
+
def track_function(func: F, *, metadata: dict[str, Any] | None = None) -> F:
|
|
89
|
+
"""Overload for when a function is passed directly."""
|
|
90
|
+
...
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@overload
|
|
94
|
+
def track_function(*, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
95
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
96
|
+
...
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None) -> Any:
|
|
79
100
|
"""
|
|
80
101
|
Decorator that can wrap any type of function (sync, async, generator,
|
|
81
102
|
async generator) and executes "tracking logic" around it.
|
|
@@ -252,3 +273,139 @@ def track_function(func: Any = None, *, metadata: dict[str, Any] | None = None):
|
|
|
252
273
|
return result
|
|
253
274
|
|
|
254
275
|
return sync_wrapper
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# Overloads for track_unregistered_function
|
|
279
|
+
@overload
|
|
280
|
+
def track_unregistered_function(func: F, *, name: str | None = None, metadata: dict[str, Any] | None = None) -> F:
|
|
281
|
+
"""Overload for when a function is passed directly."""
|
|
282
|
+
...
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@overload
|
|
286
|
+
def track_unregistered_function(*, name: str | None = None, metadata: dict[str, Any] | None = None) -> Callable[[F], F]:
|
|
287
|
+
"""Overload for decorator factory usage (when called with parentheses)."""
|
|
288
|
+
...
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def track_unregistered_function(func: Callable[..., Any] | None = None,
|
|
292
|
+
*,
|
|
293
|
+
name: str | None = None,
|
|
294
|
+
metadata: dict[str, Any] | None = None) -> Callable[..., Any]:
|
|
295
|
+
"""
|
|
296
|
+
Decorator that wraps any function with scope management and automatic tracking.
|
|
297
|
+
|
|
298
|
+
- Sets active function context using the function name
|
|
299
|
+
- Leverages Context.push_active_function for built-in tracking
|
|
300
|
+
- Avoids duplicate tracking entries by relying on the library's built-in systems
|
|
301
|
+
- Supports sync/async functions and generators
|
|
302
|
+
|
|
303
|
+
Args:
|
|
304
|
+
func: The function to wrap (auto-detected when used without parentheses)
|
|
305
|
+
name: Custom name to use for tracking instead of func.__name__
|
|
306
|
+
metadata: Additional metadata to include in tracking
|
|
307
|
+
"""
|
|
308
|
+
|
|
309
|
+
# If called with parameters: @track_unregistered_function(name="...", metadata={...})
|
|
310
|
+
if func is None:
|
|
311
|
+
|
|
312
|
+
def decorator_wrapper(actual_func: Callable[..., Any]) -> Callable[..., Any]:
|
|
313
|
+
# Cast to ensure type checker understands this returns a callable
|
|
314
|
+
return cast(Callable[..., Any], track_unregistered_function(actual_func, name=name, metadata=metadata))
|
|
315
|
+
|
|
316
|
+
return decorator_wrapper
|
|
317
|
+
|
|
318
|
+
# Direct decoration: @track_unregistered_function or recursive call with actual function
|
|
319
|
+
function_name: str = name if name else func.__name__
|
|
320
|
+
|
|
321
|
+
# --- Validate metadata ---
|
|
322
|
+
if metadata is not None:
|
|
323
|
+
if not isinstance(metadata, dict):
|
|
324
|
+
raise TypeError("metadata must be a dict[str, Any].")
|
|
325
|
+
if any(not isinstance(k, str) for k in metadata.keys()):
|
|
326
|
+
raise TypeError("All metadata keys must be strings.")
|
|
327
|
+
|
|
328
|
+
trace_metadata = TraceMetadata(provided_metadata=metadata)
|
|
329
|
+
|
|
330
|
+
# --- Now detect the function type and wrap accordingly ---
|
|
331
|
+
if inspect.isasyncgenfunction(func):
|
|
332
|
+
# ---------------------
|
|
333
|
+
# ASYNC GENERATOR
|
|
334
|
+
# ---------------------
|
|
335
|
+
|
|
336
|
+
@functools.wraps(func)
|
|
337
|
+
async def async_gen_wrapper(*args, **kwargs):
|
|
338
|
+
context = Context.get()
|
|
339
|
+
input_data = (
|
|
340
|
+
*args,
|
|
341
|
+
kwargs,
|
|
342
|
+
)
|
|
343
|
+
# Only do context management - let push_active_function handle tracking
|
|
344
|
+
with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
|
|
345
|
+
final_outputs = []
|
|
346
|
+
async for item in func(*args, **kwargs):
|
|
347
|
+
final_outputs.append(item)
|
|
348
|
+
yield item
|
|
349
|
+
|
|
350
|
+
manager.set_output(final_outputs)
|
|
351
|
+
|
|
352
|
+
return async_gen_wrapper
|
|
353
|
+
|
|
354
|
+
if inspect.iscoroutinefunction(func):
|
|
355
|
+
# ---------------------
|
|
356
|
+
# ASYNC FUNCTION
|
|
357
|
+
# ---------------------
|
|
358
|
+
@functools.wraps(func)
|
|
359
|
+
async def async_wrapper(*args, **kwargs):
|
|
360
|
+
context = Context.get()
|
|
361
|
+
input_data = (
|
|
362
|
+
*args,
|
|
363
|
+
kwargs,
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# Only do context management - let push_active_function handle tracking
|
|
367
|
+
with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
|
|
368
|
+
result = await func(*args, **kwargs)
|
|
369
|
+
manager.set_output(result)
|
|
370
|
+
return result
|
|
371
|
+
|
|
372
|
+
return async_wrapper
|
|
373
|
+
|
|
374
|
+
if inspect.isgeneratorfunction(func):
|
|
375
|
+
# ---------------------
|
|
376
|
+
# SYNC GENERATOR
|
|
377
|
+
# ---------------------
|
|
378
|
+
@functools.wraps(func)
|
|
379
|
+
def sync_gen_wrapper(*args, **kwargs):
|
|
380
|
+
context = Context.get()
|
|
381
|
+
input_data = (
|
|
382
|
+
*args,
|
|
383
|
+
kwargs,
|
|
384
|
+
)
|
|
385
|
+
|
|
386
|
+
# Only do context management - let push_active_function handle tracking
|
|
387
|
+
with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
|
|
388
|
+
final_outputs = []
|
|
389
|
+
for item in func(*args, **kwargs):
|
|
390
|
+
final_outputs.append(item)
|
|
391
|
+
yield item
|
|
392
|
+
|
|
393
|
+
manager.set_output(final_outputs)
|
|
394
|
+
|
|
395
|
+
return sync_gen_wrapper
|
|
396
|
+
|
|
397
|
+
@functools.wraps(func)
|
|
398
|
+
def sync_wrapper(*args, **kwargs):
|
|
399
|
+
context = Context.get()
|
|
400
|
+
input_data = (
|
|
401
|
+
*args,
|
|
402
|
+
kwargs,
|
|
403
|
+
)
|
|
404
|
+
|
|
405
|
+
# Only do context management - let push_active_function handle tracking
|
|
406
|
+
with context.push_active_function(function_name, input_data=input_data, metadata=trace_metadata) as manager:
|
|
407
|
+
result = func(*args, **kwargs)
|
|
408
|
+
manager.set_output(result)
|
|
409
|
+
return result
|
|
410
|
+
|
|
411
|
+
return sync_wrapper
|
|
@@ -195,7 +195,7 @@ def profile_workflow_bottlenecks(all_steps: list[list[IntermediateStep]]) -> Sim
|
|
|
195
195
|
c_max = 0
|
|
196
196
|
for ts, delta in events_sub:
|
|
197
197
|
c_curr += delta
|
|
198
|
-
if c_curr > c_max: #
|
|
198
|
+
if c_curr > c_max: # noqa: PLR1730 - don't use max built-in
|
|
199
199
|
c_max = c_curr
|
|
200
200
|
max_concurrency_by_name[op_name] = c_max
|
|
201
201
|
|
|
@@ -172,7 +172,7 @@ class CallNode(BaseModel):
|
|
|
172
172
|
if not self.children:
|
|
173
173
|
return self.duration
|
|
174
174
|
|
|
175
|
-
intervals = [(c.start_time, c.end_time) for c in self.children]
|
|
175
|
+
intervals = [(c.start_time, c.end_time) for c in self.children]
|
|
176
176
|
# Sort by start time
|
|
177
177
|
intervals.sort(key=lambda x: x[0])
|
|
178
178
|
|
|
@@ -204,7 +204,7 @@ class CallNode(BaseModel):
|
|
|
204
204
|
This ensures no overlap double-counting among children.
|
|
205
205
|
"""
|
|
206
206
|
total = self.compute_self_time()
|
|
207
|
-
for c in self.children:
|
|
207
|
+
for c in self.children:
|
|
208
208
|
total += c.compute_subtree_time()
|
|
209
209
|
return total
|
|
210
210
|
|
|
@@ -216,7 +216,7 @@ class CallNode(BaseModel):
|
|
|
216
216
|
info = (f"{indent}- {self.operation_type} '{self.operation_name}' "
|
|
217
217
|
f"(uuid={self.uuid}, start={self.start_time:.2f}, "
|
|
218
218
|
f"end={self.end_time:.2f}, dur={self.duration:.2f})")
|
|
219
|
-
child_strs = [child._repr(level + 1) for child in self.children]
|
|
219
|
+
child_strs = [child._repr(level + 1) for child in self.children]
|
|
220
220
|
return "\n".join([info] + child_strs)
|
|
221
221
|
|
|
222
222
|
|
|
@@ -228,7 +228,7 @@ def run_prefixspan(sequences_map: dict[int, list[PrefixCallNode]],
|
|
|
228
228
|
else:
|
|
229
229
|
abs_min_support = min_support
|
|
230
230
|
|
|
231
|
-
freq_patterns = ps.frequent(abs_min_support)
|
|
231
|
+
freq_patterns = ps.frequent(abs_min_support)
|
|
232
232
|
# freq_patterns => [(count, [item1, item2, ...])]
|
|
233
233
|
|
|
234
234
|
results = []
|
|
@@ -321,13 +321,12 @@ def compute_coverage_and_duration(sequences_map: dict[int, list[PrefixCallNode]]
|
|
|
321
321
|
# --------------------------------------------------------------------------------
|
|
322
322
|
|
|
323
323
|
|
|
324
|
-
def prefixspan_subworkflow_with_text(
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
|
|
324
|
+
def prefixspan_subworkflow_with_text(all_steps: list[list[IntermediateStep]],
|
|
325
|
+
min_support: int | float = 2,
|
|
326
|
+
top_k: int = 10,
|
|
327
|
+
min_coverage: float = 0.0,
|
|
328
|
+
max_text_len: int = 700,
|
|
329
|
+
prefix_list: list[str] = None) -> PrefixSpanSubworkflowResult:
|
|
331
330
|
"""
|
|
332
331
|
1) Build sequences of calls for each example (with llm_text_input).
|
|
333
332
|
2) Convert to token lists, run PrefixSpan with min_support.
|
|
@@ -66,7 +66,7 @@ def compute_inter_query_token_uniqueness_by_llm(all_steps: list[list[Intermediat
|
|
|
66
66
|
# 2) Group by (llm_name, example_number), then sort each group
|
|
67
67
|
grouped = cdf.groupby(['llm_name', 'example_number'], as_index=False, group_keys=True)
|
|
68
68
|
|
|
69
|
-
for (llm, ex_num), group_df in grouped:
|
|
69
|
+
for (llm, ex_num), group_df in grouped:
|
|
70
70
|
# Sort by event_timestamp
|
|
71
71
|
group_df = group_df.sort_values('event_timestamp', ascending=True)
|
|
72
72
|
|
|
File without changes
|