nvidia-nat 1.3.0.dev2__py3-none-any.whl → 1.3.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (242)
  1. aiq/__init__.py +2 -2
  2. nat/agent/base.py +24 -15
  3. nat/agent/dual_node.py +9 -4
  4. nat/agent/prompt_optimizer/prompt.py +68 -0
  5. nat/agent/prompt_optimizer/register.py +149 -0
  6. nat/agent/react_agent/agent.py +79 -47
  7. nat/agent/react_agent/register.py +41 -21
  8. nat/agent/reasoning_agent/reasoning_agent.py +11 -9
  9. nat/agent/register.py +1 -1
  10. nat/agent/rewoo_agent/agent.py +326 -148
  11. nat/agent/rewoo_agent/prompt.py +19 -22
  12. nat/agent/rewoo_agent/register.py +46 -26
  13. nat/agent/tool_calling_agent/agent.py +84 -28
  14. nat/agent/tool_calling_agent/register.py +51 -28
  15. nat/authentication/api_key/api_key_auth_provider.py +2 -2
  16. nat/authentication/credential_validator/bearer_token_validator.py +557 -0
  17. nat/authentication/http_basic_auth/http_basic_auth_provider.py +1 -1
  18. nat/authentication/interfaces.py +5 -2
  19. nat/authentication/oauth2/oauth2_auth_code_flow_provider.py +40 -20
  20. nat/authentication/oauth2/oauth2_resource_server_config.py +124 -0
  21. nat/authentication/register.py +0 -1
  22. nat/builder/builder.py +56 -24
  23. nat/builder/component_utils.py +9 -5
  24. nat/builder/context.py +46 -11
  25. nat/builder/eval_builder.py +16 -11
  26. nat/builder/framework_enum.py +1 -0
  27. nat/builder/front_end.py +1 -1
  28. nat/builder/function.py +378 -8
  29. nat/builder/function_base.py +3 -3
  30. nat/builder/function_info.py +6 -8
  31. nat/builder/user_interaction_manager.py +2 -2
  32. nat/builder/workflow.py +13 -1
  33. nat/builder/workflow_builder.py +281 -76
  34. nat/cli/cli_utils/config_override.py +2 -2
  35. nat/cli/commands/evaluate.py +1 -1
  36. nat/cli/commands/info/info.py +16 -6
  37. nat/cli/commands/info/list_channels.py +1 -1
  38. nat/cli/commands/info/list_components.py +7 -8
  39. nat/cli/commands/mcp/__init__.py +14 -0
  40. nat/cli/commands/mcp/mcp.py +986 -0
  41. nat/cli/commands/object_store/__init__.py +14 -0
  42. nat/cli/commands/object_store/object_store.py +227 -0
  43. nat/cli/commands/optimize.py +90 -0
  44. nat/cli/commands/registry/publish.py +2 -2
  45. nat/cli/commands/registry/pull.py +2 -2
  46. nat/cli/commands/registry/remove.py +2 -2
  47. nat/cli/commands/registry/search.py +15 -17
  48. nat/cli/commands/start.py +16 -5
  49. nat/cli/commands/uninstall.py +1 -1
  50. nat/cli/commands/workflow/templates/config.yml.j2 +0 -1
  51. nat/cli/commands/workflow/templates/pyproject.toml.j2 +4 -1
  52. nat/cli/commands/workflow/templates/register.py.j2 +0 -1
  53. nat/cli/commands/workflow/workflow_commands.py +9 -13
  54. nat/cli/entrypoint.py +8 -10
  55. nat/cli/register_workflow.py +38 -4
  56. nat/cli/type_registry.py +75 -6
  57. nat/control_flow/__init__.py +0 -0
  58. nat/control_flow/register.py +20 -0
  59. nat/control_flow/router_agent/__init__.py +0 -0
  60. nat/control_flow/router_agent/agent.py +329 -0
  61. nat/control_flow/router_agent/prompt.py +48 -0
  62. nat/control_flow/router_agent/register.py +91 -0
  63. nat/control_flow/sequential_executor.py +166 -0
  64. nat/data_models/agent.py +34 -0
  65. nat/data_models/api_server.py +10 -10
  66. nat/data_models/authentication.py +23 -9
  67. nat/data_models/common.py +1 -1
  68. nat/data_models/component.py +2 -0
  69. nat/data_models/component_ref.py +11 -0
  70. nat/data_models/config.py +41 -17
  71. nat/data_models/dataset_handler.py +1 -1
  72. nat/data_models/discovery_metadata.py +4 -4
  73. nat/data_models/evaluate.py +4 -1
  74. nat/data_models/function.py +34 -0
  75. nat/data_models/function_dependencies.py +14 -6
  76. nat/data_models/gated_field_mixin.py +242 -0
  77. nat/data_models/intermediate_step.py +3 -3
  78. nat/data_models/optimizable.py +119 -0
  79. nat/data_models/optimizer.py +149 -0
  80. nat/data_models/swe_bench_model.py +1 -1
  81. nat/data_models/temperature_mixin.py +44 -0
  82. nat/data_models/thinking_mixin.py +86 -0
  83. nat/data_models/top_p_mixin.py +44 -0
  84. nat/embedder/nim_embedder.py +1 -1
  85. nat/embedder/openai_embedder.py +1 -1
  86. nat/embedder/register.py +0 -1
  87. nat/eval/config.py +3 -1
  88. nat/eval/dataset_handler/dataset_handler.py +71 -7
  89. nat/eval/evaluate.py +86 -31
  90. nat/eval/evaluator/base_evaluator.py +1 -1
  91. nat/eval/evaluator/evaluator_model.py +13 -0
  92. nat/eval/intermediate_step_adapter.py +1 -1
  93. nat/eval/rag_evaluator/evaluate.py +2 -2
  94. nat/eval/rag_evaluator/register.py +3 -3
  95. nat/eval/register.py +4 -1
  96. nat/eval/remote_workflow.py +3 -3
  97. nat/eval/runtime_evaluator/__init__.py +14 -0
  98. nat/eval/runtime_evaluator/evaluate.py +123 -0
  99. nat/eval/runtime_evaluator/register.py +100 -0
  100. nat/eval/swe_bench_evaluator/evaluate.py +6 -6
  101. nat/eval/trajectory_evaluator/evaluate.py +1 -1
  102. nat/eval/trajectory_evaluator/register.py +1 -1
  103. nat/eval/tunable_rag_evaluator/evaluate.py +4 -7
  104. nat/eval/utils/eval_trace_ctx.py +89 -0
  105. nat/eval/utils/weave_eval.py +18 -9
  106. nat/experimental/decorators/experimental_warning_decorator.py +27 -7
  107. nat/experimental/test_time_compute/functions/plan_select_execute_function.py +7 -3
  108. nat/experimental/test_time_compute/functions/ttc_tool_orchestration_function.py +3 -3
  109. nat/experimental/test_time_compute/functions/ttc_tool_wrapper_function.py +1 -1
  110. nat/experimental/test_time_compute/models/strategy_base.py +5 -4
  111. nat/experimental/test_time_compute/register.py +0 -1
  112. nat/experimental/test_time_compute/selection/llm_based_output_merging_selector.py +1 -3
  113. nat/front_ends/console/authentication_flow_handler.py +82 -30
  114. nat/front_ends/console/console_front_end_plugin.py +8 -5
  115. nat/front_ends/fastapi/auth_flow_handlers/websocket_flow_handler.py +52 -17
  116. nat/front_ends/fastapi/dask_client_mixin.py +65 -0
  117. nat/front_ends/fastapi/fastapi_front_end_config.py +36 -5
  118. nat/front_ends/fastapi/fastapi_front_end_controller.py +4 -4
  119. nat/front_ends/fastapi/fastapi_front_end_plugin.py +135 -4
  120. nat/front_ends/fastapi/fastapi_front_end_plugin_worker.py +481 -281
  121. nat/front_ends/fastapi/job_store.py +518 -99
  122. nat/front_ends/fastapi/main.py +11 -19
  123. nat/front_ends/fastapi/message_handler.py +13 -14
  124. nat/front_ends/fastapi/message_validator.py +17 -19
  125. nat/front_ends/fastapi/response_helpers.py +4 -4
  126. nat/front_ends/fastapi/step_adaptor.py +2 -2
  127. nat/front_ends/fastapi/utils.py +57 -0
  128. nat/front_ends/mcp/introspection_token_verifier.py +73 -0
  129. nat/front_ends/mcp/mcp_front_end_config.py +10 -1
  130. nat/front_ends/mcp/mcp_front_end_plugin.py +45 -13
  131. nat/front_ends/mcp/mcp_front_end_plugin_worker.py +116 -8
  132. nat/front_ends/mcp/tool_converter.py +44 -14
  133. nat/front_ends/register.py +0 -1
  134. nat/front_ends/simple_base/simple_front_end_plugin_base.py +3 -1
  135. nat/llm/aws_bedrock_llm.py +24 -12
  136. nat/llm/azure_openai_llm.py +13 -6
  137. nat/llm/litellm_llm.py +69 -0
  138. nat/llm/nim_llm.py +20 -8
  139. nat/llm/openai_llm.py +14 -6
  140. nat/llm/register.py +4 -1
  141. nat/llm/utils/env_config_value.py +2 -3
  142. nat/llm/utils/thinking.py +215 -0
  143. nat/meta/pypi.md +9 -9
  144. nat/object_store/register.py +0 -1
  145. nat/observability/exporter/base_exporter.py +3 -3
  146. nat/observability/exporter/file_exporter.py +1 -1
  147. nat/observability/exporter/processing_exporter.py +309 -81
  148. nat/observability/exporter/span_exporter.py +1 -1
  149. nat/observability/exporter_manager.py +7 -7
  150. nat/observability/mixin/file_mixin.py +7 -7
  151. nat/observability/mixin/redaction_config_mixin.py +42 -0
  152. nat/observability/mixin/tagging_config_mixin.py +62 -0
  153. nat/observability/mixin/type_introspection_mixin.py +420 -107
  154. nat/observability/processor/batching_processor.py +5 -7
  155. nat/observability/processor/falsy_batch_filter_processor.py +55 -0
  156. nat/observability/processor/processor.py +3 -0
  157. nat/observability/processor/processor_factory.py +70 -0
  158. nat/observability/processor/redaction/__init__.py +24 -0
  159. nat/observability/processor/redaction/contextual_redaction_processor.py +125 -0
  160. nat/observability/processor/redaction/contextual_span_redaction_processor.py +66 -0
  161. nat/observability/processor/redaction/redaction_processor.py +177 -0
  162. nat/observability/processor/redaction/span_header_redaction_processor.py +92 -0
  163. nat/observability/processor/span_tagging_processor.py +68 -0
  164. nat/observability/register.py +6 -4
  165. nat/profiler/calc/calc_runner.py +3 -4
  166. nat/profiler/callbacks/agno_callback_handler.py +1 -1
  167. nat/profiler/callbacks/langchain_callback_handler.py +6 -6
  168. nat/profiler/callbacks/llama_index_callback_handler.py +3 -3
  169. nat/profiler/callbacks/semantic_kernel_callback_handler.py +3 -3
  170. nat/profiler/data_frame_row.py +1 -1
  171. nat/profiler/decorators/framework_wrapper.py +62 -13
  172. nat/profiler/decorators/function_tracking.py +160 -3
  173. nat/profiler/forecasting/models/forecasting_base_model.py +3 -1
  174. nat/profiler/inference_optimization/bottleneck_analysis/simple_stack_analysis.py +1 -1
  175. nat/profiler/inference_optimization/data_models.py +3 -3
  176. nat/profiler/inference_optimization/experimental/prefix_span_analysis.py +7 -8
  177. nat/profiler/inference_optimization/token_uniqueness.py +1 -1
  178. nat/profiler/parameter_optimization/__init__.py +0 -0
  179. nat/profiler/parameter_optimization/optimizable_utils.py +93 -0
  180. nat/profiler/parameter_optimization/optimizer_runtime.py +67 -0
  181. nat/profiler/parameter_optimization/parameter_optimizer.py +153 -0
  182. nat/profiler/parameter_optimization/parameter_selection.py +107 -0
  183. nat/profiler/parameter_optimization/pareto_visualizer.py +380 -0
  184. nat/profiler/parameter_optimization/prompt_optimizer.py +384 -0
  185. nat/profiler/parameter_optimization/update_helpers.py +66 -0
  186. nat/profiler/profile_runner.py +14 -9
  187. nat/profiler/utils.py +4 -2
  188. nat/registry_handlers/local/local_handler.py +2 -2
  189. nat/registry_handlers/package_utils.py +1 -2
  190. nat/registry_handlers/pypi/pypi_handler.py +23 -26
  191. nat/registry_handlers/register.py +3 -4
  192. nat/registry_handlers/rest/rest_handler.py +12 -13
  193. nat/retriever/milvus/retriever.py +2 -2
  194. nat/retriever/nemo_retriever/retriever.py +1 -1
  195. nat/retriever/register.py +0 -1
  196. nat/runtime/loader.py +2 -2
  197. nat/runtime/runner.py +3 -2
  198. nat/runtime/session.py +43 -8
  199. nat/settings/global_settings.py +16 -5
  200. nat/tool/chat_completion.py +5 -2
  201. nat/tool/code_execution/local_sandbox/local_sandbox_server.py +3 -3
  202. nat/tool/datetime_tools.py +49 -9
  203. nat/tool/document_search.py +2 -2
  204. nat/tool/github_tools.py +450 -0
  205. nat/tool/nvidia_rag.py +1 -1
  206. nat/tool/register.py +2 -9
  207. nat/tool/retriever.py +3 -2
  208. nat/utils/callable_utils.py +70 -0
  209. nat/utils/data_models/schema_validator.py +3 -3
  210. nat/utils/exception_handlers/automatic_retries.py +104 -51
  211. nat/utils/exception_handlers/schemas.py +1 -1
  212. nat/utils/io/yaml_tools.py +2 -2
  213. nat/utils/log_levels.py +25 -0
  214. nat/utils/reactive/base/observable_base.py +2 -2
  215. nat/utils/reactive/base/observer_base.py +1 -1
  216. nat/utils/reactive/observable.py +2 -2
  217. nat/utils/reactive/observer.py +4 -4
  218. nat/utils/reactive/subscription.py +1 -1
  219. nat/utils/settings/global_settings.py +6 -8
  220. nat/utils/type_converter.py +4 -3
  221. nat/utils/type_utils.py +9 -5
  222. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/METADATA +42 -16
  223. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/RECORD +230 -189
  224. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/entry_points.txt +1 -0
  225. nat/cli/commands/info/list_mcp.py +0 -304
  226. nat/tool/github_tools/create_github_commit.py +0 -133
  227. nat/tool/github_tools/create_github_issue.py +0 -87
  228. nat/tool/github_tools/create_github_pr.py +0 -106
  229. nat/tool/github_tools/get_github_file.py +0 -106
  230. nat/tool/github_tools/get_github_issue.py +0 -166
  231. nat/tool/github_tools/get_github_pr.py +0 -256
  232. nat/tool/github_tools/update_github_issue.py +0 -100
  233. nat/tool/mcp/exceptions.py +0 -142
  234. nat/tool/mcp/mcp_client.py +0 -255
  235. nat/tool/mcp/mcp_tool.py +0 -96
  236. nat/utils/exception_handlers/mcp.py +0 -211
  237. /nat/{tool/github_tools → agent/prompt_optimizer}/__init__.py +0 -0
  238. /nat/{tool/mcp → authentication/credential_validator}/__init__.py +0 -0
  239. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/WHEEL +0 -0
  240. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE-3rd-party.txt +0 -0
  241. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/licenses/LICENSE.md +0 -0
  242. {nvidia_nat-1.3.0.dev2.dist-info → nvidia_nat-1.3.0rc1.dist-info}/top_level.txt +0 -0
@@ -13,22 +13,67 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ import json
16
17
  import logging
17
18
  import os
18
19
  import shutil
19
- import threading
20
+ import typing
21
+ from asyncio import current_task
22
+ from collections.abc import AsyncGenerator
23
+ from collections.abc import Callable
24
+ from contextlib import asynccontextmanager
20
25
  from datetime import UTC
21
26
  from datetime import datetime
22
27
  from datetime import timedelta
23
28
  from enum import Enum
24
29
  from uuid import uuid4
25
30
 
31
+ from dask.distributed import Client as DaskClient
32
+ from dask.distributed import Future
33
+ from dask.distributed import Variable
34
+ from dask.distributed import fire_and_forget
26
35
  from pydantic import BaseModel
36
+ from sqlalchemy import DateTime
37
+ from sqlalchemy import String
38
+ from sqlalchemy import and_
39
+ from sqlalchemy import select
40
+ from sqlalchemy import update
41
+ from sqlalchemy.ext.asyncio import async_scoped_session
42
+ from sqlalchemy.ext.asyncio import async_sessionmaker
43
+ from sqlalchemy.orm import DeclarativeBase
44
+ from sqlalchemy.orm import Mapped
45
+ from sqlalchemy.orm import mapped_column
46
+ from sqlalchemy.sql import expression as sa_expr
47
+
48
+ from nat.front_ends.fastapi.dask_client_mixin import DaskClientMixin
49
+
50
+ if typing.TYPE_CHECKING:
51
+ from sqlalchemy.engine import Engine
52
+ from sqlalchemy.ext.asyncio import AsyncEngine
53
+ from sqlalchemy.ext.asyncio import AsyncSession
27
54
 
28
55
  logger = logging.getLogger(__name__)
29
56
 
30
57
 
31
58
  class JobStatus(str, Enum):
59
+ """
60
+ Enumeration of possible job statuses in the job store.
61
+
62
+ Attributes
63
+ ----------
64
+ SUBMITTED : str
65
+ Job has been submitted to the scheduler but not yet started.
66
+ RUNNING : str
67
+ Job is currently being executed.
68
+ SUCCESS : str
69
+ Job completed successfully.
70
+ FAILURE : str
71
+ Job failed during execution.
72
+ INTERRUPTED : str
73
+ Job was interrupted or cancelled before completion.
74
+ NOT_FOUND : str
75
+ Job ID does not exist in the job store.
76
+ """
32
77
  SUBMITTED = "submitted"
33
78
  RUNNING = "running"
34
79
  SUCCESS = "success"
@@ -37,42 +82,175 @@ class JobStatus(str, Enum):
37
82
  NOT_FOUND = "not_found"
38
83
 
39
84
 
40
- # pydantic model for the job status
41
- class JobInfo(BaseModel):
42
- job_id: str
43
- status: JobStatus
44
- config_file: str | None
45
- error: str | None
46
- output_path: str | None
47
- created_at: datetime
48
- updated_at: datetime
49
- expiry_seconds: int
50
- output: BaseModel | None = None
51
-
52
-
53
- class JobStore:
85
+ class Base(DeclarativeBase):
86
+ pass
87
+
88
+
89
+ class JobInfo(Base):
90
+ """
91
+ SQLAlchemy model representing job metadata and status information.
92
+
93
+ This model stores comprehensive information about jobs submitted to the Dask scheduler, including their current
94
+ status, configuration, outputs, and lifecycle metadata.
95
+
96
+ Attributes
97
+ ----------
98
+ job_id : str
99
+ Unique identifier for the job (primary key).
100
+ status : JobStatus
101
+ Current status of the job.
102
+ config_file : str, optional
103
+ Path to the configuration file used for the job.
104
+ error : str, optional
105
+ Error message if the job failed.
106
+ output_path : str, optional
107
+ Path where job outputs are stored.
108
+ created_at : datetime
109
+ Timestamp when the job was created.
110
+ updated_at : datetime
111
+ Timestamp when the job was last updated.
112
+ expiry_seconds : int
113
+ Number of seconds after which the job is eligible for cleanup.
114
+ output : str, optional
115
+ Serialized job output data (JSON format).
116
+ is_expired : bool
117
+ Flag indicating if the job has been marked as expired.
118
+ """
119
+ __tablename__ = "job_info"
120
+
121
+ job_id: Mapped[str] = mapped_column(primary_key=True)
122
+ status: Mapped[JobStatus] = mapped_column(String(11))
123
+ config_file: Mapped[str] = mapped_column(nullable=True)
124
+ error: Mapped[str] = mapped_column(nullable=True)
125
+ output_path: Mapped[str] = mapped_column(nullable=True)
126
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.now(UTC))
127
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True),
128
+ default=datetime.now(UTC),
129
+ onupdate=datetime.now(UTC))
130
+ expiry_seconds: Mapped[int]
131
+ output: Mapped[str] = mapped_column(nullable=True)
132
+ is_expired: Mapped[bool] = mapped_column(default=False, index=True)
133
+
134
+ def __repr__(self):
135
+ return f"JobInfo(job_id={self.job_id}, status={self.status})"
136
+
137
+
138
+ class JobStore(DaskClientMixin):
139
+ """
140
+ Tracks and manages jobs submitted to the Dask scheduler, along with persisting job metadata (JobInfo objects) in a
141
+ database.
142
+
143
+ Parameters
144
+ ----------
145
+ scheduler_address: str
146
+ The address of the Dask scheduler.
147
+ db_engine: AsyncEngine | None, optional, default=None
148
+ The database engine for the job store.
149
+ db_url: str | None, optional, default=None
150
+ The database URL to connect to, used when db_engine is not provided. Refer to:
151
+ https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
152
+ """
54
153
 
55
154
  MIN_EXPIRY = 600 # 10 minutes
56
155
  MAX_EXPIRY = 86400 # 24 hours
57
156
  DEFAULT_EXPIRY = 3600 # 1 hour
58
157
 
59
158
  # active jobs are exempt from expiry
60
- ACTIVE_STATUS = {"running", "submitted"}
159
+ ACTIVE_STATUS = {JobStatus.RUNNING, JobStatus.SUBMITTED}
160
+
161
+ def __init__(
162
+ self,
163
+ scheduler_address: str,
164
+ db_engine: "AsyncEngine | None" = None,
165
+ db_url: str | None = None,
166
+ ):
167
+ self._scheduler_address = scheduler_address
168
+
169
+ if db_engine is None:
170
+ if db_url is None:
171
+ raise ValueError("Either db_engine or db_url must be provided")
172
+
173
+ db_engine = get_db_engine(db_url, use_async=True)
174
+
175
+ # Disabling expire_on_commit allows us to detach (expunge) job
176
+ # instances from the session
177
+ session_maker = async_sessionmaker(db_engine, expire_on_commit=False)
178
+
179
+ # The async_scoped_session ensures that the same session is used
180
+ # within the same task, and that no two tasks share the same session.
181
+ self._session = async_scoped_session(session_maker, scopefunc=current_task)
182
+
183
+ @asynccontextmanager
184
+ async def client(self) -> AsyncGenerator[DaskClient]:
185
+ """
186
+ Async context manager for obtaining a Dask client connection.
187
+
188
+ Yields
189
+ ------
190
+ DaskClient
191
+ An active Dask client connected to the scheduler. The client is automatically closed when exiting the
192
+ context manager.
193
+ """
194
+ async with super().client(self._scheduler_address) as client:
195
+ yield client
196
+
197
+ @asynccontextmanager
198
+ async def session(self) -> AsyncGenerator["AsyncSession"]:
199
+ """
200
+ Async context manager for a SQLAlchemy session with automatic transaction management.
201
+
202
+ Creates a new database session scoped to the current async task and begins a transaction. The transaction is
203
+ committed on successful exit and rolled back on exception. The session is automatically removed from the
204
+ registry after use.
205
+
206
+ Yields
207
+ ------
208
+ AsyncSession
209
+ An active SQLAlchemy async session with an open transaction.
210
+ """
211
+ async with self._session() as session:
212
+ async with session.begin():
213
+ yield session
61
214
 
62
- def __init__(self):
63
- self._jobs = {}
64
- self._lock = threading.Lock() # Ensure thread safety for job operations
215
+ # Removes the current task key from the session registry, preventing
216
+ # potential memory leaks
217
+ await self._session.remove()
65
218
 
66
- def create_job(self,
67
- config_file: str | None = None,
68
- job_id: str | None = None,
69
- expiry_seconds: int = DEFAULT_EXPIRY) -> str:
219
+ def ensure_job_id(self, job_id: str | None) -> str:
220
+ """
221
+ Ensure a job ID is provided, generating a new one if necessary.
222
+
223
+ If a job ID is provided, it is returned as-is.
224
+
225
+ Parameters
226
+ ----------
227
+ job_id: str | None
228
+ The job ID to ensure, or None to generate a new one.
229
+ """
70
230
  if job_id is None:
71
231
  job_id = str(uuid4())
232
+ logger.info("Generated new job ID: %s", job_id)
233
+
234
+ return job_id
235
+
236
+ async def _create_job(self,
237
+ config_file: str | None = None,
238
+ job_id: str | None = None,
239
+ expiry_seconds: int = DEFAULT_EXPIRY) -> str:
240
+ """
241
+ Create a job and add it to the job store. This should not be called directly, but instead be called by
242
+ `submit_job`
243
+ """
244
+ job_id = self.ensure_job_id(job_id)
72
245
 
73
246
  clamped_expiry = max(self.MIN_EXPIRY, min(expiry_seconds, self.MAX_EXPIRY))
74
247
  if expiry_seconds != clamped_expiry:
75
- logger.info("Clamped expiry_seconds from %d to %d for job %s", expiry_seconds, clamped_expiry, job_id)
248
+ logger.info(
249
+ "Clamped expiry_seconds from %d to %d for job %s",
250
+ expiry_seconds,
251
+ clamped_expiry,
252
+ job_id,
253
+ )
76
254
 
77
255
  job = JobInfo(job_id=job_id,
78
256
  status=JobStatus.SUBMITTED,
@@ -83,101 +261,342 @@ class JobStore:
83
261
  output_path=None,
84
262
  expiry_seconds=clamped_expiry)
85
263
 
86
- with self._lock:
87
- self._jobs[job_id] = job
264
+ async with self.session() as session:
265
+ session.add(job)
88
266
 
89
267
  logger.info("Created new job %s with config %s", job_id, config_file)
90
268
  return job_id
91
269
 
92
- def update_status(self,
93
- job_id: str,
94
- status: str,
95
- error: str | None = None,
96
- output_path: str | None = None,
97
- output: BaseModel | None = None):
98
- if job_id not in self._jobs:
99
- raise ValueError(f"Job {job_id} not found")
100
-
101
- with self._lock:
102
- job = self._jobs[job_id]
103
- job.status = status
270
+ async def submit_job(self,
271
+ *,
272
+ job_id: str | None = None,
273
+ config_file: str | None = None,
274
+ expiry_seconds: int = DEFAULT_EXPIRY,
275
+ sync_timeout: int = 0,
276
+ job_fn: Callable[..., typing.Any],
277
+ job_args: list[typing.Any],
278
+ **job_kwargs) -> tuple[str, JobInfo | None]:
279
+ """
280
+ Submit a job to the Dask scheduler, and store job metadata in the database.
281
+
282
+ Parameters
283
+ ----------
284
+ job_id: str | None, optional, default=None
285
+ The job ID to use, or None to generate a new one.
286
+ config_file: str | None, optional, default=None
287
+ The config file used to run the job, if any.
288
+ expiry_seconds: int, optional, default=3600
289
+ The number of seconds after which the job should be considered expired. Expired jobs are eligible for
290
+ cleanup, but are not deleted immediately.
291
+ sync_timeout: int, optional, default=0
292
+ If greater than 0, wait for the job to complete for up to this many seconds. If the job does not complete
293
+ in this time, return immediately with the job ID and no job info. If the job completes in this time,
294
+ return the job ID and the job info. If 0, return immediately with the job ID and no job info.
295
+ job_fn: Callable[..., typing.Any]
296
+ The function to run as the job. This function must be serializable by Dask.
297
+ job_args: list[typing.Any]
298
+ The arguments to pass to the job function. These must be serializable by Dask.
299
+ job_kwargs: dict[str, typing.Any]
300
+ The keyword arguments to pass to the job function. These must be serializable by Dask
301
+ """
302
+ job_id = await self._create_job(job_id=job_id, config_file=config_file, expiry_seconds=expiry_seconds)
303
+
304
+ # We are intentionally not using job_id as the key, since Dask will clear the associated metadata once
305
+ # the job has completed, and we want the metadata to persist until the job expires.
306
+ async with self.client() as client:
307
+ logger.debug("Submitting job with job_args: %s, job_kwargs: %s", job_args, job_kwargs)
308
+ future = client.submit(job_fn, *job_args, key=f"{job_id}-job", **job_kwargs)
309
+
310
+ # Store the future in a variable, this allows us to potentially cancel the future later if needed
311
+ future_var = Variable(name=job_id, client=client)
312
+ await future_var.set(future)
313
+
314
+ if sync_timeout > 0:
315
+ try:
316
+ _ = await future.result(timeout=sync_timeout)
317
+ job = await self.get_job(job_id)
318
+ assert job is not None, "Job should exist after future result"
319
+ return (job_id, job)
320
+ except TimeoutError:
321
+ pass
322
+
323
+ fire_and_forget(future)
324
+
325
+ return (job_id, None)
326
+
327
+ async def update_status(self,
328
+ job_id: str,
329
+ status: str | JobStatus,
330
+ error: str | None = None,
331
+ output_path: str | None = None,
332
+ output: BaseModel | None = None):
333
+ """
334
+ Update the status and metadata of an existing job.
335
+
336
+ Parameters
337
+ ----------
338
+ job_id : str
339
+ The unique identifier of the job to update.
340
+ status : str | JobStatus
341
+ The new status to set for the job (should be a valid JobStatus value).
342
+ error : str, optional, default=None
343
+ Error message to store if the job failed.
344
+ output_path : str, optional, default=None
345
+ Path where job outputs are stored.
346
+ output : BaseModel, optional, default=None
347
+ Job output data. Can be a Pydantic BaseModel, dict, list, or string. BaseModel and dict/list objects are
348
+ serialized to JSON for storage.
349
+
350
+ Raises
351
+ ------
352
+ ValueError
353
+ If the specified job_id does not exist in the job store.
354
+ """
355
+
356
+ async with self.session() as session:
357
+ job: JobInfo = await session.get(JobInfo, job_id)
358
+ if job is None:
359
+ raise ValueError(f"Job {job_id} not found in job store")
360
+
361
+ if not isinstance(status, JobStatus):
362
+ status = JobStatus(status)
363
+
364
+ job.status = status.value
104
365
  job.error = error
105
366
  job.output_path = output_path
106
367
  job.updated_at = datetime.now(UTC)
368
+
369
+ if isinstance(output, BaseModel):
370
+ # Convert BaseModel to JSON string for storage
371
+ output = output.model_dump_json(round_trip=True)
372
+
373
+ if isinstance(output, dict | list):
374
+ # Convert dict or list to JSON string for storage
375
+ output = json.dumps(output)
376
+
107
377
  job.output = output
108
378
 
109
- def get_status(self, job_id: str) -> JobInfo | None:
110
- with self._lock:
111
- return self._jobs.get(job_id)
112
-
113
- def list_jobs(self):
114
- with self._lock:
115
- return self._jobs
116
-
117
- def get_job(self, job_id: str) -> JobInfo | None:
118
- """Get a job by its ID."""
119
- with self._lock:
120
- return self._jobs.get(job_id)
121
-
122
- def get_last_job(self) -> JobInfo | None:
123
- """Get the last created job."""
124
- with self._lock:
125
- if not self._jobs:
126
- logger.info("No jobs found in job store")
127
- return None
128
- last_job = max(self._jobs.values(), key=lambda job: job.created_at)
379
+ async def get_all_jobs(self) -> list[JobInfo]:
380
+ """
381
+ Retrieve all jobs from the job store.
382
+
383
+ Returns
384
+ -------
385
+ list[JobInfo]
386
+ A list of all JobInfo objects in the database. This operation can be expensive if there are many jobs
387
+ stored.
388
+
389
+ Warning
390
+ -------
391
+ This method loads all jobs into memory and should be used with caution in production environments with large
392
+ job stores.
393
+ """
394
+ async with self.session() as session:
395
+ return (await session.scalars(select(JobInfo))).all()
396
+
397
+ async def get_job(self, job_id: str) -> JobInfo | None:
398
+ """
399
+ Retrieve a specific job by its unique identifier.
400
+
401
+ Parameters
402
+ ----------
403
+ job_id : str
404
+ The unique identifier of the job to retrieve.
405
+
406
+ Returns
407
+ -------
408
+ JobInfo or None
409
+ The JobInfo object if found, None if the job_id does not exist.
410
+ """
411
+ async with self.session() as session:
412
+ return await session.get(JobInfo, job_id)
413
+
414
+ async def get_status(self, job_id: str) -> JobStatus:
415
+ """
416
+ Get the current status of a specific job.
417
+
418
+ Parameters
419
+ ----------
420
+ job_id : str
421
+ The unique identifier of the job.
422
+
423
+ Returns
424
+ -------
425
+ JobStatus
426
+ The current status of the job, or JobStatus.NOT_FOUND if the job does not exist in the store.
427
+ """
428
+ job = await self.get_job(job_id)
429
+ if job is not None:
430
+ return JobStatus(job.status)
431
+ else:
432
+ return JobStatus.NOT_FOUND
433
+
434
+ async def get_last_job(self) -> JobInfo | None:
435
+ """
436
+ Retrieve the most recently created job.
437
+
438
+ Returns
439
+ -------
440
+ JobInfo or None
441
+ The JobInfo object for the most recently created job based on the created_at timestamp, or None if no jobs
442
+ exist in the store.
443
+ """
444
+ stmt = select(JobInfo).order_by(JobInfo.created_at.desc())
445
+ async with self.session() as session:
446
+ last_job = (await session.scalars(stmt)).first()
447
+
448
+ if last_job is None:
449
+ logger.info("No jobs found in job store")
450
+ else:
129
451
  logger.info("Retrieved last job %s created at %s", last_job.job_id, last_job.created_at)
130
- return last_job
131
452
 
132
- def get_jobs_by_status(self, status: str) -> list[JobInfo]:
133
- """Get all jobs with the specified status."""
134
- with self._lock:
135
- return [job for job in self._jobs.values() if job.status == status]
453
+ return last_job
136
454
 
137
- def get_all_jobs(self) -> list[JobInfo]:
138
- """Get all jobs in the store."""
139
- with self._lock:
140
- return list(self._jobs.values())
455
+ async def get_jobs_by_status(self, status: str | JobStatus) -> list[JobInfo]:
456
+ """
457
+ Retrieve all jobs that have a specific status.
458
+
459
+ Parameters
460
+ ----------
461
+ status : str | JobStatus
462
+ The status to filter jobs by.
463
+
464
+ Returns
465
+ -------
466
+ list[JobInfo]
467
+ A list of JobInfo objects that have the specified status. Returns an empty list if no jobs match the
468
+ status.
469
+ """
470
+ if not isinstance(status, JobStatus):
471
+ status = JobStatus(status)
472
+
473
+ stmt = select(JobInfo).where(JobInfo.status == status)
474
+ async with self.session() as session:
475
+ return (await session.scalars(stmt)).all()
141
476
 
142
477
def get_expires_at(self, job: JobInfo) -> datetime | None:
    """
    Work out when a job becomes eligible for expiration.

    A job whose status is listed in ``self.ACTIVE_STATUS`` never expires, so ``None`` is returned for it. Any other
    job expires ``job.expiry_seconds`` seconds after its ``updated_at`` timestamp.

    Parameters
    ----------
    job : JobInfo
        The job whose expiration time is wanted.

    Returns
    -------
    datetime or None
        Timezone-aware expiration time (a naive ``updated_at`` is interpreted as UTC), or None for active jobs.
    """
    is_active = job.status in self.ACTIVE_STATUS
    if is_active:
        return None

    last_touched = job.updated_at
    # Some database backends drop tzinfo on round-trip; interpret naive timestamps as UTC so arithmetic and
    # comparisons against aware datetimes work.
    if last_touched.tzinfo is None:
        last_touched = last_touched.replace(tzinfo=UTC)

    time_to_live = timedelta(seconds=job.expiry_seconds)
    return last_touched + time_to_live
504
+
505
async def cleanup_expired_jobs(self):
    """
    Cleanup expired jobs, keeping the most recent one.

    Updated_at is used instead of created_at to determine the most recent job. This is because jobs may not be
    processed in the order they are created.

    For every finished, not-yet-expired job (other than the most recently updated one) whose expiration time has
    passed, this:

    1. removes its ``output_path`` file or directory, if any;
    2. cancels the Future stored in the Dask ``Variable`` named after the job id (if one is found within 5 seconds)
       and deletes that variable;
    3. sets ``is_expired=True`` on the job rows for which step 2 succeeded.

    A failure while handling one job is logged and does not stop the remaining jobs from being processed.
    """
    now = datetime.now(UTC)

    # Filter out active and already-expired jobs; newest update first so index 0 is the job we always keep.
    stmt = select(JobInfo).where(
        and_(JobInfo.is_expired == sa_expr.false(),
             JobInfo.status.not_in(self.ACTIVE_STATUS))).order_by(JobInfo.updated_at.desc())

    async with (self.client() as client, self.session() as session):
        finished_jobs = (await session.execute(stmt)).scalars().all()

        # Always keep the most recent finished job
        jobs_to_check = finished_jobs[1:]

        expired_ids = []
        for job in jobs_to_check:
            expires_at = self.get_expires_at(job)
            if expires_at and now > expires_at:
                expired_ids.append(job.job_id)
                # cleanup output dir if present
                if job.output_path:
                    logger.info("Cleaning up output directory for job %s at %s", job.job_id, job.output_path)
                    try:
                        # If it is a file remove it
                        if os.path.isfile(job.output_path):
                            os.remove(job.output_path)
                        # If it is a directory remove it
                        elif os.path.isdir(job.output_path):
                            shutil.rmtree(job.output_path)
                    except OSError:
                        # Don't let one undeletable output abort the whole pass: previously the exception escaped
                        # before any job was marked expired, so every later cleanup would fail the same way.
                        logger.exception("Failed to remove output path %s for job %s", job.output_path, job.job_id)

        if not expired_ids:
            return

        successfully_expired = []
        for job_id in expired_ids:
            try:
                var = Variable(name=job_id, client=client)
                try:
                    future = await var.get(timeout=5)
                    if isinstance(future, Future):
                        await client.cancel([future], asynchronous=True, force=True)
                except TimeoutError:
                    # NOTE(review): presumably nothing was stored in this variable; there is nothing to cancel.
                    pass

                var.delete()
                successfully_expired.append(job_id)
            except Exception:
                logger.exception("Failed to expire %s", job_id)

        # Skip the UPDATE entirely when every expiration attempt failed.
        if successfully_expired:
            await session.execute(
                update(JobInfo).where(JobInfo.job_id.in_(successfully_expired)).values(is_expired=True))
559
+
560
+
561
def get_db_engine(db_url: str | None = None, echo: bool = False, use_async: bool = True) -> "Engine | AsyncEngine":
    """
    Create a SQLAlchemy database engine, this should only be run once per process

    Parameters
    ----------
    db_url: str | None, optional, default=None
        The database URL to connect to. Refer to https://docs.sqlalchemy.org/en/20/core/engines.html#database-urls
        When None, the ``NAT_JOB_STORE_DB_URL`` environment variable is used. If that variable is unset, empty, or
        whitespace-only, a SQLite database file at ``<cwd>/.tmp/job_store.db`` is used instead. Any existing file at
        that path is deleted first.
    echo: bool, optional, default=False
        If True, SQLAlchemy will log all SQL statements. Useful for debugging.
    use_async: bool, optional, default=True
        If True, use the async database engine. The JobStore class requires an async database engine, setting
        `use_async` to False is only useful for testing.

    Returns
    -------
    Engine | AsyncEngine
        The engine returned by ``create_async_engine`` when ``use_async`` is True, otherwise the one returned by
        ``sqlalchemy.create_engine``.
    """
    if db_url is None:
        # An empty/whitespace-only variable (e.g. `NAT_JOB_STORE_DB_URL=` in a .env file) is treated as unset
        # instead of being handed to SQLAlchemy, which would fail with an opaque URL-parsing error.
        db_url = os.environ.get("NAT_JOB_STORE_DB_URL", "").strip() or None
        if db_url is None:
            dot_tmp_dir = os.path.join(os.getcwd(), ".tmp")
            os.makedirs(dot_tmp_dir, exist_ok=True)
            db_file = os.path.join(dot_tmp_dir, "job_store.db")
            if os.path.exists(db_file):
                logger.warning("Database file %s already exists, it will be overwritten.", db_file)
                os.remove(db_file)

            if use_async:
                driver = "+aiosqlite"
            else:
                driver = ""

            db_url = f"sqlite{driver}:///{db_file}"

    if use_async:
        # This is actually a blocking call, it just returns an AsyncEngine
        from sqlalchemy.ext.asyncio import create_async_engine as create_engine_fn
    else:
        from sqlalchemy import create_engine as create_engine_fn

    return create_engine_fn(db_url, echo=echo)
599
+
600
+
601
# Declaring the public API explicitly stops Sphinx from trying to document the Base class, which would otherwise
# produce warnings.
__all__ = [
    "get_db_engine",
    "JobInfo",
    "JobStatus",
    "JobStore",
]