durabletask 1.3.0.dev21__tar.gz → 1.3.0.dev22__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/PKG-INFO +1 -1
  2. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/client.py +283 -104
  3. durabletask-1.3.0.dev22/durabletask/internal/client_helpers.py +199 -0
  4. durabletask-1.3.0.dev22/durabletask/internal/grpc_interceptor.py +127 -0
  5. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/shared.py +40 -0
  6. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask.egg-info/PKG-INFO +1 -1
  7. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask.egg-info/SOURCES.txt +1 -0
  8. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/pyproject.toml +2 -1
  9. durabletask-1.3.0.dev21/durabletask/internal/grpc_interceptor.py +0 -65
  10. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/LICENSE +0 -0
  11. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/README.md +0 -0
  12. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/__init__.py +0 -0
  13. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/__init__.py +0 -0
  14. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/durable_entity.py +0 -0
  15. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/entity_context.py +0 -0
  16. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/entity_instance_id.py +0 -0
  17. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/entity_lock.py +0 -0
  18. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/entity_metadata.py +0 -0
  19. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/entities/entity_operation_failed_exception.py +0 -0
  20. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/entity_state_shim.py +0 -0
  21. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/exceptions.py +0 -0
  22. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/helpers.py +0 -0
  23. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/json_encode_output_exception.py +0 -0
  24. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/orchestration_entity_context.py +0 -0
  25. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/orchestrator_service_pb2.py +0 -0
  26. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/orchestrator_service_pb2.pyi +0 -0
  27. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/orchestrator_service_pb2_grpc.py +0 -0
  28. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/internal/proto_task_hub_sidecar_service_stub.py +0 -0
  29. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/py.typed +0 -0
  30. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/task.py +0 -0
  31. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/testing/__init__.py +0 -0
  32. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/testing/in_memory_backend.py +0 -0
  33. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask/worker.py +0 -0
  34. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask.egg-info/dependency_links.txt +0 -0
  35. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask.egg-info/requires.txt +0 -0
  36. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/durabletask.egg-info/top_level.txt +0 -0
  37. {durabletask-1.3.0.dev21 → durabletask-1.3.0.dev22}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: durabletask
3
- Version: 1.3.0.dev21
3
+ Version: 1.3.0.dev22
4
4
  Summary: A Durable Task Client SDK for Python
5
5
  License: MIT License
6
6
 
@@ -2,13 +2,13 @@
2
2
  # Licensed under the MIT License.
3
3
 
4
4
  import logging
5
- import uuid
6
5
  from dataclasses import dataclass
7
- from datetime import datetime, timezone
6
+ from datetime import datetime
8
7
  from enum import Enum
9
8
  from typing import Any, List, Optional, Sequence, TypeVar, Union
10
9
 
11
10
  import grpc
11
+ import grpc.aio
12
12
 
13
13
  from durabletask.entities import EntityInstanceId
14
14
  from durabletask.entities.entity_metadata import EntityMetadata
@@ -17,7 +17,19 @@ import durabletask.internal.orchestrator_service_pb2 as pb
17
17
  import durabletask.internal.orchestrator_service_pb2_grpc as stubs
18
18
  import durabletask.internal.shared as shared
19
19
  from durabletask import task
20
- from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
20
+ from durabletask.internal.client_helpers import (
21
+ build_query_entities_req,
22
+ build_query_instances_req,
23
+ build_purge_by_filter_req,
24
+ build_raise_event_req,
25
+ build_schedule_new_orchestration_req,
26
+ build_signal_entity_req,
27
+ build_terminate_req,
28
+ check_continuation_token,
29
+ log_completion_state,
30
+ prepare_async_interceptors,
31
+ prepare_sync_interceptors,
32
+ )
21
33
 
22
34
  TInput = TypeVar('TInput')
23
35
  TOutput = TypeVar('TOutput')
@@ -140,26 +152,22 @@ class TaskHubGrpcClient:
140
152
  interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
141
153
  default_version: Optional[str] = None):
142
154
 
143
- # If the caller provided metadata, we need to create a new interceptor for it and
144
- # add it to the list of interceptors.
145
- if interceptors is not None:
146
- interceptors = list(interceptors)
147
- if metadata is not None:
148
- interceptors.append(DefaultClientInterceptorImpl(metadata))
149
- elif metadata is not None:
150
- interceptors = [DefaultClientInterceptorImpl(metadata)]
151
- else:
152
- interceptors = None
155
+ interceptors = prepare_sync_interceptors(metadata, interceptors)
153
156
 
154
157
  channel = shared.get_grpc_channel(
155
158
  host_address=host_address,
156
159
  secure_channel=secure_channel,
157
160
  interceptors=interceptors
158
161
  )
162
+ self._channel = channel
159
163
  self._stub = stubs.TaskHubSidecarServiceStub(channel)
160
164
  self._logger = shared.get_logger("client", log_handler, log_formatter)
161
165
  self.default_version = default_version
162
166
 
167
+ def close(self) -> None:
168
+ """Closes the underlying gRPC channel."""
169
+ self._channel.close()
170
+
163
171
  def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
164
172
  input: Optional[TInput] = None,
165
173
  instance_id: Optional[str] = None,
@@ -168,19 +176,12 @@ class TaskHubGrpcClient:
168
176
  tags: Optional[dict[str, str]] = None,
169
177
  version: Optional[str] = None) -> str:
170
178
 
171
- name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)
172
-
173
- req = pb.CreateInstanceRequest(
174
- name=name,
175
- instanceId=instance_id if instance_id else uuid.uuid4().hex,
176
- input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
177
- scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
178
- version=helpers.get_string_value(version if version else self.default_version),
179
- orchestrationIdReusePolicy=reuse_id_policy,
180
- tags=tags
181
- )
179
+ req = build_schedule_new_orchestration_req(
180
+ orchestrator, input=input, instance_id=instance_id, start_at=start_at,
181
+ reuse_id_policy=reuse_id_policy, tags=tags,
182
+ version=version if version else self.default_version)
182
183
 
183
- self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
184
+ self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
184
185
  res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
185
186
  return res.instanceId
186
187
 
@@ -201,24 +202,10 @@ class TaskHubGrpcClient:
201
202
  states = []
202
203
 
203
204
  while True:
204
- req = pb.QueryInstancesRequest(
205
- query=pb.InstanceQuery(
206
- runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None,
207
- createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None,
208
- createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None,
209
- maxInstanceCount=orchestration_query.max_instance_count,
210
- fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
211
- continuationToken=_continuation_token
212
- )
213
- )
205
+ req = build_query_instances_req(orchestration_query, _continuation_token)
214
206
  resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req)
215
207
  states += [parse_orchestration_state(res) for res in resp.orchestrationState]
216
- # Check the value for continuationToken - none or "0" indicates that there are no more results.
217
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
218
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...")
219
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
220
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
221
- break
208
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
222
209
  _continuation_token = resp.continuationToken
223
210
  else:
224
211
  break
@@ -248,53 +235,35 @@ class TaskHubGrpcClient:
248
235
  self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
249
236
  res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
250
237
  state = new_orchestration_state(req.instanceId, res)
251
- if not state:
252
- return None
253
-
254
- if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
255
- details = state.failure_details
256
- self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
257
- elif state.runtime_status == OrchestrationStatus.TERMINATED:
258
- self._logger.info(f"Instance '{instance_id}' was terminated.")
259
- elif state.runtime_status == OrchestrationStatus.COMPLETED:
260
- self._logger.info(f"Instance '{instance_id}' completed.")
261
-
238
+ log_completion_state(self._logger, instance_id, state)
262
239
  return state
263
240
  except grpc.RpcError as rpc_error:
264
241
  if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
265
- # Replace gRPC error with the built-in TimeoutError
266
242
  raise TimeoutError("Timed-out waiting for the orchestration to complete")
267
243
  else:
268
244
  raise
269
245
 
270
246
  def raise_orchestration_event(self, instance_id: str, event_name: str, *,
271
- data: Optional[Any] = None):
272
- req = pb.RaiseEventRequest(
273
- instanceId=instance_id,
274
- name=event_name,
275
- input=helpers.get_string_value(shared.to_json(data) if data is not None else None)
276
- )
247
+ data: Optional[Any] = None) -> None:
248
+ req = build_raise_event_req(instance_id, event_name, data)
277
249
 
278
250
  self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
279
251
  self._stub.RaiseEvent(req)
280
252
 
281
253
  def terminate_orchestration(self, instance_id: str, *,
282
254
  output: Optional[Any] = None,
283
- recursive: bool = True):
284
- req = pb.TerminateRequest(
285
- instanceId=instance_id,
286
- output=helpers.get_string_value(shared.to_json(output) if output is not None else None),
287
- recursive=recursive)
255
+ recursive: bool = True) -> None:
256
+ req = build_terminate_req(instance_id, output, recursive)
288
257
 
289
258
  self._logger.info(f"Terminating instance '{instance_id}'.")
290
259
  self._stub.TerminateInstance(req)
291
260
 
292
- def suspend_orchestration(self, instance_id: str):
261
+ def suspend_orchestration(self, instance_id: str) -> None:
293
262
  req = pb.SuspendRequest(instanceId=instance_id)
294
263
  self._logger.info(f"Suspending instance '{instance_id}'.")
295
264
  self._stub.SuspendInstance(req)
296
265
 
297
- def resume_orchestration(self, instance_id: str):
266
+ def resume_orchestration(self, instance_id: str) -> None:
298
267
  req = pb.ResumeRequest(instanceId=instance_id)
299
268
  self._logger.info(f"Resuming instance '{instance_id}'.")
300
269
  self._stub.ResumeInstance(req)
@@ -335,29 +304,15 @@ class TaskHubGrpcClient:
335
304
  f"created_time_to={created_time_to}, "
336
305
  f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
337
306
  f"recursive={recursive}")
338
- resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest(
339
- purgeInstanceFilter=pb.PurgeInstanceFilter(
340
- createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
341
- createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
342
- runtimeStatus=[status.value for status in runtime_status] if runtime_status else None
343
- ),
344
- recursive=recursive
345
- ))
307
+ req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
308
+ resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req)
346
309
  return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
347
310
 
348
311
  def signal_entity(self,
349
312
  entity_instance_id: EntityInstanceId,
350
313
  operation_name: str,
351
314
  input: Optional[Any] = None) -> None:
352
- req = pb.SignalEntityRequest(
353
- instanceId=str(entity_instance_id),
354
- name=operation_name,
355
- input=helpers.get_string_value(shared.to_json(input) if input is not None else None),
356
- requestId=str(uuid.uuid4()),
357
- scheduledTime=None,
358
- parentTraceContext=None,
359
- requestTime=helpers.new_timestamp(datetime.now(timezone.utc))
360
- )
315
+ req = build_signal_entity_req(entity_instance_id, operation_name, input)
361
316
  self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
362
317
  self._stub.SignalEntity(req, None) # TODO: Cancellation timeout?
363
318
 
@@ -384,24 +339,10 @@ class TaskHubGrpcClient:
384
339
  entities = []
385
340
 
386
341
  while True:
387
- query_request = pb.QueryEntitiesRequest(
388
- query=pb.EntityQuery(
389
- instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
390
- lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None,
391
- lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None,
392
- includeState=entity_query.include_state,
393
- includeTransient=entity_query.include_transient,
394
- pageSize=helpers.get_int_value(entity_query.page_size),
395
- continuationToken=_continuation_token
396
- )
397
- )
342
+ query_request = build_query_entities_req(entity_query, _continuation_token)
398
343
  resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request)
399
344
  entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
400
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
401
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...")
402
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
403
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
404
- break
345
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
405
346
  _continuation_token = resp.continuationToken
406
347
  else:
407
348
  break
@@ -427,11 +368,249 @@ class TaskHubGrpcClient:
427
368
  empty_entities_removed += resp.emptyEntitiesRemoved
428
369
  orphaned_locks_released += resp.orphanedLocksReleased
429
370
 
430
- if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0":
431
- self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, cleaning next page...")
432
- if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value:
433
- self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.")
434
- break
371
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
372
+ _continuation_token = resp.continuationToken
373
+ else:
374
+ break
375
+
376
+ return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released)
377
+
378
+
379
+ class AsyncTaskHubGrpcClient:
380
+ """Async version of TaskHubGrpcClient using grpc.aio for asyncio-based applications."""
381
+
382
+ def __init__(self, *,
383
+ host_address: Optional[str] = None,
384
+ metadata: Optional[list[tuple[str, str]]] = None,
385
+ log_handler: Optional[logging.Handler] = None,
386
+ log_formatter: Optional[logging.Formatter] = None,
387
+ secure_channel: bool = False,
388
+ interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
389
+ default_version: Optional[str] = None):
390
+
391
+ interceptors = prepare_async_interceptors(metadata, interceptors)
392
+
393
+ channel = shared.get_async_grpc_channel(
394
+ host_address=host_address,
395
+ secure_channel=secure_channel,
396
+ interceptors=interceptors
397
+ )
398
+ self._channel = channel
399
+ self._stub = stubs.TaskHubSidecarServiceStub(channel)
400
+ self._logger = shared.get_logger("async_client", log_handler, log_formatter)
401
+ self.default_version = default_version
402
+
403
+ async def close(self) -> None:
404
+ """Closes the underlying gRPC channel."""
405
+ await self._channel.close()
406
+
407
+ async def __aenter__(self):
408
+ return self
409
+
410
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
411
+ await self.close()
412
+
413
+ async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
414
+ input: Optional[TInput] = None,
415
+ instance_id: Optional[str] = None,
416
+ start_at: Optional[datetime] = None,
417
+ reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
418
+ tags: Optional[dict[str, str]] = None,
419
+ version: Optional[str] = None) -> str:
420
+
421
+ req = build_schedule_new_orchestration_req(
422
+ orchestrator, input=input, instance_id=instance_id, start_at=start_at,
423
+ reuse_id_policy=reuse_id_policy, tags=tags,
424
+ version=version if version else self.default_version)
425
+
426
+ self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.")
427
+ res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
428
+ return res.instanceId
429
+
430
+ async def get_orchestration_state(self, instance_id: str, *,
431
+ fetch_payloads: bool = True) -> Optional[OrchestrationState]:
432
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
433
+ res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
434
+ return new_orchestration_state(req.instanceId, res)
435
+
436
+ async def get_all_orchestration_states(self,
437
+ orchestration_query: Optional[OrchestrationQuery] = None
438
+ ) -> List[OrchestrationState]:
439
+ if orchestration_query is None:
440
+ orchestration_query = OrchestrationQuery()
441
+ _continuation_token = None
442
+
443
+ self._logger.info(f"Querying orchestration instances with query: {orchestration_query}")
444
+
445
+ states = []
446
+
447
+ while True:
448
+ req = build_query_instances_req(orchestration_query, _continuation_token)
449
+ resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req)
450
+ states += [parse_orchestration_state(res) for res in resp.orchestrationState]
451
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
452
+ _continuation_token = resp.continuationToken
453
+ else:
454
+ break
455
+
456
+ return states
457
+
458
+ async def wait_for_orchestration_start(self, instance_id: str, *,
459
+ fetch_payloads: bool = False,
460
+ timeout: int = 60) -> Optional[OrchestrationState]:
461
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
462
+ try:
463
+ self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.")
464
+ res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout)
465
+ return new_orchestration_state(req.instanceId, res)
466
+ except grpc.aio.AioRpcError as rpc_error:
467
+ if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
468
+ raise TimeoutError("Timed-out waiting for the orchestration to start")
469
+ else:
470
+ raise
471
+
472
+ async def wait_for_orchestration_completion(self, instance_id: str, *,
473
+ fetch_payloads: bool = True,
474
+ timeout: int = 60) -> Optional[OrchestrationState]:
475
+ req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
476
+ try:
477
+ self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
478
+ res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout)
479
+ state = new_orchestration_state(req.instanceId, res)
480
+ log_completion_state(self._logger, instance_id, state)
481
+ return state
482
+ except grpc.aio.AioRpcError as rpc_error:
483
+ if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
484
+ raise TimeoutError("Timed-out waiting for the orchestration to complete")
485
+ else:
486
+ raise
487
+
488
+ async def raise_orchestration_event(self, instance_id: str, event_name: str, *,
489
+ data: Optional[Any] = None) -> None:
490
+ req = build_raise_event_req(instance_id, event_name, data)
491
+
492
+ self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
493
+ await self._stub.RaiseEvent(req)
494
+
495
+ async def terminate_orchestration(self, instance_id: str, *,
496
+ output: Optional[Any] = None,
497
+ recursive: bool = True) -> None:
498
+ req = build_terminate_req(instance_id, output, recursive)
499
+
500
+ self._logger.info(f"Terminating instance '{instance_id}'.")
501
+ await self._stub.TerminateInstance(req)
502
+
503
+ async def suspend_orchestration(self, instance_id: str) -> None:
504
+ req = pb.SuspendRequest(instanceId=instance_id)
505
+ self._logger.info(f"Suspending instance '{instance_id}'.")
506
+ await self._stub.SuspendInstance(req)
507
+
508
+ async def resume_orchestration(self, instance_id: str) -> None:
509
+ req = pb.ResumeRequest(instanceId=instance_id)
510
+ self._logger.info(f"Resuming instance '{instance_id}'.")
511
+ await self._stub.ResumeInstance(req)
512
+
513
+ async def restart_orchestration(self, instance_id: str, *,
514
+ restart_with_new_instance_id: bool = False) -> str:
515
+ """Restarts an existing orchestration instance.
516
+
517
+ Args:
518
+ instance_id: The ID of the orchestration instance to restart.
519
+ restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
520
+ If False (default), the restarted orchestration will reuse the same instance ID.
521
+
522
+ Returns:
523
+ The instance ID of the restarted orchestration.
524
+ """
525
+ req = pb.RestartInstanceRequest(
526
+ instanceId=instance_id,
527
+ restartWithNewInstanceId=restart_with_new_instance_id)
528
+
529
+ self._logger.info(f"Restarting instance '{instance_id}'.")
530
+ res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req)
531
+ return res.instanceId
532
+
533
+ async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult:
534
+ req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
535
+ self._logger.info(f"Purging instance '{instance_id}'.")
536
+ resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
537
+ return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
538
+
539
+ async def purge_orchestrations_by(self,
540
+ created_time_from: Optional[datetime] = None,
541
+ created_time_to: Optional[datetime] = None,
542
+ runtime_status: Optional[List[OrchestrationStatus]] = None,
543
+ recursive: bool = False) -> PurgeInstancesResult:
544
+ self._logger.info("Purging orchestrations by filter: "
545
+ f"created_time_from={created_time_from}, "
546
+ f"created_time_to={created_time_to}, "
547
+ f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, "
548
+ f"recursive={recursive}")
549
+ req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive)
550
+ resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req)
551
+ return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value)
552
+
553
+ async def signal_entity(self,
554
+ entity_instance_id: EntityInstanceId,
555
+ operation_name: str,
556
+ input: Optional[Any] = None) -> None:
557
+ req = build_signal_entity_req(entity_instance_id, operation_name, input)
558
+ self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.")
559
+ await self._stub.SignalEntity(req, None)
560
+
561
+ async def get_entity(self,
562
+ entity_instance_id: EntityInstanceId,
563
+ include_state: bool = True
564
+ ) -> Optional[EntityMetadata]:
565
+ req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state)
566
+ self._logger.info(f"Getting entity '{entity_instance_id}'.")
567
+ res: pb.GetEntityResponse = await self._stub.GetEntity(req)
568
+ if not res.exists:
569
+ return None
570
+
571
+ return EntityMetadata.from_entity_metadata(res.entity, include_state)
572
+
573
+ async def get_all_entities(self,
574
+ entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]:
575
+ if entity_query is None:
576
+ entity_query = EntityQuery()
577
+ _continuation_token = None
578
+
579
+ self._logger.info(f"Retrieving entities by filter: {entity_query}")
580
+
581
+ entities = []
582
+
583
+ while True:
584
+ query_request = build_query_entities_req(entity_query, _continuation_token)
585
+ resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request)
586
+ entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities]
587
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
588
+ _continuation_token = resp.continuationToken
589
+ else:
590
+ break
591
+ return entities
592
+
593
+ async def clean_entity_storage(self,
594
+ remove_empty_entities: bool = True,
595
+ release_orphaned_locks: bool = True
596
+ ) -> CleanEntityStorageResult:
597
+ self._logger.info("Cleaning entity storage")
598
+
599
+ empty_entities_removed = 0
600
+ orphaned_locks_released = 0
601
+ _continuation_token = None
602
+
603
+ while True:
604
+ req = pb.CleanEntityStorageRequest(
605
+ removeEmptyEntities=remove_empty_entities,
606
+ releaseOrphanedLocks=release_orphaned_locks,
607
+ continuationToken=_continuation_token
608
+ )
609
+ resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req)
610
+ empty_entities_removed += resp.emptyEntitiesRemoved
611
+ orphaned_locks_released += resp.orphanedLocksReleased
612
+
613
+ if check_continuation_token(resp.continuationToken, _continuation_token, self._logger):
435
614
  _continuation_token = resp.continuationToken
436
615
  else:
437
616
  break
@@ -0,0 +1,199 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from __future__ import annotations
5
+
6
+ import logging
7
+ import uuid
8
+ from datetime import datetime, timezone
9
+ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union
10
+
11
+ import durabletask.internal.helpers as helpers
12
+ import durabletask.internal.orchestrator_service_pb2 as pb
13
+ import durabletask.internal.shared as shared
14
+ from durabletask import task
15
+ from durabletask.internal.grpc_interceptor import (
16
+ DefaultAsyncClientInterceptorImpl,
17
+ DefaultClientInterceptorImpl,
18
+ )
19
+
20
+ if TYPE_CHECKING:
21
+ from durabletask.client import (
22
+ EntityQuery,
23
+ OrchestrationQuery,
24
+ OrchestrationState,
25
+ OrchestrationStatus,
26
+ )
27
+ from durabletask.entities import EntityInstanceId
28
+
29
+ TInput = TypeVar('TInput')
30
+ TOutput = TypeVar('TOutput')
31
+
32
+
33
def prepare_sync_interceptors(
        metadata: Optional[list[tuple[str, str]]],
        interceptors: Optional[Sequence[shared.ClientInterceptor]]
) -> Optional[list[shared.ClientInterceptor]]:
    """Prepare the list of sync gRPC interceptors, adding a metadata interceptor if needed.

    Returns None when there is nothing to install (no caller interceptors and
    no metadata); otherwise a fresh list combining both sources.
    """
    if interceptors is None and metadata is None:
        return None
    # Copy the caller's interceptors so the input sequence is never mutated.
    combined: list[shared.ClientInterceptor] = list(interceptors) if interceptors is not None else []
    if metadata is not None:
        combined.append(DefaultClientInterceptorImpl(metadata))
    return combined
46
+
47
+
48
def prepare_async_interceptors(
        metadata: Optional[list[tuple[str, str]]],
        interceptors: Optional[Sequence[shared.AsyncClientInterceptor]]
) -> Optional[list[shared.AsyncClientInterceptor]]:
    """Prepare the list of async gRPC interceptors, adding a metadata interceptor if needed.

    Returns None when there is nothing to install (no caller interceptors and
    no metadata); otherwise a fresh list combining both sources.
    """
    if interceptors is None and metadata is None:
        return None
    # Copy the caller's interceptors so the input sequence is never mutated.
    combined: list[shared.AsyncClientInterceptor] = list(interceptors) if interceptors is not None else []
    if metadata is not None:
        combined.append(DefaultAsyncClientInterceptorImpl(metadata))
    return combined
61
+
62
+
63
def build_schedule_new_orchestration_req(
        orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
        input: Optional[TInput],
        instance_id: Optional[str],
        start_at: Optional[datetime],
        reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy],
        tags: Optional[dict[str, str]],
        version: Optional[str]) -> pb.CreateInstanceRequest:
    """Build a CreateInstanceRequest for scheduling a new orchestration.

    A random instance ID is generated when none is supplied, and the input is
    JSON-serialized before being placed on the wire.
    """
    if isinstance(orchestrator, str):
        name = orchestrator
    else:
        name = task.get_name(orchestrator)
    serialized_input = shared.to_json(input) if input is not None else None
    return pb.CreateInstanceRequest(
        name=name,
        instanceId=instance_id or uuid.uuid4().hex,
        input=helpers.get_string_value(serialized_input),
        scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
        version=helpers.get_string_value(version),
        orchestrationIdReusePolicy=reuse_id_policy,
        tags=tags
    )
82
+
83
+
84
def build_query_instances_req(
        orchestration_query: OrchestrationQuery,
        continuation_token) -> pb.QueryInstancesRequest:
    """Translate an OrchestrationQuery into a wire-format QueryInstancesRequest."""
    statuses = None
    if orchestration_query.runtime_status:
        statuses = [status.value for status in orchestration_query.runtime_status]
    created_from = orchestration_query.created_time_from
    created_to = orchestration_query.created_time_to
    instance_query = pb.InstanceQuery(
        runtimeStatus=statuses,
        createdTimeFrom=helpers.new_timestamp(created_from) if created_from else None,
        createdTimeTo=helpers.new_timestamp(created_to) if created_to else None,
        maxInstanceCount=orchestration_query.max_instance_count,
        fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs,
        continuationToken=continuation_token
    )
    return pb.QueryInstancesRequest(query=instance_query)
98
+
99
+
100
def build_purge_by_filter_req(
        created_time_from: Optional[datetime],
        created_time_to: Optional[datetime],
        runtime_status: Optional[List[OrchestrationStatus]],
        recursive: bool) -> pb.PurgeInstancesRequest:
    """Build a PurgeInstancesRequest for purging orchestrations by filter."""
    statuses = None
    if runtime_status:
        statuses = [status.value for status in runtime_status]
    purge_filter = pb.PurgeInstanceFilter(
        createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None,
        createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None,
        runtimeStatus=statuses
    )
    return pb.PurgeInstancesRequest(purgeInstanceFilter=purge_filter, recursive=recursive)
114
+
115
+
116
def build_query_entities_req(
        entity_query: EntityQuery,
        continuation_token) -> pb.QueryEntitiesRequest:
    """Translate an EntityQuery into a wire-format QueryEntitiesRequest."""
    modified_from = entity_query.last_modified_from
    modified_to = entity_query.last_modified_to
    wire_query = pb.EntityQuery(
        instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with),
        lastModifiedFrom=helpers.new_timestamp(modified_from) if modified_from else None,
        lastModifiedTo=helpers.new_timestamp(modified_to) if modified_to else None,
        includeState=entity_query.include_state,
        includeTransient=entity_query.include_transient,
        pageSize=helpers.get_int_value(entity_query.page_size),
        continuationToken=continuation_token
    )
    return pb.QueryEntitiesRequest(query=wire_query)
131
+
132
+
133
def check_continuation_token(resp_token, prev_token, logger: logging.Logger) -> bool:
    """Check if a continuation token indicates more pages. Returns True to continue, False to stop.

    Args:
        resp_token: Continuation token from the latest response (an object with a
            string ``value`` attribute), or None.
        prev_token: Continuation token sent with the previous request, or None.
        logger: Logger used to report pagination progress.

    Returns:
        True when another page should be fetched; False when pagination is done
        or the server echoed the same token twice (guarding against an infinite loop).
    """
    # An absent/empty token, or the sentinel "0", means there are no more pages.
    if not (resp_token and resp_token.value and resp_token.value != "0"):
        return False
    # Use lazy %-style args so formatting is skipped when INFO is disabled.
    logger.info("Received continuation token with value %s, fetching next page...", resp_token.value)
    if prev_token and prev_token.value and prev_token.value == resp_token.value:
        logger.warning("Received the same continuation token value %s again, stopping to avoid infinite loop.", resp_token.value)
        return False
    return True
142
+
143
+
144
def log_completion_state(
        logger: logging.Logger,
        instance_id: str,
        state: Optional[OrchestrationState]):
    """Log the final state of a completed orchestration.

    Does nothing when *state* is falsy (e.g. the wait timed out with no state).
    """
    if not state:
        return
    # Compare against proto constants to avoid circular imports with client.py
    status_val = state.runtime_status.value
    if status_val == pb.ORCHESTRATION_STATUS_COMPLETED:
        logger.info(f"Instance '{instance_id}' completed.")
    elif status_val == pb.ORCHESTRATION_STATUS_TERMINATED:
        logger.info(f"Instance '{instance_id}' was terminated.")
    elif status_val == pb.ORCHESTRATION_STATUS_FAILED:
        details = state.failure_details
        # Failures without failure details are intentionally not logged.
        if details is not None:
            logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
160
+
161
+
162
def build_raise_event_req(
        instance_id: str,
        event_name: str,
        data: Optional[Any] = None) -> pb.RaiseEventRequest:
    """Build a RaiseEventRequest for raising an orchestration event."""
    payload = shared.to_json(data) if data is not None else None
    return pb.RaiseEventRequest(
        instanceId=instance_id,
        name=event_name,
        input=helpers.get_string_value(payload)
    )
172
+
173
+
174
def build_terminate_req(
        instance_id: str,
        output: Optional[Any] = None,
        recursive: bool = True) -> pb.TerminateRequest:
    """Build a TerminateRequest for terminating an orchestration."""
    payload = shared.to_json(output) if output is not None else None
    return pb.TerminateRequest(
        instanceId=instance_id,
        output=helpers.get_string_value(payload),
        recursive=recursive
    )
184
+
185
+
186
def build_signal_entity_req(
        entity_instance_id: EntityInstanceId,
        operation_name: str,
        input: Optional[Any] = None) -> pb.SignalEntityRequest:
    """Build a SignalEntityRequest for signaling an entity.

    A fresh request ID is generated per call, and the request time is stamped
    with the current UTC time.
    """
    payload = shared.to_json(input) if input is not None else None
    now_utc = datetime.now(timezone.utc)
    return pb.SignalEntityRequest(
        instanceId=str(entity_instance_id),
        name=operation_name,
        input=helpers.get_string_value(payload),
        requestId=str(uuid.uuid4()),
        # No scheduled delivery and no trace propagation for direct signals.
        scheduledTime=None,
        parentTraceContext=None,
        requestTime=helpers.new_timestamp(now_utc)
    )
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from collections import namedtuple
5
+
6
+ import grpc
7
+ import grpc.aio
8
+
9
+
10
class _ClientCallDetails(
        namedtuple(
            '_ClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
        grpc.ClientCallDetails):
    """This is an implementation of the ClientCallDetails interface needed for interceptors.
    This class takes six named values and inherits the ClientCallDetails from grpc package.
    This class encloses the values that describe a RPC to be invoked.

    Being a namedtuple keeps instances immutable, so interceptors build a new
    instance instead of mutating the one they receive.
    """
    pass
20
+
21
+
22
class _AsyncClientCallDetails(
        namedtuple(
            '_AsyncClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']),
        grpc.aio.ClientCallDetails):
    """This is an implementation of the aio ClientCallDetails interface needed for async interceptors.
    This class takes five named values and inherits the ClientCallDetails from grpc.aio package.
    This class encloses the values that describe a RPC to be invoked.

    NOTE: unlike the sync variant, there is no `compression` field because
    grpc.aio.ClientCallDetails does not define one.
    """
    pass
32
+
33
+
34
+ def _apply_metadata(client_call_details, metadata):
35
+ """Shared logic for applying metadata to call details. Returns the updated metadata list."""
36
+ if metadata is None:
37
+ return client_call_details.metadata
38
+
39
+ if client_call_details.metadata is not None:
40
+ new_metadata = list(client_call_details.metadata)
41
+ else:
42
+ new_metadata = []
43
+
44
+ new_metadata.extend(metadata)
45
+ return new_metadata
46
+
47
+
48
class DefaultClientInterceptorImpl (
        grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
        grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
    """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
    StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
    interceptor to add additional headers to all calls as needed."""

    def __init__(self, metadata: list[tuple[str, str]]):
        # metadata: (key, value) header pairs appended to every outgoing RPC.
        super().__init__()
        self._metadata = metadata

    def _intercept_call(
            self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
        """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
        call details."""
        new_metadata = _apply_metadata(client_call_details, self._metadata)
        # _apply_metadata returns the original metadata object unchanged when
        # there is nothing to add; reuse the existing call details in that case.
        if new_metadata is client_call_details.metadata:
            return client_call_details

        return _ClientCallDetails(
            client_call_details.method, client_call_details.timeout, new_metadata,
            client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression)

    # The four intercept_* methods below are the grpc interceptor entry points;
    # each rewrites the call details to carry the extra headers, then continues the RPC.

    def intercept_unary_unary(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return continuation(new_client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return continuation(new_client_call_details, request)

    def intercept_stream_unary(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return continuation(new_client_call_details, request)

    def intercept_stream_stream(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return continuation(new_client_call_details, request)
86
+
87
+
88
class DefaultAsyncClientInterceptorImpl(
        grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.UnaryStreamClientInterceptor,
        grpc.aio.StreamUnaryClientInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """Async gRPC interceptor that adds metadata headers to all calls."""

    def __init__(self, metadata: list[tuple[str, str]]):
        # metadata: (key, value) header pairs appended to every outgoing RPC.
        # Call super().__init__() for cooperative base-class initialization,
        # consistent with the sync DefaultClientInterceptorImpl.
        super().__init__()
        self._metadata = metadata

    async def _intercept_call(
            self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails:
        """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
        call details. This method is async to allow subclasses to perform async operations
        (e.g., refreshing auth tokens) during interception."""
        new_metadata = _apply_metadata(client_call_details, self._metadata)
        # _apply_metadata returns the original metadata object unchanged when
        # there is nothing to add; reuse the existing call details in that case.
        if new_metadata is client_call_details.metadata:
            return client_call_details

        return _AsyncClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            new_metadata,
            client_call_details.credentials,
            client_call_details.wait_for_ready,
        )

    # The four intercept_* methods below are the grpc.aio interceptor entry points;
    # each rewrites the call details to carry the extra headers, then continues the RPC.

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        new_client_call_details = await self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        new_client_call_details = await self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        new_client_call_details = await self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        new_client_call_details = await self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request_iterator)
@@ -8,6 +8,7 @@ from types import SimpleNamespace
8
8
  from typing import Any, Optional, Sequence, Union
9
9
 
10
10
  import grpc
11
+ import grpc.aio
11
12
 
12
13
  ClientInterceptor = Union[
13
14
  grpc.UnaryUnaryClientInterceptor,
@@ -16,6 +17,13 @@ ClientInterceptor = Union[
16
17
  grpc.StreamStreamClientInterceptor
17
18
  ]
18
19
 
20
+ AsyncClientInterceptor = Union[
21
+ grpc.aio.UnaryUnaryClientInterceptor,
22
+ grpc.aio.UnaryStreamClientInterceptor,
23
+ grpc.aio.StreamUnaryClientInterceptor,
24
+ grpc.aio.StreamStreamClientInterceptor
25
+ ]
26
+
19
27
  # Field name used to indicate that an object was automatically serialized
20
28
  # and should be deserialized as a SimpleNamespace
21
29
  AUTO_SERIALIZED = "__durabletask_autoobject__"
@@ -62,6 +70,38 @@ def get_grpc_channel(
62
70
  return channel
63
71
 
64
72
 
73
def get_async_grpc_channel(
        host_address: Optional[str],
        secure_channel: bool = False,
        interceptors: Optional[Sequence[AsyncClientInterceptor]] = None) -> grpc.aio.Channel:
    """Create a grpc.aio channel to *host_address* (async counterpart of get_grpc_channel).

    A "grpcs://"/"https://"-style prefix forces TLS and an insecure prefix
    forces plaintext, overriding the *secure_channel* flag; the prefix is
    stripped before the channel is created.
    """
    if host_address is None:
        host_address = get_default_host_address()

    # An explicit protocol prefix on the address overrides secure_channel.
    secure_prefix = next(
        (p for p in SECURE_PROTOCOLS if host_address.lower().startswith(p)), None)
    if secure_prefix is not None:
        secure_channel = True
        host_address = host_address[len(secure_prefix):]

    insecure_prefix = next(
        (p for p in INSECURE_PROTOCOLS if host_address.lower().startswith(p)), None)
    if insecure_prefix is not None:
        secure_channel = False
        host_address = host_address[len(insecure_prefix):]

    if secure_channel:
        return grpc.aio.secure_channel(
            host_address, grpc.ssl_channel_credentials(),
            interceptors=interceptors)
    return grpc.aio.insecure_channel(
        host_address,
        interceptors=interceptors)
103
+
104
+
65
105
  def get_logger(
66
106
  name_suffix: str,
67
107
  log_handler: Optional[logging.Handler] = None,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: durabletask
3
- Version: 1.3.0.dev21
3
+ Version: 1.3.0.dev22
4
4
  Summary: A Durable Task Client SDK for Python
5
5
  License: MIT License
6
6
 
@@ -18,6 +18,7 @@ durabletask/entities/entity_instance_id.py
18
18
  durabletask/entities/entity_lock.py
19
19
  durabletask/entities/entity_metadata.py
20
20
  durabletask/entities/entity_operation_failed_exception.py
21
+ durabletask/internal/client_helpers.py
21
22
  durabletask/internal/entity_state_shim.py
22
23
  durabletask/internal/exceptions.py
23
24
  durabletask/internal/grpc_interceptor.py
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
9
9
 
10
10
  [project]
11
11
  name = "durabletask"
12
- version = "1.3.0.dev21"
12
+ version = "1.3.0.dev22"
13
13
  description = "A Durable Task Client SDK for Python"
14
14
  keywords = [
15
15
  "durable",
@@ -41,3 +41,4 @@ include = ["durabletask", "durabletask.*"]
41
41
  [tool.pytest.ini_options]
42
42
  minversion = "6.0"
43
43
  testpaths = ["tests"]
44
+ asyncio_mode = "auto"
@@ -1,65 +0,0 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
3
-
4
- from collections import namedtuple
5
-
6
- import grpc
7
-
8
-
9
- class _ClientCallDetails(
10
- namedtuple(
11
- '_ClientCallDetails',
12
- ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
13
- grpc.ClientCallDetails):
14
- """This is an implementation of the ClientCallDetails interface needed for interceptors.
15
- This class takes six named values and inherits the ClientCallDetails from grpc package.
16
- This class encloses the values that describe a RPC to be invoked.
17
- """
18
- pass
19
-
20
-
21
- class DefaultClientInterceptorImpl (
22
- grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
23
- grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
24
- """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
25
- StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
26
- interceptor to add additional headers to all calls as needed."""
27
-
28
- def __init__(self, metadata: list[tuple[str, str]]):
29
- super().__init__()
30
- self._metadata = metadata
31
-
32
- def _intercept_call(
33
- self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
34
- """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
35
- call details."""
36
- if self._metadata is None:
37
- return client_call_details
38
-
39
- if client_call_details.metadata is not None:
40
- metadata = list(client_call_details.metadata)
41
- else:
42
- metadata = []
43
-
44
- metadata.extend(self._metadata)
45
- client_call_details = _ClientCallDetails(
46
- client_call_details.method, client_call_details.timeout, metadata,
47
- client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression)
48
-
49
- return client_call_details
50
-
51
- def intercept_unary_unary(self, continuation, client_call_details, request):
52
- new_client_call_details = self._intercept_call(client_call_details)
53
- return continuation(new_client_call_details, request)
54
-
55
- def intercept_unary_stream(self, continuation, client_call_details, request):
56
- new_client_call_details = self._intercept_call(client_call_details)
57
- return continuation(new_client_call_details, request)
58
-
59
- def intercept_stream_unary(self, continuation, client_call_details, request):
60
- new_client_call_details = self._intercept_call(client_call_details)
61
- return continuation(new_client_call_details, request)
62
-
63
- def intercept_stream_stream(self, continuation, client_call_details, request):
64
- new_client_call_details = self._intercept_call(client_call_details)
65
- return continuation(new_client_call_details, request)