durabletask 0.0.0.dev68__py3-none-any.whl → 0.0.0.dev69__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
durabletask/__init__.py CHANGED
@@ -3,8 +3,14 @@
3
3
 
4
4
  """Durable Task SDK for Python"""
5
5
 
6
+ from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore
6
7
  from durabletask.worker import ConcurrencyOptions, VersioningOptions
7
8
 
8
- __all__ = ["ConcurrencyOptions", "VersioningOptions"]
9
+ __all__ = [
10
+ "ConcurrencyOptions",
11
+ "LargePayloadStorageOptions",
12
+ "PayloadStore",
13
+ "VersioningOptions",
14
+ ]
9
15
 
10
16
  PACKAGE_NAME = "durabletask"
durabletask/client.py CHANGED
@@ -4,11 +4,12 @@
4
4
  import logging
5
5
  import uuid
6
6
  from dataclasses import dataclass
7
- from datetime import datetime, timezone
7
+ from datetime import datetime
8
8
  from enum import Enum
9
9
  from typing import Any, List, Optional, Sequence, TypeVar, Union
10
10
 
11
11
  import grpc
12
+ import grpc.aio
12
13
 
13
14
  from durabletask.entities import EntityInstanceId
14
15
  from durabletask.entities.entity_metadata import EntityMetadata
@@ -16,8 +17,23 @@ import durabletask.internal.helpers as helpers
16
17
  import durabletask.internal.orchestrator_service_pb2 as pb
17
18
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
18
19
  import durabletask.internal.shared as shared
20
+ import durabletask.internal.tracing as tracing
19
21
  from durabletask import task
20
- from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
22
+ from durabletask.internal.client_helpers import (
23
+ build_query_entities_req,
24
+ build_query_instances_req,
25
+ build_purge_by_filter_req,
26
+ build_raise_event_req,
27
+ build_schedule_new_orchestration_req,
28
+ build_signal_entity_req,
29
+ build_terminate_req,
30
+ check_continuation_token,
31
+ log_completion_state,
32
+ prepare_async_interceptors,
33
+ prepare_sync_interceptors,
34
+ )
35
+ from durabletask.payload import helpers as payload_helpers
36
+ from durabletask.payload.store import PayloadStore
21
37
 
22
38
  TInput = TypeVar('TInput')
23
39
  TOutput = TypeVar('TOutput')
@@ -138,27 +154,25 @@ class TaskHubGrpcClient:
138
154
  log_formatter: Optional[logging.Formatter] = None,
139
155
  secure_channel: bool = False,
140
156
  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
141
- default_version: Optional[str] = None):
142
-
143
- # If the caller provided metadata, we need to create a new interceptor for it and
144
- # add it to the list of interceptors.
145
- if interceptors is not None:
146
- interceptors = list(interceptors)
147
- if metadata is not None:
148
- interceptors.append(DefaultClientInterceptorImpl(metadata))
149
- elif metadata is not None:
150
- interceptors = [DefaultClientInterceptorImpl(metadata)]
151
- else:
152
- interceptors = None
157
+ default_version: Optional[str] = None,
158
+ payload_store: Optional[PayloadStore] = None):
159
+
160
+ interceptors = prepare_sync_interceptors(metadata, interceptors)
153
161
 
154
162
  channel = shared.get_grpc_channel(
155
163
  host_address=host_address,
156
164
  secure_channel=secure_channel,
157
165
  interceptors=interceptors
158
166
  )
167
+ self._channel = channel
159
168
  self._stub = stubs.TaskHubSidecarServiceStub(channel)
160
169
  self._logger = shared.get_logger("client", log_handler, log_formatter)
161
170
  self.default_version = default_version
171
+ self._payload_store = payload_store
172
+
173
+ def close(self) -> None:
174
+ """Closes the underlying gRPC channel."""
175
+ self._channel.close()
162
176
 
163
177
  def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
164
178
  input: Optional[TInput] = None,
@@ -169,24 +183,39 @@ class TaskHubGrpcClient:
169
183
  version: Optional[str] = None) -> str:
170
184
 
171
185
  name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
172
-
173
- req = pb.CreateInstanceRequest(
174
- name=name,
175
- instanceId=instance_id if instance_id else uuid.uuid4().hex,
176
- input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
177
- scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
178
- version=helpers.get_string_value(version if version else self.default_version),
179
- orchestrationIdReusePolicy=reuse_id_policy,
180
- tags=tags
181
- )
182
-
183
- self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
184
- res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
185
- return res.instanceId
186
+ resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
187
+ resolved_version = version if version else self.default_version
188
+
189
+ with tracing.start_create_orchestration_span(
190
+ name, resolved_instance_id, version=resolved_version,
191
+ ):
192
+ req = build_schedule_new_orchestration_req(
193
+ orchestrator, input=input, instance_id=instance_id, start_at=start_at,
194
+ reuse_id_policy=reuse_id_policy, tags=tags,
195
+ version=version if version else self.default_version)
196
+
197
+ # Inject the active PRODUCER span context into the request so the sidecar
198
+ # stores it in the executionStarted event and the worker can parent all
199
+ # orchestration/activity/timer spans under this trace.
200
+ parent_trace_ctx = tracing.get_current_trace_context()
201
+ if parent_trace_ctx is not None:
202
+ req.parentTraceContext.CopyFrom(parent_trace_ctx)
203
+
204
+ self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
205
+ # Externalize any large payloads in the request
206
+ if self._payload_store is not None:
207
+ payload_helpers.externalize_payloads(
208
+ req, self._payload_store, instance_id=req.instanceId,
209
+ )
210
+ res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
211
+ return res.instanceId
186
212
 
187
213
  def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
188
214
  req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
189
215
  res: pb.GetInstanceResponse = self._stub.GetInstance(req)
216
+ # De-externalize any large-payload tokens in the response
217
+ if self._payload_store is not None and res.exists:
218
+ payload_helpers.deexternalize_payloads(res, self._payload_store)
190
219
  return new_orchestration_state(req.instanceId, res)
191
220
 
192
221
  def get_all_orchestration_states(self,
@@ -201,24 +230,12 @@ class TaskHubGrpcClient:
201
230
  states = []
202
231
 
203
232
  while True:
204
- req = pb.QueryInstancesRequest(
205
- query=pb.InstanceQuery(
206
- runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
207
- createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
208
- createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
209
- maxInstanceCount=orchestration_query.max_instance_count,
210
- fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
211
- continuationToken=_continuation_token
212
- )
213
- )
233
+ req = build_query_instances_req(orchestration_query, _continuation_token)
214
234
  resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
235
+ if self._payload_store is not None:
236
+ payload_helpers.deexternalize_payloads(resp, self._payload_store)
215
237
  states += [parse_orchestration_state(res) for res in resp.orchestrationState]
216
- # Check the value for continuationToken - none or "0" indicates that there are no more results.
217
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
218
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
219
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
220
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
221
- break
238
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
222
239
  _continuation_token = resp.continuationToken
223
240
  else:
224
241
  break
@@ -232,6 +249,8 @@ class TaskHubGrpcClient:
232
249
  try:
233
250
  self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
234
251
  res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout)
252
+ if self._payload_store is not None and res.exists:
253
+ payload_helpers.deexternalize_payloads(res, self._payload_store)
235
254
  return new_orchestration_state(req.instanceId, res)
236
255
  except grpc.RpcError as rpc_error:
237
256
  if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
@@ -247,54 +266,46 @@ class TaskHubGrpcClient:
247
266
  try:
248
267
  self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
249
268
  res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
269
+ if self._payload_store is not None and res.exists:
270
+ payload_helpers.deexternalize_payloads(res, self._payload_store)
250
271
  state = new_orchestration_state(req.instanceId, res)
251
- if not state:
252
- return None
253
-
254
- if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
255
- details = state.failure_details
256
- self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
257
- elif state.runtime_status == OrchestrationStatus.TERMINATED:
258
- self._logger.info(f"Instance '{instance_id}' was terminated.")
259
- elif state.runtime_status == OrchestrationStatus.COMPLETED:
260
- self._logger.info(f"Instance '{instance_id}' completed.")
261
-
272
+ log_completion_state(self._logger, instance_id, state)
262
273
  return state
263
274
  except grpc.RpcError as rpc_error:
264
275
  if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
265
- # Replace gRPC error with the built-in TimeoutError
266
276
  raise TimeoutError("Timed-out waiting for the orchestration to complete")
267
277
  else:
268
278
  raise
269
279
 
270
280
  def raise_orchestration_event(self, instance_id: str, event_name: str, *,
271
- data: Optional[Any] = None):
272
- req = pb.RaiseEventRequest(
273
- instanceId=instance_id,
274
- name=event_name,
275
- input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
276
- )
277
-
278
- self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
279
- self._stub.RaiseEvent(req)
281
+ data: Optional[Any] = None) -> None:
282
+ with tracing.start_raise_event_span(event_name, instance_id):
283
+ req = build_raise_event_req(instance_id, event_name, data)
284
+ self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
285
+ if self._payload_store is not None:
286
+ payload_helpers.externalize_payloads(
287
+ req, self._payload_store, instance_id=instance_id,
288
+ )
289
+ self._stub.RaiseEvent(req)
280
290
 
281
291
  def terminate_orchestration(self, instance_id: str, *,
282
292
  output: Optional[Any] = None,
283
- recursive: bool = True):
284
- req = pb.TerminateRequest(
285
- instanceId=instance_id,
286
- output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
287
- recursive=recursive)
293
+ recursive: bool = True) -> None:
294
+ req = build_terminate_req(instance_id, output, recursive)
288
295
 
289
296
  self._logger.info(f"Terminating instance '{instance_id}'.")
297
+ if self._payload_store is not None:
298
+ payload_helpers.externalize_payloads(
299
+ req, self._payload_store, instance_id=instance_id,
300
+ )
290
301
  self._stub.TerminateInstance(req)
291
302
 
292
- def suspend_orchestration(self, instance_id: str):
303
+ def suspend_orchestration(self, instance_id: str) -> None:
293
304
  req = pb.SuspendRequest(instanceId=instance_id)
294
305
  self._logger.info(f"Suspending instance '{instance_id}'.")
295
306
  self._stub.SuspendInstance(req)
296
307
 
297
- def resume_orchestration(self, instance_id: str):
308
+ def resume_orchestration(self, instance_id: str) -> None:
298
309
  req = pb.ResumeRequest(instanceId=instance_id)
299
310
  self._logger.info(f"Resuming instance '{instance_id}'.")
300
311
  self._stub.ResumeInstance(req)
@@ -335,30 +346,20 @@ class TaskHubGrpcClient:
335
346
  f"created_time_to={created_time_to}, "
336
347
  f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
337
348
  f"recursive={recursive}")
338
- resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
339
- purgeInstanceFilter=pb.PurgeInstanceFilter(
340
- createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
341
- createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
342
- runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
343
- ),
344
- recursive=recursive
345
- ))
349
+ req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
350
+ resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
346
351
  return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
347
352
 
348
353
  def signal_entity(self,
349
354
  entity_instance_id: EntityInstanceId,
350
355
  operation_name: str,
351
356
  input: Optional[Any] = None) -> None:
352
- req = pb.SignalEntityRequest(
353
- instanceId=str(entity_instance_id),
354
- name=operation_name,
355
- input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
356
- requestId=str(uuid.uuid4()),
357
- scheduledTime=None,
358
- parentTraceContext=None,
359
- requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
360
- )
357
+ req = build_signal_entity_req(entity_instance_id, operation_name, input)
361
358
  self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
359
+ if self._payload_store is not None:
360
+ payload_helpers.externalize_payloads(
361
+ req, self._payload_store, instance_id=str(entity_instance_id),
362
+ )
362
363
  self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
363
364
 
364
365
  def get_entity(self,
@@ -370,7 +371,8 @@ class TaskHubGrpcClient:
370
371
  res: pb.GetEntityResponse = self._stub.GetEntity(req)
371
372
  if not res.exists:
372
373
  return None
373
-
374
+ if self._payload_store is not None:
375
+ payload_helpers.deexternalize_payloads(res, self._payload_store)
374
376
  return EntityMetadata.from_entity_metadata(res.entity, include_state)
375
377
 
376
378
  def get_all_entities(self,
@@ -384,24 +386,12 @@ class TaskHubGrpcClient:
384
386
  entities = []
385
387
 
386
388
  while True:
387
- query_request = pb.QueryEntitiesRequest(
388
- query=pb.EntityQuery(
389
- instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
390
- lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
391
- lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
392
- includeState=entity_query.include_state,
393
- includeTransient=entity_query.include_transient,
394
- pageSize=helpers.get_int_value(entity_query.page_size),
395
- continuationToken=_continuation_token
396
- )
397
- )
389
+ query_request = build_query_entities_req(entity_query, _continuation_token)
398
390
  resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
391
+ if self._payload_store is not None:
392
+ payload_helpers.deexternalize_payloads(resp, self._payload_store)
399
393
  entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
400
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
401
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
402
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
403
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
404
- break
394
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
405
395
  _continuation_token = resp.continuationToken
406
396
  else:
407
397
  break
@@ -427,11 +417,290 @@ class TaskHubGrpcClient:
427
417
  empty_entities_removed += resp.emptyEntitiesRemoved
428
418
  orphaned_locks_released += resp.orphanedLocksReleased
429
419
 
430
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
431
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, cleaning next page...")
432
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
433
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
434
- break
420
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
421
+ _continuation_token = resp.continuationToken
422
+ else:
423
+ break
424
+
425
+ return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)
426
+
427
+
428
+ class AsyncTaskHubGrpcClient:
429
+ """Async version of TaskHubGrpcClient using grpc.aio for asyncio-based applications."""
430
+
431
+ def __init__(self, *,
432
+ host_address: Optional[str] = None,
433
+ metadata: Optional[list[tuple[str, str]]] = None,
434
+ log_handler: Optional[logging.Handler] = None,
435
+ log_formatter: Optional[logging.Formatter] = None,
436
+ secure_channel: bool = False,
437
+ interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
438
+ default_version: Optional[str] = None,
439
+ payload_store: Optional[PayloadStore] = None):
440
+
441
+ interceptors = prepare_async_interceptors(metadata, interceptors)
442
+
443
+ channel = shared.get_async_grpc_channel(
444
+ host_address=host_address,
445
+ secure_channel=secure_channel,
446
+ interceptors=interceptors
447
+ )
448
+ self._channel = channel
449
+ self._stub = stubs.TaskHubSidecarServiceStub(channel)
450
+ self._logger = shared.get_logger("async_client", log_handler, log_formatter)
451
+ self.default_version = default_version
452
+ self._payload_store = payload_store
453
+
454
+ async def close(self) -> None:
455
+ """Closes the underlying gRPC channel."""
456
+ await self._channel.close()
457
+
458
+ async def __aenter__(self):
459
+ return self
460
+
461
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
462
+ await self.close()
463
+
464
+ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
465
+ input: Optional[TInput] = None,
466
+ instance_id: Optional[str] = None,
467
+ start_at: Optional[datetime] = None,
468
+ reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
469
+ tags: Optional[dict[str, str]] = None,
470
+ version: Optional[str] = None) -> str:
471
+
472
+ name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
473
+ resolved_instance_id = instance_id if instance_id else uuid.uuid4().hex
474
+ resolved_version = version if version else self.default_version
475
+
476
+ with tracing.start_create_orchestration_span(
477
+ name, resolved_instance_id, version=resolved_version,
478
+ ):
479
+ req = build_schedule_new_orchestration_req(
480
+ orchestrator, input=input, instance_id=instance_id, start_at=start_at,
481
+ reuse_id_policy=reuse_id_policy, tags=tags,
482
+ version=version if version else self.default_version)
483
+
484
+ parent_trace_ctx = tracing.get_current_trace_context()
485
+ if parent_trace_ctx is not None:
486
+ req.parentTraceContext.CopyFrom(parent_trace_ctx)
487
+
488
+ self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
489
+ # Externalize any large payloads in the request
490
+ if self._payload_store is not None:
491
+ await payload_helpers.externalize_payloads_async(
492
+ req, self._payload_store, instance_id=req.instanceId,
493
+ )
494
+ res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
495
+ return res.instanceId
496
+
497
+ async def get_orchestration_state(self, instance_id: str, *,
498
+ fetch_payloads: bool = True) -> Optional[OrchestrationState]:
499
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
500
+ res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
501
+ if self._payload_store is not None and res.exists:
502
+ await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
503
+ return new_orchestration_state(req.instanceId, res)
504
+
505
+ async def get_all_orchestration_states(self,
506
+ orchestration_query: Optional[OrchestrationQuery] = None
507
+ ) -> List[OrchestrationState]:
508
+ if orchestration_query is None:
509
+ orchestration_query = OrchestrationQuery()
510
+ _continuation_token = None
511
+
512
+ self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
513
+
514
+ states = []
515
+
516
+ while True:
517
+ req = build_query_instances_req(orchestration_query, _continuation_token)
518
+ resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
519
+ if self._payload_store is not None:
520
+ await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
521
+ states += [parse_orchestration_state(res) for res in resp.orchestrationState]
522
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
523
+ _continuation_token = resp.continuationToken
524
+ else:
525
+ break
526
+
527
+ return states
528
+
529
+ async def wait_for_orchestration_start(self, instance_id: str, *,
530
+ fetch_payloads: bool = False,
531
+ timeout: int = 60) -> Optional[OrchestrationState]:
532
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
533
+ try:
534
+ self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
535
+ res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
536
+ if self._payload_store is not None and res.exists:
537
+ await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
538
+ return new_orchestration_state(req.instanceId, res)
539
+ except grpc.aio.AioRpcError as rpc_error:
540
+ if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
541
+ raise TimeoutError("Timed-out waiting for the orchestration to start")
542
+ else:
543
+ raise
544
+
545
+ async def wait_for_orchestration_completion(self, instance_id: str, *,
546
+ fetch_payloads: bool = True,
547
+ timeout: int = 60) -> Optional[OrchestrationState]:
548
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
549
+ try:
550
+ self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
551
+ res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
552
+ if self._payload_store is not None and res.exists:
553
+ await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
554
+ state = new_orchestration_state(req.instanceId, res)
555
+ log_completion_state(self._logger, instance_id, state)
556
+ return state
557
+ except grpc.aio.AioRpcError as rpc_error:
558
+ if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
559
+ raise TimeoutError("Timed-out waiting for the orchestration to complete")
560
+ else:
561
+ raise
562
+
563
+ async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
564
+ data: Optional[Any] = None) -> None:
565
+ with tracing.start_raise_event_span(event_name, instance_id):
566
+ req = build_raise_event_req(instance_id, event_name, data)
567
+ self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
568
+ if self._payload_store is not None:
569
+ await payload_helpers.externalize_payloads_async(
570
+ req, self._payload_store, instance_id=instance_id,
571
+ )
572
+ await self._stub.RaiseEvent(req)
573
+
574
+ async def terminate_orchestration(self, instance_id: str, *,
575
+ output: Optional[Any] = None,
576
+ recursive: bool = True) -> None:
577
+ req = build_terminate_req(instance_id, output, recursive)
578
+
579
+ self._logger.info(f"Terminating instance '{instance_id}'.")
580
+ if self._payload_store is not None:
581
+ await payload_helpers.externalize_payloads_async(
582
+ req, self._payload_store, instance_id=instance_id,
583
+ )
584
+ await self._stub.TerminateInstance(req)
585
+
586
+ async def suspend_orchestration(self, instance_id: str) -> None:
587
+ req = pb.SuspendRequest(instanceId=instance_id)
588
+ self._logger.info(f"Suspending instance '{instance_id}'.")
589
+ await self._stub.SuspendInstance(req)
590
+
591
+ async def resume_orchestration(self, instance_id: str) -> None:
592
+ req = pb.ResumeRequest(instanceId=instance_id)
593
+ self._logger.info(f"Resuming instance '{instance_id}'.")
594
+ await self._stub.ResumeInstance(req)
595
+
596
+ async def restart_orchestration(self, instance_id: str, *,
597
+ restart_with_new_instance_id: bool = False) -> str:
598
+ """Restarts an existing orchestration instance.
599
+
600
+ Args:
601
+ instance_id: The ID of the orchestration instance to restart.
602
+ restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
603
+ If False (default), the restarted orchestration will reuse the same instance ID.
604
+
605
+ Returns:
606
+ The instance ID of the restarted orchestration.
607
+ """
608
+ req = pb.RestartInstanceRequest(
609
+ instanceId=instance_id,
610
+ restartWithNewInstanceId=restart_with_new_instance_id)
611
+
612
+ self._logger.info(f"Restarting instance '{instance_id}'.")
613
+ res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
614
+ return res.instanceId
615
+
616
+ async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
617
+ req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
618
+ self._logger.info(f"Purging instance '{instance_id}'.")
619
+ resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
620
+ return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
621
+
622
+ async def purge_orchestrations_by(self,
623
+ created_time_from: Optional[datetime] = None,
624
+ created_time_to: Optional[datetime] = None,
625
+ runtime_status: Optional[List[OrchestrationStatus]] = None,
626
+ recursive: bool = False) -> PurgeInstancesResult:
627
+ self._logger.info("Purging orchestrations by filter: "
628
+ f"created_time_from={created_time_from}, "
629
+ f"created_time_to={created_time_to}, "
630
+ f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
631
+ f"recursive={recursive}")
632
+ req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
633
+ resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
634
+ return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
635
+
636
+ async def signal_entity(self,
637
+ entity_instance_id: EntityInstanceId,
638
+ operation_name: str,
639
+ input: Optional[Any] = None) -> None:
640
+ req = build_signal_entity_req(entity_instance_id, operation_name, input)
641
+ self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
642
+ if self._payload_store is not None:
643
+ await payload_helpers.externalize_payloads_async(
644
+ req, self._payload_store, instance_id=str(entity_instance_id),
645
+ )
646
+ await self._stub.SignalEntity(req, None)
647
+
648
+ async def get_entity(self,
649
+ entity_instance_id: EntityInstanceId,
650
+ include_state: bool = True
651
+ ) -> Optional[EntityMetadata]:
652
+ req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
653
+ self._logger.info(f"Getting entity '{entity_instance_id}'.")
654
+ res: pb.GetEntityResponse = await self._stub.GetEntity(req)
655
+ if not res.exists:
656
+ return None
657
+ if self._payload_store is not None:
658
+ await payload_helpers.deexternalize_payloads_async(res, self._payload_store)
659
+ return EntityMetadata.from_entity_metadata(res.entity, include_state)
660
+
661
+ async def get_all_entities(self,
662
+ entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
663
+ if entity_query is None:
664
+ entity_query = EntityQuery()
665
+ _continuation_token = None
666
+
667
+ self._logger.info(f"Retrieving entities by filter: {entity_query}")
668
+
669
+ entities = []
670
+
671
+ while True:
672
+ query_request = build_query_entities_req(entity_query, _continuation_token)
673
+ resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
674
+ if self._payload_store is not None:
675
+ await payload_helpers.deexternalize_payloads_async(resp, self._payload_store)
676
+ entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
677
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
678
+ _continuation_token = resp.continuationToken
679
+ else:
680
+ break
681
+ return entities
682
+
683
+ async def clean_entity_storage(self,
684
+ remove_empty_entities: bool = True,
685
+ release_orphaned_locks: bool = True
686
+ ) -> CleanEntityStorageResult:
687
+ self._logger.info("Cleaning entity storage")
688
+
689
+ empty_entities_removed = 0
690
+ orphaned_locks_released = 0
691
+ _continuation_token = None
692
+
693
+ while True:
694
+ req = pb.CleanEntityStorageRequest(
695
+ removeEmptyEntities=remove_empty_entities,
696
+ releaseOrphanedLocks=release_orphaned_locks,
697
+ continuationToken=_continuation_token
698
+ )
699
+ resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
700
+ empty_entities_removed += resp.emptyEntitiesRemoved
701
+ orphaned_locks_released += resp.orphanedLocksReleased
702
+
703
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
435
704
  _continuation_token = resp.continuationToken
436
705
  else:
437
706
  break
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ """Durable Task SDK extension packages."""