durabletask 0.0.0.dev67__py3-none-any.whl → 0.0.0.dev69__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- durabletask/__init__.py +7 -1
- durabletask/client.py +400 -111
- durabletask/extensions/__init__.py +4 -0
- durabletask/extensions/azure_blob_payloads/__init__.py +38 -0
- durabletask/extensions/azure_blob_payloads/blob_payload_store.py +225 -0
- durabletask/extensions/azure_blob_payloads/options.py +40 -0
- durabletask/internal/client_helpers.py +199 -0
- durabletask/internal/grpc_interceptor.py +74 -12
- durabletask/internal/helpers.py +17 -7
- durabletask/internal/orchestrator_service_pb2.py +272 -220
- durabletask/internal/orchestrator_service_pb2.pyi +232 -26
- durabletask/internal/orchestrator_service_pb2_grpc.py +132 -0
- durabletask/internal/proto_task_hub_sidecar_service_stub.py +3 -0
- durabletask/internal/shared.py +45 -8
- durabletask/internal/tracing.py +863 -0
- durabletask/payload/__init__.py +29 -0
- durabletask/payload/helpers.py +349 -0
- durabletask/payload/store.py +91 -0
- durabletask/task.py +73 -3
- durabletask/testing/__init__.py +14 -0
- durabletask/testing/in_memory_backend.py +1642 -0
- durabletask/worker.py +341 -50
- {durabletask-0.0.0.dev67.dist-info → durabletask-0.0.0.dev69.dist-info}/METADATA +20 -1
- durabletask-0.0.0.dev69.dist-info/RECORD +39 -0
- {durabletask-0.0.0.dev67.dist-info → durabletask-0.0.0.dev69.dist-info}/WHEEL +1 -1
- durabletask-0.0.0.dev67.dist-info/RECORD +0 -28
- {durabletask-0.0.0.dev67.dist-info → durabletask-0.0.0.dev69.dist-info}/licenses/LICENSE +0 -0
- {durabletask-0.0.0.dev67.dist-info → durabletask-0.0.0.dev69.dist-info}/top_level.txt +0 -0
durabletask/__init__.py
CHANGED
|
@@ -3,8 +3,14 @@
|
|
|
3
3
|
|
|
4
4
|
"""Durable Task SDK for Python"""
|
|
5
5
|
|
|
6
|
+
from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore
|
|
6
7
|
from durabletask.worker import ConcurrencyOptions, VersioningOptions
|
|
7
8
|
|
|
8
|
-
__all__ = [
|
|
9
|
+
__all__ = [
|
|
10
|
+
"ConcurrencyOptions",
|
|
11
|
+
"LargePayloadStorageOptions",
|
|
12
|
+
"PayloadStore",
|
|
13
|
+
"VersioningOptions",
|
|
14
|
+
]
|
|
9
15
|
|
|
10
16
|
PACKAGE_NAME = "durabletask"
|
durabletask/client.py
CHANGED
|
@@ -4,11 +4,12 @@
|
|
|
4
4
|
import logging
|
|
5
5
|
import uuid
|
|
6
6
|
from dataclasses import dataclass
|
|
7
|
-
from datetime import datetime
|
|
7
|
+
from datetime import datetime
|
|
8
8
|
from enum import Enum
|
|
9
9
|
from typing import Any, List, Optional, Sequence, TypeVar, Union
|
|
10
10
|
|
|
11
11
|
import grpc
|
|
12
|
+
import grpc.aio
|
|
12
13
|
|
|
13
14
|
from durabletask.entities import EntityInstanceId
|
|
14
15
|
from durabletask.entities.entity_metadata import EntityMetadata
|
|
@@ -16,8 +17,23 @@ import durabletask.internal.helpers as helpers
|
|
|
16
17
|
import durabletask.internal.orchestrator_service_pb2 as pb
|
|
17
18
|
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
|
|
18
19
|
import durabletask.internal.shared as shared
|
|
20
|
+
import durabletask.internal.tracing as tracing
|
|
19
21
|
from durabletask import task
|
|
20
|
-
from durabletask.internal.
|
|
22
|
+
from durabletask.internal.client_helpers import (
|
|
23
|
+
build_query_entities_req,
|
|
24
|
+
build_query_instances_req,
|
|
25
|
+
build_purge_by_filter_req,
|
|
26
|
+
build_raise_event_req,
|
|
27
|
+
build_schedule_new_orchestration_req,
|
|
28
|
+
build_signal_entity_req,
|
|
29
|
+
build_terminate_req,
|
|
30
|
+
check_continuation_token,
|
|
31
|
+
log_completion_state,
|
|
32
|
+
prepare_async_interceptors,
|
|
33
|
+
prepare_sync_interceptors,
|
|
34
|
+
)
|
|
35
|
+
from durabletask.payload import helpers as payload_helpers
|
|
36
|
+
from durabletask.payload.store import PayloadStore
|
|
21
37
|
|
|
22
38
|
TInput = TypeVar('TInput')
|
|
23
39
|
TOutput = TypeVar('TOutput')
|
|
@@ -138,27 +154,25 @@ class TaskHubGrpcClient:
|
|
|
138
154
|
log_formatter: Optional[logging.Formatter] = None,
|
|
139
155
|
secure_channel: bool = False,
|
|
140
156
|
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
|
|
141
|
-
default_version: Optional[str] = None
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
if interceptors is not None:
|
|
146
|
-
interceptors = list(interceptors)
|
|
147
|
-
if metadata is not None:
|
|
148
|
-
interceptors.append(DefaultClientInterceptorImpl(metadata))
|
|
149
|
-
elif metadata is not None:
|
|
150
|
-
interceptors = [DefaultClientInterceptorImpl(metadata)]
|
|
151
|
-
else:
|
|
152
|
-
interceptors = None
|
|
157
|
+
default_version: Optional[str] = None,
|
|
158
|
+
payload_store: Optional[PayloadStore] = None):
|
|
159
|
+
|
|
160
|
+
interceptors = prepare_sync_interceptors(metadata, interceptors)
|
|
153
161
|
|
|
154
162
|
channel = shared.get_grpc_channel(
|
|
155
163
|
host_address=host_address,
|
|
156
164
|
secure_channel=secure_channel,
|
|
157
165
|
interceptors=interceptors
|
|
158
166
|
)
|
|
167
|
+
self._channel = channel
|
|
159
168
|
self._stub = stubs.TaskHubSidecarServiceStub(channel)
|
|
160
169
|
self._logger = shared.get_logger("client", log_handler, log_formatter)
|
|
161
170
|
self.default_version = default_version
|
|
171
|
+
self._payload_store = payload_store
|
|
172
|
+
|
|
173
|
+
def close(self) -> None:
|
|
174
|
+
"""Closes the underlying gRPC channel."""
|
|
175
|
+
self._channel.close()
|
|
162
176
|
|
|
163
177
|
def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
|
|
164
178
|
input: Optional[TInput] = None,
|
|
@@ -169,24 +183,39 @@ class TaskHubGrpcClient:
|
|
|
169
183
|
version: Optional[str] = None) -> str:
|
|
170
184
|
|
|
171
185
|
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
+
resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
|
|
187
|
+
resolved_version = version if version else self.default_version
|
|
188
|
+
|
|
189
|
+
with tracing.start_create_orchestration_span(
|
|
190
|
+
name, resolved_instance_id, version=resolved_version,
|
|
191
|
+
):
|
|
192
|
+
req = build_schedule_new_orchestration_req(
|
|
193
|
+
orchestrator, input=input, instance_id=instance_id, start_at=start_at,
|
|
194
|
+
reuse_id_policy=reuse_id_policy, tags=tags,
|
|
195
|
+
version=version if version else self.default_version)
|
|
196
|
+
|
|
197
|
+
# Inject the active PRODUCER span context into the request so the sidecar
|
|
198
|
+
# stores it in the executionStarted event and the worker can parent all
|
|
199
|
+
# orchestration/activity/timer spans under this trace.
|
|
200
|
+
parent_trace_ctx = tracing.get_current_trace_context()
|
|
201
|
+
if parent_trace_ctx is not None:
|
|
202
|
+
req.parentTraceContext.CopyFrom(parent_trace_ctx)
|
|
203
|
+
|
|
204
|
+
self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
|
|
205
|
+
# Externalize any large payloads in the request
|
|
206
|
+
if self._payload_store is not None:
|
|
207
|
+
payload_helpers.externalize_payloads(
|
|
208
|
+
req, self._payload_store, instance_id=req.instanceId,
|
|
209
|
+
)
|
|
210
|
+
res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
|
|
211
|
+
return res.instanceId
|
|
186
212
|
|
|
187
213
|
def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
|
|
188
214
|
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
|
|
189
215
|
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
|
|
216
|
+
# De-externalize any large-payload tokens in the response
|
|
217
|
+
if self._payload_store is not None and res.exists:
|
|
218
|
+
payload_helpers.deexternalize_payloads(res, self._payload_store)
|
|
190
219
|
return new_orchestration_state(req.instanceId, res)
|
|
191
220
|
|
|
192
221
|
def get_all_orchestration_states(self,
|
|
@@ -201,24 +230,12 @@ class TaskHubGrpcClient:
|
|
|
201
230
|
states = []
|
|
202
231
|
|
|
203
232
|
while True:
|
|
204
|
-
req =
|
|
205
|
-
query=pb.InstanceQuery(
|
|
206
|
-
runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
|
|
207
|
-
createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
|
|
208
|
-
createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
|
|
209
|
-
maxInstanceCount=orchestration_query.max_instance_count,
|
|
210
|
-
fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
|
|
211
|
-
continuationToken=_continuation_token
|
|
212
|
-
)
|
|
213
|
-
)
|
|
233
|
+
req = build_query_instances_req(orchestration_query, _continuation_token)
|
|
214
234
|
resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
|
|
235
|
+
if self._payload_store is not None:
|
|
236
|
+
payload_helpers.deexternalize_payloads(resp, self._payload_store)
|
|
215
237
|
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
|
|
216
|
-
|
|
217
|
-
if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
|
|
218
|
-
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
|
|
219
|
-
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
|
|
220
|
-
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
|
|
221
|
-
break
|
|
238
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
222
239
|
_continuation_token = resp.continuationToken
|
|
223
240
|
else:
|
|
224
241
|
break
|
|
@@ -232,6 +249,8 @@ class TaskHubGrpcClient:
|
|
|
232
249
|
try:
|
|
233
250
|
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
|
|
234
251
|
res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout)
|
|
252
|
+
if self._payload_store is not None and res.exists:
|
|
253
|
+
payload_helpers.deexternalize_payloads(res, self._payload_store)
|
|
235
254
|
return new_orchestration_state(req.instanceId, res)
|
|
236
255
|
except grpc.RpcError as rpc_error:
|
|
237
256
|
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
|
|
@@ -247,58 +266,70 @@ class TaskHubGrpcClient:
|
|
|
247
266
|
try:
|
|
248
267
|
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
|
|
249
268
|
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
|
|
269
|
+
if self._payload_store is not None and res.exists:
|
|
270
|
+
payload_helpers.deexternalize_payloads(res, self._payload_store)
|
|
250
271
|
state = new_orchestration_state(req.instanceId, res)
|
|
251
|
-
|
|
252
|
-
return None
|
|
253
|
-
|
|
254
|
-
if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
|
|
255
|
-
details = state.failure_details
|
|
256
|
-
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
|
|
257
|
-
elif state.runtime_status == OrchestrationStatus.TERMINATED:
|
|
258
|
-
self._logger.info(f"Instance '{instance_id}' was terminated.")
|
|
259
|
-
elif state.runtime_status == OrchestrationStatus.COMPLETED:
|
|
260
|
-
self._logger.info(f"Instance '{instance_id}' completed.")
|
|
261
|
-
|
|
272
|
+
log_completion_state(self._logger, instance_id, state)
|
|
262
273
|
return state
|
|
263
274
|
except grpc.RpcError as rpc_error:
|
|
264
275
|
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
|
|
265
|
-
# Replace gRPC error with the built-in TimeoutError
|
|
266
276
|
raise TimeoutError("Timed-out waiting for the orchestration to complete")
|
|
267
277
|
else:
|
|
268
278
|
raise
|
|
269
279
|
|
|
270
280
|
def raise_orchestration_event(self, instance_id: str, event_name: str, *,
|
|
271
|
-
data: Optional[Any] = None):
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
281
|
+
data: Optional[Any] = None) -> None:
|
|
282
|
+
with tracing.start_raise_event_span(event_name, instance_id):
|
|
283
|
+
req = build_raise_event_req(instance_id, event_name, data)
|
|
284
|
+
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
|
|
285
|
+
if self._payload_store is not None:
|
|
286
|
+
payload_helpers.externalize_payloads(
|
|
287
|
+
req, self._payload_store, instance_id=instance_id,
|
|
288
|
+
)
|
|
289
|
+
self._stub.RaiseEvent(req)
|
|
280
290
|
|
|
281
291
|
def terminate_orchestration(self, instance_id: str, *,
|
|
282
292
|
output: Optional[Any] = None,
|
|
283
|
-
recursive: bool = True):
|
|
284
|
-
req =
|
|
285
|
-
instanceId=instance_id,
|
|
286
|
-
output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
|
|
287
|
-
recursive=recursive)
|
|
293
|
+
recursive: bool = True) -> None:
|
|
294
|
+
req = build_terminate_req(instance_id, output, recursive)
|
|
288
295
|
|
|
289
296
|
self._logger.info(f"Terminating instance '{instance_id}'.")
|
|
297
|
+
if self._payload_store is not None:
|
|
298
|
+
payload_helpers.externalize_payloads(
|
|
299
|
+
req, self._payload_store, instance_id=instance_id,
|
|
300
|
+
)
|
|
290
301
|
self._stub.TerminateInstance(req)
|
|
291
302
|
|
|
292
|
-
def suspend_orchestration(self, instance_id: str):
|
|
303
|
+
def suspend_orchestration(self, instance_id: str) -> None:
|
|
293
304
|
req = pb.SuspendRequest(instanceId=instance_id)
|
|
294
305
|
self._logger.info(f"Suspending instance '{instance_id}'.")
|
|
295
306
|
self._stub.SuspendInstance(req)
|
|
296
307
|
|
|
297
|
-
def resume_orchestration(self, instance_id: str):
|
|
308
|
+
def resume_orchestration(self, instance_id: str) -> None:
|
|
298
309
|
req = pb.ResumeRequest(instanceId=instance_id)
|
|
299
310
|
self._logger.info(f"Resuming instance '{instance_id}'.")
|
|
300
311
|
self._stub.ResumeInstance(req)
|
|
301
312
|
|
|
313
|
+
def restart_orchestration(self, instance_id: str, *,
|
|
314
|
+
restart_with_new_instance_id: bool = False) -> str:
|
|
315
|
+
"""Restarts an existing orchestration instance.
|
|
316
|
+
|
|
317
|
+
Args:
|
|
318
|
+
instance_id: The ID of the orchestration instance to restart.
|
|
319
|
+
restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
|
|
320
|
+
If False (default), the restarted orchestration will reuse the same instance ID.
|
|
321
|
+
|
|
322
|
+
Returns:
|
|
323
|
+
The instance ID of the restarted orchestration.
|
|
324
|
+
"""
|
|
325
|
+
req = pb.RestartInstanceRequest(
|
|
326
|
+
instanceId=instance_id,
|
|
327
|
+
restartWithNewInstanceId=restart_with_new_instance_id)
|
|
328
|
+
|
|
329
|
+
self._logger.info(f"Restarting instance '{instance_id}'.")
|
|
330
|
+
res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
|
|
331
|
+
return res.instanceId
|
|
332
|
+
|
|
302
333
|
def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
|
|
303
334
|
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
|
|
304
335
|
self._logger.info(f"Purging instance '{instance_id}'.")
|
|
@@ -315,30 +346,20 @@ class TaskHubGrpcClient:
|
|
|
315
346
|
f"created_time_to={created_time_to}, "
|
|
316
347
|
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
|
|
317
348
|
f"recursive={recursive}")
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
|
|
321
|
-
createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
|
|
322
|
-
runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
|
|
323
|
-
),
|
|
324
|
-
recursive=recursive
|
|
325
|
-
))
|
|
349
|
+
req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
|
|
350
|
+
resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
|
|
326
351
|
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
|
|
327
352
|
|
|
328
353
|
def signal_entity(self,
|
|
329
354
|
entity_instance_id: EntityInstanceId,
|
|
330
355
|
operation_name: str,
|
|
331
356
|
input: Optional[Any] = None) -> None:
|
|
332
|
-
req =
|
|
333
|
-
instanceId=str(entity_instance_id),
|
|
334
|
-
name=operation_name,
|
|
335
|
-
input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
|
|
336
|
-
requestId=str(uuid.uuid4()),
|
|
337
|
-
scheduledTime=None,
|
|
338
|
-
parentTraceContext=None,
|
|
339
|
-
requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
|
|
340
|
-
)
|
|
357
|
+
req = build_signal_entity_req(entity_instance_id, operation_name, input)
|
|
341
358
|
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
|
|
359
|
+
if self._payload_store is not None:
|
|
360
|
+
payload_helpers.externalize_payloads(
|
|
361
|
+
req, self._payload_store, instance_id=str(entity_instance_id),
|
|
362
|
+
)
|
|
342
363
|
self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
|
|
343
364
|
|
|
344
365
|
def get_entity(self,
|
|
@@ -350,7 +371,8 @@ class TaskHubGrpcClient:
|
|
|
350
371
|
res: pb.GetEntityResponse = self._stub.GetEntity(req)
|
|
351
372
|
if not res.exists:
|
|
352
373
|
return None
|
|
353
|
-
|
|
374
|
+
if self._payload_store is not None:
|
|
375
|
+
payload_helpers.deexternalize_payloads(res, self._payload_store)
|
|
354
376
|
return EntityMetadata.from_entity_metadata(res.entity, include_state)
|
|
355
377
|
|
|
356
378
|
def get_all_entities(self,
|
|
@@ -364,24 +386,12 @@ class TaskHubGrpcClient:
|
|
|
364
386
|
entities = []
|
|
365
387
|
|
|
366
388
|
while True:
|
|
367
|
-
query_request =
|
|
368
|
-
query=pb.EntityQuery(
|
|
369
|
-
instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
|
|
370
|
-
lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
|
|
371
|
-
lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
|
|
372
|
-
includeState=entity_query.include_state,
|
|
373
|
-
includeTransient=entity_query.include_transient,
|
|
374
|
-
pageSize=helpers.get_int_value(entity_query.page_size),
|
|
375
|
-
continuationToken=_continuation_token
|
|
376
|
-
)
|
|
377
|
-
)
|
|
389
|
+
query_request = build_query_entities_req(entity_query, _continuation_token)
|
|
378
390
|
resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
|
|
391
|
+
if self._payload_store is not None:
|
|
392
|
+
payload_helpers.deexternalize_payloads(resp, self._payload_store)
|
|
379
393
|
entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
|
|
380
|
-
if resp.continuationToken
|
|
381
|
-
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
|
|
382
|
-
if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
|
|
383
|
-
self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
|
|
384
|
-
break
|
|
394
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
385
395
|
_continuation_token = resp.continuationToken
|
|
386
396
|
else:
|
|
387
397
|
break
|
|
@@ -407,11 +417,290 @@ class TaskHubGrpcClient:
|
|
|
407
417
|
empty_entities_removed += resp.emptyEntitiesRemoved
|
|
408
418
|
orphaned_locks_released += resp.orphanedLocksReleased
|
|
409
419
|
|
|
410
|
-
if resp.continuationToken
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
420
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
421
|
+
_continuation_token = resp.continuationToken
|
|
422
|
+
else:
|
|
423
|
+
break
|
|
424
|
+
|
|
425
|
+
return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class AsyncTaskHubGrpcClient:
|
|
429
|
+
"""Async version of TaskHubGrpcClient using grpc.aio for asyncio-based applications."""
|
|
430
|
+
|
|
431
|
+
def __init__(self, *,
|
|
432
|
+
host_address: Optional[str] = None,
|
|
433
|
+
metadata: Optional[list[tuple[str, str]]] = None,
|
|
434
|
+
log_handler: Optional[logging.Handler] = None,
|
|
435
|
+
log_formatter: Optional[logging.Formatter] = None,
|
|
436
|
+
secure_channel: bool = False,
|
|
437
|
+
interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
|
|
438
|
+
default_version: Optional[str] = None,
|
|
439
|
+
payload_store: Optional[PayloadStore] = None):
|
|
440
|
+
|
|
441
|
+
interceptors = prepare_async_interceptors(metadata, interceptors)
|
|
442
|
+
|
|
443
|
+
channel = shared.get_async_grpc_channel(
|
|
444
|
+
host_address=host_address,
|
|
445
|
+
secure_channel=secure_channel,
|
|
446
|
+
interceptors=interceptors
|
|
447
|
+
)
|
|
448
|
+
self._channel = channel
|
|
449
|
+
self._stub = stubs.TaskHubSidecarServiceStub(channel)
|
|
450
|
+
self._logger = shared.get_logger("async_client", log_handler, log_formatter)
|
|
451
|
+
self.default_version = default_version
|
|
452
|
+
self._payload_store = payload_store
|
|
453
|
+
|
|
454
|
+
async def close(self) -> None:
|
|
455
|
+
"""Closes the underlying gRPC channel."""
|
|
456
|
+
await self._channel.close()
|
|
457
|
+
|
|
458
|
+
async def __aenter__(self):
|
|
459
|
+
return self
|
|
460
|
+
|
|
461
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
462
|
+
await self.close()
|
|
463
|
+
|
|
464
|
+
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
|
|
465
|
+
input: Optional[TInput] = None,
|
|
466
|
+
instance_id: Optional[str] = None,
|
|
467
|
+
start_at: Optional[datetime] = None,
|
|
468
|
+
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
|
|
469
|
+
tags: Optional[dict[str, str]] = None,
|
|
470
|
+
version: Optional[str] = None) -> str:
|
|
471
|
+
|
|
472
|
+
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
|
|
473
|
+
resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
|
|
474
|
+
resolved_version = version if version else self.default_version
|
|
475
|
+
|
|
476
|
+
with tracing.start_create_orchestration_span(
|
|
477
|
+
name, resolved_instance_id, version=resolved_version,
|
|
478
|
+
):
|
|
479
|
+
req = build_schedule_new_orchestration_req(
|
|
480
|
+
orchestrator, input=input, instance_id=instance_id, start_at=start_at,
|
|
481
|
+
reuse_id_policy=reuse_id_policy, tags=tags,
|
|
482
|
+
version=version if version else self.default_version)
|
|
483
|
+
|
|
484
|
+
parent_trace_ctx = tracing.get_current_trace_context()
|
|
485
|
+
if parent_trace_ctx is not None:
|
|
486
|
+
req.parentTraceContext.CopyFrom(parent_trace_ctx)
|
|
487
|
+
|
|
488
|
+
self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
|
|
489
|
+
# Externalize any large payloads in the request
|
|
490
|
+
if self._payload_store is not None:
|
|
491
|
+
await payload_helpers.externalize_payloads_async(
|
|
492
|
+
req, self._payload_store, instance_id=req.instanceId,
|
|
493
|
+
)
|
|
494
|
+
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
|
|
495
|
+
return res.instanceId
|
|
496
|
+
|
|
497
|
+
async def get_orchestration_state(self, instance_id: str, *,
|
|
498
|
+
fetch_payloads: bool = True) -> Optional[OrchestrationState]:
|
|
499
|
+
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
|
|
500
|
+
res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
|
|
501
|
+
if self._payload_store is not None and res.exists:
|
|
502
|
+
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
|
|
503
|
+
return new_orchestration_state(req.instanceId, res)
|
|
504
|
+
|
|
505
|
+
async def get_all_orchestration_states(self,
|
|
506
|
+
orchestration_query: Optional[OrchestrationQuery] = None
|
|
507
|
+
) -> List[OrchestrationState]:
|
|
508
|
+
if orchestration_query is None:
|
|
509
|
+
orchestration_query = OrchestrationQuery()
|
|
510
|
+
_continuation_token = None
|
|
511
|
+
|
|
512
|
+
self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
|
|
513
|
+
|
|
514
|
+
states = []
|
|
515
|
+
|
|
516
|
+
while True:
|
|
517
|
+
req = build_query_instances_req(orchestration_query, _continuation_token)
|
|
518
|
+
resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
|
|
519
|
+
if self._payload_store is not None:
|
|
520
|
+
await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
|
|
521
|
+
states += [parse_orchestration_state(res) for res in resp.orchestrationState]
|
|
522
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
523
|
+
_continuation_token = resp.continuationToken
|
|
524
|
+
else:
|
|
525
|
+
break
|
|
526
|
+
|
|
527
|
+
return states
|
|
528
|
+
|
|
529
|
+
async def wait_for_orchestration_start(self, instance_id: str, *,
|
|
530
|
+
fetch_payloads: bool = False,
|
|
531
|
+
timeout: int = 60) -> Optional[OrchestrationState]:
|
|
532
|
+
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
|
|
533
|
+
try:
|
|
534
|
+
self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
|
|
535
|
+
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
|
|
536
|
+
if self._payload_store is not None and res.exists:
|
|
537
|
+
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
|
|
538
|
+
return new_orchestration_state(req.instanceId, res)
|
|
539
|
+
except grpc.aio.AioRpcError as rpc_error:
|
|
540
|
+
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
541
|
+
raise TimeoutError("Timed-out waiting for the orchestration to start")
|
|
542
|
+
else:
|
|
543
|
+
raise
|
|
544
|
+
|
|
545
|
+
async def wait_for_orchestration_completion(self, instance_id: str, *,
|
|
546
|
+
fetch_payloads: bool = True,
|
|
547
|
+
timeout: int = 60) -> Optional[OrchestrationState]:
|
|
548
|
+
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
|
|
549
|
+
try:
|
|
550
|
+
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
|
|
551
|
+
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
|
|
552
|
+
if self._payload_store is not None and res.exists:
|
|
553
|
+
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
|
|
554
|
+
state = new_orchestration_state(req.instanceId, res)
|
|
555
|
+
log_completion_state(self._logger, instance_id, state)
|
|
556
|
+
return state
|
|
557
|
+
except grpc.aio.AioRpcError as rpc_error:
|
|
558
|
+
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
|
|
559
|
+
raise TimeoutError("Timed-out waiting for the orchestration to complete")
|
|
560
|
+
else:
|
|
561
|
+
raise
|
|
562
|
+
|
|
563
|
+
async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
|
|
564
|
+
data: Optional[Any] = None) -> None:
|
|
565
|
+
with tracing.start_raise_event_span(event_name, instance_id):
|
|
566
|
+
req = build_raise_event_req(instance_id, event_name, data)
|
|
567
|
+
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
|
|
568
|
+
if self._payload_store is not None:
|
|
569
|
+
await payload_helpers.externalize_payloads_async(
|
|
570
|
+
req, self._payload_store, instance_id=instance_id,
|
|
571
|
+
)
|
|
572
|
+
await self._stub.RaiseEvent(req)
|
|
573
|
+
|
|
574
|
+
async def terminate_orchestration(self, instance_id: str, *,
|
|
575
|
+
output: Optional[Any] = None,
|
|
576
|
+
recursive: bool = True) -> None:
|
|
577
|
+
req = build_terminate_req(instance_id, output, recursive)
|
|
578
|
+
|
|
579
|
+
self._logger.info(f"Terminating instance '{instance_id}'.")
|
|
580
|
+
if self._payload_store is not None:
|
|
581
|
+
await payload_helpers.externalize_payloads_async(
|
|
582
|
+
req, self._payload_store, instance_id=instance_id,
|
|
583
|
+
)
|
|
584
|
+
await self._stub.TerminateInstance(req)
|
|
585
|
+
|
|
586
|
+
async def suspend_orchestration(self, instance_id: str) -> None:
|
|
587
|
+
req = pb.SuspendRequest(instanceId=instance_id)
|
|
588
|
+
self._logger.info(f"Suspending instance '{instance_id}'.")
|
|
589
|
+
await self._stub.SuspendInstance(req)
|
|
590
|
+
|
|
591
|
+
async def resume_orchestration(self, instance_id: str) -> None:
|
|
592
|
+
req = pb.ResumeRequest(instanceId=instance_id)
|
|
593
|
+
self._logger.info(f"Resuming instance '{instance_id}'.")
|
|
594
|
+
await self._stub.ResumeInstance(req)
|
|
595
|
+
|
|
596
|
+
async def restart_orchestration(self, instance_id: str, *,
|
|
597
|
+
restart_with_new_instance_id: bool = False) -> str:
|
|
598
|
+
"""Restarts an existing orchestration instance.
|
|
599
|
+
|
|
600
|
+
Args:
|
|
601
|
+
instance_id: The ID of the orchestration instance to restart.
|
|
602
|
+
restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
|
|
603
|
+
If False (default), the restarted orchestration will reuse the same instance ID.
|
|
604
|
+
|
|
605
|
+
Returns:
|
|
606
|
+
The instance ID of the restarted orchestration.
|
|
607
|
+
"""
|
|
608
|
+
req = pb.RestartInstanceRequest(
|
|
609
|
+
instanceId=instance_id,
|
|
610
|
+
restartWithNewInstanceId=restart_with_new_instance_id)
|
|
611
|
+
|
|
612
|
+
self._logger.info(f"Restarting instance '{instance_id}'.")
|
|
613
|
+
res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
|
|
614
|
+
return res.instanceId
|
|
615
|
+
|
|
616
|
+
async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
|
|
617
|
+
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
|
|
618
|
+
self._logger.info(f"Purging instance '{instance_id}'.")
|
|
619
|
+
resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
|
|
620
|
+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
|
|
621
|
+
|
|
622
|
+
async def purge_orchestrations_by(self,
|
|
623
|
+
created_time_from: Optional[datetime] = None,
|
|
624
|
+
created_time_to: Optional[datetime] = None,
|
|
625
|
+
runtime_status: Optional[List[OrchestrationStatus]] = None,
|
|
626
|
+
recursive: bool = False) -> PurgeInstancesResult:
|
|
627
|
+
self._logger.info("Purging orchestrations by filter: "
|
|
628
|
+
f"created_time_from={created_time_from}, "
|
|
629
|
+
f"created_time_to={created_time_to}, "
|
|
630
|
+
f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
|
|
631
|
+
f"recursive={recursive}")
|
|
632
|
+
req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
|
|
633
|
+
resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
|
|
634
|
+
return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
|
|
635
|
+
|
|
636
|
+
async def signal_entity(self,
|
|
637
|
+
entity_instance_id: EntityInstanceId,
|
|
638
|
+
operation_name: str,
|
|
639
|
+
input: Optional[Any] = None) -> None:
|
|
640
|
+
req = build_signal_entity_req(entity_instance_id, operation_name, input)
|
|
641
|
+
self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
|
|
642
|
+
if self._payload_store is not None:
|
|
643
|
+
await payload_helpers.externalize_payloads_async(
|
|
644
|
+
req, self._payload_store, instance_id=str(entity_instance_id),
|
|
645
|
+
)
|
|
646
|
+
await self._stub.SignalEntity(req, None)
|
|
647
|
+
|
|
648
|
+
async def get_entity(self,
|
|
649
|
+
entity_instance_id: EntityInstanceId,
|
|
650
|
+
include_state: bool = True
|
|
651
|
+
) -> Optional[EntityMetadata]:
|
|
652
|
+
req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
|
|
653
|
+
self._logger.info(f"Getting entity '{entity_instance_id}'.")
|
|
654
|
+
res: pb.GetEntityResponse = await self._stub.GetEntity(req)
|
|
655
|
+
if not res.exists:
|
|
656
|
+
return None
|
|
657
|
+
if self._payload_store is not None:
|
|
658
|
+
await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
|
|
659
|
+
return EntityMetadata.from_entity_metadata(res.entity, include_state)
|
|
660
|
+
|
|
661
|
+
async def get_all_entities(self,
|
|
662
|
+
entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
|
|
663
|
+
if entity_query is None:
|
|
664
|
+
entity_query = EntityQuery()
|
|
665
|
+
_continuation_token = None
|
|
666
|
+
|
|
667
|
+
self._logger.info(f"Retrieving entities by filter: {entity_query}")
|
|
668
|
+
|
|
669
|
+
entities = []
|
|
670
|
+
|
|
671
|
+
while True:
|
|
672
|
+
query_request = build_query_entities_req(entity_query, _continuation_token)
|
|
673
|
+
resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
|
|
674
|
+
if self._payload_store is not None:
|
|
675
|
+
await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
|
|
676
|
+
entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
|
|
677
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
678
|
+
_continuation_token = resp.continuationToken
|
|
679
|
+
else:
|
|
680
|
+
break
|
|
681
|
+
return entities
|
|
682
|
+
|
|
683
|
+
async def clean_entity_storage(self,
|
|
684
|
+
remove_empty_entities: bool = True,
|
|
685
|
+
release_orphaned_locks: bool = True
|
|
686
|
+
) -> CleanEntityStorageResult:
|
|
687
|
+
self._logger.info("Cleaning entity storage")
|
|
688
|
+
|
|
689
|
+
empty_entities_removed = 0
|
|
690
|
+
orphaned_locks_released = 0
|
|
691
|
+
_continuation_token = None
|
|
692
|
+
|
|
693
|
+
while True:
|
|
694
|
+
req = pb.CleanEntityStorageRequest(
|
|
695
|
+
removeEmptyEntities=remove_empty_entities,
|
|
696
|
+
releaseOrphanedLocks=release_orphaned_locks,
|
|
697
|
+
continuationToken=_continuation_token
|
|
698
|
+
)
|
|
699
|
+
resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
|
|
700
|
+
empty_entities_removed += resp.emptyEntitiesRemoved
|
|
701
|
+
orphaned_locks_released += resp.orphanedLocksReleased
|
|
702
|
+
|
|
703
|
+
if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
|
|
415
704
|
_continuation_token = resp.continuationToken
|
|
416
705
|
else:
|
|
417
706
|
break
|