hatchet-sdk 0.42.2__py3-none-any.whl → 0.42.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of hatchet-sdk has been flagged as potentially problematic.
Files changed (38)
  1. hatchet_sdk/clients/admin.py +6 -6
  2. hatchet_sdk/clients/dispatcher/dispatcher.py +22 -15
  3. hatchet_sdk/clients/events.py +3 -3
  4. hatchet_sdk/clients/rest/__init__.py +9 -3
  5. hatchet_sdk/clients/rest/api/event_api.py +6 -8
  6. hatchet_sdk/clients/rest/api/workflow_api.py +69 -0
  7. hatchet_sdk/clients/rest/models/__init__.py +9 -3
  8. hatchet_sdk/clients/rest/models/concurrency_limit_strategy.py +39 -0
  9. hatchet_sdk/clients/rest/models/cron_workflows.py +3 -16
  10. hatchet_sdk/clients/rest/models/cron_workflows_method.py +37 -0
  11. hatchet_sdk/clients/rest/models/events.py +110 -0
  12. hatchet_sdk/clients/rest/models/scheduled_workflows.py +5 -9
  13. hatchet_sdk/clients/rest/models/scheduled_workflows_method.py +37 -0
  14. hatchet_sdk/clients/rest/models/worker.py +2 -10
  15. hatchet_sdk/clients/rest/models/worker_type.py +38 -0
  16. hatchet_sdk/clients/rest/models/workflow_concurrency.py +6 -13
  17. hatchet_sdk/clients/rest/models/workflow_list.py +4 -4
  18. hatchet_sdk/clients/rest/models/workflow_run.py +10 -10
  19. hatchet_sdk/context/context.py +9 -7
  20. hatchet_sdk/context/worker_context.py +5 -5
  21. hatchet_sdk/contracts/dispatcher_pb2.pyi +0 -2
  22. hatchet_sdk/contracts/events_pb2.pyi +0 -2
  23. hatchet_sdk/contracts/workflows_pb2.pyi +0 -2
  24. hatchet_sdk/hatchet.py +67 -54
  25. hatchet_sdk/labels.py +1 -1
  26. hatchet_sdk/rate_limit.py +117 -4
  27. hatchet_sdk/utils/aio_utils.py +13 -4
  28. hatchet_sdk/utils/typing.py +4 -1
  29. hatchet_sdk/v2/callable.py +24 -24
  30. hatchet_sdk/v2/concurrency.py +10 -8
  31. hatchet_sdk/v2/hatchet.py +38 -36
  32. hatchet_sdk/worker/runner/runner.py +1 -1
  33. hatchet_sdk/worker/worker.py +16 -9
  34. hatchet_sdk/workflow.py +21 -9
  35. {hatchet_sdk-0.42.2.dist-info → hatchet_sdk-0.42.4.dist-info}/METADATA +2 -1
  36. {hatchet_sdk-0.42.2.dist-info → hatchet_sdk-0.42.4.dist-info}/RECORD +38 -33
  37. {hatchet_sdk-0.42.2.dist-info → hatchet_sdk-0.42.4.dist-info}/entry_points.txt +1 -0
  38. {hatchet_sdk-0.42.2.dist-info → hatchet_sdk-0.42.4.dist-info}/WHEEL +0 -0
hatchet_sdk/clients/rest/models/events.py ADDED
@@ -0,0 +1,110 @@
+ # coding: utf-8
+
+ """
+     Hatchet API
+
+     The Hatchet API
+
+     The version of the OpenAPI document: 1.0.0
+     Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+     Do not edit the class manually.
+ """  # noqa: E501
+
+
+ from __future__ import annotations
+
+ import json
+ import pprint
+ import re  # noqa: F401
+ from typing import Any, ClassVar, Dict, List, Optional, Set
+
+ from pydantic import BaseModel, ConfigDict, Field
+ from typing_extensions import Self
+
+ from hatchet_sdk.clients.rest.models.api_resource_meta import APIResourceMeta
+ from hatchet_sdk.clients.rest.models.event import Event
+
+
+ class Events(BaseModel):
+     """
+     Events
+     """  # noqa: E501
+
+     metadata: APIResourceMeta
+     events: List[Event] = Field(description="The events.")
+     __properties: ClassVar[List[str]] = ["metadata", "events"]
+
+     model_config = ConfigDict(
+         populate_by_name=True,
+         validate_assignment=True,
+         protected_namespaces=(),
+     )
+
+     def to_str(self) -> str:
+         """Returns the string representation of the model using alias"""
+         return pprint.pformat(self.model_dump(by_alias=True))
+
+     def to_json(self) -> str:
+         """Returns the JSON representation of the model using alias"""
+         # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
+         return json.dumps(self.to_dict())
+
+     @classmethod
+     def from_json(cls, json_str: str) -> Optional[Self]:
+         """Create an instance of Events from a JSON string"""
+         return cls.from_dict(json.loads(json_str))
+
+     def to_dict(self) -> Dict[str, Any]:
+         """Return the dictionary representation of the model using alias.
+
+         This has the following differences from calling pydantic's
+         `self.model_dump(by_alias=True)`:
+
+         * `None` is only added to the output dict for nullable fields that
+           were set at model initialization. Other fields with value `None`
+           are ignored.
+         """
+         excluded_fields: Set[str] = set([])
+
+         _dict = self.model_dump(
+             by_alias=True,
+             exclude=excluded_fields,
+             exclude_none=True,
+         )
+         # override the default output from pydantic by calling `to_dict()` of metadata
+         if self.metadata:
+             _dict["metadata"] = self.metadata.to_dict()
+         # override the default output from pydantic by calling `to_dict()` of each item in events (list)
+         _items = []
+         if self.events:
+             for _item_events in self.events:
+                 if _item_events:
+                     _items.append(_item_events.to_dict())
+             _dict["events"] = _items
+         return _dict
+
+     @classmethod
+     def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
+         """Create an instance of Events from a dict"""
+         if obj is None:
+             return None
+
+         if not isinstance(obj, dict):
+             return cls.model_validate(obj)
+
+         _obj = cls.model_validate(
+             {
+                 "metadata": (
+                     APIResourceMeta.from_dict(obj["metadata"])
+                     if obj.get("metadata") is not None
+                     else None
+                 ),
+                 "events": (
+                     [Event.from_dict(_item) for _item in obj["events"]]
+                     if obj.get("events") is not None
+                     else None
+                 ),
+             }
+         )
+         return _obj
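
Usage sketch (illustrative, not part of the package diff): round-tripping the new Events model through its generated helpers. The metadata keys shown (id, createdAt, updatedAt) are assumed from the other generated REST models and may not match the actual APIResourceMeta schema.

from hatchet_sdk.clients.rest.models.events import Events

# Hypothetical payload; the metadata field names are assumptions, not taken from this diff.
payload = {
    "metadata": {
        "id": "9c1f2b84-0000-0000-0000-000000000000",
        "createdAt": "2024-01-01T00:00:00Z",
        "updatedAt": "2024-01-01T00:00:00Z",
    },
    "events": [],
}

events = Events.from_dict(payload)  # builds nested models via their from_dict helpers
print(events.to_json())             # serializes back out using the field aliases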
hatchet_sdk/clients/rest/models/scheduled_workflows.py CHANGED
@@ -20,10 +20,13 @@ import re  # noqa: F401
  from datetime import datetime
  from typing import Any, ClassVar, Dict, List, Optional, Set
 
- from pydantic import BaseModel, ConfigDict, Field, StrictStr, field_validator
+ from pydantic import BaseModel, ConfigDict, Field, StrictStr
  from typing_extensions import Annotated, Self
 
  from hatchet_sdk.clients.rest.models.api_resource_meta import APIResourceMeta
+ from hatchet_sdk.clients.rest.models.scheduled_workflows_method import (
+     ScheduledWorkflowsMethod,
+ )
  from hatchet_sdk.clients.rest.models.workflow_run_status import WorkflowRunStatus
 
 
@@ -54,7 +57,7 @@ class ScheduledWorkflows(BaseModel):
      workflow_run_id: Optional[
          Annotated[str, Field(min_length=36, strict=True, max_length=36)]
      ] = Field(default=None, alias="workflowRunId")
-     method: StrictStr
+     method: ScheduledWorkflowsMethod
      __properties: ClassVar[List[str]] = [
          "metadata",
          "tenantId",
@@ -71,13 +74,6 @@
          "method",
      ]
 
-     @field_validator("method")
-     def method_validate_enum(cls, value):
-         """Validates the enum"""
-         if value not in set(["DEFAULT", "API"]):
-             raise ValueError("must be one of enum values ('DEFAULT', 'API')")
-         return value
-
      model_config = ConfigDict(
          populate_by_name=True,
          validate_assignment=True,
hatchet_sdk/clients/rest/models/scheduled_workflows_method.py ADDED
@@ -0,0 +1,37 @@
+ # coding: utf-8
+
+ """
+     Hatchet API
+
+     The Hatchet API
+
+     The version of the OpenAPI document: 1.0.0
+     Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+     Do not edit the class manually.
+ """  # noqa: E501
+
+
+ from __future__ import annotations
+
+ import json
+ from enum import Enum
+
+ from typing_extensions import Self
+
+
+ class ScheduledWorkflowsMethod(str, Enum):
+     """
+     ScheduledWorkflowsMethod
+     """
+
+     """
+     allowed enum values
+     """
+     DEFAULT = "DEFAULT"
+     API = "API"
+
+     @classmethod
+     def from_json(cls, json_str: str) -> Self:
+         """Create an instance of ScheduledWorkflowsMethod from a JSON string"""
+         return cls(json.loads(json_str))
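
For reference, this enum replaces the string validator removed from ScheduledWorkflows above; a small sketch of how it behaves:

from hatchet_sdk.clients.rest.models.scheduled_workflows_method import (
    ScheduledWorkflowsMethod,
)

method = ScheduledWorkflowsMethod.from_json('"API"')
assert method is ScheduledWorkflowsMethod.API
assert method == "API"  # str-backed enum, so it still compares equal to the raw string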
hatchet_sdk/clients/rest/models/worker.py CHANGED
@@ -28,6 +28,7 @@ from hatchet_sdk.clients.rest.models.recent_step_runs import RecentStepRuns
  from hatchet_sdk.clients.rest.models.semaphore_slots import SemaphoreSlots
  from hatchet_sdk.clients.rest.models.worker_label import WorkerLabel
  from hatchet_sdk.clients.rest.models.worker_runtime_info import WorkerRuntimeInfo
+ from hatchet_sdk.clients.rest.models.worker_type import WorkerType
 
 
  class Worker(BaseModel):
@@ -37,7 +38,7 @@ class Worker(BaseModel):
 
      metadata: APIResourceMeta
      name: StrictStr = Field(description="The name of the worker.")
-     type: StrictStr
+     type: WorkerType
      last_heartbeat_at: Optional[datetime] = Field(
          default=None,
          description="The time this worker last sent a heartbeat.",
@@ -108,15 +109,6 @@
          "runtimeInfo",
      ]
 
-     @field_validator("type")
-     def type_validate_enum(cls, value):
-         """Validates the enum"""
-         if value not in set(["SELFHOSTED", "MANAGED", "WEBHOOK"]):
-             raise ValueError(
-                 "must be one of enum values ('SELFHOSTED', 'MANAGED', 'WEBHOOK')"
-             )
-         return value
-
      @field_validator("status")
      def status_validate_enum(cls, value):
          """Validates the enum"""
hatchet_sdk/clients/rest/models/worker_type.py ADDED
@@ -0,0 +1,38 @@
+ # coding: utf-8
+
+ """
+     Hatchet API
+
+     The Hatchet API
+
+     The version of the OpenAPI document: 1.0.0
+     Generated by OpenAPI Generator (https://openapi-generator.tech)
+
+     Do not edit the class manually.
+ """  # noqa: E501
+
+
+ from __future__ import annotations
+
+ import json
+ from enum import Enum
+
+ from typing_extensions import Self
+
+
+ class WorkerType(str, Enum):
+     """
+     WorkerType
+     """
+
+     """
+     allowed enum values
+     """
+     SELFHOSTED = "SELFHOSTED"
+     MANAGED = "MANAGED"
+     WEBHOOK = "WEBHOOK"
+
+     @classmethod
+     def from_json(cls, json_str: str) -> Self:
+         """Create an instance of WorkerType from a JSON string"""
+         return cls(json.loads(json_str))
hatchet_sdk/clients/rest/models/workflow_concurrency.py CHANGED
@@ -19,9 +19,13 @@ import pprint
  import re  # noqa: F401
  from typing import Any, ClassVar, Dict, List, Optional, Set
 
- from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr, field_validator
+ from pydantic import BaseModel, ConfigDict, Field, StrictInt, StrictStr
  from typing_extensions import Self
 
+ from hatchet_sdk.clients.rest.models.concurrency_limit_strategy import (
+     ConcurrencyLimitStrategy,
+ )
+
 
  class WorkflowConcurrency(BaseModel):
      """
@@ -31,7 +35,7 @@ class WorkflowConcurrency(BaseModel):
      max_runs: StrictInt = Field(
          description="The maximum number of concurrent workflow runs.", alias="maxRuns"
      )
-     limit_strategy: StrictStr = Field(
+     limit_strategy: ConcurrencyLimitStrategy = Field(
          description="The strategy to use when the concurrency limit is reached.",
          alias="limitStrategy",
      )
@@ -45,17 +49,6 @@
          "getConcurrencyGroup",
      ]
 
-     @field_validator("limit_strategy")
-     def limit_strategy_validate_enum(cls, value):
-         """Validates the enum"""
-         if value not in set(
-             ["CANCEL_IN_PROGRESS", "DROP_NEWEST", "QUEUE_NEWEST", "GROUP_ROUND_ROBIN"]
-         ):
-             raise ValueError(
-                 "must be one of enum values ('CANCEL_IN_PROGRESS', 'DROP_NEWEST', 'QUEUE_NEWEST', 'GROUP_ROUND_ROBIN')"
-             )
-         return value
-
      model_config = ConfigDict(
          populate_by_name=True,
          validate_assignment=True,
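
The limit_strategy field now takes the new ConcurrencyLimitStrategy REST enum (added in this release but not shown in this diff); a minimal sketch, assuming its members mirror the values the removed validator accepted:

from hatchet_sdk.clients.rest.models.concurrency_limit_strategy import (
    ConcurrencyLimitStrategy,
)

# Assumed members: CANCEL_IN_PROGRESS, DROP_NEWEST, QUEUE_NEWEST, GROUP_ROUND_ROBIN.
strategy = ConcurrencyLimitStrategy("GROUP_ROUND_ROBIN")
assert strategy is ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN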
hatchet_sdk/clients/rest/models/workflow_list.py CHANGED
@@ -32,10 +32,10 @@ class WorkflowList(BaseModel):
      WorkflowList
      """  # noqa: E501
 
-     metadata: Optional[APIResourceMeta] = None
-     rows: Optional[List[Workflow]] = None
-     pagination: Optional[PaginationResponse] = None
-     __properties: ClassVar[List[str]] = ["metadata", "rows", "pagination"]
+     metadata: APIResourceMeta | None = None
+     rows: list[Workflow] | None = None
+     pagination: PaginationResponse | None = None
+     __properties: ClassVar[list[str]] = ["metadata", "rows", "pagination"]
 
      model_config = ConfigDict(
          populate_by_name=True,
hatchet_sdk/clients/rest/models/workflow_run.py CHANGED
@@ -39,28 +39,28 @@ class WorkflowRun(BaseModel):
      metadata: APIResourceMeta
      tenant_id: StrictStr = Field(alias="tenantId")
      workflow_version_id: StrictStr = Field(alias="workflowVersionId")
-     workflow_version: Optional[WorkflowVersion] = Field(
+     workflow_version: WorkflowVersion | None = Field(
          default=None, alias="workflowVersion"
      )
      status: WorkflowRunStatus
-     display_name: Optional[StrictStr] = Field(default=None, alias="displayName")
-     job_runs: Optional[List[JobRun]] = Field(default=None, alias="jobRuns")
+     display_name: StrictStr | None = Field(default=None, alias="displayName")
+     job_runs: list[JobRun] | None = Field(default=None, alias="jobRuns")
      triggered_by: WorkflowRunTriggeredBy = Field(alias="triggeredBy")
-     input: Optional[Dict[str, Any]] = None
-     error: Optional[StrictStr] = None
-     started_at: Optional[datetime] = Field(default=None, alias="startedAt")
-     finished_at: Optional[datetime] = Field(default=None, alias="finishedAt")
-     duration: Optional[StrictInt] = None
+     input: dict[str, Any] | None = None
+     error: StrictStr | None = None
+     started_at: datetime | None = Field(default=None, alias="startedAt")
+     finished_at: datetime | None = Field(default=None, alias="finishedAt")
+     duration: StrictInt | None = None
      parent_id: Optional[
          Annotated[str, Field(min_length=36, strict=True, max_length=36)]
      ] = Field(default=None, alias="parentId")
      parent_step_run_id: Optional[
          Annotated[str, Field(min_length=36, strict=True, max_length=36)]
      ] = Field(default=None, alias="parentStepRunId")
-     additional_metadata: Optional[Dict[str, Any]] = Field(
+     additional_metadata: dict[str, Any] | None = Field(
          default=None, alias="additionalMetadata"
      )
-     __properties: ClassVar[List[str]] = [
+     __properties: ClassVar[list[str]] = [
          "metadata",
          "tenantId",
          "workflowVersionId",
hatchet_sdk/context/context.py CHANGED
@@ -13,8 +13,8 @@ from hatchet_sdk.clients.rest_client import RestApi
  from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
  from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
  from hatchet_sdk.context.worker_context import WorkerContext
- from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData  # type: ignore
- from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
+ from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
+ from hatchet_sdk.contracts.workflows_pb2 import (
      BulkTriggerWorkflowRequest,
      TriggerWorkflowRequest,
  )
@@ -69,7 +69,7 @@ class BaseContext:
              meta = options["additional_metadata"]
 
          ## TODO: Pydantic here to simplify this
-         trigger_options: TriggerWorkflowOptions = {  # type: ignore[typeddict-item]
+         trigger_options: TriggerWorkflowOptions = {
              "parent_id": workflow_run_id,
              "parent_step_run_id": step_run_id,
              "child_key": key,
@@ -149,8 +149,7 @@ class ContextAioImpl(BaseContext):
              key = child_workflow_run.get("key")
              options = child_workflow_run.get("options", {})
 
-             ## TODO: figure out why this is failing
-             trigger_options = self._prepare_workflow_options(key, options, worker_id)  # type: ignore[arg-type]
+             trigger_options = self._prepare_workflow_options(key, options, worker_id)
 
              bulk_trigger_workflow_runs.append(
                  WorkflowRunDict(
@@ -238,14 +237,17 @@ class Context(BaseContext):
          self.input = self.data.get("input", {})
 
      def step_output(self, step: str) -> dict[str, Any] | BaseModel:
-         validators = self.validator_registry.get(step)
+         workflow_validator = next(
+             (v for k, v in self.validator_registry.items() if k.split(":")[-1] == step),
+             None,
+         )
 
          try:
              parent_step_data = cast(dict[str, Any], self.data["parents"][step])
          except KeyError:
              raise ValueError(f"Step output for '{step}' not found")
 
-         if validators and (v := validators.step_output):
+         if workflow_validator and (v := workflow_validator.step_output):
              return v.model_validate(parent_step_data)
 
          return parent_step_data
hatchet_sdk/context/worker_context.py CHANGED
@@ -2,7 +2,7 @@ from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
 
 
  class WorkerContext:
-     _worker_id: str = None
+     _worker_id: str | None = None
      _registered_workflow_names: list[str] = []
      _labels: dict[str, str | int] = {}
 
@@ -10,18 +10,18 @@ class WorkerContext:
          self._labels = labels
          self.client = client
 
-     def labels(self):
+     def labels(self) -> dict[str, str | int]:
          return self._labels
 
-     def upsert_labels(self, labels: dict[str, str | int]):
+     def upsert_labels(self, labels: dict[str, str | int]) -> None:
          self.client.upsert_worker_labels(self._worker_id, labels)
          self._labels.update(labels)
 
-     async def async_upsert_labels(self, labels: dict[str, str | int]):
+     async def async_upsert_labels(self, labels: dict[str, str | int]) -> None:
          await self.client.async_upsert_worker_labels(self._worker_id, labels)
          self._labels.update(labels)
 
-     def id(self) -> str:
+     def id(self) -> str | None:
          return self._worker_id
 
      # def has_workflow(self, workflow_name: str):
hatchet_sdk/contracts/dispatcher_pb2.pyi CHANGED
@@ -1,5 +1,3 @@
- # type: ignore
-
  from google.protobuf import timestamp_pb2 as _timestamp_pb2
  from google.protobuf.internal import containers as _containers
  from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
hatchet_sdk/contracts/events_pb2.pyi CHANGED
@@ -1,5 +1,3 @@
- # type: ignore
-
  from google.protobuf import timestamp_pb2 as _timestamp_pb2
  from google.protobuf.internal import containers as _containers
  from google.protobuf import descriptor as _descriptor
hatchet_sdk/contracts/workflows_pb2.pyi CHANGED
@@ -1,5 +1,3 @@
- # type: ignore
-
  from google.protobuf import timestamp_pb2 as _timestamp_pb2
  from google.protobuf.internal import containers as _containers
  from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
hatchet_sdk/hatchet.py CHANGED
@@ -1,15 +1,13 @@
  import asyncio
  import logging
- from typing import Any, Callable, Optional, Type, TypeVar, cast, get_type_hints
+ from typing import Any, Callable, Optional, Type, TypeVar, Union
 
  from pydantic import BaseModel
  from typing_extensions import deprecated
 
  from hatchet_sdk.clients.rest_client import RestApi
-
- ## TODO: These type stubs need to be updated to mass MyPy, and then we can remove this ignore
- ## There are file-level type ignore lines in the corresponding .pyi files.
- from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
+ from hatchet_sdk.context.context import Context
+ from hatchet_sdk.contracts.workflows_pb2 import (
      ConcurrencyLimitStrategy,
      CreateStepRateLimit,
      DesiredWorkerLabels,
@@ -20,6 +18,7 @@ from hatchet_sdk.features.scheduled import ScheduledClient
  from hatchet_sdk.labels import DesiredWorkerLabel
  from hatchet_sdk.loader import ClientConfig, ConfigLoader
  from hatchet_sdk.rate_limit import RateLimit
+ from hatchet_sdk.v2.callable import HatchetCallable
 
  from .client import Client, new_client, new_client_raw
  from .clients.admin import AdminClient
@@ -36,6 +35,7 @@ from .workflow import (
  )
 
  T = TypeVar("T", bound=BaseModel)
+ TWorkflow = TypeVar("TWorkflow", bound=object)
 
 
  def workflow(
@@ -45,30 +45,35 @@
      version: str = "",
      timeout: str = "60m",
      schedule_timeout: str = "5m",
-     sticky: StickyStrategy = None,
+     sticky: Union[StickyStrategy.Value, None] = None,  # type: ignore[name-defined]
      default_priority: int | None = None,
      concurrency: ConcurrencyExpression | None = None,
      input_validator: Type[T] | None = None,
- ) -> Callable[[Type[WorkflowInterface]], WorkflowMeta]:
+ ) -> Callable[[Type[TWorkflow]], WorkflowMeta]:
      on_events = on_events or []
      on_crons = on_crons or []
 
-     def inner(cls: Type[WorkflowInterface]) -> WorkflowMeta:
-         cls.on_events = on_events
-         cls.on_crons = on_crons
-         cls.name = name or str(cls.__name__)
-         cls.version = version
-         cls.timeout = timeout
-         cls.schedule_timeout = schedule_timeout
-         cls.sticky = sticky
-         cls.default_priority = default_priority
-         cls.concurrency_expression = concurrency
+     def inner(cls: Type[TWorkflow]) -> WorkflowMeta:
+         nonlocal name
+         name = name or str(cls.__name__)
+
+         setattr(cls, "on_events", on_events)
+         setattr(cls, "on_crons", on_crons)
+         setattr(cls, "name", name)
+         setattr(cls, "version", version)
+         setattr(cls, "timeout", timeout)
+         setattr(cls, "schedule_timeout", schedule_timeout)
+         setattr(cls, "sticky", sticky)
+         setattr(cls, "default_priority", default_priority)
+         setattr(cls, "concurrency_expression", concurrency)
+
          # Define a new class with the same name and bases as the original, but
         # with WorkflowMeta as its metaclass
 
          ## TODO: Figure out how to type this metaclass correctly
-         cls.input_validator = input_validator
-         return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))
+         setattr(cls, "input_validator", input_validator)
+
+         return WorkflowMeta(name, cls.__bases__, dict(cls.__dict__))
 
      return inner
 
@@ -82,37 +87,39 @@
      desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
      backoff_factor: float | None = None,
      backoff_max_seconds: int | None = None,
- ) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
+ ) -> Callable[..., Any]:
      parents = parents or []
 
-     def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
+     def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]:
          limits = None
          if rate_limits:
-             limits = [
-                 CreateStepRateLimit(key=rate_limit.key, units=rate_limit.units)
-                 for rate_limit in rate_limits or []
-             ]
-
-         func._step_name = name.lower() or str(func.__name__).lower()
-         func._step_parents = parents
-         func._step_timeout = timeout
-         func._step_retries = retries
-         func._step_rate_limits = limits
-         func._step_backoff_factor = backoff_factor
-         func._step_backoff_max_seconds = backoff_max_seconds
-
-         func._step_desired_worker_labels = {}
-
-         for key, d in desired_worker_labels.items():
+             limits = [rate_limit._req for rate_limit in rate_limits or []]
+
+         setattr(func, "_step_name", name.lower() or str(func.__name__).lower())
+         setattr(func, "_step_parents", parents)
+         setattr(func, "_step_timeout", timeout)
+         setattr(func, "_step_retries", retries)
+         setattr(func, "_step_rate_limits", retries)
+         setattr(func, "_step_rate_limits", limits)
+         setattr(func, "_step_backoff_factor", backoff_factor)
+         setattr(func, "_step_backoff_max_seconds", backoff_max_seconds)
+
+         def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels:
              value = d["value"] if "value" in d else None
-             func._step_desired_worker_labels[key] = DesiredWorkerLabels(
+             return DesiredWorkerLabels(
                  strValue=str(value) if not isinstance(value, int) else None,
                  intValue=value if isinstance(value, int) else None,
-                 required=d["required"] if "required" in d else None,
+                 required=d["required"] if "required" in d else None,  # type: ignore[arg-type]
                  weight=d["weight"] if "weight" in d else None,
-                 comparator=d["comparator"] if "comparator" in d else None,
+                 comparator=d["comparator"] if "comparator" in d else None,  # type: ignore[arg-type]
              )
 
+         setattr(
+             func,
+             "_step_desired_worker_labels",
+             {key: create_label(d) for key, d in desired_worker_labels.items()},
+         )
+
          return func
 
      return inner
@@ -125,21 +132,23 @@
      rate_limits: list[RateLimit] | None = None,
      backoff_factor: float | None = None,
      backoff_max_seconds: int | None = None,
- ) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
-     def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
+ ) -> Callable[..., Any]:
+     def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]:
          limits = None
          if rate_limits:
              limits = [
-                 CreateStepRateLimit(key=rate_limit.key, units=rate_limit.units)
+                 CreateStepRateLimit(key=rate_limit.static_key, units=rate_limit.units)  # type: ignore[arg-type]
                  for rate_limit in rate_limits or []
              ]
 
-         func._on_failure_step_name = name.lower() or str(func.__name__).lower()
-         func._on_failure_step_timeout = timeout
-         func._on_failure_step_retries = retries
-         func._on_failure_step_rate_limits = limits
-         func._on_failure_step_backoff_factor = backoff_factor
-         func._on_failure_step_backoff_max_seconds = backoff_max_seconds
+         setattr(
+             func, "_on_failure_step_name", name.lower() or str(func.__name__).lower()
+         )
+         setattr(func, "_on_failure_step_timeout", timeout)
+         setattr(func, "_on_failure_step_retries", retries)
+         setattr(func, "_on_failure_step_rate_limits", limits)
+         setattr(func, "_on_failure_step_backoff_factor", backoff_factor)
+         setattr(func, "_on_failure_step_backoff_max_seconds", backoff_max_seconds)
 
          return func
 
@@ -150,11 +159,15 @@
      name: str = "",
      max_runs: int = 1,
      limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS,
- ) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
-     def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
-         func._concurrency_fn_name = name.lower() or str(func.__name__).lower()
-         func._concurrency_max_runs = max_runs
-         func._concurrency_limit_strategy = limit_strategy
+ ) -> Callable[..., Any]:
+     def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]:
+         setattr(
+             func,
+             "_concurrency_fn_name",
+             name.lower() or str(func.__name__).lower(),
+         )
+         setattr(func, "_concurrency_max_runs", max_runs)
+         setattr(func, "_concurrency_limit_strategy", limit_strategy)
          return func
 
 
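Usage sketch (illustrative, not part of the diff): the reworked decorators above still support the usual class-based workflow definition. This assumes the Hatchet client exposes them as hatchet.workflow / hatchet.step, per the SDK's documented pattern, and that credentials are configured in the environment.

from pydantic import BaseModel

from hatchet_sdk import Context, Hatchet

hatchet = Hatchet()


class MyInput(BaseModel):
    user_id: str


@hatchet.workflow(on_events=["user:created"], input_validator=MyInput)
class MyWorkflow:
    @hatchet.step(timeout="30s", retries=1)
    def step1(self, context: Context) -> dict:
        # Steps receive the run Context and return JSON-serializable output.
        return {"ok": True}
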
hatchet_sdk/labels.py CHANGED
@@ -1,7 +1,7 @@
  from typing import TypedDict
 
 
- class DesiredWorkerLabel(TypedDict):
+ class DesiredWorkerLabel(TypedDict, total=False):
      value: str | int
      required: bool | None = None
      weight: int | None = None
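
With total=False, every key of DesiredWorkerLabel becomes optional, so callers can pass partial label dicts; for example:

from hatchet_sdk.labels import DesiredWorkerLabel

# Only the keys you care about need to be present.
gpu_label: DesiredWorkerLabel = {"value": "a100", "required": True}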