nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiq/__init__.py +2 -2
- nat/agent/base.py +24 -15
- nat/agent/dual_node.py +9 -4
- nat/agent/prompt_optimizer/prompt.py +68 -0
- nat/agent/prompt_optimizer/register.py +149 -0
- nat/agent/react_agent/agent.py +79 -47
- nat/agent/react_agent/register.py +41 -21
- nat/agent/reasoning_agent/reasoning_agent.py +11 -9
- nat/agent/register.py +1 -1
- nat/agent/rewoo_agent/agent.py +326 -148
- nat/agent/rewoo_agent/prompt.py +19 -22
- nat/agent/rewoo_agent/register.py +46 -26
- nat/agent/tool_calling_agent/agent.py +84 -28
- nat/agent/tool_calling_agent/register.py +51 -28
- nat/authentication/api_key/api_key_auth_provider.py +2 -2
- nat/authentication/credential_validator/bearer_token_validator.py +557 -0
- nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
- nat/authentication/interfaces.py +5 -2
- nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
- nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
- nat/authentication/register.py +0 -1
- nat/builder/builder.py +56 -24
- nat/builder/component_utils.py +9 -5
- nat/builder/context.py +46 -11
- nat/builder/eval_builder.py +16 -11
- nat/builder/framework_enum.py +1 -0
- nat/builder/front_end.py +1 -1
- nat/builder/function.py +378 -8
- nat/builder/function_base.py +3 -3
- nat/builder/function_info.py +6 -8
- nat/builder/user_interaction_manager.py +2 -2
- nat/builder/workflow.py +13 -1
- nat/builder/workflow_builder.py +281 -76
- nat/cli/cli_utils/config_override.py +2 -2
- nat/cli/commands/evaluate.py +1 -1
- nat/cli/commands/info/info.py +16 -6
- nat/cli/commands/info/list_channels.py +1 -1
- nat/cli/commands/info/list_components.py +7 -8
- nat/cli/commands/mcp/__init__.py +14 -0
- nat/cli/commands/mcp/mcp.py +986 -0
- nat/cli/commands/object_store/__init__.py +14 -0
- nat/cli/commands/object_store/object_store.py +227 -0
- nat/cli/commands/optimize.py +90 -0
- nat/cli/commands/registry/publish.py +2 -2
- nat/cli/commands/registry/pull.py +2 -2
- nat/cli/commands/registry/remove.py +2 -2
- nat/cli/commands/registry/search.py +15 -17
- nat/cli/commands/start.py +16 -5
- nat/cli/commands/uninstall.py +1 -1
- nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
- nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
- nat/cli/commands/workflow/templates/register.py.j2 +0 -1
- nat/cli/commands/workflow/workflow_commands.py +9 -13
- nat/cli/entrypoint.py +8 -10
- nat/cli/register_workflow.py +38 -4
- nat/cli/type_registry.py +75 -6
- nat/control_flow/__init__.py +0 -0
- nat/control_flow/register.py +20 -0
- nat/control_flow/router_agent/__init__.py +0 -0
- nat/control_flow/router_agent/agent.py +329 -0
- nat/control_flow/router_agent/prompt.py +48 -0
- nat/control_flow/router_agent/register.py +91 -0
- nat/control_flow/sequential_executor.py +166 -0
- nat/data_models/agent.py +34 -0
- nat/data_models/api_server.py +10 -10
- nat/data_models/authentication.py +23 -9
- nat/data_models/common.py +1 -1
- nat/data_models/component.py +2 -0
- nat/data_models/component_ref.py +11 -0
- nat/data_models/config.py +41 -17
- nat/data_models/dataset_handler.py +1 -1
- nat/data_models/discovery_metadata.py +4 -4
- nat/data_models/evaluate.py +4 -1
- nat/data_models/function.py +34 -0
- nat/data_models/function_dependencies.py +14 -6
- nat/data_models/gated_field_mixin.py +242 -0
- nat/data_models/intermediate_step.py +3 -3
- nat/data_models/optimizable.py +119 -0
- nat/data_models/optimizer.py +149 -0
- nat/data_models/swe_bench_model.py +1 -1
- nat/data_models/temperature_mixin.py +44 -0
- nat/data_models/thinking_mixin.py +86 -0
- nat/data_models/top_p_mixin.py +44 -0
- nat/embedder/nim_embedder.py +1 -1
- nat/embedder/openai_embedder.py +1 -1
- nat/embedder/register.py +0 -1
- nat/eval/config.py +3 -1
- nat/eval/dataset_handler/dataset_handler.py +71 -7
- nat/eval/evaluate.py +86 -31
- nat/eval/evaluator/base_evaluator.py +1 -1
- nat/eval/evaluator/evaluator_model.py +13 -0
- nat/eval/intermediate_step_adapter.py +1 -1
- nat/eval/rag_evaluator/evaluate.py +2 -2
- nat/eval/rag_evaluator/register.py +3 -3
- nat/eval/register.py +4 -1
- nat/eval/remote_workflow.py +3 -3
- nat/eval/runtime_evaluator/__init__.py +14 -0
- nat/eval/runtime_evaluator/evaluate.py +123 -0
- nat/eval/runtime_evaluator/register.py +100 -0
- nat/eval/swe_bench_evaluator/evaluate.py +6 -6
- nat/eval/trajectory_evaluator/evaluate.py +1 -1
- nat/eval/trajectory_evaluator/register.py +1 -1
- nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
- nat/eval/utils/eval_trace_ctx.py +89 -0
- nat/eval/utils/weave_eval.py +18 -9
- nat/experimental/decorators/experimental_warning_decorator.py +27 -7
- nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
- nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
- nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
- nat/experimental/test_time_compute/models/strategy_base.py +5 -4
- nat/experimental/test_time_compute/register.py +0 -1
- nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
- nat/front_ends/console/authentication_flow_handler.py +82 -30
- nat/front_ends/console/console_front_end_plugin.py +8 -5
- nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
- nat/front_ends/fastapi/dask_client_mixin.py +65 -0
- nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
- nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
- nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
- nat/front_ends/fastapi/job_store.py +518 -99
- nat/front_ends/fastapi/main.py +11 -19
- nat/front_ends/fastapi/message_handler.py +13 -14
- nat/front_ends/fastapi/message_validator.py +17 -19
- nat/front_ends/fastapi/response_helpers.py +4 -4
- nat/front_ends/fastapi/step_adaptor.py +2 -2
- nat/front_ends/fastapi/utils.py +57 -0
- nat/front_ends/mcp/introspection_token_verifier.py +73 -0
- nat/front_ends/mcp/mcp_front_end_config.py +10 -1
- nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
- nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
- nat/front_ends/mcp/tool_converter.py +44 -14
- nat/front_ends/register.py +0 -1
- nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
- nat/llm/aws_bedrock_llm.py +24 -12
- nat/llm/azure_openai_llm.py +13 -6
- nat/llm/litellm_llm.py +69 -0
- nat/llm/nim_llm.py +20 -8
- nat/llm/openai_llm.py +14 -6
- nat/llm/register.py +4 -1
- nat/llm/utils/env_config_value.py +2 -3
- nat/llm/utils/thinking.py +215 -0
- nat/meta/pypi.md +9 -9
- nat/object_store/register.py +0 -1
- nat/observability/exporter/base_exporter.py +3 -3
- nat/observability/exporter/file_exporter.py +1 -1
- nat/observability/exporter/processing_exporter.py +309 -81
- nat/observability/exporter/span_exporter.py +1 -1
- nat/observability/exporter_manager.py +7 -7
- nat/observability/mixin/file_mixin.py +7 -7
- nat/observability/mixin/redaction_config_mixin.py +42 -0
- nat/observability/mixin/tagging_config_mixin.py +62 -0
- nat/observability/mixin/type_introspection_mixin.py +420 -107
- nat/observability/processor/batching_processor.py +5 -7
- nat/observability/processor/falsy_batch_filter_processor.py +55 -0
- nat/observability/processor/processor.py +3 -0
- nat/observability/processor/processor_factory.py +70 -0
- nat/observability/processor/redaction/__init__.py +24 -0
- nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
- nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
- nat/observability/processor/redaction/redaction_processor.py +177 -0
- nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
- nat/observability/processor/span_tagging_processor.py +68 -0
- nat/observability/register.py +6 -4
- nat/profiler/calc/calc_runner.py +3 -4
- nat/profiler/callbacks/agno_callback_handler.py +1 -1
- nat/profiler/callbacks/langchain_callback_handler.py +6 -6
- nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
- nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
- nat/profiler/data_frame_row.py +1 -1
- nat/profiler/decorators/framework_wrapper.py +62 -13
- nat/profiler/decorators/function_tracking.py +160 -3
- nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
- nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
- nat/profiler/inference_optimization/data_models.py +3 -3
- nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
- nat/profiler/inference_optimization/token_uniqueness.py +1 -1
- nat/profiler/parameter_optimization/__init__.py +0 -0
- nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
- nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
- nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
- nat/profiler/parameter_optimization/parameter_selection.py +107 -0
- nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
- nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
- nat/profiler/parameter_optimization/update_helpers.py +66 -0
- nat/profiler/profile_runner.py +14 -9
- nat/profiler/utils.py +4 -2
- nat/registry_handlers/local/local_handler.py +2 -2
- nat/registry_handlers/package_utils.py +1 -2
- nat/registry_handlers/pypi/pypi_handler.py +23 -26
- nat/registry_handlers/register.py +3 -4
- nat/registry_handlers/rest/rest_handler.py +12 -13
- nat/retriever/milvus/retriever.py +2 -2
- nat/retriever/nemo_retriever/retriever.py +1 -1
- nat/retriever/register.py +0 -1
- nat/runtime/loader.py +2 -2
- nat/runtime/runner.py +3 -2
- nat/runtime/session.py +43 -8
- nat/settings/global_settings.py +16 -5
- nat/tool/chat_completion.py +5 -2
- nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
- nat/tool/datetime_tools.py +49 -9
- nat/tool/document_search.py +2 -2
- nat/tool/github_tools.py +450 -0
- nat/tool/nvidia_rag.py +1 -1
- nat/tool/register.py +2 -9
- nat/tool/retriever.py +3 -2
- nat/utils/callable_utils.py +70 -0
- nat/utils/data_models/schema_validator.py +3 -3
- nat/utils/exception_handlers/automatic_retries.py +104 -51
- nat/utils/exception_handlers/schemas.py +1 -1
- nat/utils/io/yaml_tools.py +2 -2
- nat/utils/log_levels.py +25 -0
- nat/utils/reactive/base/observable_base.py +2 -2
- nat/utils/reactive/base/observer_base.py +1 -1
- nat/utils/reactive/observable.py +2 -2
- nat/utils/reactive/observer.py +4 -4
- nat/utils/reactive/subscription.py +1 -1
- nat/utils/settings/global_settings.py +6 -8
- nat/utils/type_converter.py +4 -3
- nat/utils/type_utils.py +9 -5
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
- nat/cli/commands/info/list_mcp.py +0 -304
- nat/tool/github_tools/create_github_commit.py +0 -133
- nat/tool/github_tools/create_github_issue.py +0 -87
- nat/tool/github_tools/create_github_pr.py +0 -106
- nat/tool/github_tools/get_github_file.py +0 -106
- nat/tool/github_tools/get_github_issue.py +0 -166
- nat/tool/github_tools/get_github_pr.py +0 -256
- nat/tool/github_tools/update_github_issue.py +0 -100
- nat/tool/mcp/exceptions.py +0 -142
- nat/tool/mcp/mcp_client.py +0 -255
- nat/tool/mcp/mcp_tool.py +0 -96
- nat/utils/exception_handlers/mcp.py +0 -211
- /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
- /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
- {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
nat/front_ends/fastapi/job_store.py  +518 -99

@@ -13,22 +13,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import shutil
-import
+import typing
+from asyncio import current_task
+from collections.abc import AsyncGenerator
+from collections.abc import Callable
+from contextlib import asynccontextmanager
 from datetime import UTC
 from datetime import datetime
 from datetime import timedelta
 from enum import Enum
 from uuid import uuid4

+from dask.distributed import Client as DaskClient
+from dask.distributed import Future
+from dask.distributed import Variable
+from dask.distributed import fire_and_forget
 from pydantic import BaseModel
+from sqlalchemy import DateTime
+from sqlalchemy import String
+from sqlalchemy import and_
+from sqlalchemy import select
+from sqlalchemy import update
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.sql import expression as sa_expr
+
+from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
+
+if typing.TYPE_CHECKING:
+    from sqlalchemy.engine import Engine
+    from sqlalchemy.ext.asyncio import AsyncEngine
+    from sqlalchemy.ext.asyncio import AsyncSession

 logger = logging.getLogger(__name__)


 class JobStatus(str, Enum):
+    """
+    Enumeration of possible job statuses in the job store.
+
+    Attributes
+    ----------
+    SUBMITTED : str
+        Job has been submitted to the scheduler but not yet started.
+    RUNNING : str
+        Job is currently being executed.
+    SUCCESS : str
+        Job completed successfully.
+    FAILURE : str
+        Job failed during execution.
+    INTERRUPTED : str
+        Job was interrupted or cancelled before completion.
+    NOT_FOUND : str
+        Job ID does not exist in the job store.
+    """
     SUBMITTED = "submitted"
     RUNNING = "running"
     SUCCESS = "success"
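Since `JobStatus` derives from both `str` and `Enum`, status values round-trip cleanly through the database's plain string column: a raw string compares equal to its member, and `JobStatus(value)` re-wraps it, which is what `update_status()` and `get_status()` in the later hunks rely on. A quick illustrative check (not part of the package):

```python
from enum import Enum


class JobStatus(str, Enum):
    SUBMITTED = "submitted"
    RUNNING = "running"
    SUCCESS = "success"


# A plain string read back from the status column compares equal to the member...
assert JobStatus.RUNNING == "running"
# ...and the enum constructor re-wraps it without any lookup table.
assert JobStatus("running") is JobStatus.RUNNING
```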
@@ -37,42 +82,175 @@ class JobStatus(str, Enum):
     NOT_FOUND = "not_found"


-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class Base(DeclarativeBase):
+    pass
+
+
+class JobInfo(Base):
+    """
+    SQLAlchemy model representing job metadata and status information.
+
+    This model stores comprehensive information about jobs submitted to the Dask scheduler, including their current
+    status, configuration, outputs, and lifecycle metadata.
+
+    Attributes
+    ----------
+    job_id : str
+        Unique identifier for the job (primary key).
+    status : JobStatus
+        Current status of the job.
+    config_file : str, optional
+        Path to the configuration file used for the job.
+    error : str, optional
+        Error message if the job failed.
+    output_path : str, optional
+        Path where job outputs are stored.
+    created_at : datetime
+        Timestamp when the job was created.
+    updated_at : datetime
+        Timestamp when the job was last updated.
+    expiry_seconds : int
+        Number of seconds after which the job is eligible for cleanup.
+    output : str, optional
+        Serialized job output data (JSON format).
+    is_expired : bool
+        Flag indicating if the job has been marked as expired.
+    """
+    __tablename__ = "job_info"
+
+    job_id: Mapped[str] = mapped_column(primary_key=True)
+    status: Mapped[JobStatus] = mapped_column(String(11))
+    config_file: Mapped[str] = mapped_column(nullable=True)
+    error: Mapped[str] = mapped_column(nullable=True)
+    output_path: Mapped[str] = mapped_column(nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.now(UTC))
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True),
+                                                 default=datetime.now(UTC),
+                                                 onupdate=datetime.now(UTC))
+    expiry_seconds: Mapped[int]
+    output: Mapped[str] = mapped_column(nullable=True)
+    is_expired: Mapped[bool] = mapped_column(default=False, index=True)
+
+    def __repr__(self):
+        return f"JobInfo(job_id={self.job_id}, status={self.status})"
+
+
+class JobStore(DaskClientMixin):
+    """
+    Tracks and manages jobs submitted to the Dask scheduler, along with persisting job metadata (JobInfo objects) in a
+    database.
+
+    Parameters
+    ----------
+    scheduler_address: str
+        The address of the Dask scheduler.
+    db_engine: AsyncEngine | None, optional, default=None
+        The database engine for the job store.
+    db_url: str | None, optional, default=None
+        The database URL to connect to, used when db_engine is not provided. Refer to:
+        https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
+    """

     MIN_EXPIRY = 600  # 10 minutes
     MAX_EXPIRY = 86400  # 24 hours
     DEFAULT_EXPIRY = 3600  # 1 hour

     # active jobs are exempt from expiry
-    ACTIVE_STATUS = {
+    ACTIVE_STATUS = {JobStatus.RUNNING, JobStatus.SUBMITTED}
+
+    def __init__(
+        self,
+        scheduler_address: str,
+        db_engine: "AsyncEngine | None" = None,
+        db_url: str | None = None,
+    ):
+        self._scheduler_address = scheduler_address
+
+        if db_engine is None:
+            if db_url is None:
+                raise ValueError("Either db_engine or db_url must be provided")
+
+            db_engine = get_db_engine(db_url, use_async=True)
+
+        # Disabling expire_on_commit allows us to detach (expunge) job
+        # instances from the session
+        session_maker = async_sessionmaker(db_engine, expire_on_commit=False)
+
+        # The async_scoped_session ensures that the same session is used
+        # within the same task, and that no two tasks share the same session.
+        self._session = async_scoped_session(session_maker, scopefunc=current_task)
+
+    @asynccontextmanager
+    async def client(self) -> AsyncGenerator[DaskClient]:
+        """
+        Async context manager for obtaining a Dask client connection.
+
+        Yields
+        ------
+        DaskClient
+            An active Dask client connected to the scheduler. The client is automatically closed when exiting the
+            context manager.
+        """
+        async with super().client(self._scheduler_address) as client:
+            yield client
+
+    @asynccontextmanager
+    async def session(self) -> AsyncGenerator["AsyncSession"]:
+        """
+        Async context manager for a SQLAlchemy session with automatic transaction management.
+
+        Creates a new database session scoped to the current async task and begins a transaction. The transaction is
+        committed on successful exit and rolled back on exception. The session is automatically removed from the
+        registry after use.
+
+        Yields
+        ------
+        AsyncSession
+            An active SQLAlchemy async session with an open transaction.
+        """
+        async with self._session() as session:
+            async with session.begin():
+                yield session

-
-
-        self.
+        # Removes the current task key from the session registry, preventing
+        # potential memory leaks
+        await self._session.remove()

-    def
-
-
-
+    def ensure_job_id(self, job_id: str | None) -> str:
+        """
+        Ensure a job ID is provided, generating a new one if necessary.
+
+        If a job ID is provided, it is returned as-is.
+
+        Parameters
+        ----------
+        job_id: str | None
+            The job ID to ensure, or None to generate a new one.
+        """
         if job_id is None:
             job_id = str(uuid4())
+            logger.info("Generated new job ID: %s", job_id)
+
+        return job_id
+
+    async def _create_job(self,
+                          config_file: str | None = None,
+                          job_id: str | None = None,
+                          expiry_seconds: int = DEFAULT_EXPIRY) -> str:
+        """
+        Create a job and add it to the job store. This should not be called directly, but instead be called by
+        `submit_job`
+        """
+        job_id = self.ensure_job_id(job_id)

         clamped_expiry = max(self.MIN_EXPIRY, min(expiry_seconds, self.MAX_EXPIRY))
         if expiry_seconds != clamped_expiry:
-            logger.info(
+            logger.info(
+                "Clamped expiry_seconds from %d to %d for job %s",
+                expiry_seconds,
+                clamped_expiry,
+                job_id,
+            )

         job = JobInfo(job_id=job_id,
                       status=JobStatus.SUBMITTED,
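The new `session()` helper pairs `async_scoped_session` (keyed on `asyncio.current_task`) with `session.begin()`, so each asyncio task gets its own session, each block runs in a single transaction, and the task's registry entry is dropped afterwards. A minimal, self-contained sketch of the same pattern (illustrative names, assumes `aiosqlite` is installed; this is not the package's code):

```python
import asyncio
from asyncio import current_task
from contextlib import asynccontextmanager

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_scoped_session
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///:memory:")
# expire_on_commit=False keeps returned objects usable after the commit,
# mirroring the JobStore constructor above.
factory = async_sessionmaker(engine, expire_on_commit=False)
scoped = async_scoped_session(factory, scopefunc=current_task)


@asynccontextmanager
async def session():
    # Same asyncio task -> same session; commit on success, rollback on error.
    async with scoped() as s, s.begin():
        yield s
    # Drop this task's entry from the registry to avoid leaks.
    await scoped.remove()


async def main():
    async with session() as s:
        print((await s.execute(text("SELECT 1"))).scalar())  # -> 1


asyncio.run(main())
```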
@@ -83,101 +261,342 @@ class JobStore:
                       output_path=None,
                       expiry_seconds=clamped_expiry)

-        with self.
-
+        async with self.session() as session:
+            session.add(job)

         logger.info("Created new job %s with config %s", job_id, config_file)
         return job_id

-    def
-
-
-
-
-
-
-
-
-
-
-
+    async def submit_job(self,
+                         *,
+                         job_id: str | None = None,
+                         config_file: str | None = None,
+                         expiry_seconds: int = DEFAULT_EXPIRY,
+                         sync_timeout: int = 0,
+                         job_fn: Callable[..., typing.Any],
+                         job_args: list[typing.Any],
+                         **job_kwargs) -> tuple[str, JobInfo | None]:
+        """
+        Submit a job to the Dask scheduler, and store job metadata in the database.
+
+        Parameters
+        ----------
+        job_id: str | None, optional, default=None
+            The job ID to use, or None to generate a new one.
+        config_file: str | None, optional, default=None
+            The config file used to run the job, if any.
+        expiry_seconds: int, optional, default=3600
+            The number of seconds after which the job should be considered expired. Expired jobs are eligible for
+            cleanup, but are not deleted immediately.
+        sync_timeout: int, optional, default=0
+            If greater than 0, wait for the job to complete for up to this many seconds. If the job does not complete
+            in this time, return immediately with the job ID and no job info. If the job completes in this time,
+            return the job ID and the job info. If 0, return immediately with the job ID and no job info.
+        job_fn: Callable[..., typing.Any]
+            The function to run as the job. This function must be serializable by Dask.
+        job_args: list[typing.Any]
+            The arguments to pass to the job function. These must be serializable by Dask.
+        job_kwargs: dict[str, typing.Any]
+            The keyword arguments to pass to the job function. These must be serializable by Dask
+        """
+        job_id = await self._create_job(job_id=job_id, config_file=config_file, expiry_seconds=expiry_seconds)
+
+        # We are intentionally not using job_id as the key, since Dask will clear the associated metadata once
+        # the job has completed, and we want the metadata to persist until the job expires.
+        async with self.client() as client:
+            logger.debug("Submitting job with job_args: %s, job_kwargs: %s", job_args, job_kwargs)
+            future = client.submit(job_fn, *job_args, key=f"{job_id}-job", **job_kwargs)
+
+            # Store the future in a variable, this allows us to potentially cancel the future later if needed
+            future_var = Variable(name=job_id, client=client)
+            await future_var.set(future)
+
+            if sync_timeout > 0:
+                try:
+                    _ = await future.result(timeout=sync_timeout)
+                    job = await self.get_job(job_id)
+                    assert job is not None, "Job should exist after future result"
+                    return (job_id, job)
+                except TimeoutError:
+                    pass
+
+            fire_and_forget(future)
+
+        return (job_id, None)
+
+    async def update_status(self,
+                            job_id: str,
+                            status: str | JobStatus,
+                            error: str | None = None,
+                            output_path: str | None = None,
+                            output: BaseModel | None = None):
+        """
+        Update the status and metadata of an existing job.
+
+        Parameters
+        ----------
+        job_id : str
+            The unique identifier of the job to update.
+        status : str | JobStatus
+            The new status to set for the job (should be a valid JobStatus value).
+        error : str, optional, default=None
+            Error message to store if the job failed.
+        output_path : str, optional, default=None
+            Path where job outputs are stored.
+        output : BaseModel, optional, default=None
+            Job output data. Can be a Pydantic BaseModel, dict, list, or string. BaseModel and dict/list objects are
+            serialized to JSON for storage.
+
+        Raises
+        ------
+        ValueError
+            If the specified job_id does not exist in the job store.
+        """
+
+        async with self.session() as session:
+            job: JobInfo = await session.get(JobInfo, job_id)
+            if job is None:
+                raise ValueError(f"Job {job_id} not found in job store")
+
+            if not isinstance(status, JobStatus):
+                status = JobStatus(status)
+
+            job.status = status.value
             job.error = error
             job.output_path = output_path
             job.updated_at = datetime.now(UTC)
+
+            if isinstance(output, BaseModel):
+                # Convert BaseModel to JSON string for storage
+                output = output.model_dump_json(round_trip=True)
+
+            if isinstance(output, dict | list):
+                # Convert dict or list to JSON string for storage
+                output = json.dumps(output)
+
             job.output = output

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-        """
-        with self.
-
-
-
-
+    async def get_all_jobs(self) -> list[JobInfo]:
+        """
+        Retrieve all jobs from the job store.
+
+        Returns
+        -------
+        list[JobInfo]
+            A list of all JobInfo objects in the database. This operation can be expensive if there are many jobs
+            stored.
+
+        Warning
+        -------
+        This method loads all jobs into memory and should be used with caution in production environments with large
+        job stores.
+        """
+        async with self.session() as session:
+            return (await session.scalars(select(JobInfo))).all()
+
+    async def get_job(self, job_id: str) -> JobInfo | None:
+        """
+        Retrieve a specific job by its unique identifier.
+
+        Parameters
+        ----------
+        job_id : str
+            The unique identifier of the job to retrieve.
+
+        Returns
+        -------
+        JobInfo or None
+            The JobInfo object if found, None if the job_id does not exist.
+        """
+        async with self.session() as session:
+            return await session.get(JobInfo, job_id)
+
+    async def get_status(self, job_id: str) -> JobStatus:
+        """
+        Get the current status of a specific job.
+
+        Parameters
+        ----------
+        job_id : str
+            The unique identifier of the job.
+
+        Returns
+        -------
+        JobStatus
+            The current status of the job, or JobStatus.NOT_FOUND if the job does not exist in the store.
+        """
+        job = await self.get_job(job_id)
+        if job is not None:
+            return JobStatus(job.status)
+        else:
+            return JobStatus.NOT_FOUND
+
+    async def get_last_job(self) -> JobInfo | None:
+        """
+        Retrieve the most recently created job.
+
+        Returns
+        -------
+        JobInfo or None
+            The JobInfo object for the most recently created job based on the created_at timestamp, or None if no jobs
+            exist in the store.
+        """
+        stmt = select(JobInfo).order_by(JobInfo.created_at.desc())
+        async with self.session() as session:
+            last_job = (await session.scalars(stmt)).first()
+
+        if last_job is None:
+            logger.info("No jobs found in job store")
+        else:
             logger.info("Retrieved last job %s created at %s", last_job.job_id, last_job.created_at)
-        return last_job

-
-        """Get all jobs with the specified status."""
-        with self._lock:
-            return [job for job in self._jobs.values() if job.status == status]
+        return last_job

-    def
-        """
-
-
+    async def get_jobs_by_status(self, status: str | JobStatus) -> list[JobInfo]:
+        """
+        Retrieve all jobs that have a specific status.
+
+        Parameters
+        ----------
+        status : str | JobStatus
+            The status to filter jobs by.
+
+        Returns
+        -------
+        list[JobInfo]
+            A list of JobInfo objects that have the specified status. Returns an empty list if no jobs match the
+            status.
+        """
+        if not isinstance(status, JobStatus):
+            status = JobStatus(status)
+
+        stmt = select(JobInfo).where(JobInfo.status == status)
+        async with self.session() as session:
+            return (await session.scalars(stmt)).all()

     def get_expires_at(self, job: JobInfo) -> datetime | None:
-        """
+        """
+        Calculate the expiration time for a given job.
+
+        Active jobs (with status in `self.ACTIVE_STATUS`) do not expire and return `None`. For non-active jobs, the
+        expiration time is calculated as updated_at + expiry_seconds.
+
+        Parameters
+        ----------
+        job : JobInfo
+            The job object to calculate expiration time for.
+
+        Returns
+        -------
+        datetime or None
+            The UTC datetime when the job will expire, or None if the job is active and therefore exempt from
+            expiration.
+        """
         if job.status in self.ACTIVE_STATUS:
             return None
-        return job.updated_at + timedelta(seconds=job.expiry_seconds)

-
+        updated_at = job.updated_at
+        if updated_at.tzinfo is None:
+            # Not all DB backends support timezone aware datetimes
+            updated_at = updated_at.replace(tzinfo=UTC)
+
+        return updated_at + timedelta(seconds=job.expiry_seconds)
+
+    async def cleanup_expired_jobs(self):
         """
         Cleanup expired jobs, keeping the most recent one.
-
-
+
+        Updated_at is used instead of created_at to determine the most recent job. This is because jobs may not be
+        processed in the order they are created.
         """
         now = datetime.now(UTC)

+        stmt = select(JobInfo).where(
+            and_(JobInfo.is_expired == sa_expr.false(),
+                 JobInfo.status.not_in(self.ACTIVE_STATUS))).order_by(JobInfo.updated_at.desc())
         # Filter out active jobs
-        with self.
-        finished_jobs =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async with (self.client() as client, self.session() as session):
+            finished_jobs = (await session.execute(stmt)).scalars().all()
+
+            # Always keep the most recent finished job
+            jobs_to_check = finished_jobs[1:]
+
+            expired_ids = []
+            for job in jobs_to_check:
+                expires_at = self.get_expires_at(job)
+                if expires_at and now > expires_at:
+                    expired_ids.append(job.job_id)
+                    # cleanup output dir if present
+                    if job.output_path:
+                        logger.info("Cleaning up output directory for job %s at %s", job.job_id, job.output_path)
+                        # If it is a file remove it
+                        if os.path.isfile(job.output_path):
+                            os.remove(job.output_path)
+                        # If it is a directory remove it
+                        elif os.path.isdir(job.output_path):
+                            shutil.rmtree(job.output_path)
+
+            if len(expired_ids) > 0:
+                successfully_expired = []
+                for job_id in expired_ids:
+                    try:
+                        var = Variable(name=job_id, client=client)
+                        try:
+                            future = await var.get(timeout=5)
+                            if isinstance(future, Future):
+                                await client.cancel([future], asynchronous=True, force=True)
+
+                        except TimeoutError:
+                            pass
+
+                        var.delete()
+                        successfully_expired.append(job_id)
+                    except Exception:
+                        logger.exception("Failed to expire %s", job_id)
+
+                await session.execute(
+                    update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
+
+
+def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
+    """
+    Create a SQLAlchemy database engine, this should only be run once per process
+
+    Parameters
+    ----------
+    db_url: str | None, optional, default=None
+        The database URL to connect to. Refer to https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
+    echo: bool, optional, default=False
+        If True, SQLAlchemy will log all SQL statements. Useful for debugging.
+    use_async: bool, optional, default=True
+        If True, use the async database engine. The JobStore class requires an async database engine, setting
+        `use_async` to False is only useful for testing.
+    """
+    if db_url is None:
+        db_url = os.environ.get("NAT_JOB_STORE_DB_URL")
+        if db_url is None:
+            dot_tmp_dir = os.path.join(os.getcwd(), ".tmp")
+            os.makedirs(dot_tmp_dir, exist_ok=True)
+            db_file = os.path.join(dot_tmp_dir, "job_store.db")
+            if os.path.exists(db_file):
+                logger.warning("Database file %s already exists, it will be overwritten.", db_file)
+                os.remove(db_file)
+
+            if use_async:
+                driver = "+aiosqlite"
+            else:
+                driver = ""
+
+            db_url = f"sqlite{driver}:///{db_file}"
+
+    if use_async:
+        # This is actually a blocking call, it just returns an AsyncEngine
+        from sqlalchemy.ext.asyncio import create_async_engine as create_engine_fn
+    else:
+        from sqlalchemy import create_engine as create_engine_fn
+
+    return create_engine_fn(db_url, echo=echo)
+
+
+# Prevent Sphinx from attempting to document the Base class which produces warnings
+__all__ = ["get_db_engine", "JobInfo", "JobStatus", "JobStore"]
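Taken together, these hunks replace the old in-memory, lock-guarded job dictionary with Dask-backed execution and SQLAlchemy persistence. A hypothetical end-to-end sketch of the resulting API follows; everything here is illustrative, not part of the diff. It assumes a Dask scheduler at `tcp://127.0.0.1:8786`, `aiosqlite` installed, and that the caller creates the `job_info` table up front (in practice the FastAPI front-end plugin presumably handles that at startup):

```python
import asyncio

from nat.front_ends.fastapi.job_store import JobInfo, JobStore, get_db_engine


def run_workflow(config_file: str) -> dict:
    # Placeholder job body; it must be importable and serializable by Dask workers.
    return {"config": config_file, "result": "ok"}


async def main():
    # Create the async engine and ensure the job_info table exists.
    engine = get_db_engine("sqlite+aiosqlite:///jobs.db")
    async with engine.begin() as conn:
        await conn.run_sync(JobInfo.metadata.create_all)

    store = JobStore("tcp://127.0.0.1:8786", db_engine=engine)

    # Wait up to 30s for a synchronous result; otherwise poll until the job
    # leaves the active (non-expiring) states.
    job_id, job = await store.submit_job(job_fn=run_workflow,
                                         job_args=["config.yml"],
                                         sync_timeout=30)
    while job is None and await store.get_status(job_id) in JobStore.ACTIVE_STATUS:
        await asyncio.sleep(1)

    print(job_id, await store.get_status(job_id))

    # Non-active jobs past their expiry (all but the most recent) become
    # eligible for cleanup and have their Dask futures cancelled.
    await store.cleanup_expired_jobs()


asyncio.run(main())
```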