durabletask 1.3.0.dev21__tar.gz → 1.3.0.dev23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/PKG-INFO +4 -1
  2. durabletask-1.3.0.dev23/durabletask/client.py +645 -0
  3. durabletask-1.3.0.dev23/durabletask/internal/client_helpers.py +199 -0
  4. durabletask-1.3.0.dev23/durabletask/internal/grpc_interceptor.py +127 -0
  5. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/helpers.py +12 -6
  6. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/shared.py +40 -0
  7. durabletask-1.3.0.dev23/durabletask/internal/tracing.py +863 -0
  8. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/task.py +2 -2
  9. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/worker.py +301 -46
  10. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask.egg-info/PKG-INFO +4 -1
  11. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask.egg-info/SOURCES.txt +2 -0
  12. durabletask-1.3.0.dev23/durabletask.egg-info/requires.txt +8 -0
  13. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/pyproject.toml +8 -1
  14. durabletask-1.3.0.dev21/durabletask/client.py +0 -439
  15. durabletask-1.3.0.dev21/durabletask/internal/grpc_interceptor.py +0 -65
  16. durabletask-1.3.0.dev21/durabletask.egg-info/requires.txt +0 -4
  17. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/LICENSE +0 -0
  18. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/README.md +0 -0
  19. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/__init__.py +0 -0
  20. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/__init__.py +0 -0
  21. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/durable_entity.py +0 -0
  22. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/entity_context.py +0 -0
  23. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/entity_instance_id.py +0 -0
  24. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/entity_lock.py +0 -0
  25. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/entity_metadata.py +0 -0
  26. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/entities/entity_operation_failed_exception.py +0 -0
  27. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/entity_state_shim.py +0 -0
  28. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/exceptions.py +0 -0
  29. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/json_encode_output_exception.py +0 -0
  30. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/orchestration_entity_context.py +0 -0
  31. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/orchestrator_service_pb2.py +0 -0
  32. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/orchestrator_service_pb2.pyi +0 -0
  33. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/orchestrator_service_pb2_grpc.py +0 -0
  34. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/internal/proto_task_hub_sidecar_service_stub.py +0 -0
  35. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/py.typed +0 -0
  36. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/testing/__init__.py +0 -0
  37. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask/testing/in_memory_backend.py +0 -0
  38. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask.egg-info/dependency_links.txt +0 -0
  39. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/durabletask.egg-info/top_level.txt +0 -0
  40. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev23}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: durabletask
3
- Version: 1.3.0.dev21
3
+ Version: 1.3.0.dev23
4
4
  Summary: A Durable Task Client SDK for Python
5
5
  License: MIT License
6
6
 
@@ -37,6 +37,9 @@ Requires-Dist: grpcio
37
37
  Requires-Dist: protobuf
38
38
  Requires-Dist: asyncio
39
39
  Requires-Dist: packaging
40
+ Provides-Extra: opentelemetry
41
+ Requires-Dist: opentelemetry-api>=1.0.0; extra == "opentelemetry"
42
+ Requires-Dist: opentelemetry-sdk>=1.0.0; extra == "opentelemetry"
40
43
  Dynamic: license-file
41
44
 
42
45
  # Durable Task SDK for Python
@@ -0,0 +1,645 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import logging
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from datetime import datetime
8
+ from enum import Enum
9
+ from typing import Any, List, Optional, Sequence, TypeVar, Union
10
+
11
+ import grpc
12
+ import grpc.aio
13
+
14
+ from durabletask.entities import EntityInstanceId
15
+ from durabletask.entities.entity_metadata import EntityMetadata
16
+ import durabletask.internal.helpers as helpers
17
+ import durabletask.internal.orchestrator_service_pb2 as pb
18
+ import durabletask.internal.orchestrator_service_pb2_grpc as stubs
19
+ import durabletask.internal.shared as shared
20
+ import durabletask.internal.tracing as tracing
21
+ from durabletask import task
22
+ from durabletask.internal.client_helpers import (
23
+ build_query_entities_req,
24
+ build_query_instances_req,
25
+ build_purge_by_filter_req,
26
+ build_raise_event_req,
27
+ build_schedule_new_orchestration_req,
28
+ build_signal_entity_req,
29
+ build_terminate_req,
30
+ check_continuation_token,
31
+ log_completion_state,
32
+ prepare_async_interceptors,
33
+ prepare_sync_interceptors,
34
+ )
35
+
36
# Generic type parameters describing an orchestrator's input and output payloads.
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")
38
+
39
+
40
class OrchestrationStatus(Enum):
    """Runtime status of an orchestration instance.

    Every member's value is the matching ``pb.ORCHESTRATION_STATUS_*`` protobuf
    constant, so a member can be obtained straight from a wire value with
    ``OrchestrationStatus(value)``.
    """
    RUNNING = pb.ORCHESTRATION_STATUS_RUNNING
    COMPLETED = pb.ORCHESTRATION_STATUS_COMPLETED
    FAILED = pb.ORCHESTRATION_STATUS_FAILED
    TERMINATED = pb.ORCHESTRATION_STATUS_TERMINATED
    CONTINUED_AS_NEW = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
    PENDING = pb.ORCHESTRATION_STATUS_PENDING
    SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED

    def __str__(self):
        """Return the display string that the helpers module maps this status to."""
        raw_status = self.value
        return helpers.get_orchestration_status_str(raw_status)
52
+
53
+
54
@dataclass
class OrchestrationState:
    """Snapshot of a single orchestration instance as reported by the sidecar.

    The ``serialized_*`` fields carry the raw serialized payload strings, or
    ``None`` when the corresponding protobuf field was empty.
    """
    # ID of the orchestration instance.
    instance_id: str
    # Registered name of the orchestrator function.
    name: str
    runtime_status: OrchestrationStatus
    created_at: datetime
    last_updated_at: datetime
    serialized_input: Optional[str]
    serialized_output: Optional[str]
    serialized_custom_status: Optional[str]
    # Present only when the sidecar reported an error message or error type.
    failure_details: Optional[task.FailureDetails]

    def raise_if_failed(self):
        """Raise OrchestrationFailedError when this state carries failure details.

        Does nothing when ``failure_details`` is ``None``.
        """
        details = self.failure_details
        if details is None:
            return
        raise OrchestrationFailedError(
            f"Orchestration '{self.instance_id}' failed: {details.message}",
            details)
71
+
72
+
73
@dataclass
class OrchestrationQuery:
    """Filter criteria for ``get_all_orchestration_states``.

    Fields left at ``None`` are treated as "no filter" by the request builder.
    """
    # Lower bound (inclusive or exclusive per backend) on instance creation time.
    created_time_from: Optional[datetime] = None
    # Upper bound on instance creation time.
    created_time_to: Optional[datetime] = None
    # Restrict results to these runtime statuses.
    runtime_status: Optional[List[OrchestrationStatus]] = None
    # Some backends misbehave when max_instance_count is None, so the default is
    # the largest signed 32-bit integer to request effectively unpaginated results.
    max_instance_count: Optional[int] = 2 ** 31 - 1
    # Whether instance inputs and outputs should be included in the results.
    fetch_inputs_and_outputs: bool = False
82
+
83
+
84
@dataclass
class EntityQuery:
    """Filter criteria for ``get_all_entities``."""
    # Only return entities whose instance ID starts with this prefix.
    instance_id_starts_with: Optional[str] = None
    # Lower bound on the entity's last-modified time.
    last_modified_from: Optional[datetime] = None
    # Upper bound on the entity's last-modified time.
    last_modified_to: Optional[datetime] = None
    # Whether entity state should be returned with each result.
    include_state: bool = True
    # Whether transient entities should be included.
    include_transient: bool = False
    # Page size requested from the backend; None leaves it to the backend.
    page_size: Optional[int] = None
92
+
93
+
94
@dataclass
class PurgeInstancesResult:
    """Outcome of a purge request."""
    # Number of instances the sidecar reported as deleted.
    deleted_instance_count: int
    # Completion flag reported by the sidecar for the purge operation.
    is_complete: bool
98
+
99
+
100
@dataclass
class CleanEntityStorageResult:
    """Totals accumulated across all pages of an entity-storage cleanup."""
    # Count of empty entities removed.
    empty_entities_removed: int
    # Count of orphaned entity locks released.
    orphaned_locks_released: int
104
+
105
+
106
class OrchestrationFailedError(Exception):
    """Error carrying the failure details of an orchestration instance.

    Raised by ``OrchestrationState.raise_if_failed``. The human-readable message
    is passed to ``Exception``; the structured details are exposed through the
    read-only :attr:`failure_details` property.
    """

    def __init__(self, message: str, failure_details: task.FailureDetails):
        super().__init__(message)
        # Kept private so callers cannot rebind it; read via the property below.
        self._failure_details = failure_details

    @property
    def failure_details(self):
        """The failure details this error was constructed with."""
        return self._failure_details
114
+
115
+
116
def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Optional[OrchestrationState]:
    """Convert a ``GetInstanceResponse`` into an :class:`OrchestrationState`.

    Args:
        instance_id: The instance ID that was sent in the request.
        res: The sidecar's response.

    Returns:
        ``None`` when the response reports that the instance does not exist;
        otherwise the parsed state, with ``instance_id`` overwritten by the
        request's ID to match the behavior of earlier versions.
    """
    if res.exists:
        parsed = parse_orchestration_state(res.orchestrationState)
        # Prefer the caller-supplied ID over the one echoed back in the payload.
        parsed.instance_id = instance_id
        return parsed
    return None
125
+
126
+
127
def parse_orchestration_state(state: pb.OrchestrationState) -> OrchestrationState:
    """Build an :class:`OrchestrationState` from its protobuf representation.

    Empty wrapper fields (input, output, custom status, stack trace) become
    ``None``. Failure details are produced only when the protobuf carries a
    non-empty error message or error type.
    """
    def _value_or_none(wrapper):
        # Map an empty protobuf wrapper to None instead of its default value.
        return None if helpers.is_empty(wrapper) else wrapper.value

    fd = state.failureDetails
    has_failure = fd.errorMessage != '' or fd.errorType != ''
    failure_details = (
        task.FailureDetails(fd.errorMessage, fd.errorType, _value_or_none(fd.stackTrace))
        if has_failure
        else None
    )

    return OrchestrationState(
        state.instanceId,
        state.name,
        OrchestrationStatus(state.orchestrationStatus),
        state.createdTimestamp.ToDatetime(),
        state.lastUpdatedTimestamp.ToDatetime(),
        _value_or_none(state.input),
        _value_or_none(state.output),
        _value_or_none(state.customStatus),
        failure_details)
145
+
146
+
147
class TaskHubGrpcClient:
    """Synchronous gRPC client for the Durable Task sidecar.

    Schedules and manages orchestration instances and durable entities. The
    gRPC channel is opened in ``__init__`` and released by :meth:`close`. The
    client can also be used as a context manager
    (``with TaskHubGrpcClient(...) as client:``), which closes the channel on
    exit, mirroring the ``async with`` support of the async client.
    """

    def __init__(self, *,
                 host_address: Optional[str] = None,
                 metadata: Optional[list[tuple[str, str]]] = None,
                 log_handler: Optional[logging.Handler] = None,
                 log_formatter: Optional[logging.Formatter] = None,
                 secure_channel: bool = False,
                 interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
                 default_version: Optional[str] = None):
        """Create the client and open its gRPC channel.

        Args:
            host_address: Sidecar address passed to ``shared.get_grpc_channel``.
            metadata: Optional gRPC metadata, passed together with
                *interceptors* to ``prepare_sync_interceptors``.
            log_handler: Optional handler for the client's logger.
            log_formatter: Optional formatter for the client's logger.
            secure_channel: Whether to request a secure channel.
            interceptors: Additional client interceptors.
            default_version: Version used by ``schedule_new_orchestration``
                when no explicit ``version`` is given.
        """
        interceptors = prepare_sync_interceptors(metadata, interceptors)

        channel = shared.get_grpc_channel(
            host_address=host_address,
            secure_channel=secure_channel,
            interceptors=interceptors
        )
        self._channel = channel
        self._stub = stubs.TaskHubSidecarServiceStub(channel)
        self._logger = shared.get_logger("client", log_handler, log_formatter)
        self.default_version = default_version

    def close(self) -> None:
        """Closes the underlying gRPC channel."""
        self._channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception raised in the with-block propagate.
        self.close()

    def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
                                   input: Optional[TInput] = None,
                                   instance_id: Optional[str] = None,
                                   start_at: Optional[datetime] = None,
                                   reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
                                   tags: Optional[dict[str, str]] = None,
                                   version: Optional[str] = None) -> str:
        """Schedule a new orchestration instance.

        Args:
            orchestrator: The orchestrator function, or its registered name.
            input: Optional orchestration input.
            instance_id: Instance ID to use; a random hex UUID is generated
                when omitted or empty.
            start_at: Optional scheduled start time.
            reuse_id_policy: Optional policy for reusing an existing instance ID.
            tags: Optional tags attached to the instance.
            version: Orchestration version; falls back to ``default_version``.

        Returns:
            The instance ID reported by the sidecar.
        """
        name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
        # Resolve the ID and version once and use them for BOTH the tracing span and
        # the request. Previously the request was built from the raw (possibly None)
        # instance_id, so the span could record a locally generated ID that differs
        # from the instance that was actually scheduled.
        resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
        resolved_version = version if version else self.default_version

        with tracing.start_create_orchestration_span(
            name, resolved_instance_id, version=resolved_version,
        ):
            req = build_schedule_new_orchestration_req(
                orchestrator, input=input, instance_id=resolved_instance_id, start_at=start_at,
                reuse_id_policy=reuse_id_policy, tags=tags,
                version=resolved_version)

            # Inject the active PRODUCER span context into the request so the sidecar
            # stores it in the executionStarted event and the worker can parent all
            # orchestration/activity/timer spans under this trace.
            parent_trace_ctx = tracing.get_current_trace_context()
            if parent_trace_ctx is not None:
                req.parentTraceContext.CopyFrom(parent_trace_ctx)

            self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
            res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
            return res.instanceId

    def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
        """Fetch the current state of an instance, or None if it does not exist."""
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        res: pb.GetInstanceResponse = self._stub.GetInstance(req)
        return new_orchestration_state(req.instanceId, res)

    def get_all_orchestration_states(self,
                                     orchestration_query: Optional[OrchestrationQuery] = None
                                     ) -> List[OrchestrationState]:
        """Return every instance matching *orchestration_query*.

        Follows continuation tokens page by page until ``check_continuation_token``
        reports that no further page should be requested. A default
        :class:`OrchestrationQuery` is used when none is given.
        """
        if orchestration_query is None:
            orchestration_query = OrchestrationQuery()
        _continuation_token = None

        self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")

        states = []

        while True:
            req = build_query_instances_req(orchestration_query, _continuation_token)
            resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
            states += [parse_orchestration_state(res) for res in resp.orchestrationState]
            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break

        return states

    def wait_for_orchestration_start(self, instance_id: str, *,
                                     fetch_payloads: bool = False,
                                     timeout: int = 60) -> Optional[OrchestrationState]:
        """Block until the instance starts, then return its state.

        Raises:
            TimeoutError: If the gRPC deadline of *timeout* seconds is exceeded.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
            res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout)
            return new_orchestration_state(req.instanceId, res)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                # Replace gRPC error with the built-in TimeoutError, keeping the cause.
                raise TimeoutError("Timed-out waiting for the orchestration to start") from rpc_error
            else:
                raise

    def wait_for_orchestration_completion(self, instance_id: str, *,
                                          fetch_payloads: bool = True,
                                          timeout: int = 60) -> Optional[OrchestrationState]:
        """Block until the instance reaches a completed state, then return it.

        Raises:
            TimeoutError: If the gRPC deadline of *timeout* seconds is exceeded.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
            res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
            state = new_orchestration_state(req.instanceId, res)
            log_completion_state(self._logger, instance_id, state)
            return state
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                raise TimeoutError("Timed-out waiting for the orchestration to complete") from rpc_error
            else:
                raise

    def raise_orchestration_event(self, instance_id: str, event_name: str, *,
                                  data: Optional[Any] = None) -> None:
        """Send the named external event (with optional *data*) to an instance."""
        with tracing.start_raise_event_span(event_name, instance_id):
            req = build_raise_event_req(instance_id, event_name, data)
            self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
            self._stub.RaiseEvent(req)

    def terminate_orchestration(self, instance_id: str, *,
                                output: Optional[Any] = None,
                                recursive: bool = True) -> None:
        """Terminate an instance.

        Args:
            instance_id: The instance to terminate.
            output: Optional output recorded for the terminated instance.
            recursive: Forwarded to the terminate request (presumably also
                terminates sub-orchestrations; confirm against the sidecar).
        """
        req = build_terminate_req(instance_id, output, recursive)

        self._logger.info(f"Terminating instance '{instance_id}'.")
        self._stub.TerminateInstance(req)

    def suspend_orchestration(self, instance_id: str) -> None:
        """Suspend a running instance."""
        req = pb.SuspendRequest(instanceId=instance_id)
        self._logger.info(f"Suspending instance '{instance_id}'.")
        self._stub.SuspendInstance(req)

    def resume_orchestration(self, instance_id: str) -> None:
        """Resume a suspended instance."""
        req = pb.ResumeRequest(instanceId=instance_id)
        self._logger.info(f"Resuming instance '{instance_id}'.")
        self._stub.ResumeInstance(req)

    def restart_orchestration(self, instance_id: str, *,
                              restart_with_new_instance_id: bool = False) -> str:
        """Restarts an existing orchestration instance.

        Args:
            instance_id: The ID of the orchestration instance to restart.
            restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
                If False (default), the restarted orchestration will reuse the same instance ID.

        Returns:
            The instance ID of the restarted orchestration.
        """
        req = pb.RestartInstanceRequest(
            instanceId=instance_id,
            restartWithNewInstanceId=restart_with_new_instance_id)

        self._logger.info(f"Restarting instance '{instance_id}'.")
        res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
        return res.instanceId

    def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
        """Purge a single instance and return the sidecar's purge result."""
        req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
        self._logger.info(f"Purging instance '{instance_id}'.")
        resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
        return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

    def purge_orchestrations_by(self,
                                created_time_from: Optional[datetime] = None,
                                created_time_to: Optional[datetime] = None,
                                runtime_status: Optional[List[OrchestrationStatus]] = None,
                                recursive: bool = False) -> PurgeInstancesResult:
        """Purge every instance matching the given creation-time and status filters."""
        self._logger.info("Purging orchestrations by filter: "
                          f"created_time_from={created_time_from}, "
                          f"created_time_to={created_time_to}, "
                          f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                          f"recursive={recursive}")
        req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
        resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
        return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

    def signal_entity(self,
                      entity_instance_id: EntityInstanceId,
                      operation_name: str,
                      input: Optional[Any] = None) -> None:
        """Send a one-way operation signal to an entity."""
        req = build_signal_entity_req(entity_instance_id, operation_name, input)
        self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
        # TODO: Cancellation timeout? No timeout is passed, so gRPC's default (no deadline) applies.
        self._stub.SignalEntity(req)

    def get_entity(self,
                   entity_instance_id: EntityInstanceId,
                   include_state: bool = True
                   ) -> Optional[EntityMetadata]:
        """Fetch an entity's metadata (and state, if requested), or None if it does not exist."""
        req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
        self._logger.info(f"Getting entity '{entity_instance_id}'.")
        res: pb.GetEntityResponse = self._stub.GetEntity(req)
        if not res.exists:
            return None

        return EntityMetadata.from_entity_metadata(res.entity, include_state)

    def get_all_entities(self,
                         entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
        """Return every entity matching *entity_query*, following continuation tokens."""
        if entity_query is None:
            entity_query = EntityQuery()
        _continuation_token = None

        self._logger.info(f"Retrieving entities by filter: {entity_query}")

        entities = []

        while True:
            query_request = build_query_entities_req(entity_query, _continuation_token)
            resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
            entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState)
                         for entity in resp.entities]
            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break
        return entities

    def clean_entity_storage(self,
                             remove_empty_entities: bool = True,
                             release_orphaned_locks: bool = True
                             ) -> CleanEntityStorageResult:
        """Run entity-storage cleanup page by page and return the summed counts."""
        self._logger.info("Cleaning entity storage")

        empty_entities_removed = 0
        orphaned_locks_released = 0
        _continuation_token = None

        while True:
            req = pb.CleanEntityStorageRequest(
                removeEmptyEntities=remove_empty_entities,
                releaseOrphanedLocks=release_orphaned_locks,
                continuationToken=_continuation_token
            )
            resp: pb.CleanEntityStorageResponse = self._stub.CleanEntityStorage(req)
            empty_entities_removed += resp.emptyEntitiesRemoved
            orphaned_locks_released += resp.orphanedLocksReleased

            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break

        return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)
393
+
394
+
395
class AsyncTaskHubGrpcClient:
    """Async version of TaskHubGrpcClient using grpc.aio for asyncio-based applications.

    The channel is opened in ``__init__`` and released by :meth:`close`; the
    client also supports ``async with``, which closes the channel on exit.
    """

    def __init__(self, *,
                 host_address: Optional[str] = None,
                 metadata: Optional[list[tuple[str, str]]] = None,
                 log_handler: Optional[logging.Handler] = None,
                 log_formatter: Optional[logging.Formatter] = None,
                 secure_channel: bool = False,
                 interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
                 default_version: Optional[str] = None):
        """Create the client and open its asyncio gRPC channel.

        Args:
            host_address: Sidecar address passed to ``shared.get_async_grpc_channel``.
            metadata: Optional gRPC metadata, passed together with
                *interceptors* to ``prepare_async_interceptors``.
            log_handler: Optional handler for the client's logger.
            log_formatter: Optional formatter for the client's logger.
            secure_channel: Whether to request a secure channel.
            interceptors: Additional async client interceptors.
            default_version: Version used by ``schedule_new_orchestration``
                when no explicit ``version`` is given.
        """
        interceptors = prepare_async_interceptors(metadata, interceptors)

        channel = shared.get_async_grpc_channel(
            host_address=host_address,
            secure_channel=secure_channel,
            interceptors=interceptors
        )
        self._channel = channel
        self._stub = stubs.TaskHubSidecarServiceStub(channel)
        self._logger = shared.get_logger("async_client", log_handler, log_formatter)
        self.default_version = default_version

    async def close(self) -> None:
        """Closes the underlying gRPC channel."""
        await self._channel.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception raised in the async-with block propagate.
        await self.close()

    async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
                                         input: Optional[TInput] = None,
                                         instance_id: Optional[str] = None,
                                         start_at: Optional[datetime] = None,
                                         reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
                                         tags: Optional[dict[str, str]] = None,
                                         version: Optional[str] = None) -> str:
        """Schedule a new orchestration instance.

        Args:
            orchestrator: The orchestrator function, or its registered name.
            input: Optional orchestration input.
            instance_id: Instance ID to use; a random hex UUID is generated
                when omitted or empty.
            start_at: Optional scheduled start time.
            reuse_id_policy: Optional policy for reusing an existing instance ID.
            tags: Optional tags attached to the instance.
            version: Orchestration version; falls back to ``default_version``.

        Returns:
            The instance ID reported by the sidecar.
        """
        name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
        # Resolve the ID and version once and use them for BOTH the tracing span and
        # the request, so the span always names the instance that is actually scheduled.
        resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
        resolved_version = version if version else self.default_version

        with tracing.start_create_orchestration_span(
            name, resolved_instance_id, version=resolved_version,
        ):
            req = build_schedule_new_orchestration_req(
                orchestrator, input=input, instance_id=resolved_instance_id, start_at=start_at,
                reuse_id_policy=reuse_id_policy, tags=tags,
                version=resolved_version)

            # Propagate the active span context so worker-side spans join this trace.
            parent_trace_ctx = tracing.get_current_trace_context()
            if parent_trace_ctx is not None:
                req.parentTraceContext.CopyFrom(parent_trace_ctx)

            self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
            res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
            return res.instanceId

    async def get_orchestration_state(self, instance_id: str, *,
                                      fetch_payloads: bool = True) -> Optional[OrchestrationState]:
        """Fetch the current state of an instance, or None if it does not exist."""
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
        return new_orchestration_state(req.instanceId, res)

    async def get_all_orchestration_states(self,
                                           orchestration_query: Optional[OrchestrationQuery] = None
                                           ) -> List[OrchestrationState]:
        """Return every instance matching *orchestration_query*, following continuation tokens."""
        if orchestration_query is None:
            orchestration_query = OrchestrationQuery()
        _continuation_token = None

        self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")

        states = []

        while True:
            req = build_query_instances_req(orchestration_query, _continuation_token)
            resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
            states += [parse_orchestration_state(res) for res in resp.orchestrationState]
            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break

        return states

    async def wait_for_orchestration_start(self, instance_id: str, *,
                                           fetch_payloads: bool = False,
                                           timeout: int = 60) -> Optional[OrchestrationState]:
        """Wait until the instance starts, then return its state.

        Raises:
            TimeoutError: If the gRPC deadline of *timeout* seconds is exceeded.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
            return new_orchestration_state(req.instanceId, res)
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                raise TimeoutError("Timed-out waiting for the orchestration to start") from rpc_error
            else:
                raise

    async def wait_for_orchestration_completion(self, instance_id: str, *,
                                                fetch_payloads: bool = True,
                                                timeout: int = 60) -> Optional[OrchestrationState]:
        """Wait until the instance reaches a completed state, then return it.

        Raises:
            TimeoutError: If the gRPC deadline of *timeout* seconds is exceeded.
        """
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
            state = new_orchestration_state(req.instanceId, res)
            log_completion_state(self._logger, instance_id, state)
            return state
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                raise TimeoutError("Timed-out waiting for the orchestration to complete") from rpc_error
            else:
                raise

    async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
                                        data: Optional[Any] = None) -> None:
        """Send the named external event (with optional *data*) to an instance."""
        with tracing.start_raise_event_span(event_name, instance_id):
            req = build_raise_event_req(instance_id, event_name, data)
            self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
            await self._stub.RaiseEvent(req)

    async def terminate_orchestration(self, instance_id: str, *,
                                      output: Optional[Any] = None,
                                      recursive: bool = True) -> None:
        """Terminate an instance.

        Args:
            instance_id: The instance to terminate.
            output: Optional output recorded for the terminated instance.
            recursive: Forwarded to the terminate request (presumably also
                terminates sub-orchestrations; confirm against the sidecar).
        """
        req = build_terminate_req(instance_id, output, recursive)

        self._logger.info(f"Terminating instance '{instance_id}'.")
        await self._stub.TerminateInstance(req)

    async def suspend_orchestration(self, instance_id: str) -> None:
        """Suspend a running instance."""
        req = pb.SuspendRequest(instanceId=instance_id)
        self._logger.info(f"Suspending instance '{instance_id}'.")
        await self._stub.SuspendInstance(req)

    async def resume_orchestration(self, instance_id: str) -> None:
        """Resume a suspended instance."""
        req = pb.ResumeRequest(instanceId=instance_id)
        self._logger.info(f"Resuming instance '{instance_id}'.")
        await self._stub.ResumeInstance(req)

    async def restart_orchestration(self, instance_id: str, *,
                                    restart_with_new_instance_id: bool = False) -> str:
        """Restarts an existing orchestration instance.

        Args:
            instance_id: The ID of the orchestration instance to restart.
            restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
                If False (default), the restarted orchestration will reuse the same instance ID.

        Returns:
            The instance ID of the restarted orchestration.
        """
        req = pb.RestartInstanceRequest(
            instanceId=instance_id,
            restartWithNewInstanceId=restart_with_new_instance_id)

        self._logger.info(f"Restarting instance '{instance_id}'.")
        res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
        return res.instanceId

    async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
        """Purge a single instance and return the sidecar's purge result."""
        req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
        self._logger.info(f"Purging instance '{instance_id}'.")
        resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
        return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

    async def purge_orchestrations_by(self,
                                      created_time_from: Optional[datetime] = None,
                                      created_time_to: Optional[datetime] = None,
                                      runtime_status: Optional[List[OrchestrationStatus]] = None,
                                      recursive: bool = False) -> PurgeInstancesResult:
        """Purge every instance matching the given creation-time and status filters."""
        self._logger.info("Purging orchestrations by filter: "
                          f"created_time_from={created_time_from}, "
                          f"created_time_to={created_time_to}, "
                          f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
                          f"recursive={recursive}")
        req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
        resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
        return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)

    async def signal_entity(self,
                            entity_instance_id: EntityInstanceId,
                            operation_name: str,
                            input: Optional[Any] = None) -> None:
        """Send a one-way operation signal to an entity."""
        req = build_signal_entity_req(entity_instance_id, operation_name, input)
        self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
        # grpc.aio multi-callables accept `timeout` (and every other call option) only as
        # keyword arguments; the previous positional `SignalEntity(req, None)` raised
        # TypeError. Omitting it keeps the default of no deadline.
        await self._stub.SignalEntity(req)

    async def get_entity(self,
                         entity_instance_id: EntityInstanceId,
                         include_state: bool = True
                         ) -> Optional[EntityMetadata]:
        """Fetch an entity's metadata (and state, if requested), or None if it does not exist."""
        req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
        self._logger.info(f"Getting entity '{entity_instance_id}'.")
        res: pb.GetEntityResponse = await self._stub.GetEntity(req)
        if not res.exists:
            return None

        return EntityMetadata.from_entity_metadata(res.entity, include_state)

    async def get_all_entities(self,
                               entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
        """Return every entity matching *entity_query*, following continuation tokens."""
        if entity_query is None:
            entity_query = EntityQuery()
        _continuation_token = None

        self._logger.info(f"Retrieving entities by filter: {entity_query}")

        entities = []

        while True:
            query_request = build_query_entities_req(entity_query, _continuation_token)
            resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
            entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState)
                         for entity in resp.entities]
            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break
        return entities

    async def clean_entity_storage(self,
                                   remove_empty_entities: bool = True,
                                   release_orphaned_locks: bool = True
                                   ) -> CleanEntityStorageResult:
        """Run entity-storage cleanup page by page and return the summed counts."""
        self._logger.info("Cleaning entity storage")

        empty_entities_removed = 0
        orphaned_locks_released = 0
        _continuation_token = None

        while True:
            req = pb.CleanEntityStorageRequest(
                removeEmptyEntities=remove_empty_entities,
                releaseOrphanedLocks=release_orphaned_locks,
                continuationToken=_continuation_token
            )
            resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
            empty_entities_removed += resp.emptyEntitiesRemoved
            orphaned_locks_released += resp.orphanedLocksReleased

            if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
                _continuation_token = resp.continuationToken
            else:
                break

        return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)