@temporalio/core-bridge 0.22.0 → 0.23.0

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
package/Cargo.lock CHANGED
@@ -1577,6 +1577,7 @@ dependencies = [
1577
1577
  "futures-retry",
1578
1578
  "http",
1579
1579
  "opentelemetry 0.17.0",
1580
+ "parking_lot 0.12.0",
1580
1581
  "prost-types",
1581
1582
  "temporal-sdk-core-protos",
1582
1583
  "thiserror",
@@ -1683,6 +1684,7 @@ dependencies = [
1683
1684
  "neon",
1684
1685
  "once_cell",
1685
1686
  "opentelemetry 0.16.0",
1687
+ "parking_lot 0.12.0",
1686
1688
  "prost",
1687
1689
  "prost-types",
1688
1690
  "temporal-client",
package/Cargo.toml CHANGED
@@ -19,8 +19,9 @@ incremental = false
19
19
  [dependencies]
20
20
  futures = { version = "0.3", features = ["executor"] }
21
21
  log = "0.4"
22
- neon = { version = "0.8.0", default-features = false, features = ["napi-4", "event-queue-api"] }
22
+ neon = { version = "0.8", default-features = false, features = ["napi-6", "event-queue-api"] }
23
23
  opentelemetry = "0.16"
24
+ parking_lot = "0.12"
24
25
  prost = "0.9"
25
26
  prost-types = "0.9"
26
27
  tokio = "1.13"
package/index.d.ts CHANGED
@@ -53,34 +53,104 @@ export interface ClientOptions {
53
53
  retry?: RetryOptions;
54
54
  }
55
55
 
56
+ /**
57
+ * Log directly to console
58
+ */
59
+ export interface ConsoleLogger {
60
+ console: {};
61
+ }
62
+
63
+ /**
64
+ * Forward logs to {@link Runtime} logger
65
+ */
66
+ export interface ForwardLogger {
67
+ forward: {
68
+ /**
69
+ * What level, if any, logs should be forwarded from core at
70
+ *
71
+ */
72
+ // These strings should match the log::LevelFilter enum in rust
73
+ level: LogLevel;
74
+ };
75
+ }
76
+
77
+ /**
78
+ * Logger types supported by Core
79
+ */
80
+ export type Logger = ConsoleLogger | ForwardLogger;
81
+
82
+ /**
83
+ * OpenTelemetry Collector options for exporting metrics or traces
84
+ */
85
+ export interface OtelCollectorExporter {
86
+ otel: {
87
+ /**
88
+ * URL of a gRPC OpenTelemetry collector.
89
+ */
90
+ url: string;
91
+ /**
92
+ * Optional set of HTTP request headers to send to Collector (e.g. for authentication)
93
+ */
94
+ headers?: Record<string, string>;
95
+ };
96
+ }
97
+
98
+ /**
99
+ * Prometheus metrics exporter options
100
+ */
101
+ export interface PrometheusMetricsExporter {
102
+ prometheus: {
103
+ /**
104
+ * Address to bind the Prometheus HTTP metrics exporter server
105
+ * (for example, `0.0.0.0:1234`).
106
+ *
107
+ * Metrics will be available for scraping under the standard `/metrics` route.
108
+ */
109
+ bindAddress: string;
110
+ };
111
+ }
112
+
113
+ /**
114
+ * Metrics exporters supported by Core
115
+ */
116
+ export type MetricsExporter = PrometheusMetricsExporter | OtelCollectorExporter;
117
+
118
+ /**
119
+ * Trace exporters supported by Core
120
+ */
121
+ export type TraceExporter = OtelCollectorExporter;
122
+
56
123
  export interface TelemetryOptions {
57
- /**
58
- * If set, telemetry is turned on and this URL must be the URL of a gRPC
59
- * OTel collector.
60
- */
61
- oTelCollectorUrl?: string;
62
124
  /**
63
125
  * A string in the env filter format specified here:
64
126
  * https://docs.rs/tracing-subscriber/0.2.20/tracing_subscriber/struct.EnvFilter.html
65
127
  *
66
128
  * Which determines what tracing data is collected in the Core SDK
129
+ *
130
+ * @default `temporal_sdk_core=WARN`
67
131
  */
68
132
  tracingFilter?: string;
133
+
69
134
  /**
70
- * What level, if any, logs should be forwarded from core at
135
+ * Control where to send Rust Core logs
71
136
  *
72
- * @default OFF
137
+ * @default log to console
73
138
  */
74
- // These strings should match the log::LevelFilter enum in rust
75
- logForwardingLevel?: 'OFF' | LogLevel;
76
- /** If set, metrics will be exposed on an http server in this process for direct scraping by
77
- * prometheus. If used in conjunction with the OTel collector, metrics will *not* be exported
78
- * to the collector, but traces will be.
139
+ logging?: Logger;
140
+
141
+ /**
142
+ * Control where to send traces generated by Rust Core, optional and turned off by default.
79
143
  *
80
- * The string provided must be an address the server will bind to. For example, `0.0.0.0:1234`.
81
- * Metrics will be available for scraping under the standard `/metrics` route.
144
+ * This is typically used for profiling SDK internals.
82
145
  */
83
- prometheusMetricsBindAddress?: string;
146
+ tracing?: TraceExporter;
147
+
148
+ /**
149
+ * Control exporting {@link NativeConnection} and {@link Worker} metrics.
150
+ *
151
+ * Turned off by default
152
+ */
153
+ metrics?: MetricsExporter;
84
154
  }
85
155
 
86
156
  export interface WorkerOptions {
@@ -166,6 +236,11 @@ export declare function newReplayWorker(
166
236
  callback: WorkerCallback
167
237
  ): void;
168
238
  export declare function workerShutdown(worker: Worker, callback: VoidCallback): void;
239
+ export declare function clientUpdateHeaders(
240
+ client: Client,
241
+ headers: Record<string, string>,
242
+ callback: VoidCallback
243
+ ): void;
169
244
  export declare function clientClose(client: Client): void;
170
245
  export declare function runtimeShutdown(runtime: Runtime, callback: VoidCallback): void;
171
246
  export declare function pollLogs(runtime: Runtime, callback: LogsCallback): void;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@temporalio/core-bridge",
3
- "version": "0.22.0",
3
+ "version": "0.23.0",
4
4
  "description": "Temporal.io SDK Core<>Node bridge",
5
5
  "main": "index.js",
6
6
  "types": "index.d.ts",
@@ -20,7 +20,7 @@
20
20
  "license": "MIT",
21
21
  "dependencies": {
22
22
  "@opentelemetry/api": "^1.0.3",
23
- "@temporalio/internal-non-workflow-common": "^0.22.0",
23
+ "@temporalio/internal-non-workflow-common": "^0.23.0",
24
24
  "arg": "^5.0.1",
25
25
  "cargo-cp-artifact": "^0.1.4",
26
26
  "which": "^2.0.2"
@@ -43,5 +43,5 @@
43
43
  "publishConfig": {
44
44
  "access": "public"
45
45
  },
46
- "gitHead": "3aa1f14982bd170d21b728cbf016dc4f1b595a76"
46
+ "gitHead": "81ee3fd09c2fd866b31b1dbfabce7ef221e338ea"
47
47
  }
@@ -397,7 +397,7 @@ pub extern "C" fn tmprl_client_init(
397
397
 
398
398
  let user_data = UserDataHandle(user_data);
399
399
  runtime.tokio_runtime.spawn(async move {
400
- match req.connect(namespace, None).await {
400
+ match req.connect(namespace, None, None).await {
401
401
  Ok(client) => unsafe {
402
402
  callback(
403
403
  user_data.into(),
@@ -1,6 +1,8 @@
1
1
  use log::LevelFilter;
2
2
  use std::{net::SocketAddr, str::FromStr};
3
- use temporal_sdk_core::{TelemetryOptionsBuilder, Url};
3
+ use temporal_sdk_core::{
4
+ Logger, MetricsExporter, OtelCollectorOptions, TelemetryOptionsBuilder, TraceExporter, Url,
5
+ };
4
6
  use temporal_sdk_core_protos::coresdk::bridge;
5
7
 
6
8
  // Present for try-from only
@@ -11,41 +13,68 @@ impl TryFrom<InitTelemetryRequest> for temporal_sdk_core::TelemetryOptions {
11
13
 
12
14
  fn try_from(InitTelemetryRequest(req): InitTelemetryRequest) -> Result<Self, Self::Error> {
13
15
  let mut telemetry_opts = TelemetryOptionsBuilder::default();
14
- if !req.otel_collector_url.is_empty() {
15
- telemetry_opts.otel_collector_url(
16
- Url::parse(&req.otel_collector_url)
17
- .map_err(|err| format!("invalid OpenTelemetry collector URL: {}", err))?,
18
- );
19
- }
20
- if !req.tracing_filter.is_empty() {
21
- telemetry_opts.tracing_filter(req.tracing_filter.clone());
22
- }
23
- match req.log_forwarding_level() {
24
- bridge::LogLevel::Unspecified => {}
25
- bridge::LogLevel::Off => {
26
- telemetry_opts.log_forwarding_level(LevelFilter::Off);
27
- }
28
- bridge::LogLevel::Error => {
29
- telemetry_opts.log_forwarding_level(LevelFilter::Error);
30
- }
31
- bridge::LogLevel::Warn => {
32
- telemetry_opts.log_forwarding_level(LevelFilter::Warn);
16
+ telemetry_opts.tracing_filter(req.tracing_filter.clone());
17
+ match req.logging {
18
+ None => {}
19
+ Some(bridge::init_telemetry_request::Logging::Console(_)) => {
20
+ telemetry_opts.logging(Logger::Console);
33
21
  }
34
- bridge::LogLevel::Info => {
35
- telemetry_opts.log_forwarding_level(LevelFilter::Info);
22
+ Some(bridge::init_telemetry_request::Logging::Forward(
23
+ bridge::init_telemetry_request::ForwardLoggerOptions { level },
24
+ )) => {
25
+ if let Some(level) = bridge::LogLevel::from_i32(level) {
26
+ let level = match level {
27
+ bridge::LogLevel::Unspecified => {
28
+ return Err("Must specify log level".to_string())
29
+ }
30
+ bridge::LogLevel::Off => LevelFilter::Off,
31
+ bridge::LogLevel::Error => LevelFilter::Error,
32
+ bridge::LogLevel::Warn => LevelFilter::Warn,
33
+ bridge::LogLevel::Info => LevelFilter::Info,
34
+ bridge::LogLevel::Debug => LevelFilter::Debug,
35
+ bridge::LogLevel::Trace => LevelFilter::Trace,
36
+ };
37
+ telemetry_opts.logging(Logger::Forward(level));
38
+ } else {
39
+ return Err(format!("Unknown LogLevel: {}", level));
40
+ }
36
41
  }
37
- bridge::LogLevel::Debug => {
38
- telemetry_opts.log_forwarding_level(LevelFilter::Debug);
42
+ }
43
+ match req.metrics {
44
+ None => {}
45
+ Some(bridge::init_telemetry_request::Metrics::Prometheus(
46
+ bridge::init_telemetry_request::PrometheusOptions {
47
+ export_bind_address,
48
+ },
49
+ )) => {
50
+ telemetry_opts.metrics(MetricsExporter::Prometheus(
51
+ SocketAddr::from_str(&export_bind_address)
52
+ .map_err(|err| format!("invalid Prometheus address: {}", err))?,
53
+ ));
39
54
  }
40
- bridge::LogLevel::Trace => {
41
- telemetry_opts.log_forwarding_level(LevelFilter::Trace);
55
+ Some(bridge::init_telemetry_request::Metrics::OtelMetrics(
56
+ bridge::init_telemetry_request::OtelCollectorOptions { url, headers },
57
+ )) => {
58
+ telemetry_opts.metrics(MetricsExporter::Otel(OtelCollectorOptions {
59
+ url: Url::parse(&url).map_err(|err| {
60
+ format!("invalid OpenTelemetry collector URL for metrics: {}", err)
61
+ })?,
62
+ headers,
63
+ }));
42
64
  }
43
65
  }
44
- if !req.prometheus_export_bind_address.is_empty() {
45
- telemetry_opts.prometheus_export_bind_address(
46
- SocketAddr::from_str(&req.prometheus_export_bind_address)
47
- .map_err(|err| format!("invalid Prometheus address: {}", err))?,
48
- );
66
+ match req.tracing {
67
+ None => {}
68
+ Some(bridge::init_telemetry_request::Tracing::OtelTracing(
69
+ bridge::init_telemetry_request::OtelCollectorOptions { url, headers },
70
+ )) => {
71
+ telemetry_opts.tracing(TraceExporter::Otel(OtelCollectorOptions {
72
+ url: Url::parse(&url).map_err(|err| {
73
+ format!("invalid OpenTelemetry collector URL for tracing: {}", err)
74
+ })?,
75
+ headers,
76
+ }));
77
+ }
49
78
  }
50
79
 
51
80
  telemetry_opts
@@ -73,9 +102,6 @@ impl TryFrom<ClientOptions> for temporal_sdk_core::ClientOptions {
73
102
  if !req.client_version.is_empty() {
74
103
  client_opts.client_version(req.client_version);
75
104
  }
76
- if !req.static_headers.is_empty() {
77
- client_opts.static_headers(req.static_headers);
78
- }
79
105
  if !req.identity.is_empty() {
80
106
  client_opts.identity(req.identity);
81
107
  }
@@ -20,6 +20,7 @@ futures = "0.3"
20
20
  futures-retry = "0.6.0"
21
21
  http = "0.2"
22
22
  opentelemetry = { version = "0.17", features = ["metrics"] }
23
+ parking_lot = "0.12"
23
24
  prost-types = "0.9"
24
25
  thiserror = "1.0"
25
26
  tokio = "1.1"
@@ -25,6 +25,7 @@ use crate::{
25
25
  use backoff::{ExponentialBackoff, SystemClock};
26
26
  use http::uri::InvalidUri;
27
27
  use opentelemetry::metrics::Meter;
28
+ use parking_lot::RwLock;
28
29
  use std::{
29
30
  collections::HashMap,
30
31
  fmt::{Debug, Formatter},
@@ -80,10 +81,6 @@ pub struct ClientOptions {
80
81
  /// in all RPC calls. The server decides if the client is supported based on this.
81
82
  pub client_version: String,
82
83
 
83
- /// Other non-required static headers which should be attached to every request
84
- #[builder(default)]
85
- pub static_headers: HashMap<String, String>,
86
-
87
84
  /// A human-readable string that can identify this process. Defaults to empty string.
88
85
  #[builder(default)]
89
86
  pub identity: String,
@@ -252,11 +249,18 @@ impl From<ConfiguredClient<WorkflowServiceClientWithMetrics>> for AnyClient {
252
249
  pub struct ConfiguredClient<C> {
253
250
  client: C,
254
251
  options: ClientOptions,
252
+ headers: Arc<RwLock<HashMap<String, String>>>,
255
253
  /// Capabilities as read from the `get_system_info` RPC call made on client connection
256
254
  capabilities: Option<get_system_info_response::Capabilities>,
257
255
  }
258
256
 
259
257
  impl<C> ConfiguredClient<C> {
258
+ /// Set HTTP request headers overwriting previous headers
259
+ pub fn set_headers(&self, headers: HashMap<String, String>) {
260
+ let mut guard = self.headers.write();
261
+ *guard = headers;
262
+ }
263
+
260
264
  /// Returns the options the client is configured with
261
265
  pub fn options(&self) -> &ClientOptions {
262
266
  &self.options
@@ -295,8 +299,12 @@ impl ClientOptions {
295
299
  &self,
296
300
  namespace: impl Into<String>,
297
301
  metrics_meter: Option<&Meter>,
302
+ headers: Option<Arc<RwLock<HashMap<String, String>>>>,
298
303
  ) -> Result<RetryClient<Client>, ClientInitError> {
299
- let client = self.connect_no_namespace(metrics_meter).await?.into_inner();
304
+ let client = self
305
+ .connect_no_namespace(metrics_meter, headers)
306
+ .await?
307
+ .into_inner();
300
308
  let client = Client::new(client, namespace.into());
301
309
  let retry_client = RetryClient::new(client, self.retry_config.clone());
302
310
  Ok(retry_client)
@@ -309,6 +317,7 @@ impl ClientOptions {
309
317
  pub async fn connect_no_namespace(
310
318
  &self,
311
319
  metrics_meter: Option<&Meter>,
320
+ headers: Option<Arc<RwLock<HashMap<String, String>>>>,
312
321
  ) -> Result<RetryClient<ConfiguredClient<WorkflowServiceClientWithMetrics>>, ClientInitError>
313
322
  {
314
323
  let channel = Channel::from_shared(self.target_url.to_string())?;
@@ -320,9 +329,14 @@ impl ClientOptions {
320
329
  metrics: metrics_meter.map(|mm| MetricsContext::new(vec![], mm)),
321
330
  })
322
331
  .service(channel);
323
- let interceptor = ServiceCallInterceptor { opts: self.clone() };
332
+ let headers = headers.unwrap_or_default();
333
+ let interceptor = ServiceCallInterceptor {
334
+ opts: self.clone(),
335
+ headers: headers.clone(),
336
+ };
324
337
 
325
338
  let mut client = ConfiguredClient {
339
+ headers,
326
340
  client: WorkflowServiceClient::with_interceptor(service, interceptor),
327
341
  options: self.clone(),
328
342
  capabilities: None,
@@ -394,6 +408,8 @@ pub struct WorkflowTaskCompletion {
394
408
  #[derive(Clone)]
395
409
  pub struct ServiceCallInterceptor {
396
410
  opts: ClientOptions,
411
+ /// Only accessed as a reader
412
+ headers: Arc<RwLock<HashMap<String, String>>>,
397
413
  }
398
414
 
399
415
  impl Interceptor for ServiceCallInterceptor {
@@ -415,7 +431,8 @@ impl Interceptor for ServiceCallInterceptor {
415
431
  .parse()
416
432
  .unwrap_or_else(|_| MetadataValue::from_static("")),
417
433
  );
418
- for (k, v) in &self.opts.static_headers {
434
+ let headers = &*self.headers.read();
435
+ for (k, v) in headers {
419
436
  if let (Ok(k), Ok(v)) = (MetadataKey::from_str(k), MetadataValue::from_str(v)) {
420
437
  metadata.insert(k, v);
421
438
  }
@@ -558,7 +558,7 @@ mod tests {
558
558
  #[allow(dead_code)]
559
559
  async fn raw_client_retry_compiles() {
560
560
  let opts = ClientOptionsBuilder::default().build().unwrap();
561
- let raw_client = opts.connect_no_namespace(None).await.unwrap();
561
+ let raw_client = opts.connect_no_namespace(None, None).await.unwrap();
562
562
  let mut retry_client = RetryClient::new(raw_client, opts.retry_config);
563
563
 
564
564
  let the_request = ListNamespacesRequest::default();
@@ -11,13 +11,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
11
11
  .enable_time()
12
12
  .build()
13
13
  .unwrap();
14
- telemetry_init(
15
- &TelemetryOptionsBuilder::default()
16
- .totally_disable(true)
17
- .build()
18
- .unwrap(),
19
- )
20
- .unwrap();
14
+ telemetry_init(&TelemetryOptionsBuilder::default().build().unwrap()).unwrap();
21
15
  let _g = tokio_runtime.enter();
22
16
 
23
17
  let num_timers = 10;
@@ -34,7 +34,8 @@ pub use pollers::{
34
34
  TlsConfig, WorkflowClientTrait,
35
35
  };
36
36
  pub use telemetry::{
37
- fetch_global_buffered_logs, telemetry_init, TelemetryOptions, TelemetryOptionsBuilder,
37
+ fetch_global_buffered_logs, telemetry_init, Logger, MetricsExporter, OtelCollectorOptions,
38
+ TelemetryOptions, TelemetryOptionsBuilder, TraceExporter,
38
39
  };
39
40
  pub use temporal_sdk_core_api as api;
40
41
  pub use temporal_sdk_core_protos as protos;
@@ -359,9 +359,7 @@ impl ValidScheduleLA {
359
359
  ))
360
360
  }
361
361
  };
362
- let retry_policy = v
363
- .retry_policy
364
- .ok_or_else(|| anyhow!("Retry policy must be defined!"))?;
362
+ let retry_policy = v.retry_policy.unwrap_or_default();
365
363
  let local_retry_threshold = v
366
364
  .local_retry_threshold
367
365
  .clone()
@@ -1,5 +1,8 @@
1
1
  use std::time::Duration;
2
- use temporal_sdk_core_protos::{coresdk::common::RetryPolicy, utilities::TryIntoOrNone};
2
+ use temporal_sdk_core_protos::{
3
+ coresdk::common::RetryPolicy, temporal::api::failure::v1::ApplicationFailureInfo,
4
+ utilities::TryIntoOrNone,
5
+ };
3
6
 
4
7
  pub(crate) trait RetryPolicyExt {
5
8
  /// Ask this retry policy if a retry should be performed. Caller provides the current attempt
@@ -7,11 +10,31 @@ pub(crate) trait RetryPolicyExt {
7
10
  ///
8
11
  /// Returns `None` if it should not, otherwise a duration indicating how long to wait before
9
12
  /// performing the retry.
10
- fn should_retry(&self, attempt_number: usize, err_type_str: &str) -> Option<Duration>;
13
+ ///
14
+ /// Applies defaults to missing fields:
15
+ /// `initial_interval` - 1 second
16
+ /// `maximum_interval` - 100 x initial_interval
17
+ /// `backoff_coefficient` - 2.0
18
+ fn should_retry(
19
+ &self,
20
+ attempt_number: usize,
21
+ application_failure: Option<&ApplicationFailureInfo>,
22
+ ) -> Option<Duration>;
11
23
  }
12
24
 
13
25
  impl RetryPolicyExt for RetryPolicy {
14
- fn should_retry(&self, attempt_number: usize, err_type_str: &str) -> Option<Duration> {
26
+ fn should_retry(
27
+ &self,
28
+ attempt_number: usize,
29
+ application_failure: Option<&ApplicationFailureInfo>,
30
+ ) -> Option<Duration> {
31
+ let non_retryable = application_failure
32
+ .map(|f| f.non_retryable)
33
+ .unwrap_or_default();
34
+ if non_retryable {
35
+ return None;
36
+ }
37
+ let err_type_str = application_failure.map_or("", |f| &f.r#type);
15
38
  let realmax = self.maximum_attempts.max(0);
16
39
  if realmax > 0 && attempt_number >= realmax as usize {
17
40
  return None;
@@ -23,7 +46,11 @@ impl RetryPolicyExt for RetryPolicy {
23
46
  }
24
47
  }
25
48
 
26
- let converted_interval = self.initial_interval.clone().try_into_or_none();
49
+ let converted_interval = self
50
+ .initial_interval
51
+ .clone()
52
+ .try_into_or_none()
53
+ .or(Some(Duration::from_secs(1)));
27
54
  if attempt_number == 1 {
28
55
  return converted_interval;
29
56
  }
@@ -74,32 +101,32 @@ mod tests {
74
101
  maximum_attempts: 10,
75
102
  non_retryable_error_types: vec![],
76
103
  };
77
- let res = rp.should_retry(1, "").unwrap();
104
+ let res = rp.should_retry(1, None).unwrap();
78
105
  assert_eq!(res.as_millis(), 1_000);
79
- let res = rp.should_retry(2, "").unwrap();
106
+ let res = rp.should_retry(2, None).unwrap();
80
107
  assert_eq!(res.as_millis(), 2_000);
81
- let res = rp.should_retry(3, "").unwrap();
108
+ let res = rp.should_retry(3, None).unwrap();
82
109
  assert_eq!(res.as_millis(), 4_000);
83
- let res = rp.should_retry(4, "").unwrap();
110
+ let res = rp.should_retry(4, None).unwrap();
84
111
  assert_eq!(res.as_millis(), 8_000);
85
- let res = rp.should_retry(5, "").unwrap();
112
+ let res = rp.should_retry(5, None).unwrap();
86
113
  assert_eq!(res.as_millis(), 10_000);
87
- let res = rp.should_retry(6, "").unwrap();
114
+ let res = rp.should_retry(6, None).unwrap();
88
115
  assert_eq!(res.as_millis(), 10_000);
89
116
  // Max attempts - no retry
90
- assert!(rp.should_retry(10, "").is_none());
117
+ assert!(rp.should_retry(10, None).is_none());
91
118
  }
92
119
 
93
120
  #[test]
94
121
  fn no_interval_no_backoff() {
95
122
  let rp = RetryPolicy {
96
123
  initial_interval: None,
97
- backoff_coefficient: 2.0,
124
+ backoff_coefficient: 0.,
98
125
  maximum_interval: None,
99
126
  maximum_attempts: 10,
100
127
  non_retryable_error_types: vec![],
101
128
  };
102
- assert!(rp.should_retry(1, "").is_none());
129
+ assert!(rp.should_retry(1, None).is_some());
103
130
  }
104
131
 
105
132
  #[test]
@@ -112,7 +139,7 @@ mod tests {
112
139
  non_retryable_error_types: vec![],
113
140
  };
114
141
  for i in 0..50 {
115
- assert!(rp.should_retry(i, "").is_some());
142
+ assert!(rp.should_retry(i, None).is_some());
116
143
  }
117
144
  }
118
145
 
@@ -126,7 +153,7 @@ mod tests {
126
153
  non_retryable_error_types: vec![],
127
154
  };
128
155
  for i in 0..50 {
129
- assert!(rp.should_retry(i, "").is_some());
156
+ assert!(rp.should_retry(i, None).is_some());
130
157
  }
131
158
  }
132
159
 
@@ -139,6 +166,36 @@ mod tests {
139
166
  maximum_attempts: 10,
140
167
  non_retryable_error_types: vec!["no retry".to_string()],
141
168
  };
142
- assert!(rp.should_retry(1, "no retry").is_none());
169
+ assert!(rp
170
+ .should_retry(
171
+ 1,
172
+ Some(&ApplicationFailureInfo {
173
+ r#type: "no retry".to_string(),
174
+ non_retryable: false,
175
+ details: None,
176
+ })
177
+ )
178
+ .is_none());
179
+ }
180
+
181
+ #[test]
182
+ fn no_non_retryable_application_failure() {
183
+ let rp = RetryPolicy {
184
+ initial_interval: Some(Duration::from_secs(1).into()),
185
+ backoff_coefficient: 2.0,
186
+ maximum_interval: Some(Duration::from_secs(10).into()),
187
+ maximum_attempts: 10,
188
+ non_retryable_error_types: vec![],
189
+ };
190
+ assert!(rp
191
+ .should_retry(
192
+ 1,
193
+ Some(&ApplicationFailureInfo {
194
+ r#type: "".to_string(),
195
+ non_retryable: true,
196
+ details: None,
197
+ })
198
+ )
199
+ .is_none());
143
200
  }
144
201
  }