hatchet-sdk 1.6.5__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic.

@@ -0,0 +1,159 @@
+# coding: utf-8
+
+"""
+    Hatchet API
+
+    The Hatchet API
+
+    The version of the OpenAPI document: 1.0.0
+    Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+    Do not edit the class manually.
+"""  # noqa: E501
+
+
+from __future__ import annotations
+
+import json
+import pprint
+import re  # noqa: F401
+from datetime import datetime
+from typing import Any, ClassVar, Dict, List, Optional, Set
+
+from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr
+from typing_extensions import Annotated, Self
+
+from hatchet_sdk.clients.rest.models.api_resource_meta import APIResourceMeta
+from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
+
+
+class V1TaskTiming(BaseModel):
+    """
+    V1TaskTiming
+    """  # noqa: E501
+
+    metadata: APIResourceMeta
+    depth: StrictInt = Field(description="The depth of the task in the waterfall.")
+    status: V1TaskStatus
+    task_display_name: StrictStr = Field(
+        description="The display name of the task run.", alias="taskDisplayName"
+    )
+    task_external_id: Annotated[
+        str, Field(min_length=36, strict=True, max_length=36)
+    ] = Field(description="The external ID of the task.", alias="taskExternalId")
+    task_id: StrictInt = Field(description="The ID of the task.", alias="taskId")
+    task_inserted_at: datetime = Field(
+        description="The timestamp the task was inserted.", alias="taskInsertedAt"
+    )
+    tenant_id: Annotated[str, Field(min_length=36, strict=True, max_length=36)] = Field(
+        description="The ID of the tenant.", alias="tenantId"
+    )
+    parent_task_external_id: Optional[
+        Annotated[str, Field(min_length=36, strict=True, max_length=36)]
+    ] = Field(
+        default=None,
+        description="The external ID of the parent task.",
+        alias="parentTaskExternalId",
+    )
+    queued_at: Optional[datetime] = Field(
+        default=None,
+        description="The timestamp the task run was queued.",
+        alias="queuedAt",
+    )
+    started_at: Optional[datetime] = Field(
+        default=None,
+        description="The timestamp the task run started.",
+        alias="startedAt",
+    )
+    finished_at: Optional[datetime] = Field(
+        default=None,
+        description="The timestamp the task run finished.",
+        alias="finishedAt",
+    )
+    __properties: ClassVar[List[str]] = [
+        "metadata",
+        "depth",
+        "status",
+        "taskDisplayName",
+        "taskExternalId",
+        "taskId",
+        "taskInsertedAt",
+        "tenantId",
+        "parentTaskExternalId",
+        "queuedAt",
+        "startedAt",
+        "finishedAt",
+    ]
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+        validate_assignment=True,
+        protected_namespaces=(),
+    )
+
+    def to_str(self) -> str:
+        """Returns the string representation of the model using alias"""
+        return pprint.pformat(self.model_dump(by_alias=True))
+
+    def to_json(self) -> str:
+        """Returns the JSON representation of the model using alias"""
+        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
+        return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str: str) -> Optional[Self]:
+        """Create an instance of V1TaskTiming from a JSON string"""
+        return cls.from_dict(json.loads(json_str))
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return the dictionary representation of the model using alias.
+
+        This has the following differences from calling pydantic's
+        `self.model_dump(by_alias=True)`:
+
+        * `None` is only added to the output dict for nullable fields that
+          were set at model initialization. Other fields with value `None`
+          are ignored.
+        """
+        excluded_fields: Set[str] = set([])
+
+        _dict = self.model_dump(
+            by_alias=True,
+            exclude=excluded_fields,
+            exclude_none=True,
+        )
+        # override the default output from pydantic by calling `to_dict()` of metadata
+        if self.metadata:
+            _dict["metadata"] = self.metadata.to_dict()
+        return _dict
+
+    @classmethod
+    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
+        """Create an instance of V1TaskTiming from a dict"""
+        if obj is None:
+            return None
+
+        if not isinstance(obj, dict):
+            return cls.model_validate(obj)
+
+        _obj = cls.model_validate(
+            {
+                "metadata": (
+                    APIResourceMeta.from_dict(obj["metadata"])
+                    if obj.get("metadata") is not None
+                    else None
+                ),
+                "depth": obj.get("depth"),
+                "status": obj.get("status"),
+                "taskDisplayName": obj.get("taskDisplayName"),
+                "taskExternalId": obj.get("taskExternalId"),
+                "taskId": obj.get("taskId"),
+                "taskInsertedAt": obj.get("taskInsertedAt"),
+                "tenantId": obj.get("tenantId"),
+                "parentTaskExternalId": obj.get("parentTaskExternalId"),
+                "queuedAt": obj.get("queuedAt"),
+                "startedAt": obj.get("startedAt"),
+                "finishedAt": obj.get("finishedAt"),
+            }
+        )
+        return _obj
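The new V1TaskTiming model is generated Pydantic v2 code, so it round-trips between camelCase API payloads and snake_case Python attributes through the helpers above. A minimal sketch of that round trip; the metadata keys (id/createdAt/updatedAt) and the "COMPLETED" status value are assumptions for illustration, not taken from this diff:

    from hatchet_sdk.clients.rest.models.v1_task_timing import V1TaskTiming

    # Hypothetical API payload; metadata shape and status value are assumed.
    payload = """
    {
      "metadata": {
        "id": "11111111-1111-1111-1111-111111111111",
        "createdAt": "2025-01-01T00:00:00Z",
        "updatedAt": "2025-01-01T00:00:00Z"
      },
      "depth": 0,
      "status": "COMPLETED",
      "taskDisplayName": "process-order",
      "taskExternalId": "22222222-2222-2222-2222-222222222222",
      "taskId": 42,
      "taskInsertedAt": "2025-01-01T00:00:01Z",
      "tenantId": "33333333-3333-3333-3333-333333333333",
      "queuedAt": "2025-01-01T00:00:02Z",
      "startedAt": "2025-01-01T00:00:03Z",
      "finishedAt": "2025-01-01T00:00:05Z"
    }
    """

    timing = V1TaskTiming.from_json(payload)
    assert timing is not None
    print(timing.task_display_name, timing.status)   # snake_case attributes in Python
    print(timing.to_dict()["taskDisplayName"])       # camelCase aliases on output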
@@ -0,0 +1,110 @@
+# coding: utf-8
+
+"""
+    Hatchet API
+
+    The Hatchet API
+
+    The version of the OpenAPI document: 1.0.0
+    Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+    Do not edit the class manually.
+"""  # noqa: E501
+
+
+from __future__ import annotations
+
+import json
+import pprint
+import re  # noqa: F401
+from typing import Any, ClassVar, Dict, List, Optional, Set
+
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Self
+
+from hatchet_sdk.clients.rest.models.pagination_response import PaginationResponse
+from hatchet_sdk.clients.rest.models.v1_task_timing import V1TaskTiming
+
+
+class V1TaskTimingList(BaseModel):
+    """
+    V1TaskTimingList
+    """  # noqa: E501
+
+    pagination: PaginationResponse
+    rows: List[V1TaskTiming] = Field(description="The list of task timings")
+    __properties: ClassVar[List[str]] = ["pagination", "rows"]
+
+    model_config = ConfigDict(
+        populate_by_name=True,
+        validate_assignment=True,
+        protected_namespaces=(),
+    )
+
+    def to_str(self) -> str:
+        """Returns the string representation of the model using alias"""
+        return pprint.pformat(self.model_dump(by_alias=True))
+
+    def to_json(self) -> str:
+        """Returns the JSON representation of the model using alias"""
+        # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
+        return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str: str) -> Optional[Self]:
+        """Create an instance of V1TaskTimingList from a JSON string"""
+        return cls.from_dict(json.loads(json_str))
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Return the dictionary representation of the model using alias.
+
+        This has the following differences from calling pydantic's
+        `self.model_dump(by_alias=True)`:
+
+        * `None` is only added to the output dict for nullable fields that
+          were set at model initialization. Other fields with value `None`
+          are ignored.
+        """
+        excluded_fields: Set[str] = set([])

+        _dict = self.model_dump(
+            by_alias=True,
+            exclude=excluded_fields,
+            exclude_none=True,
+        )
+        # override the default output from pydantic by calling `to_dict()` of pagination
+        if self.pagination:
+            _dict["pagination"] = self.pagination.to_dict()
+        # override the default output from pydantic by calling `to_dict()` of each item in rows (list)
+        _items = []
+        if self.rows:
+            for _item_rows in self.rows:
+                if _item_rows:
+                    _items.append(_item_rows.to_dict())
+            _dict["rows"] = _items
+        return _dict
+
+    @classmethod
+    def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
+        """Create an instance of V1TaskTimingList from a dict"""
+        if obj is None:
+            return None
+
+        if not isinstance(obj, dict):
+            return cls.model_validate(obj)
+
+        _obj = cls.model_validate(
+            {
+                "pagination": (
+                    PaginationResponse.from_dict(obj["pagination"])
+                    if obj.get("pagination") is not None
+                    else None
+                ),
+                "rows": (
+                    [V1TaskTiming.from_dict(_item) for _item in obj["rows"]]
+                    if obj.get("rows") is not None
+                    else None
+                ),
+            }
+        )
+        return _obj
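Because each V1TaskTiming row carries queuedAt/startedAt/finishedAt timestamps and a waterfall depth, a V1TaskTimingList can be reduced to simple latency figures. A small sketch, assuming `timings` is a V1TaskTimingList already fetched from the API (the endpoint that returns it is not part of this diff):

    from hatchet_sdk.clients.rest.models.v1_task_timing_list import V1TaskTimingList

    def summarize(timings: V1TaskTimingList) -> None:
        for row in timings.rows:
            # queued_at/started_at/finished_at are optional: a run that never
            # started simply has no started_at or finished_at.
            queue_latency = (
                (row.started_at - row.queued_at).total_seconds()
                if row.started_at and row.queued_at
                else None
            )
            run_time = (
                (row.finished_at - row.started_at).total_seconds()
                if row.finished_at and row.started_at
                else None
            )
            indent = "  " * row.depth  # depth reflects position in the waterfall
            print(f"{indent}{row.task_display_name}: queued {queue_latency}s, ran {run_time}s")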
@@ -172,6 +172,8 @@ class CronClient(BaseRestClient):
         additional_metadata: JSONSerializableMapping | None = None,
         order_by_field: CronWorkflowsOrderByField | None = None,
         order_by_direction: WorkflowRunOrderByDirection | None = None,
+        workflow_name: str | None = None,
+        cron_name: str | None = None,
     ) -> CronWorkflowsList:
         """
         Retrieve a list of all workflow cron triggers matching the criteria.
@@ -182,6 +184,8 @@ class CronClient(BaseRestClient):
         :param additional_metadata: Filter by additional metadata keys.
         :param order_by_field: The field to order the list by.
         :param order_by_direction: The direction to order the list by.
+        :param workflow_name: The name of the workflow to filter by.
+        :param cron_name: The name of the cron trigger to filter by.

         :return: A list of cron workflows.
         """
@@ -193,6 +197,8 @@ class CronClient(BaseRestClient):
            additional_metadata=additional_metadata,
            order_by_field=order_by_field,
            order_by_direction=order_by_direction,
+            workflow_name=workflow_name,
+            cron_name=cron_name,
        )

    def list(
@@ -203,6 +209,8 @@ class CronClient(BaseRestClient):
        additional_metadata: JSONSerializableMapping | None = None,
        order_by_field: CronWorkflowsOrderByField | None = None,
        order_by_direction: WorkflowRunOrderByDirection | None = None,
+        workflow_name: str | None = None,
+        cron_name: str | None = None,
    ) -> CronWorkflowsList:
        """
        Retrieve a list of all workflow cron triggers matching the criteria.
@@ -213,6 +221,8 @@ class CronClient(BaseRestClient):
        :param additional_metadata: Filter by additional metadata keys.
        :param order_by_field: The field to order the list by.
        :param order_by_direction: The direction to order the list by.
+        :param workflow_name: The name of the workflow to filter by.
+        :param cron_name: The name of the cron trigger to filter by.

        :return: A list of cron workflows.
        """
@@ -227,6 +237,8 @@ class CronClient(BaseRestClient):
            ),
            order_by_field=order_by_field,
            order_by_direction=order_by_direction,
+            workflow_name=workflow_name,
+            cron_name=cron_name,
        )

    def get(self, cron_id: str) -> CronWorkflows:
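The two new filters let CronClient.list / aio_list narrow a cron listing server-side instead of paging through every trigger in the tenant. A usage sketch, assuming the client is exposed as `hatchet.cron` (as elsewhere in the SDK) and that the workflow and trigger names are placeholders:

    from hatchet_sdk import Hatchet

    hatchet = Hatchet()  # assumes HATCHET_CLIENT_TOKEN etc. are configured in the environment

    # "nightly-report" / "nightly-report-cron" are hypothetical names.
    crons = hatchet.cron.list(
        workflow_name="nightly-report",
        cron_name="nightly-report-cron",
    )

    # CronWorkflows rows expose the trigger name and cron expression.
    for cron in crons.rows or []:
        print(cron.name, cron.cron)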
@@ -131,7 +131,7 @@ class RunsClient(BaseRestClient):

    async def aio_list(
        self,
-        since: datetime = datetime.now() - timedelta(hours=1),
+        since: datetime | None = None,
        only_tasks: bool = False,
        offset: int | None = None,
        limit: int | None = None,
@@ -160,7 +160,7 @@ class RunsClient(BaseRestClient):
        """
        return await asyncio.to_thread(
            self.list,
-            since=since,
+            since=since or datetime.now() - timedelta(days=1),
            only_tasks=only_tasks,
            offset=offset,
            limit=limit,
@@ -174,7 +174,7 @@ class RunsClient(BaseRestClient):

    def list(
        self,
-        since: datetime = datetime.now() - timedelta(hours=1),
+        since: datetime | None = None,
        only_tasks: bool = False,
        offset: int | None = None,
        limit: int | None = None,
@@ -204,7 +204,7 @@ class RunsClient(BaseRestClient):
        with self.client() as client:
            return self._wra(client).v1_workflow_run_list(
                tenant=self.client_config.tenant_id,
-                since=since,
+                since=since or datetime.now() - timedelta(days=1),
                only_tasks=only_tasks,
                offset=offset,
                limit=limit,
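The reason for this change: a default such as `since: datetime = datetime.now() - timedelta(hours=1)` is evaluated once, when the module is imported, not on each call, so a long-lived worker would keep querying the same frozen time window. Computing the fallback inside the call (`since or ...`) defers it to call time, and the default window widens to one day. A standalone illustration of the pitfall, not Hatchet-specific:

    import time
    from datetime import datetime, timedelta

    def stale_window(since: datetime = datetime.now() - timedelta(hours=1)) -> datetime:
        # The default expression runs once, when the function is defined.
        return since

    def fresh_window(since: datetime | None = None) -> datetime:
        # The fallback runs on every call, so the window follows the clock.
        return since or datetime.now() - timedelta(hours=1)

    a = stale_window()
    time.sleep(2)
    print(stale_window() - a)   # 0:00:00 -- the "one hour ago" boundary is frozen
    print(fresh_window() - a)   # ~0:00:02 -- this boundary moved with the clock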
hatchet_sdk/hatchet.py CHANGED
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from datetime import timedelta
 from typing import Any, Callable, Type, Union, cast, overload

 from hatchet_sdk import Context, DurableContext
@@ -21,8 +22,6 @@ from hatchet_sdk.logger import logger
 from hatchet_sdk.rate_limit import RateLimit
 from hatchet_sdk.runnables.standalone import Standalone
 from hatchet_sdk.runnables.types import (
-    DEFAULT_EXECUTION_TIMEOUT,
-    DEFAULT_SCHEDULE_TIMEOUT,
     ConcurrencyExpression,
     EmptyModel,
     R,
@@ -294,8 +293,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | list[ConcurrencyExpression] | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
@@ -316,8 +315,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | list[ConcurrencyExpression] | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
@@ -339,8 +338,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | list[ConcurrencyExpression] | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
@@ -451,8 +450,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
@@ -475,8 +474,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
@@ -498,8 +497,8 @@ class Hatchet:
         sticky: StickyStrategy | None = None,
         default_priority: int = 1,
         concurrency: ConcurrencyExpression | None = None,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = timedelta(minutes=5),
+        execution_timeout: Duration = timedelta(seconds=60),
         retries: int = 0,
         rate_limits: list[RateLimit] = [],
         desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
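Behaviour is unchanged here: the decorator overloads now spell the defaults out as `timedelta(minutes=5)` / `timedelta(seconds=60)` instead of importing the removed `DEFAULT_SCHEDULE_TIMEOUT` / `DEFAULT_EXECUTION_TIMEOUT` constants. Explicit timeouts are passed the same way as before; a sketch, assuming these overloads belong to the `hatchet.task` decorator (the method names sit outside the hunks shown here):

    from datetime import timedelta

    from hatchet_sdk import Context, EmptyModel, Hatchet

    hatchet = Hatchet()  # assumes HATCHET_CLIENT_TOKEN etc. are set in the environment

    @hatchet.task(
        schedule_timeout=timedelta(minutes=10),  # overrides the 5-minute default
        execution_timeout=timedelta(minutes=2),  # overrides the 60-second default
        retries=3,
    )
    def slow_report(input: EmptyModel, ctx: Context) -> dict[str, str]:
        return {"status": "done"}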
@@ -1,4 +1,5 @@
-from datetime import datetime
+import asyncio
+from datetime import datetime, timedelta
 from typing import Any, Generic, cast, get_type_hints

 from hatchet_sdk.clients.admin import (
@@ -7,7 +8,10 @@ from hatchet_sdk.clients.admin import (
     WorkflowRunTriggerConfig,
 )
 from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
+from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
+from hatchet_sdk.clients.rest.models.v1_task_summary import V1TaskSummary
 from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
+from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.types import EmptyModel, R, TWorkflowInput
 from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
@@ -294,3 +298,88 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):

     def to_task(self) -> Task[TWorkflowInput, R]:
         return self._task
+
+    def list_runs(
+        self,
+        since: datetime | None = None,
+        until: datetime | None = None,
+        limit: int = 100,
+        offset: int | None = None,
+        statuses: list[V1TaskStatus] | None = None,
+        additional_metadata: dict[str, str] | None = None,
+        worker_id: str | None = None,
+        parent_task_external_id: str | None = None,
+    ) -> list[V1TaskSummary]:
+        """
+        List runs of the workflow.
+
+        :param since: The start time for the runs to be listed.
+        :param until: The end time for the runs to be listed.
+        :param limit: The maximum number of runs to be listed.
+        :param offset: The offset for pagination.
+        :param statuses: The statuses of the runs to be listed.
+        :param additional_metadata: Additional metadata for filtering the runs.
+        :param worker_id: The ID of the worker that ran the tasks.
+        :param parent_task_external_id: The external ID of the parent task.
+
+        :returns: A list of `V1TaskSummary` objects representing the runs of the workflow.
+        """
+        workflows = self.client.workflows.list(workflow_name=self._workflow.name)
+
+        if not workflows.rows:
+            logger.warning(f"No runs found for {self.name}")
+            return []
+
+        workflow = workflows.rows[0]
+
+        response = self.client.runs.list(
+            workflow_ids=[workflow.metadata.id],
+            since=since or datetime.now() - timedelta(days=1),
+            only_tasks=True,
+            offset=offset,
+            limit=limit,
+            statuses=statuses,
+            until=until,
+            additional_metadata=additional_metadata,
+            worker_id=worker_id,
+            parent_task_external_id=parent_task_external_id,
+        )
+
+        return response.rows
+
+    async def aio_list_runs(
+        self,
+        since: datetime | None = None,
+        until: datetime | None = None,
+        limit: int = 100,
+        offset: int | None = None,
+        statuses: list[V1TaskStatus] | None = None,
+        additional_metadata: dict[str, str] | None = None,
+        worker_id: str | None = None,
+        parent_task_external_id: str | None = None,
+    ) -> list[V1TaskSummary]:
+        """
+        List runs of the workflow.
+
+        :param since: The start time for the runs to be listed.
+        :param until: The end time for the runs to be listed.
+        :param limit: The maximum number of runs to be listed.
+        :param offset: The offset for pagination.
+        :param statuses: The statuses of the runs to be listed.
+        :param additional_metadata: Additional metadata for filtering the runs.
+        :param worker_id: The ID of the worker that ran the tasks.
+        :param parent_task_external_id: The external ID of the parent task.
+
+        :returns: A list of `V1TaskSummary` objects representing the runs of the workflow.
+        """
+        return await asyncio.to_thread(
+            self.list_runs,
+            since=since or datetime.now() - timedelta(days=1),
+            offset=offset,
+            limit=limit,
+            statuses=statuses,
+            until=until,
+            additional_metadata=additional_metadata,
+            worker_id=worker_id,
+            parent_task_external_id=parent_task_external_id,
+        )
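The new `list_runs` / `aio_list_runs` methods let a task query its own run history without going through `client.runs` by hand. A sketch, assuming `@hatchet.task` returns a `Standalone` (as in the rest of the SDK) and that `V1TaskStatus.FAILED` is one of the status enum members:

    from datetime import datetime, timedelta

    from hatchet_sdk import Context, EmptyModel, Hatchet
    from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus

    hatchet = Hatchet()

    @hatchet.task()  # the decorator returns a Standalone wrapping the function
    def process_order(input: EmptyModel, ctx: Context) -> dict[str, str]:
        return {"status": "ok"}

    # Ask for this task's own failures from the last six hours.
    failed = process_order.list_runs(
        since=datetime.now() - timedelta(hours=6),
        statuses=[V1TaskStatus.FAILED],
        limit=25,
    )
    for run in failed:
        # V1TaskSummary rows carry metadata and status among other fields.
        print(run.metadata.id, run.status)

    # The async variant wraps the same call in a thread:
    #     failed = await process_order.aio_list_runs(statuses=[V1TaskStatus.FAILED])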
@@ -1,10 +1,10 @@
+from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
     Generic,
-    TypeVar,
     Union,
     cast,
     get_type_hints,
@@ -18,8 +18,6 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import (
     DesiredWorkerLabels,
 )
 from hatchet_sdk.runnables.types import (
-    DEFAULT_EXECUTION_TIMEOUT,
-    DEFAULT_SCHEDULE_TIMEOUT,
     ConcurrencyExpression,
     R,
     StepType,
@@ -43,18 +41,6 @@ if TYPE_CHECKING:
     from hatchet_sdk.runnables.workflow import Workflow


-T = TypeVar("T")
-
-
-def fall_back_to_default(value: T, default: T, fallback_value: T) -> T:
-    ## If the value is not the default, it's set
-    if value != default:
-        return value
-
-    ## Otherwise, it's unset, so return the fallback value
-    return fallback_value
-
-
 class Task(Generic[TWorkflowInput, R]):
     def __init__(
         self,
@@ -68,8 +54,8 @@ class Task(Generic[TWorkflowInput, R]):
         type: StepType,
         workflow: "Workflow[TWorkflowInput]",
         name: str,
-        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
-        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
+        execution_timeout: Duration = timedelta(seconds=60),
+        schedule_timeout: Duration = timedelta(minutes=5),
         parents: "list[Task[TWorkflowInput, Any]]" = [],
         retries: int = 0,
         rate_limits: list[CreateTaskRateLimit] = [],
@@ -89,12 +75,8 @@ class Task(Generic[TWorkflowInput, R]):
         self.workflow = workflow

         self.type = type
-        self.execution_timeout = fall_back_to_default(
-            execution_timeout, DEFAULT_EXECUTION_TIMEOUT, DEFAULT_EXECUTION_TIMEOUT
-        )
-        self.schedule_timeout = fall_back_to_default(
-            schedule_timeout, DEFAULT_SCHEDULE_TIMEOUT, DEFAULT_SCHEDULE_TIMEOUT
-        )
+        self.execution_timeout = execution_timeout
+        self.schedule_timeout = schedule_timeout
         self.name = name
         self.parents = parents
         self.retries = retries
@@ -1,9 +1,8 @@
 import asyncio
-from datetime import timedelta
 from enum import Enum
 from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeGuard, TypeVar, Union

-from pydantic import BaseModel, ConfigDict, Field, StrictInt, model_validator
+from pydantic import BaseModel, ConfigDict, Field, model_validator

 from hatchet_sdk.context.context import Context, DurableContext
 from hatchet_sdk.contracts.v1.workflows_pb2 import Concurrency
@@ -16,11 +15,6 @@ R = TypeVar("R", bound=Union[ValidTaskReturnType, Awaitable[ValidTaskReturnType]
 P = ParamSpec("P")


-DEFAULT_EXECUTION_TIMEOUT = timedelta(seconds=60)
-DEFAULT_SCHEDULE_TIMEOUT = timedelta(minutes=5)
-DEFAULT_PRIORITY = 1
-
-
 class EmptyModel(BaseModel):
     model_config = ConfigDict(extra="allow", frozen=True)

@@ -65,9 +59,12 @@ TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel)


 class TaskDefaults(BaseModel):
-    schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT
-    execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT
-    priority: StrictInt = Field(gt=0, lt=4, default=DEFAULT_PRIORITY)
+    schedule_timeout: Duration | None = None
+    execution_timeout: Duration | None = None
+    priority: int | None = Field(gt=0, lt=4, default=None)
+    retries: int | None = None
+    backoff_factor: float | None = None
+    backoff_max_seconds: int | None = None


 class WorkflowConfig(BaseModel):
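TaskDefaults now leaves every field unset by default and gains `retries`, `backoff_factor`, and `backoff_max_seconds`, so workflow-level defaults only override what is explicitly provided. A sketch, assuming `task_defaults` is accepted by `hatchet.workflow(...)` (the wiring into WorkflowConfig is not shown in this hunk) and using a hypothetical workflow name:

    from datetime import timedelta

    from hatchet_sdk import Hatchet
    from hatchet_sdk.runnables.types import TaskDefaults

    hatchet = Hatchet()

    defaults = TaskDefaults(
        execution_timeout=timedelta(minutes=2),
        retries=3,
        backoff_factor=2.0,
        backoff_max_seconds=60,
        # schedule_timeout and priority stay None, i.e. "not overridden".
    )

    # "order-pipeline" is a placeholder; task_defaults usage here is an assumption.
    wf = hatchet.workflow(name="order-pipeline", task_defaults=defaults)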