@temporalio/core-bridge 1.6.0 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138):
  1. package/Cargo.lock +520 -456
  2. package/lib/index.d.ts +8 -6
  3. package/lib/index.js.map +1 -1
  4. package/package.json +8 -3
  5. package/releases/aarch64-apple-darwin/index.node +0 -0
  6. package/releases/aarch64-unknown-linux-gnu/index.node +0 -0
  7. package/releases/x86_64-apple-darwin/index.node +0 -0
  8. package/releases/x86_64-pc-windows-msvc/index.node +0 -0
  9. package/releases/x86_64-unknown-linux-gnu/index.node +0 -0
  10. package/sdk-core/.buildkite/docker/Dockerfile +2 -2
  11. package/sdk-core/.buildkite/docker/docker-compose.yaml +1 -1
  12. package/sdk-core/.buildkite/pipeline.yml +1 -1
  13. package/sdk-core/.github/workflows/heavy.yml +1 -0
  14. package/sdk-core/README.md +13 -7
  15. package/sdk-core/client/src/lib.rs +27 -9
  16. package/sdk-core/client/src/metrics.rs +17 -8
  17. package/sdk-core/client/src/raw.rs +3 -3
  18. package/sdk-core/core/Cargo.toml +3 -4
  19. package/sdk-core/core/src/abstractions/take_cell.rs +28 -0
  20. package/sdk-core/core/src/abstractions.rs +197 -18
  21. package/sdk-core/core/src/core_tests/activity_tasks.rs +137 -45
  22. package/sdk-core/core/src/core_tests/child_workflows.rs +6 -5
  23. package/sdk-core/core/src/core_tests/determinism.rs +212 -2
  24. package/sdk-core/core/src/core_tests/local_activities.rs +183 -36
  25. package/sdk-core/core/src/core_tests/queries.rs +32 -14
  26. package/sdk-core/core/src/core_tests/workers.rs +8 -5
  27. package/sdk-core/core/src/core_tests/workflow_tasks.rs +340 -51
  28. package/sdk-core/core/src/ephemeral_server/mod.rs +110 -8
  29. package/sdk-core/core/src/internal_flags.rs +141 -0
  30. package/sdk-core/core/src/lib.rs +14 -9
  31. package/sdk-core/core/src/replay/mod.rs +16 -27
  32. package/sdk-core/core/src/telemetry/metrics.rs +69 -35
  33. package/sdk-core/core/src/telemetry/mod.rs +38 -14
  34. package/sdk-core/core/src/telemetry/prometheus_server.rs +19 -13
  35. package/sdk-core/core/src/test_help/mod.rs +65 -13
  36. package/sdk-core/core/src/worker/activities/activity_heartbeat_manager.rs +119 -160
  37. package/sdk-core/core/src/worker/activities/activity_task_poller_stream.rs +89 -0
  38. package/sdk-core/core/src/worker/activities/local_activities.rs +122 -6
  39. package/sdk-core/core/src/worker/activities.rs +347 -173
  40. package/sdk-core/core/src/worker/client/mocks.rs +22 -2
  41. package/sdk-core/core/src/worker/client.rs +18 -2
  42. package/sdk-core/core/src/worker/mod.rs +137 -44
  43. package/sdk-core/core/src/worker/workflow/history_update.rs +132 -51
  44. package/sdk-core/core/src/worker/workflow/machines/activity_state_machine.rs +207 -166
  45. package/sdk-core/core/src/worker/workflow/machines/cancel_external_state_machine.rs +6 -7
  46. package/sdk-core/core/src/worker/workflow/machines/cancel_workflow_state_machine.rs +6 -7
  47. package/sdk-core/core/src/worker/workflow/machines/child_workflow_state_machine.rs +157 -82
  48. package/sdk-core/core/src/worker/workflow/machines/complete_workflow_state_machine.rs +12 -12
  49. package/sdk-core/core/src/worker/workflow/machines/continue_as_new_workflow_state_machine.rs +6 -7
  50. package/sdk-core/core/src/worker/workflow/machines/fail_workflow_state_machine.rs +13 -15
  51. package/sdk-core/core/src/worker/workflow/machines/local_activity_state_machine.rs +170 -60
  52. package/sdk-core/core/src/worker/workflow/machines/mod.rs +24 -16
  53. package/sdk-core/core/src/worker/workflow/machines/modify_workflow_properties_state_machine.rs +6 -8
  54. package/sdk-core/core/src/worker/workflow/machines/patch_state_machine.rs +320 -204
  55. package/sdk-core/core/src/worker/workflow/machines/signal_external_state_machine.rs +10 -13
  56. package/sdk-core/core/src/worker/workflow/machines/timer_state_machine.rs +15 -23
  57. package/sdk-core/core/src/worker/workflow/machines/upsert_search_attributes_state_machine.rs +187 -46
  58. package/sdk-core/core/src/worker/workflow/machines/workflow_machines.rs +237 -111
  59. package/sdk-core/core/src/worker/workflow/machines/workflow_task_state_machine.rs +13 -13
  60. package/sdk-core/core/src/worker/workflow/managed_run/managed_wf_test.rs +10 -6
  61. package/sdk-core/core/src/worker/workflow/managed_run.rs +81 -62
  62. package/sdk-core/core/src/worker/workflow/mod.rs +341 -79
  63. package/sdk-core/core/src/worker/workflow/run_cache.rs +18 -11
  64. package/sdk-core/core/src/worker/workflow/wft_extraction.rs +15 -3
  65. package/sdk-core/core/src/worker/workflow/workflow_stream/saved_wf_inputs.rs +2 -0
  66. package/sdk-core/core/src/worker/workflow/workflow_stream.rs +75 -52
  67. package/sdk-core/core-api/Cargo.toml +0 -1
  68. package/sdk-core/core-api/src/lib.rs +13 -7
  69. package/sdk-core/core-api/src/telemetry.rs +4 -6
  70. package/sdk-core/core-api/src/worker.rs +5 -0
  71. package/sdk-core/fsm/rustfsm_procmacro/src/lib.rs +80 -55
  72. package/sdk-core/fsm/rustfsm_trait/src/lib.rs +22 -68
  73. package/sdk-core/histories/ends_empty_wft_complete.bin +0 -0
  74. package/sdk-core/histories/old_change_marker_format.bin +0 -0
  75. package/sdk-core/protos/api_upstream/.github/CODEOWNERS +2 -1
  76. package/sdk-core/protos/api_upstream/Makefile +1 -1
  77. package/sdk-core/protos/api_upstream/temporal/api/command/v1/message.proto +5 -17
  78. package/sdk-core/protos/api_upstream/temporal/api/common/v1/message.proto +11 -0
  79. package/sdk-core/protos/api_upstream/temporal/api/enums/v1/command_type.proto +1 -6
  80. package/sdk-core/protos/api_upstream/temporal/api/enums/v1/event_type.proto +6 -6
  81. package/sdk-core/protos/api_upstream/temporal/api/enums/v1/failed_cause.proto +5 -0
  82. package/sdk-core/protos/api_upstream/temporal/api/enums/v1/update.proto +22 -6
  83. package/sdk-core/protos/api_upstream/temporal/api/history/v1/message.proto +48 -19
  84. package/sdk-core/protos/api_upstream/temporal/api/namespace/v1/message.proto +2 -0
  85. package/sdk-core/protos/api_upstream/temporal/api/operatorservice/v1/request_response.proto +3 -0
  86. package/sdk-core/protos/api_upstream/temporal/api/{enums/v1/interaction_type.proto → protocol/v1/message.proto} +29 -11
  87. package/sdk-core/protos/api_upstream/temporal/api/sdk/v1/task_complete_metadata.proto +63 -0
  88. package/sdk-core/protos/api_upstream/temporal/api/update/v1/message.proto +111 -0
  89. package/sdk-core/protos/api_upstream/temporal/api/workflowservice/v1/request_response.proto +59 -28
  90. package/sdk-core/protos/api_upstream/temporal/api/workflowservice/v1/service.proto +2 -2
  91. package/sdk-core/protos/local/temporal/sdk/core/activity_result/activity_result.proto +7 -8
  92. package/sdk-core/protos/local/temporal/sdk/core/activity_task/activity_task.proto +10 -7
  93. package/sdk-core/protos/local/temporal/sdk/core/child_workflow/child_workflow.proto +19 -30
  94. package/sdk-core/protos/local/temporal/sdk/core/common/common.proto +1 -0
  95. package/sdk-core/protos/local/temporal/sdk/core/core_interface.proto +1 -0
  96. package/sdk-core/protos/local/temporal/sdk/core/external_data/external_data.proto +8 -0
  97. package/sdk-core/protos/local/temporal/sdk/core/workflow_activation/workflow_activation.proto +65 -60
  98. package/sdk-core/protos/local/temporal/sdk/core/workflow_commands/workflow_commands.proto +85 -84
  99. package/sdk-core/protos/local/temporal/sdk/core/workflow_completion/workflow_completion.proto +9 -3
  100. package/sdk-core/sdk/Cargo.toml +1 -1
  101. package/sdk-core/sdk/src/lib.rs +21 -5
  102. package/sdk-core/sdk/src/workflow_context/options.rs +7 -1
  103. package/sdk-core/sdk/src/workflow_context.rs +24 -17
  104. package/sdk-core/sdk/src/workflow_future.rs +9 -3
  105. package/sdk-core/sdk-core-protos/src/history_builder.rs +114 -89
  106. package/sdk-core/sdk-core-protos/src/history_info.rs +6 -1
  107. package/sdk-core/sdk-core-protos/src/lib.rs +205 -64
  108. package/sdk-core/test-utils/src/canned_histories.rs +106 -296
  109. package/sdk-core/test-utils/src/lib.rs +32 -5
  110. package/sdk-core/tests/heavy_tests.rs +10 -43
  111. package/sdk-core/tests/integ_tests/ephemeral_server_tests.rs +25 -3
  112. package/sdk-core/tests/integ_tests/heartbeat_tests.rs +5 -3
  113. package/sdk-core/tests/integ_tests/metrics_tests.rs +218 -16
  114. package/sdk-core/tests/integ_tests/polling_tests.rs +3 -8
  115. package/sdk-core/tests/integ_tests/queries_tests.rs +4 -2
  116. package/sdk-core/tests/integ_tests/visibility_tests.rs +34 -23
  117. package/sdk-core/tests/integ_tests/workflow_tests/activities.rs +97 -81
  118. package/sdk-core/tests/integ_tests/workflow_tests/cancel_external.rs +1 -0
  119. package/sdk-core/tests/integ_tests/workflow_tests/cancel_wf.rs +1 -0
  120. package/sdk-core/tests/integ_tests/workflow_tests/child_workflows.rs +80 -3
  121. package/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs +5 -1
  122. package/sdk-core/tests/integ_tests/workflow_tests/determinism.rs +1 -0
  123. package/sdk-core/tests/integ_tests/workflow_tests/local_activities.rs +25 -3
  124. package/sdk-core/tests/integ_tests/workflow_tests/modify_wf_properties.rs +2 -4
  125. package/sdk-core/tests/integ_tests/workflow_tests/patches.rs +30 -0
  126. package/sdk-core/tests/integ_tests/workflow_tests/replay.rs +64 -0
  127. package/sdk-core/tests/integ_tests/workflow_tests/resets.rs +1 -0
  128. package/sdk-core/tests/integ_tests/workflow_tests/signals.rs +4 -0
  129. package/sdk-core/tests/integ_tests/workflow_tests/stickyness.rs +3 -1
  130. package/sdk-core/tests/integ_tests/workflow_tests/timers.rs +7 -2
  131. package/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs +6 -7
  132. package/sdk-core/tests/integ_tests/workflow_tests.rs +8 -8
  133. package/sdk-core/tests/main.rs +16 -25
  134. package/sdk-core/tests/runner.rs +11 -9
  135. package/src/conversions.rs +14 -8
  136. package/src/runtime.rs +9 -8
  137. package/ts/index.ts +8 -6
  138. package/sdk-core/protos/api_upstream/temporal/api/interaction/v1/message.proto +0 -87
@@ -27,9 +27,10 @@ use opentelemetry_otlp::WithExportConfig;
27
27
  use parking_lot::Mutex;
28
28
  use std::{
29
29
  cell::RefCell,
30
- collections::VecDeque,
30
+ collections::{HashMap, VecDeque},
31
31
  convert::TryInto,
32
32
  env,
33
+ net::SocketAddr,
33
34
  sync::{
34
35
  atomic::{AtomicBool, Ordering},
35
36
  Arc,
@@ -60,6 +61,7 @@ pub struct TelemetryInstance {
60
61
  logs_out: Option<Mutex<CoreLogsOut>>,
61
62
  metrics: Option<(Box<dyn MeterProvider + Send + Sync + 'static>, Meter)>,
62
63
  trace_subscriber: Arc<dyn Subscriber + Send + Sync>,
64
+ prom_binding: Option<SocketAddr>,
63
65
  _keepalive_rx: Receiver<()>,
64
66
  }
65
67
 
@@ -69,6 +71,7 @@ impl TelemetryInstance {
69
71
  logs_out: Option<Mutex<CoreLogsOut>>,
70
72
  metric_prefix: &'static str,
71
73
  mut meter_provider: Option<Box<dyn MeterProvider + Send + Sync + 'static>>,
74
+ prom_binding: Option<SocketAddr>,
72
75
  keepalive_rx: Receiver<()>,
73
76
  ) -> Self {
74
77
  let metrics = meter_provider.take().map(|mp| {
@@ -80,6 +83,7 @@ impl TelemetryInstance {
80
83
  logs_out,
81
84
  metrics,
82
85
  trace_subscriber,
86
+ prom_binding,
83
87
  _keepalive_rx: keepalive_rx,
84
88
  }
85
89
  }
@@ -89,6 +93,18 @@ impl TelemetryInstance {
89
93
  pub fn trace_subscriber(&self) -> Arc<dyn Subscriber + Send + Sync> {
90
94
  self.trace_subscriber.clone()
91
95
  }
96
+
97
+ /// Returns the address the Prometheus server is bound to if it is running
98
+ pub fn prom_port(&self) -> Option<SocketAddr> {
99
+ self.prom_binding
100
+ }
101
+
102
+ /// Returns our wrapper for OTel metric meters, can be used to, ex: initialize clients
103
+ pub fn get_metric_meter(&self) -> Option<TemporalMeter> {
104
+ self.metrics
105
+ .as_ref()
106
+ .map(|(_, m)| TemporalMeter::new(m, self.metric_prefix))
107
+ }
92
108
  }
93
109
 
94
110
  thread_local! {
@@ -129,10 +145,6 @@ impl CoreTelemetry for TelemetryInstance {
129
145
  vec![]
130
146
  }
131
147
  }
132
-
133
- fn get_metric_meter(&self) -> Option<&Meter> {
134
- self.metrics.as_ref().map(|(_, m)| m)
135
- }
136
148
  }
137
149
 
138
150
  /// Initialize tracing subscribers/output and logging export, returning a [TelemetryInstance]
@@ -159,6 +171,7 @@ pub fn telemetry_init(opts: TelemetryOptions) -> Result<TelemetryInstance, anyho
159
171
  // Parts of telem dat ====
160
172
  let mut logs_out = None;
161
173
  let metric_prefix = metric_prefix(&opts);
174
+ let mut prom_binding = None;
162
175
  // =======================
163
176
 
164
177
  // Tracing subscriber layers =========
@@ -208,11 +221,15 @@ pub fn telemetry_init(opts: TelemetryOptions) -> Result<TelemetryInstance, anyho
208
221
  let aggregator = SDKAggSelector { metric_prefix };
209
222
  match metrics {
210
223
  MetricsExporter::Prometheus(addr) => {
211
- let srv = PromServer::new(
212
- *addr,
213
- aggregator,
214
- metric_temporality_to_selector(opts.metric_temporality),
215
- )?;
224
+ let srv = runtime.block_on(async {
225
+ PromServer::new(
226
+ *addr,
227
+ aggregator,
228
+ metric_temporality_to_selector(opts.metric_temporality),
229
+ &opts.global_tags,
230
+ )
231
+ })?;
232
+ prom_binding = Some(srv.bound_addr());
216
233
  let mp = srv.exporter.meter_provider()?;
217
234
  runtime.spawn(async move { srv.run().await });
218
235
  Some(Box::new(mp) as Box<dyn MeterProvider + Send + Sync>)
@@ -229,7 +246,7 @@ pub fn telemetry_init(opts: TelemetryOptions) -> Result<TelemetryInstance, anyho
229
246
  runtime::Tokio,
230
247
  )
231
248
  .with_period(metric_periodicity.unwrap_or_else(|| Duration::from_secs(1)))
232
- .with_resource(default_resource())
249
+ .with_resource(default_resource(&opts.global_tags))
233
250
  .with_exporter(
234
251
  opentelemetry_otlp::new_exporter()
235
252
  .tonic()
@@ -250,7 +267,8 @@ pub fn telemetry_init(opts: TelemetryOptions) -> Result<TelemetryInstance, anyho
250
267
  match &tracing.exporter {
251
268
  TraceExporter::Otel(OtelCollectorOptions { url, headers, .. }) => {
252
269
  runtime.block_on(async {
253
- let tracer_cfg = Config::default().with_resource(default_resource());
270
+ let tracer_cfg =
271
+ Config::default().with_resource(default_resource(&opts.global_tags));
254
272
  let tracer = opentelemetry_otlp::new_pipeline()
255
273
  .tracing()
256
274
  .with_exporter(
@@ -284,6 +302,7 @@ pub fn telemetry_init(opts: TelemetryOptions) -> Result<TelemetryInstance, anyho
284
302
  logs_out,
285
303
  metric_prefix,
286
304
  meter_provider,
305
+ prom_binding,
287
306
  keepalive_rx,
288
307
  ))
289
308
  .expect("Must be able to send telem instance out of thread");
@@ -321,8 +340,12 @@ fn default_resource_kvs() -> &'static [KeyValue] {
321
340
  static INSTANCE: OnceCell<[KeyValue; 1]> = OnceCell::new();
322
341
  INSTANCE.get_or_init(|| [KeyValue::new("service.name", TELEM_SERVICE_NAME)])
323
342
  }
324
- fn default_resource() -> Resource {
325
- Resource::new(default_resource_kvs().iter().cloned())
343
+
344
+ fn default_resource(override_values: &HashMap<String, String>) -> Resource {
345
+ let override_kvs = override_values
346
+ .iter()
347
+ .map(|(k, v)| KeyValue::new(k.clone(), v.clone()));
348
+ Resource::new(default_resource_kvs().iter().cloned()).merge(&Resource::new(override_kvs))
326
349
  }
327
350
 
328
351
  fn metric_temporality_to_selector(
@@ -375,6 +398,7 @@ pub mod test_initters {
375
398
  .unwrap();
376
399
  }
377
400
  }
401
+ use crate::telemetry::metrics::TemporalMeter;
378
402
  #[cfg(test)]
379
403
  pub use test_initters::*;
380
404
 
@@ -1,23 +1,21 @@
1
1
  use crate::telemetry::default_resource;
2
2
  use hyper::{
3
3
  header::CONTENT_TYPE,
4
+ server::conn::AddrIncoming,
4
5
  service::{make_service_fn, service_fn},
5
6
  Body, Method, Request, Response, Server,
6
7
  };
7
- use opentelemetry::{
8
- metrics::MetricsError,
9
- sdk::{
10
- export::metrics::{aggregation::TemporalitySelector, AggregatorSelector},
11
- metrics::{controllers, processors},
12
- },
8
+ use opentelemetry::sdk::{
9
+ export::metrics::{aggregation::TemporalitySelector, AggregatorSelector},
10
+ metrics::{controllers, processors},
13
11
  };
14
12
  use opentelemetry_prometheus::{ExporterBuilder, PrometheusExporter};
15
13
  use prometheus::{Encoder, TextEncoder};
16
- use std::{convert::Infallible, net::SocketAddr, sync::Arc};
14
+ use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
17
15
 
18
16
  /// Exposes prometheus metrics for scraping
19
17
  pub(super) struct PromServer {
20
- addr: SocketAddr,
18
+ bound_addr: AddrIncoming,
21
19
  pub exporter: Arc<PrometheusExporter>,
22
20
  }
23
21
 
@@ -26,19 +24,23 @@ impl PromServer {
26
24
  addr: SocketAddr,
27
25
  aggregation: impl AggregatorSelector + Send + Sync + 'static,
28
26
  temporality: impl TemporalitySelector + Send + Sync + 'static,
29
- ) -> Result<Self, MetricsError> {
27
+ tags: &HashMap<String, String>,
28
+ ) -> Result<Self, anyhow::Error> {
30
29
  let controller =
31
30
  controllers::basic(processors::factory(aggregation, temporality).with_memory(true))
32
- .with_resource(default_resource())
31
+ // Because Prom is pull-based, make this always refresh
32
+ .with_collect_period(Duration::from_secs(0))
33
+ .with_resource(default_resource(tags))
33
34
  .build();
34
35
  let exporter = ExporterBuilder::new(controller).try_init()?;
36
+ let bound_addr = AddrIncoming::bind(&addr)?;
35
37
  Ok(Self {
36
38
  exporter: Arc::new(exporter),
37
- addr,
39
+ bound_addr,
38
40
  })
39
41
  }
40
42
 
41
- pub async fn run(&self) -> hyper::Result<()> {
43
+ pub async fn run(self) -> hyper::Result<()> {
42
44
  // Spin up hyper server to serve metrics for scraping. We use hyper since we already depend
43
45
  // on it via Tonic.
44
46
  let expclone = self.exporter.clone();
@@ -46,9 +48,13 @@ impl PromServer {
46
48
  let expclone = expclone.clone();
47
49
  async move { Ok::<_, Infallible>(service_fn(move |req| metrics_req(req, expclone.clone()))) }
48
50
  });
49
- let server = Server::bind(&self.addr).serve(svc);
51
+ let server = Server::builder(self.bound_addr).serve(svc);
50
52
  server.await
51
53
  }
54
+
55
+ pub fn bound_addr(&self) -> SocketAddr {
56
+ self.bound_addr.local_addr()
57
+ }
52
58
  }
53
59
 
54
60
  /// Serves prometheus metrics in the expected format for scraping
@@ -14,6 +14,7 @@ use crate::{
14
14
  },
15
15
  TaskToken, Worker, WorkerConfig, WorkerConfigBuilder,
16
16
  };
17
+ use async_trait::async_trait;
17
18
  use bimap::BiMap;
18
19
  use futures::{future::BoxFuture, stream, stream::BoxStream, FutureExt, Stream, StreamExt};
19
20
  use mockall::TimesRange;
@@ -29,7 +30,10 @@ use std::{
29
30
  task::{Context, Poll},
30
31
  time::Duration,
31
32
  };
32
- use temporal_sdk_core_api::Worker as WorkerTrait;
33
+ use temporal_sdk_core_api::{
34
+ errors::{PollActivityError, PollWfError},
35
+ Worker as WorkerTrait,
36
+ };
33
37
  use temporal_sdk_core_protos::{
34
38
  coresdk::{
35
39
  workflow_activation::WorkflowActivation,
@@ -52,7 +56,6 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
52
56
  use tokio_util::sync::CancellationToken;
53
57
 
54
58
  pub const TEST_Q: &str = "q";
55
- pub static NO_MORE_WORK_ERROR_MSG: &str = "No more work to do";
56
59
 
57
60
  pub fn test_worker_cfg() -> WorkerConfigBuilder {
58
61
  let mut wcb = WorkerConfigBuilder::default();
@@ -148,6 +151,7 @@ pub(crate) fn mock_worker(mocks: MocksHolder) -> Worker {
148
151
  mocks.inputs.wft_stream,
149
152
  act_poller,
150
153
  MetricsContext::no_op(),
154
+ None,
151
155
  CancellationToken::new(),
152
156
  )
153
157
  }
@@ -301,7 +305,7 @@ where
301
305
  }
302
306
  .boxed()
303
307
  } else {
304
- async { Some(Err(tonic::Status::cancelled(NO_MORE_WORK_ERROR_MSG))) }.boxed()
308
+ async { None }.boxed()
305
309
  }
306
310
  });
307
311
  Box::new(mock_poller) as BoxedPoller<T>
@@ -375,6 +379,7 @@ pub(crate) struct MockPollCfg {
375
379
  pub expect_fail_wft_matcher:
376
380
  Box<dyn Fn(&TaskToken, &WorkflowTaskFailedCause, &Option<Failure>) -> bool + Send>,
377
381
  pub completion_asserts: Option<Box<dyn Fn(&WorkflowTaskCompletion) + Send>>,
382
+ pub num_expected_completions: Option<TimesRange>,
378
383
  /// If being used with the Rust SDK, this is set true. It ensures pollers will not error out
379
384
  /// early with no work, since we cannot know the exact number of times polling will happen.
380
385
  /// Instead, they will just block forever.
@@ -396,6 +401,7 @@ impl MockPollCfg {
396
401
  mock_client: mock_workflow_client(),
397
402
  expect_fail_wft_matcher: Box::new(|_, _, _| true),
398
403
  completion_asserts: None,
404
+ num_expected_completions: None,
399
405
  using_rust_sdk: false,
400
406
  make_poll_stream_interminable: false,
401
407
  }
@@ -418,6 +424,7 @@ impl MockPollCfg {
418
424
  mock_client,
419
425
  expect_fail_wft_matcher: Box::new(|_, _, _| true),
420
426
  completion_asserts: None,
427
+ num_expected_completions: None,
421
428
  using_rust_sdk: false,
422
429
  make_poll_stream_interminable: false,
423
430
  }
@@ -587,16 +594,20 @@ pub(crate) fn build_mock_pollers(mut cfg: MockPollCfg) -> MocksHolder {
587
594
  );
588
595
 
589
596
  let outstanding = outstanding_wf_task_tokens.clone();
590
- cfg.mock_client
591
- .expect_complete_workflow_task()
592
- .returning(move |comp| {
593
- if let Some(ass) = cfg.completion_asserts.as_ref() {
594
- // tee hee
595
- ass(&comp)
596
- }
597
- outstanding.release_token(&comp.task_token);
598
- Ok(RespondWorkflowTaskCompletedResponse::default())
599
- });
597
+ let expect_completes = cfg.mock_client.expect_complete_workflow_task();
598
+ if let Some(range) = cfg.num_expected_completions {
599
+ expect_completes.times(range);
600
+ } else if cfg.completion_asserts.is_some() {
601
+ expect_completes.times(1..);
602
+ }
603
+ expect_completes.returning(move |comp| {
604
+ if let Some(ass) = cfg.completion_asserts.as_ref() {
605
+ // tee hee
606
+ ass(&comp)
607
+ }
608
+ outstanding.release_token(&comp.task_token);
609
+ Ok(RespondWorkflowTaskCompletedResponse::default())
610
+ });
600
611
  let outstanding = outstanding_wf_task_tokens.clone();
601
612
  cfg.mock_client
602
613
  .expect_fail_workflow_task()
@@ -842,6 +853,7 @@ pub(crate) fn gen_assert_and_fail(asserter: &dyn Fn(&WorkflowActivation)) -> Ass
842
853
  message: "Intentional test failure".to_string(),
843
854
  ..Default::default()
844
855
  }),
856
+ ..Default::default()
845
857
  }
846
858
  .into(),
847
859
  )
@@ -887,3 +899,43 @@ macro_rules! prost_dur {
887
899
  .expect("test duration fits")
888
900
  };
889
901
  }
902
+
903
+ #[async_trait]
904
+ pub(crate) trait WorkerExt {
905
+ /// Initiate shutdown, drain the pollers, and wait for shutdown to complete.
906
+ async fn drain_pollers_and_shutdown(self);
907
+ /// Initiate shutdown, drain the *activity* poller, and wait for shutdown to complete.
908
+ /// Takes a ref because of that one test that needs it.
909
+ async fn drain_activity_poller_and_shutdown(&self);
910
+ }
911
+
912
+ #[async_trait]
913
+ impl WorkerExt for Worker {
914
+ async fn drain_pollers_and_shutdown(self) {
915
+ self.initiate_shutdown();
916
+ tokio::join!(
917
+ async {
918
+ assert_matches!(
919
+ self.poll_activity_task().await.unwrap_err(),
920
+ PollActivityError::ShutDown
921
+ );
922
+ },
923
+ async {
924
+ assert_matches!(
925
+ self.poll_workflow_activation().await.unwrap_err(),
926
+ PollWfError::ShutDown
927
+ );
928
+ }
929
+ );
930
+ self.finalize_shutdown().await;
931
+ }
932
+
933
+ async fn drain_activity_poller_and_shutdown(&self) {
934
+ self.initiate_shutdown();
935
+ assert_matches!(
936
+ self.poll_activity_task().await.unwrap_err(),
937
+ PollActivityError::ShutDown
938
+ );
939
+ self.shutdown().await;
940
+ }
941
+ }