spikard 0.6.2 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. checksums.yaml +4 -4
  2. data/README.md +90 -508
  3. data/ext/spikard_rb/Cargo.lock +3287 -0
  4. data/ext/spikard_rb/Cargo.toml +1 -1
  5. data/ext/spikard_rb/extconf.rb +3 -3
  6. data/lib/spikard/app.rb +72 -49
  7. data/lib/spikard/background.rb +38 -7
  8. data/lib/spikard/testing.rb +42 -4
  9. data/lib/spikard/version.rb +1 -1
  10. data/sig/spikard.rbs +4 -0
  11. data/vendor/crates/spikard-bindings-shared/Cargo.toml +1 -1
  12. data/vendor/crates/spikard-bindings-shared/tests/config_extractor_behavior.rs +191 -0
  13. data/vendor/crates/spikard-core/Cargo.toml +1 -1
  14. data/vendor/crates/spikard-core/src/http.rs +1 -0
  15. data/vendor/crates/spikard-core/src/lifecycle.rs +63 -0
  16. data/vendor/crates/spikard-core/tests/bindings_response_tests.rs +136 -0
  17. data/vendor/crates/spikard-core/tests/di_dependency_defaults.rs +37 -0
  18. data/vendor/crates/spikard-core/tests/error_mapper.rs +761 -0
  19. data/vendor/crates/spikard-core/tests/parameters_edge_cases.rs +106 -0
  20. data/vendor/crates/spikard-core/tests/parameters_full.rs +701 -0
  21. data/vendor/crates/spikard-core/tests/parameters_schema_and_formats.rs +301 -0
  22. data/vendor/crates/spikard-core/tests/request_data_roundtrip.rs +67 -0
  23. data/vendor/crates/spikard-core/tests/validation_coverage.rs +250 -0
  24. data/vendor/crates/spikard-core/tests/validation_error_paths.rs +45 -0
  25. data/vendor/crates/spikard-http/Cargo.toml +1 -1
  26. data/vendor/crates/spikard-http/src/jsonrpc/http_handler.rs +502 -0
  27. data/vendor/crates/spikard-http/src/jsonrpc/method_registry.rs +648 -0
  28. data/vendor/crates/spikard-http/src/jsonrpc/mod.rs +58 -0
  29. data/vendor/crates/spikard-http/src/jsonrpc/protocol.rs +1207 -0
  30. data/vendor/crates/spikard-http/src/jsonrpc/router.rs +2262 -0
  31. data/vendor/crates/spikard-http/src/testing/test_client.rs +155 -2
  32. data/vendor/crates/spikard-http/src/testing.rs +171 -0
  33. data/vendor/crates/spikard-http/src/websocket.rs +79 -6
  34. data/vendor/crates/spikard-http/tests/auth_integration.rs +647 -0
  35. data/vendor/crates/spikard-http/tests/common/test_builders.rs +633 -0
  36. data/vendor/crates/spikard-http/tests/di_handler_error_responses.rs +162 -0
  37. data/vendor/crates/spikard-http/tests/middleware_stack_integration.rs +389 -0
  38. data/vendor/crates/spikard-http/tests/request_extraction_full.rs +513 -0
  39. data/vendor/crates/spikard-http/tests/server_auth_middleware_behavior.rs +244 -0
  40. data/vendor/crates/spikard-http/tests/server_configured_router_behavior.rs +200 -0
  41. data/vendor/crates/spikard-http/tests/server_cors_preflight.rs +82 -0
  42. data/vendor/crates/spikard-http/tests/server_handler_wrappers.rs +464 -0
  43. data/vendor/crates/spikard-http/tests/server_method_router_additional_behavior.rs +286 -0
  44. data/vendor/crates/spikard-http/tests/server_method_router_coverage.rs +118 -0
  45. data/vendor/crates/spikard-http/tests/server_middleware_behavior.rs +99 -0
  46. data/vendor/crates/spikard-http/tests/server_middleware_branches.rs +206 -0
  47. data/vendor/crates/spikard-http/tests/server_openapi_jsonrpc_static.rs +281 -0
  48. data/vendor/crates/spikard-http/tests/server_router_behavior.rs +121 -0
  49. data/vendor/crates/spikard-http/tests/sse_full_behavior.rs +584 -0
  50. data/vendor/crates/spikard-http/tests/sse_handler_behavior.rs +130 -0
  51. data/vendor/crates/spikard-http/tests/test_client_requests.rs +167 -0
  52. data/vendor/crates/spikard-http/tests/testing_helpers.rs +87 -0
  53. data/vendor/crates/spikard-http/tests/testing_module_coverage.rs +156 -0
  54. data/vendor/crates/spikard-http/tests/urlencoded_content_type.rs +82 -0
  55. data/vendor/crates/spikard-http/tests/websocket_full_behavior.rs +440 -0
  56. data/vendor/crates/spikard-http/tests/websocket_integration.rs +152 -0
  57. data/vendor/crates/spikard-rb/Cargo.toml +1 -1
  58. data/vendor/crates/spikard-rb/src/gvl.rs +80 -0
  59. data/vendor/crates/spikard-rb/src/handler.rs +12 -9
  60. data/vendor/crates/spikard-rb/src/lib.rs +137 -124
  61. data/vendor/crates/spikard-rb/src/request.rs +342 -0
  62. data/vendor/crates/spikard-rb/src/runtime/server_runner.rs +1 -8
  63. data/vendor/crates/spikard-rb/src/server.rs +1 -8
  64. data/vendor/crates/spikard-rb/src/testing/client.rs +168 -9
  65. data/vendor/crates/spikard-rb/src/websocket.rs +119 -30
  66. data/vendor/crates/spikard-rb-macros/Cargo.toml +14 -0
  67. data/vendor/crates/spikard-rb-macros/src/lib.rs +52 -0
  68. metadata +44 -1
@@ -81,7 +81,7 @@ impl TestClient {
81
81
  let mut request = self.server.post(&full_path);
82
82
 
83
83
  if let Some(headers_vec) = headers {
84
- request = self.add_headers(request, headers_vec.clone())?;
84
+ request = self.add_headers(request, headers_vec)?;
85
85
  }
86
86
 
87
87
  if let Some((form_fields, files)) = multipart {
@@ -90,7 +90,9 @@ impl TestClient {
90
90
  request = request.add_header("content-type", &content_type);
91
91
  request = request.bytes(Bytes::from(body));
92
92
  } else if let Some(form_fields) = form_data {
93
- let encoded = super::encode_urlencoded_body(&serde_json::to_value(&form_fields).unwrap_or(Value::Null))
93
+ let fields_value = serde_json::to_value(&form_fields)
94
+ .map_err(|e| SnapshotError::Decompression(format!("Failed to serialize form fields: {}", e)))?;
95
+ let encoded = super::encode_urlencoded_body(&fields_value)
94
96
  .map_err(|e| SnapshotError::Decompression(format!("Form encoding failed: {}", e)))?;
95
97
  request = request.add_header("content-type", "application/x-www-form-urlencoded");
96
98
  request = request.bytes(Bytes::from(encoded));
@@ -241,6 +243,67 @@ impl TestClient {
241
243
  snapshot_response(response).await
242
244
  }
243
245
 
246
+ /// Send a GraphQL query/mutation to a custom endpoint
247
+ pub async fn graphql_at(
248
+ &self,
249
+ endpoint: &str,
250
+ query: &str,
251
+ variables: Option<Value>,
252
+ operation_name: Option<&str>,
253
+ ) -> Result<ResponseSnapshot, SnapshotError> {
254
+ let body = build_graphql_body(query, variables, operation_name);
255
+ self.post(endpoint, Some(body), None, None, None, None).await
256
+ }
257
+
258
+ /// Send a GraphQL query/mutation
259
+ pub async fn graphql(
260
+ &self,
261
+ query: &str,
262
+ variables: Option<Value>,
263
+ operation_name: Option<&str>,
264
+ ) -> Result<ResponseSnapshot, SnapshotError> {
265
+ self.graphql_at("/graphql", query, variables, operation_name).await
266
+ }
267
+
268
+ /// Send a GraphQL query and return HTTP status code separately
269
+ ///
270
+ /// This method allows tests to distinguish between:
271
+ /// - HTTP-level errors (400/422 for invalid requests)
272
+ /// - GraphQL-level errors (200 with errors in response body)
273
+ ///
274
+ /// # Example
275
+ /// ```ignore
276
+ /// let (status, snapshot) = client.graphql_with_status(
277
+ /// "query { invalid syntax",
278
+ /// None,
279
+ /// None
280
+ /// ).await?;
281
+ /// assert_eq!(status, 400); // HTTP parse error
282
+ /// ```
283
+ pub async fn graphql_with_status(
284
+ &self,
285
+ query: &str,
286
+ variables: Option<Value>,
287
+ operation_name: Option<&str>,
288
+ ) -> Result<(u16, ResponseSnapshot), SnapshotError> {
289
+ let snapshot = self.graphql(query, variables, operation_name).await?;
290
+ let status = snapshot.status;
291
+ Ok((status, snapshot))
292
+ }
293
+
294
+ /// Send a GraphQL subscription (WebSocket)
295
+ pub async fn graphql_subscription(
296
+ &self,
297
+ _query: &str,
298
+ _variables: Option<Value>,
299
+ _operation_name: Option<&str>,
300
+ ) -> Result<(), SnapshotError> {
301
+ // For now, return a placeholder - full WebSocket implementation comes later
302
+ Err(SnapshotError::Decompression(
303
+ "GraphQL subscriptions not yet implemented".to_string(),
304
+ ))
305
+ }
306
+
244
307
  /// Add headers to a test request builder
245
308
  fn add_headers(
246
309
  &self,
@@ -258,6 +321,18 @@ impl TestClient {
258
321
  }
259
322
  }
260
323
 
324
+ /// Build a GraphQL request body from query, variables, and operation name
325
+ pub fn build_graphql_body(query: &str, variables: Option<Value>, operation_name: Option<&str>) -> Value {
326
+ let mut body = serde_json::json!({ "query": query });
327
+ if let Some(vars) = variables {
328
+ body["variables"] = vars;
329
+ }
330
+ if let Some(op_name) = operation_name {
331
+ body["operationName"] = Value::String(op_name.to_string());
332
+ }
333
+ body
334
+ }
335
+
261
336
  /// Build a full path with query parameters
262
337
  fn build_full_path(path: &str, query_params: Option<&[(String, String)]>) -> String {
263
338
  match query_params {
@@ -308,4 +383,82 @@ mod tests {
308
383
  let result = build_full_path(path, Some(&params));
309
384
  assert_eq!(result, "/users?active=true&id=123");
310
385
  }
386
+
387
+ #[test]
388
+ fn test_graphql_query_builder() {
389
+ let query = "{ users { id name } }";
390
+ let variables = Some(serde_json::json!({ "limit": 10 }));
391
+ let op_name = Some("GetUsers");
392
+
393
+ let mut body = serde_json::json!({ "query": query });
394
+ if let Some(vars) = variables {
395
+ body["variables"] = vars;
396
+ }
397
+ if let Some(op_name) = op_name {
398
+ body["operationName"] = Value::String(op_name.to_string());
399
+ }
400
+
401
+ assert_eq!(body["query"], query);
402
+ assert_eq!(body["variables"]["limit"], 10);
403
+ assert_eq!(body["operationName"], "GetUsers");
404
+ }
405
+
406
+ #[test]
407
+ fn test_graphql_with_status_method() {
408
+ let query = "query { hello }";
409
+ let body = serde_json::json!({
410
+ "query": query,
411
+ "variables": null,
412
+ "operationName": null
413
+ });
414
+
415
+ // This test validates the method signature and return type
416
+ // Actual HTTP status testing will happen in integration tests
417
+ let expected_fields = vec!["query", "variables", "operationName"];
418
+ for field in expected_fields {
419
+ assert!(body.get(field).is_some(), "Missing field: {}", field);
420
+ }
421
+ }
422
+
423
+ #[test]
424
+ fn test_build_graphql_body_basic() {
425
+ let query = "{ users { id name } }";
426
+ let body = build_graphql_body(query, None, None);
427
+
428
+ assert_eq!(body["query"], query);
429
+ assert!(body.get("variables").is_none() || body["variables"].is_null());
430
+ assert!(body.get("operationName").is_none() || body["operationName"].is_null());
431
+ }
432
+
433
+ #[test]
434
+ fn test_build_graphql_body_with_variables() {
435
+ let query = "query GetUser($id: ID!) { user(id: $id) { name } }";
436
+ let variables = Some(serde_json::json!({ "id": "123" }));
437
+ let body = build_graphql_body(query, variables, None);
438
+
439
+ assert_eq!(body["query"], query);
440
+ assert_eq!(body["variables"]["id"], "123");
441
+ }
442
+
443
+ #[test]
444
+ fn test_build_graphql_body_with_operation_name() {
445
+ let query = "query GetUsers { users { id } }";
446
+ let op_name = Some("GetUsers");
447
+ let body = build_graphql_body(query, None, op_name);
448
+
449
+ assert_eq!(body["query"], query);
450
+ assert_eq!(body["operationName"], "GetUsers");
451
+ }
452
+
453
+ #[test]
454
+ fn test_build_graphql_body_all_fields() {
455
+ let query = "mutation CreateUser($name: String!) { createUser(name: $name) { id } }";
456
+ let variables = Some(serde_json::json!({ "name": "Alice" }));
457
+ let op_name = Some("CreateUser");
458
+ let body = build_graphql_body(query, variables, op_name);
459
+
460
+ assert_eq!(body["query"], query);
461
+ assert_eq!(body["variables"]["name"], "Alice");
462
+ assert_eq!(body["operationName"], "CreateUser");
463
+ }
311
464
  }
@@ -44,6 +44,28 @@ impl ResponseSnapshot {
44
44
  pub fn header(&self, name: &str) -> Option<&str> {
45
45
  self.headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
46
46
  }
47
+
48
+ /// Extract GraphQL data from response
49
+ pub fn graphql_data(&self) -> Result<Value, SnapshotError> {
50
+ let body: Value = serde_json::from_slice(&self.body)
51
+ .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
52
+
53
+ body.get("data")
54
+ .cloned()
55
+ .ok_or_else(|| SnapshotError::Decompression("No 'data' field in GraphQL response".to_string()))
56
+ }
57
+
58
+ /// Extract GraphQL errors from response
59
+ pub fn graphql_errors(&self) -> Result<Vec<Value>, SnapshotError> {
60
+ let body: Value = serde_json::from_slice(&self.body)
61
+ .map_err(|e| SnapshotError::Decompression(format!("Failed to parse JSON: {}", e)))?;
62
+
63
+ Ok(body
64
+ .get("errors")
65
+ .and_then(|e| e.as_array())
66
+ .cloned()
67
+ .unwrap_or_default())
68
+ }
47
69
  }
48
70
 
49
71
  /// Possible errors while converting an Axum response into a snapshot.
@@ -377,6 +399,9 @@ impl SseEvent {
377
399
  #[cfg(test)]
378
400
  mod tests {
379
401
  use super::*;
402
+ use axum::body::Body;
403
+ use axum::response::Response;
404
+ use std::io::Write;
380
405
 
381
406
  #[test]
382
407
  fn sse_stream_parses_multiple_events() {
@@ -403,4 +428,150 @@ mod tests {
403
428
  };
404
429
  assert!(event.as_json().is_err());
405
430
  }
431
+
432
+ #[test]
433
+ fn test_graphql_data_extraction() {
434
+ let mut headers = HashMap::new();
435
+ headers.insert("content-type".to_string(), "application/json".to_string());
436
+
437
+ let graphql_response = serde_json::json!({
438
+ "data": {
439
+ "user": {
440
+ "id": "1",
441
+ "name": "Alice"
442
+ }
443
+ }
444
+ });
445
+
446
+ let snapshot = ResponseSnapshot {
447
+ status: 200,
448
+ headers,
449
+ body: serde_json::to_vec(&graphql_response).unwrap(),
450
+ };
451
+
452
+ let data = snapshot.graphql_data().expect("data extraction");
453
+ assert_eq!(data["user"]["id"], "1");
454
+ assert_eq!(data["user"]["name"], "Alice");
455
+ }
456
+
457
+ #[test]
458
+ fn test_graphql_errors_extraction() {
459
+ let mut headers = HashMap::new();
460
+ headers.insert("content-type".to_string(), "application/json".to_string());
461
+
462
+ let graphql_response = serde_json::json!({
463
+ "errors": [
464
+ {
465
+ "message": "Field not found",
466
+ "path": ["user", "email"]
467
+ },
468
+ {
469
+ "message": "Unauthorized",
470
+ "extensions": { "code": "UNAUTHENTICATED" }
471
+ }
472
+ ]
473
+ });
474
+
475
+ let snapshot = ResponseSnapshot {
476
+ status: 400,
477
+ headers,
478
+ body: serde_json::to_vec(&graphql_response).unwrap(),
479
+ };
480
+
481
+ let errors = snapshot.graphql_errors().expect("errors extraction");
482
+ assert_eq!(errors.len(), 2);
483
+ assert_eq!(errors[0]["message"], "Field not found");
484
+ assert_eq!(errors[1]["message"], "Unauthorized");
485
+ }
486
+
487
+ #[test]
488
+ fn test_graphql_missing_data_field() {
489
+ let mut headers = HashMap::new();
490
+ headers.insert("content-type".to_string(), "application/json".to_string());
491
+
492
+ let graphql_response = serde_json::json!({
493
+ "errors": [{ "message": "Query failed" }]
494
+ });
495
+
496
+ let snapshot = ResponseSnapshot {
497
+ status: 400,
498
+ headers,
499
+ body: serde_json::to_vec(&graphql_response).unwrap(),
500
+ };
501
+
502
+ let result = snapshot.graphql_data();
503
+ assert!(result.is_err());
504
+ assert!(result.unwrap_err().to_string().contains("No 'data' field"));
505
+ }
506
+
507
+ #[test]
508
+ fn test_graphql_empty_errors() {
509
+ let mut headers = HashMap::new();
510
+ headers.insert("content-type".to_string(), "application/json".to_string());
511
+
512
+ let graphql_response = serde_json::json!({
513
+ "data": { "result": null }
514
+ });
515
+
516
+ let snapshot = ResponseSnapshot {
517
+ status: 200,
518
+ headers,
519
+ body: serde_json::to_vec(&graphql_response).unwrap(),
520
+ };
521
+
522
+ let errors = snapshot.graphql_errors().expect("errors extraction");
523
+ assert!(errors.is_empty());
524
+ }
525
+
526
+ fn gzip_bytes(input: &[u8]) -> Vec<u8> {
527
+ let mut encoder = flate2::write::GzEncoder::new(Vec::new(), flate2::Compression::default());
528
+ encoder.write_all(input).expect("gzip write");
529
+ encoder.finish().expect("gzip finish")
530
+ }
531
+
532
+ fn brotli_bytes(input: &[u8]) -> Vec<u8> {
533
+ let mut encoder = brotli::CompressorWriter::new(Vec::new(), 4096, 5, 22);
534
+ encoder.write_all(input).expect("brotli write");
535
+ encoder.into_inner()
536
+ }
537
+
538
+ #[tokio::test]
539
+ async fn snapshot_http_response_decodes_gzip_body() {
540
+ let body = b"hello gzip";
541
+ let compressed = gzip_bytes(body);
542
+ let response = Response::builder()
543
+ .status(200)
544
+ .header("content-encoding", "gzip")
545
+ .body(Body::from(compressed))
546
+ .unwrap();
547
+
548
+ let snapshot = snapshot_http_response(response).await.expect("snapshot");
549
+ assert_eq!(snapshot.body, body);
550
+ }
551
+
552
+ #[tokio::test]
553
+ async fn snapshot_http_response_decodes_brotli_body() {
554
+ let body = b"hello brotli";
555
+ let compressed = brotli_bytes(body);
556
+ let response = Response::builder()
557
+ .status(200)
558
+ .header("content-encoding", "br")
559
+ .body(Body::from(compressed))
560
+ .unwrap();
561
+
562
+ let snapshot = snapshot_http_response(response).await.expect("snapshot");
563
+ assert_eq!(snapshot.body, body);
564
+ }
565
+
566
+ #[tokio::test]
567
+ async fn snapshot_http_response_leaves_plain_body() {
568
+ let body = b"plain";
569
+ let response = Response::builder()
570
+ .status(200)
571
+ .body(Body::from(body.as_slice()))
572
+ .unwrap();
573
+
574
+ let snapshot = snapshot_http_response(response).await.expect("snapshot");
575
+ assert_eq!(snapshot.body, body);
576
+ }
406
577
  }
@@ -90,20 +90,31 @@ pub trait WebSocketHandler: Send + Sync {
90
90
  /// Contains the message handler and optional JSON schemas for validating
91
91
  /// incoming and outgoing messages. This state is shared among all connections
92
92
  /// to the same WebSocket endpoint.
93
- #[derive(Debug)]
94
93
  pub struct WebSocketState<H: WebSocketHandler> {
95
94
  /// The message handler implementation
96
95
  handler: Arc<H>,
96
+ /// Factory for producing per-connection handlers
97
+ handler_factory: Arc<dyn Fn() -> Result<Arc<H>, String> + Send + Sync>,
97
98
  /// Optional JSON Schema for validating incoming messages
98
99
  message_schema: Option<Arc<jsonschema::Validator>>,
99
100
  /// Optional JSON Schema for validating outgoing responses
100
101
  response_schema: Option<Arc<jsonschema::Validator>>,
101
102
  }
102
103
 
104
+ impl<H: WebSocketHandler> std::fmt::Debug for WebSocketState<H> {
105
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106
+ f.debug_struct("WebSocketState")
107
+ .field("message_schema", &self.message_schema.is_some())
108
+ .field("response_schema", &self.response_schema.is_some())
109
+ .finish()
110
+ }
111
+ }
112
+
103
113
  impl<H: WebSocketHandler> Clone for WebSocketState<H> {
104
114
  fn clone(&self) -> Self {
105
115
  Self {
106
116
  handler: Arc::clone(&self.handler),
117
+ handler_factory: Arc::clone(&self.handler_factory),
107
118
  message_schema: self.message_schema.clone(),
108
119
  response_schema: self.response_schema.clone(),
109
120
  }
@@ -125,8 +136,13 @@ impl<H: WebSocketHandler + 'static> WebSocketState<H> {
125
136
  /// let state = WebSocketState::new(MyHandler);
126
137
  /// ```
127
138
  pub fn new(handler: H) -> Self {
139
+ let handler = Arc::new(handler);
128
140
  Self {
129
- handler: Arc::new(handler),
141
+ handler_factory: Arc::new({
142
+ let handler = Arc::clone(&handler);
143
+ move || Ok(Arc::clone(&handler))
144
+ }),
145
+ handler,
130
146
  message_schema: None,
131
147
  response_schema: None,
132
148
  }
@@ -187,8 +203,56 @@ impl<H: WebSocketHandler + 'static> WebSocketState<H> {
187
203
  None
188
204
  };
189
205
 
206
+ let handler = Arc::new(handler);
190
207
  Ok(Self {
191
- handler: Arc::new(handler),
208
+ handler_factory: Arc::new({
209
+ let handler = Arc::clone(&handler);
210
+ move || Ok(Arc::clone(&handler))
211
+ }),
212
+ handler,
213
+ message_schema: message_validator,
214
+ response_schema: response_validator,
215
+ })
216
+ }
217
+
218
+ /// Create new WebSocket state with a handler factory and optional validation schemas.
219
+ ///
220
+ /// The factory is invoked once per connection, enabling per-connection handler state.
221
+ pub fn with_factory<F>(
222
+ factory: F,
223
+ message_schema: Option<serde_json::Value>,
224
+ response_schema: Option<serde_json::Value>,
225
+ ) -> Result<Self, String>
226
+ where
227
+ F: Fn() -> Result<H, String> + Send + Sync + 'static,
228
+ {
229
+ let message_validator = if let Some(schema) = message_schema {
230
+ Some(Arc::new(
231
+ jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
232
+ ))
233
+ } else {
234
+ None
235
+ };
236
+
237
+ let response_validator = if let Some(schema) = response_schema {
238
+ Some(Arc::new(
239
+ jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
240
+ ))
241
+ } else {
242
+ None
243
+ };
244
+
245
+ let factory = Arc::new(factory);
246
+ let handler = factory()
247
+ .map(Arc::new)
248
+ .map_err(|e| format!("Failed to build WebSocket handler: {}", e))?;
249
+
250
+ Ok(Self {
251
+ handler_factory: Arc::new({
252
+ let factory = Arc::clone(&factory);
253
+ move || factory().map(Arc::new)
254
+ }),
255
+ handler,
192
256
  message_schema: message_validator,
193
257
  response_schema: response_validator,
194
258
  })
@@ -258,7 +322,16 @@ async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSoc
258
322
  info!("WebSocket client connected");
259
323
  trace_ws("socket:connected");
260
324
 
261
- state.handler.on_connect().await;
325
+ let handler = match (state.handler_factory)() {
326
+ Ok(handler) => handler,
327
+ Err(err) => {
328
+ error!("Failed to create WebSocket handler: {}", err);
329
+ trace_ws("socket:handler-factory:error");
330
+ return;
331
+ }
332
+ };
333
+
334
+ handler.on_connect().await;
262
335
  trace_ws("socket:on_connect:done");
263
336
 
264
337
  while let Some(msg) = socket.recv().await {
@@ -285,7 +358,7 @@ async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSoc
285
358
  continue;
286
359
  }
287
360
 
288
- if let Some(response) = state.handler.handle_message(json_msg).await {
361
+ if let Some(response) = handler.handle_message(json_msg).await {
289
362
  trace_ws("handler:response:some");
290
363
  if let Some(validator) = &state.response_schema
291
364
  && !validator.is_valid(&response)
@@ -358,7 +431,7 @@ async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSoc
358
431
  }
359
432
  }
360
433
 
361
- state.handler.on_disconnect().await;
434
+ handler.on_disconnect().await;
362
435
  trace_ws("socket:on_disconnect:done");
363
436
  info!("WebSocket client disconnected");
364
437
  }