spikard 0.3.5 → 0.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +659 -659
  4. data/ext/spikard_rb/Cargo.toml +17 -17
  5. data/ext/spikard_rb/extconf.rb +10 -10
  6. data/ext/spikard_rb/src/lib.rs +6 -6
  7. data/lib/spikard/app.rb +386 -386
  8. data/lib/spikard/background.rb +27 -27
  9. data/lib/spikard/config.rb +396 -396
  10. data/lib/spikard/converters.rb +13 -13
  11. data/lib/spikard/handler_wrapper.rb +113 -113
  12. data/lib/spikard/provide.rb +214 -214
  13. data/lib/spikard/response.rb +173 -173
  14. data/lib/spikard/schema.rb +243 -243
  15. data/lib/spikard/sse.rb +111 -111
  16. data/lib/spikard/streaming_response.rb +44 -44
  17. data/lib/spikard/testing.rb +221 -221
  18. data/lib/spikard/upload_file.rb +131 -131
  19. data/lib/spikard/version.rb +5 -5
  20. data/lib/spikard/websocket.rb +59 -59
  21. data/lib/spikard.rb +43 -43
  22. data/sig/spikard.rbs +366 -360
  23. data/vendor/crates/spikard-core/Cargo.toml +40 -40
  24. data/vendor/crates/spikard-core/src/bindings/mod.rs +3 -3
  25. data/vendor/crates/spikard-core/src/bindings/response.rs +133 -133
  26. data/vendor/crates/spikard-core/src/debug.rs +63 -63
  27. data/vendor/crates/spikard-core/src/di/container.rs +726 -726
  28. data/vendor/crates/spikard-core/src/di/dependency.rs +273 -273
  29. data/vendor/crates/spikard-core/src/di/error.rs +118 -118
  30. data/vendor/crates/spikard-core/src/di/factory.rs +538 -538
  31. data/vendor/crates/spikard-core/src/di/graph.rs +545 -545
  32. data/vendor/crates/spikard-core/src/di/mod.rs +192 -192
  33. data/vendor/crates/spikard-core/src/di/resolved.rs +411 -411
  34. data/vendor/crates/spikard-core/src/di/value.rs +283 -283
  35. data/vendor/crates/spikard-core/src/errors.rs +39 -39
  36. data/vendor/crates/spikard-core/src/http.rs +153 -153
  37. data/vendor/crates/spikard-core/src/lib.rs +29 -29
  38. data/vendor/crates/spikard-core/src/lifecycle.rs +422 -422
  39. data/vendor/crates/spikard-core/src/parameters.rs +722 -722
  40. data/vendor/crates/spikard-core/src/problem.rs +310 -310
  41. data/vendor/crates/spikard-core/src/request_data.rs +189 -189
  42. data/vendor/crates/spikard-core/src/router.rs +249 -249
  43. data/vendor/crates/spikard-core/src/schema_registry.rs +183 -183
  44. data/vendor/crates/spikard-core/src/type_hints.rs +304 -304
  45. data/vendor/crates/spikard-core/src/validation.rs +699 -699
  46. data/vendor/crates/spikard-http/Cargo.toml +68 -68
  47. data/vendor/crates/spikard-http/src/auth.rs +247 -247
  48. data/vendor/crates/spikard-http/src/background.rs +249 -249
  49. data/vendor/crates/spikard-http/src/bindings/mod.rs +3 -3
  50. data/vendor/crates/spikard-http/src/bindings/response.rs +1 -1
  51. data/vendor/crates/spikard-http/src/body_metadata.rs +8 -8
  52. data/vendor/crates/spikard-http/src/cors.rs +490 -490
  53. data/vendor/crates/spikard-http/src/debug.rs +63 -63
  54. data/vendor/crates/spikard-http/src/di_handler.rs +423 -423
  55. data/vendor/crates/spikard-http/src/handler_response.rs +190 -190
  56. data/vendor/crates/spikard-http/src/handler_trait.rs +228 -228
  57. data/vendor/crates/spikard-http/src/handler_trait_tests.rs +284 -284
  58. data/vendor/crates/spikard-http/src/lib.rs +529 -529
  59. data/vendor/crates/spikard-http/src/lifecycle/adapter.rs +149 -149
  60. data/vendor/crates/spikard-http/src/lifecycle.rs +428 -428
  61. data/vendor/crates/spikard-http/src/middleware/mod.rs +285 -285
  62. data/vendor/crates/spikard-http/src/middleware/multipart.rs +86 -86
  63. data/vendor/crates/spikard-http/src/middleware/urlencoded.rs +147 -147
  64. data/vendor/crates/spikard-http/src/middleware/validation.rs +287 -287
  65. data/vendor/crates/spikard-http/src/openapi/mod.rs +309 -309
  66. data/vendor/crates/spikard-http/src/openapi/parameter_extraction.rs +190 -190
  67. data/vendor/crates/spikard-http/src/openapi/schema_conversion.rs +308 -308
  68. data/vendor/crates/spikard-http/src/openapi/spec_generation.rs +195 -195
  69. data/vendor/crates/spikard-http/src/parameters.rs +1 -1
  70. data/vendor/crates/spikard-http/src/problem.rs +1 -1
  71. data/vendor/crates/spikard-http/src/query_parser.rs +369 -369
  72. data/vendor/crates/spikard-http/src/response.rs +399 -399
  73. data/vendor/crates/spikard-http/src/router.rs +1 -1
  74. data/vendor/crates/spikard-http/src/schema_registry.rs +1 -1
  75. data/vendor/crates/spikard-http/src/server/handler.rs +87 -87
  76. data/vendor/crates/spikard-http/src/server/lifecycle_execution.rs +98 -98
  77. data/vendor/crates/spikard-http/src/server/mod.rs +805 -805
  78. data/vendor/crates/spikard-http/src/server/request_extraction.rs +119 -119
  79. data/vendor/crates/spikard-http/src/sse.rs +447 -447
  80. data/vendor/crates/spikard-http/src/testing/form.rs +14 -14
  81. data/vendor/crates/spikard-http/src/testing/multipart.rs +60 -60
  82. data/vendor/crates/spikard-http/src/testing/test_client.rs +285 -285
  83. data/vendor/crates/spikard-http/src/testing.rs +377 -377
  84. data/vendor/crates/spikard-http/src/type_hints.rs +1 -1
  85. data/vendor/crates/spikard-http/src/validation.rs +1 -1
  86. data/vendor/crates/spikard-http/src/websocket.rs +324 -324
  87. data/vendor/crates/spikard-rb/Cargo.toml +42 -42
  88. data/vendor/crates/spikard-rb/build.rs +8 -8
  89. data/vendor/crates/spikard-rb/src/background.rs +63 -63
  90. data/vendor/crates/spikard-rb/src/config.rs +294 -294
  91. data/vendor/crates/spikard-rb/src/conversion.rs +453 -453
  92. data/vendor/crates/spikard-rb/src/di.rs +409 -409
  93. data/vendor/crates/spikard-rb/src/handler.rs +625 -625
  94. data/vendor/crates/spikard-rb/src/lib.rs +2771 -2771
  95. data/vendor/crates/spikard-rb/src/lifecycle.rs +274 -274
  96. data/vendor/crates/spikard-rb/src/server.rs +283 -283
  97. data/vendor/crates/spikard-rb/src/sse.rs +231 -231
  98. data/vendor/crates/spikard-rb/src/test_client.rs +404 -404
  99. data/vendor/crates/spikard-rb/src/test_sse.rs +143 -143
  100. data/vendor/crates/spikard-rb/src/test_websocket.rs +221 -221
  101. data/vendor/crates/spikard-rb/src/websocket.rs +233 -233
  102. metadata +1 -1
@@ -1 +1 @@
1
- pub use spikard_core::type_hints::*;
1
+ pub use spikard_core::type_hints::*;
@@ -1 +1 @@
1
- pub use spikard_core::validation::*;
1
+ pub use spikard_core::validation::*;
@@ -1,324 +1,324 @@
1
- //! WebSocket support for Spikard
2
- //!
3
- //! Provides WebSocket connection handling with message validation and routing.
4
-
5
- use axum::{
6
- extract::{
7
- State,
8
- ws::{Message, WebSocket, WebSocketUpgrade},
9
- },
10
- response::IntoResponse,
11
- };
12
- use serde_json::Value;
13
- use std::sync::Arc;
14
- use tracing::{debug, error, info, warn};
15
-
16
- /// WebSocket message handler trait
17
- ///
18
- /// Implement this trait to create custom WebSocket message handlers for your application.
19
- /// The handler processes JSON messages received from WebSocket clients and can optionally
20
- /// send responses back.
21
- ///
22
- /// # Implementing the Trait
23
- ///
24
- /// You must implement the `handle_message` method. The `on_connect` and `on_disconnect`
25
- /// methods are optional and provide lifecycle hooks.
26
- ///
27
- /// # Example
28
- ///
29
- /// ```ignore
30
- /// use spikard_http::websocket::WebSocketHandler;
31
- /// use serde_json::{json, Value};
32
- ///
33
- /// struct EchoHandler;
34
- ///
35
- /// #[async_trait]
36
- /// impl WebSocketHandler for EchoHandler {
37
- /// async fn handle_message(&self, message: Value) -> Option<Value> {
38
- /// // Echo the message back to the client
39
- /// Some(message)
40
- /// }
41
- ///
42
- /// async fn on_connect(&self) {
43
- /// println!("Client connected");
44
- /// }
45
- ///
46
- /// async fn on_disconnect(&self) {
47
- /// println!("Client disconnected");
48
- /// }
49
- /// }
50
- /// ```
51
- pub trait WebSocketHandler: Send + Sync {
52
- /// Handle incoming WebSocket message
53
- ///
54
- /// Called whenever a text message is received from a WebSocket client.
55
- /// Messages are automatically parsed as JSON.
56
- ///
57
- /// # Arguments
58
- /// * `message` - JSON value received from the client
59
- ///
60
- /// # Returns
61
- /// * `Some(value)` - JSON value to send back to the client
62
- /// * `None` - No response to send
63
- fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
64
-
65
- /// Called when a client connects to the WebSocket
66
- ///
67
- /// Optional lifecycle hook invoked when a new WebSocket connection is established.
68
- /// Default implementation does nothing.
69
- fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
70
- async {}
71
- }
72
-
73
- /// Called when a client disconnects from the WebSocket
74
- ///
75
- /// Optional lifecycle hook invoked when a WebSocket connection is closed
76
- /// (either by the client or due to an error). Default implementation does nothing.
77
- fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
78
- async {}
79
- }
80
- }
81
-
82
- /// WebSocket state shared across connections
83
- ///
84
- /// Contains the message handler and optional JSON schemas for validating
85
- /// incoming and outgoing messages. This state is shared among all connections
86
- /// to the same WebSocket endpoint.
87
- pub struct WebSocketState<H: WebSocketHandler> {
88
- /// The message handler implementation
89
- handler: Arc<H>,
90
- /// Optional JSON Schema for validating incoming messages
91
- message_schema: Option<Arc<jsonschema::Validator>>,
92
- /// Optional JSON Schema for validating outgoing responses
93
- response_schema: Option<Arc<jsonschema::Validator>>,
94
- }
95
-
96
- impl<H: WebSocketHandler> Clone for WebSocketState<H> {
97
- fn clone(&self) -> Self {
98
- Self {
99
- handler: Arc::clone(&self.handler),
100
- message_schema: self.message_schema.clone(),
101
- response_schema: self.response_schema.clone(),
102
- }
103
- }
104
- }
105
-
106
- impl<H: WebSocketHandler + 'static> WebSocketState<H> {
107
- /// Create new WebSocket state with a handler
108
- ///
109
- /// Creates a new state without message or response validation schemas.
110
- /// Messages and responses are not validated.
111
- ///
112
- /// # Arguments
113
- /// * `handler` - The message handler implementation
114
- ///
115
- /// # Example
116
- ///
117
- /// ```ignore
118
- /// let state = WebSocketState::new(MyHandler);
119
- /// ```
120
- pub fn new(handler: H) -> Self {
121
- Self {
122
- handler: Arc::new(handler),
123
- message_schema: None,
124
- response_schema: None,
125
- }
126
- }
127
-
128
- /// Create new WebSocket state with a handler and optional validation schemas
129
- ///
130
- /// Creates a new state with optional JSON schemas for validating incoming messages
131
- /// and outgoing responses. If a schema is provided and validation fails, the message
132
- /// or response is rejected.
133
- ///
134
- /// # Arguments
135
- /// * `handler` - The message handler implementation
136
- /// * `message_schema` - Optional JSON schema for validating client messages
137
- /// * `response_schema` - Optional JSON schema for validating handler responses
138
- ///
139
- /// # Returns
140
- /// * `Ok(state)` - Successfully created state
141
- /// * `Err(msg)` - Invalid schema provided
142
- ///
143
- /// # Example
144
- ///
145
- /// ```ignore
146
- /// use serde_json::json;
147
- ///
148
- /// let message_schema = json!({
149
- /// "type": "object",
150
- /// "properties": {
151
- /// "type": {"type": "string"},
152
- /// "data": {"type": "string"}
153
- /// }
154
- /// });
155
- ///
156
- /// let state = WebSocketState::with_schemas(
157
- /// MyHandler,
158
- /// Some(message_schema),
159
- /// None,
160
- /// )?;
161
- /// ```
162
- pub fn with_schemas(
163
- handler: H,
164
- message_schema: Option<serde_json::Value>,
165
- response_schema: Option<serde_json::Value>,
166
- ) -> Result<Self, String> {
167
- let message_validator = if let Some(schema) = message_schema {
168
- Some(Arc::new(
169
- jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
170
- ))
171
- } else {
172
- None
173
- };
174
-
175
- let response_validator = if let Some(schema) = response_schema {
176
- Some(Arc::new(
177
- jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
178
- ))
179
- } else {
180
- None
181
- };
182
-
183
- Ok(Self {
184
- handler: Arc::new(handler),
185
- message_schema: message_validator,
186
- response_schema: response_validator,
187
- })
188
- }
189
- }
190
-
191
- /// WebSocket upgrade handler
192
- ///
193
- /// This is the main entry point for WebSocket connections. Use this as an Axum route
194
- /// handler by passing it to an Axum router's `.route()` method with `get()`.
195
- ///
196
- /// # Arguments
197
- /// * `ws` - WebSocket upgrade from Axum
198
- /// * `State(state)` - Application state containing the handler and optional schemas
199
- ///
200
- /// # Returns
201
- /// An Axum response that upgrades the connection to WebSocket
202
- ///
203
- /// # Example
204
- ///
205
- /// ```ignore
206
- /// use axum::{Router, routing::get, extract::State};
207
- ///
208
- /// let state = WebSocketState::new(MyHandler);
209
- /// let router = Router::new()
210
- /// .route("/ws", get(websocket_handler::<MyHandler>))
211
- /// .with_state(state);
212
- /// ```
213
- pub async fn websocket_handler<H: WebSocketHandler + 'static>(
214
- ws: WebSocketUpgrade,
215
- State(state): State<WebSocketState<H>>,
216
- ) -> impl IntoResponse {
217
- ws.on_upgrade(move |socket| handle_socket(socket, state))
218
- }
219
-
220
- /// Handle an individual WebSocket connection
221
- async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
222
- info!("WebSocket client connected");
223
-
224
- state.handler.on_connect().await;
225
-
226
- while let Some(msg) = socket.recv().await {
227
- match msg {
228
- Ok(Message::Text(text)) => {
229
- debug!("Received text message: {}", text);
230
-
231
- match serde_json::from_str::<Value>(&text) {
232
- Ok(json_msg) => {
233
- if let Some(validator) = &state.message_schema
234
- && !validator.is_valid(&json_msg)
235
- {
236
- error!("Message validation failed");
237
- let error_response = serde_json::json!({
238
- "error": "Message validation failed"
239
- });
240
- if let Ok(error_text) = serde_json::to_string(&error_response) {
241
- let _ = socket.send(Message::Text(error_text.into())).await;
242
- }
243
- continue;
244
- }
245
-
246
- if let Some(response) = state.handler.handle_message(json_msg).await {
247
- if let Some(validator) = &state.response_schema
248
- && !validator.is_valid(&response)
249
- {
250
- error!("Response validation failed");
251
- continue;
252
- }
253
-
254
- let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
255
-
256
- if let Err(e) = socket.send(Message::Text(response_text.into())).await {
257
- error!("Failed to send response: {}", e);
258
- break;
259
- }
260
- }
261
- }
262
- Err(e) => {
263
- warn!("Failed to parse JSON message: {}", e);
264
- let error_msg = serde_json::json!({
265
- "type": "error",
266
- "message": "Invalid JSON"
267
- });
268
- let error_text = serde_json::to_string(&error_msg).unwrap();
269
- let _ = socket.send(Message::Text(error_text.into())).await;
270
- }
271
- }
272
- }
273
- Ok(Message::Binary(data)) => {
274
- debug!("Received binary message: {} bytes", data.len());
275
- if let Err(e) = socket.send(Message::Binary(data)).await {
276
- error!("Failed to send binary response: {}", e);
277
- break;
278
- }
279
- }
280
- Ok(Message::Ping(data)) => {
281
- debug!("Received ping");
282
- if let Err(e) = socket.send(Message::Pong(data)).await {
283
- error!("Failed to send pong: {}", e);
284
- break;
285
- }
286
- }
287
- Ok(Message::Pong(_)) => {
288
- debug!("Received pong");
289
- }
290
- Ok(Message::Close(_)) => {
291
- info!("Client closed connection");
292
- break;
293
- }
294
- Err(e) => {
295
- error!("WebSocket error: {}", e);
296
- break;
297
- }
298
- }
299
- }
300
-
301
- state.handler.on_disconnect().await;
302
- info!("WebSocket client disconnected");
303
- }
304
-
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler that echoes every message back unchanged.
    struct EchoHandler;

    impl WebSocketHandler for EchoHandler {
        async fn handle_message(&self, message: Value) -> Option<Value> {
            Some(message)
        }
    }

    #[test]
    fn test_websocket_state_creation() {
        let handler = EchoHandler;
        let state = WebSocketState::new(handler);
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
    }

    #[test]
    fn test_new_has_no_schemas() {
        let state = WebSocketState::new(EchoHandler);
        assert!(state.message_schema.is_none());
        assert!(state.response_schema.is_none());
    }

    #[test]
    fn test_with_schemas_compiles_and_shares_validators() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "data": { "type": "string" } }
        });
        let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None)
            .expect("a valid schema should compile");

        let validator = state
            .message_schema
            .as_ref()
            .expect("the message schema should be set");
        assert!(validator.is_valid(&serde_json::json!({ "data": "hello" })));
        assert!(!validator.is_valid(&serde_json::json!({ "data": 42 })));
        assert!(state.response_schema.is_none());

        // Clones must share the compiled validator rather than recompiling it.
        let cloned = state.clone();
        assert!(Arc::ptr_eq(
            state.message_schema.as_ref().unwrap(),
            cloned.message_schema.as_ref().unwrap()
        ));
    }

    #[test]
    fn test_with_schemas_rejects_invalid_schema() {
        let bad = serde_json::json!({ "type": "not-a-type" });
        // `.err()` avoids needing `Debug` on `WebSocketState` for `unwrap_err`.
        let err = WebSocketState::with_schemas(EchoHandler, None, Some(bad))
            .err()
            .expect("an invalid schema should be rejected");
        assert!(err.starts_with("Invalid response schema"));
    }
}
1
+ //! WebSocket support for Spikard
2
+ //!
3
+ //! Provides WebSocket connection handling with message validation and routing.
4
+
5
+ use axum::{
6
+ extract::{
7
+ State,
8
+ ws::{Message, WebSocket, WebSocketUpgrade},
9
+ },
10
+ response::IntoResponse,
11
+ };
12
+ use serde_json::Value;
13
+ use std::sync::Arc;
14
+ use tracing::{debug, error, info, warn};
15
+
16
+ /// WebSocket message handler trait
17
+ ///
18
+ /// Implement this trait to create custom WebSocket message handlers for your application.
19
+ /// The handler processes JSON messages received from WebSocket clients and can optionally
20
+ /// send responses back.
21
+ ///
22
+ /// # Implementing the Trait
23
+ ///
24
+ /// You must implement the `handle_message` method. The `on_connect` and `on_disconnect`
25
+ /// methods are optional and provide lifecycle hooks.
26
+ ///
27
+ /// # Example
28
+ ///
29
+ /// ```ignore
30
+ /// use spikard_http::websocket::WebSocketHandler;
31
+ /// use serde_json::{json, Value};
32
+ ///
33
+ /// struct EchoHandler;
34
+ ///
35
+ /// #[async_trait]
36
+ /// impl WebSocketHandler for EchoHandler {
37
+ /// async fn handle_message(&self, message: Value) -> Option<Value> {
38
+ /// // Echo the message back to the client
39
+ /// Some(message)
40
+ /// }
41
+ ///
42
+ /// async fn on_connect(&self) {
43
+ /// println!("Client connected");
44
+ /// }
45
+ ///
46
+ /// async fn on_disconnect(&self) {
47
+ /// println!("Client disconnected");
48
+ /// }
49
+ /// }
50
+ /// ```
51
+ pub trait WebSocketHandler: Send + Sync {
52
+ /// Handle incoming WebSocket message
53
+ ///
54
+ /// Called whenever a text message is received from a WebSocket client.
55
+ /// Messages are automatically parsed as JSON.
56
+ ///
57
+ /// # Arguments
58
+ /// * `message` - JSON value received from the client
59
+ ///
60
+ /// # Returns
61
+ /// * `Some(value)` - JSON value to send back to the client
62
+ /// * `None` - No response to send
63
+ fn handle_message(&self, message: Value) -> impl std::future::Future<Output = Option<Value>> + Send;
64
+
65
+ /// Called when a client connects to the WebSocket
66
+ ///
67
+ /// Optional lifecycle hook invoked when a new WebSocket connection is established.
68
+ /// Default implementation does nothing.
69
+ fn on_connect(&self) -> impl std::future::Future<Output = ()> + Send {
70
+ async {}
71
+ }
72
+
73
+ /// Called when a client disconnects from the WebSocket
74
+ ///
75
+ /// Optional lifecycle hook invoked when a WebSocket connection is closed
76
+ /// (either by the client or due to an error). Default implementation does nothing.
77
+ fn on_disconnect(&self) -> impl std::future::Future<Output = ()> + Send {
78
+ async {}
79
+ }
80
+ }
81
+
82
+ /// WebSocket state shared across connections
83
+ ///
84
+ /// Contains the message handler and optional JSON schemas for validating
85
+ /// incoming and outgoing messages. This state is shared among all connections
86
+ /// to the same WebSocket endpoint.
87
+ pub struct WebSocketState<H: WebSocketHandler> {
88
+ /// The message handler implementation
89
+ handler: Arc<H>,
90
+ /// Optional JSON Schema for validating incoming messages
91
+ message_schema: Option<Arc<jsonschema::Validator>>,
92
+ /// Optional JSON Schema for validating outgoing responses
93
+ response_schema: Option<Arc<jsonschema::Validator>>,
94
+ }
95
+
96
+ impl<H: WebSocketHandler> Clone for WebSocketState<H> {
97
+ fn clone(&self) -> Self {
98
+ Self {
99
+ handler: Arc::clone(&self.handler),
100
+ message_schema: self.message_schema.clone(),
101
+ response_schema: self.response_schema.clone(),
102
+ }
103
+ }
104
+ }
105
+
106
+ impl<H: WebSocketHandler + 'static> WebSocketState<H> {
107
+ /// Create new WebSocket state with a handler
108
+ ///
109
+ /// Creates a new state without message or response validation schemas.
110
+ /// Messages and responses are not validated.
111
+ ///
112
+ /// # Arguments
113
+ /// * `handler` - The message handler implementation
114
+ ///
115
+ /// # Example
116
+ ///
117
+ /// ```ignore
118
+ /// let state = WebSocketState::new(MyHandler);
119
+ /// ```
120
+ pub fn new(handler: H) -> Self {
121
+ Self {
122
+ handler: Arc::new(handler),
123
+ message_schema: None,
124
+ response_schema: None,
125
+ }
126
+ }
127
+
128
+ /// Create new WebSocket state with a handler and optional validation schemas
129
+ ///
130
+ /// Creates a new state with optional JSON schemas for validating incoming messages
131
+ /// and outgoing responses. If a schema is provided and validation fails, the message
132
+ /// or response is rejected.
133
+ ///
134
+ /// # Arguments
135
+ /// * `handler` - The message handler implementation
136
+ /// * `message_schema` - Optional JSON schema for validating client messages
137
+ /// * `response_schema` - Optional JSON schema for validating handler responses
138
+ ///
139
+ /// # Returns
140
+ /// * `Ok(state)` - Successfully created state
141
+ /// * `Err(msg)` - Invalid schema provided
142
+ ///
143
+ /// # Example
144
+ ///
145
+ /// ```ignore
146
+ /// use serde_json::json;
147
+ ///
148
+ /// let message_schema = json!({
149
+ /// "type": "object",
150
+ /// "properties": {
151
+ /// "type": {"type": "string"},
152
+ /// "data": {"type": "string"}
153
+ /// }
154
+ /// });
155
+ ///
156
+ /// let state = WebSocketState::with_schemas(
157
+ /// MyHandler,
158
+ /// Some(message_schema),
159
+ /// None,
160
+ /// )?;
161
+ /// ```
162
+ pub fn with_schemas(
163
+ handler: H,
164
+ message_schema: Option<serde_json::Value>,
165
+ response_schema: Option<serde_json::Value>,
166
+ ) -> Result<Self, String> {
167
+ let message_validator = if let Some(schema) = message_schema {
168
+ Some(Arc::new(
169
+ jsonschema::validator_for(&schema).map_err(|e| format!("Invalid message schema: {}", e))?,
170
+ ))
171
+ } else {
172
+ None
173
+ };
174
+
175
+ let response_validator = if let Some(schema) = response_schema {
176
+ Some(Arc::new(
177
+ jsonschema::validator_for(&schema).map_err(|e| format!("Invalid response schema: {}", e))?,
178
+ ))
179
+ } else {
180
+ None
181
+ };
182
+
183
+ Ok(Self {
184
+ handler: Arc::new(handler),
185
+ message_schema: message_validator,
186
+ response_schema: response_validator,
187
+ })
188
+ }
189
+ }
190
+
191
+ /// WebSocket upgrade handler
192
+ ///
193
+ /// This is the main entry point for WebSocket connections. Use this as an Axum route
194
+ /// handler by passing it to an Axum router's `.route()` method with `get()`.
195
+ ///
196
+ /// # Arguments
197
+ /// * `ws` - WebSocket upgrade from Axum
198
+ /// * `State(state)` - Application state containing the handler and optional schemas
199
+ ///
200
+ /// # Returns
201
+ /// An Axum response that upgrades the connection to WebSocket
202
+ ///
203
+ /// # Example
204
+ ///
205
+ /// ```ignore
206
+ /// use axum::{Router, routing::get, extract::State};
207
+ ///
208
+ /// let state = WebSocketState::new(MyHandler);
209
+ /// let router = Router::new()
210
+ /// .route("/ws", get(websocket_handler::<MyHandler>))
211
+ /// .with_state(state);
212
+ /// ```
213
+ pub async fn websocket_handler<H: WebSocketHandler + 'static>(
214
+ ws: WebSocketUpgrade,
215
+ State(state): State<WebSocketState<H>>,
216
+ ) -> impl IntoResponse {
217
+ ws.on_upgrade(move |socket| handle_socket(socket, state))
218
+ }
219
+
220
+ /// Handle an individual WebSocket connection
221
+ async fn handle_socket<H: WebSocketHandler>(mut socket: WebSocket, state: WebSocketState<H>) {
222
+ info!("WebSocket client connected");
223
+
224
+ state.handler.on_connect().await;
225
+
226
+ while let Some(msg) = socket.recv().await {
227
+ match msg {
228
+ Ok(Message::Text(text)) => {
229
+ debug!("Received text message: {}", text);
230
+
231
+ match serde_json::from_str::<Value>(&text) {
232
+ Ok(json_msg) => {
233
+ if let Some(validator) = &state.message_schema
234
+ && !validator.is_valid(&json_msg)
235
+ {
236
+ error!("Message validation failed");
237
+ let error_response = serde_json::json!({
238
+ "error": "Message validation failed"
239
+ });
240
+ if let Ok(error_text) = serde_json::to_string(&error_response) {
241
+ let _ = socket.send(Message::Text(error_text.into())).await;
242
+ }
243
+ continue;
244
+ }
245
+
246
+ if let Some(response) = state.handler.handle_message(json_msg).await {
247
+ if let Some(validator) = &state.response_schema
248
+ && !validator.is_valid(&response)
249
+ {
250
+ error!("Response validation failed");
251
+ continue;
252
+ }
253
+
254
+ let response_text = serde_json::to_string(&response).unwrap_or_else(|_| "{}".to_string());
255
+
256
+ if let Err(e) = socket.send(Message::Text(response_text.into())).await {
257
+ error!("Failed to send response: {}", e);
258
+ break;
259
+ }
260
+ }
261
+ }
262
+ Err(e) => {
263
+ warn!("Failed to parse JSON message: {}", e);
264
+ let error_msg = serde_json::json!({
265
+ "type": "error",
266
+ "message": "Invalid JSON"
267
+ });
268
+ let error_text = serde_json::to_string(&error_msg).unwrap();
269
+ let _ = socket.send(Message::Text(error_text.into())).await;
270
+ }
271
+ }
272
+ }
273
+ Ok(Message::Binary(data)) => {
274
+ debug!("Received binary message: {} bytes", data.len());
275
+ if let Err(e) = socket.send(Message::Binary(data)).await {
276
+ error!("Failed to send binary response: {}", e);
277
+ break;
278
+ }
279
+ }
280
+ Ok(Message::Ping(data)) => {
281
+ debug!("Received ping");
282
+ if let Err(e) = socket.send(Message::Pong(data)).await {
283
+ error!("Failed to send pong: {}", e);
284
+ break;
285
+ }
286
+ }
287
+ Ok(Message::Pong(_)) => {
288
+ debug!("Received pong");
289
+ }
290
+ Ok(Message::Close(_)) => {
291
+ info!("Client closed connection");
292
+ break;
293
+ }
294
+ Err(e) => {
295
+ error!("WebSocket error: {}", e);
296
+ break;
297
+ }
298
+ }
299
+ }
300
+
301
+ state.handler.on_disconnect().await;
302
+ info!("WebSocket client disconnected");
303
+ }
304
+
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal handler that echoes every message back unchanged.
    struct EchoHandler;

    impl WebSocketHandler for EchoHandler {
        async fn handle_message(&self, message: Value) -> Option<Value> {
            Some(message)
        }
    }

    #[test]
    fn test_websocket_state_creation() {
        let handler = EchoHandler;
        let state = WebSocketState::new(handler);
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.handler, &cloned.handler));
    }

    #[test]
    fn test_new_has_no_schemas() {
        let state = WebSocketState::new(EchoHandler);
        assert!(state.message_schema.is_none());
        assert!(state.response_schema.is_none());
    }

    #[test]
    fn test_with_schemas_compiles_and_shares_validators() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": { "data": { "type": "string" } }
        });
        let state = WebSocketState::with_schemas(EchoHandler, Some(schema), None)
            .expect("a valid schema should compile");

        let validator = state
            .message_schema
            .as_ref()
            .expect("the message schema should be set");
        assert!(validator.is_valid(&serde_json::json!({ "data": "hello" })));
        assert!(!validator.is_valid(&serde_json::json!({ "data": 42 })));
        assert!(state.response_schema.is_none());

        // Clones must share the compiled validator rather than recompiling it.
        let cloned = state.clone();
        assert!(Arc::ptr_eq(
            state.message_schema.as_ref().unwrap(),
            cloned.message_schema.as_ref().unwrap()
        ));
    }

    #[test]
    fn test_with_schemas_rejects_invalid_schema() {
        let bad = serde_json::json!({ "type": "not-a-type" });
        // `.err()` avoids needing `Debug` on `WebSocketState` for `unwrap_err`.
        let err = WebSocketState::with_schemas(EchoHandler, None, Some(bad))
            .err()
            .expect("an invalid schema should be rejected");
        assert!(err.starts_with("Invalid response schema"));
    }
}