distfl_client-1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. distfl_client-1.0.0/MANIFEST.in +9 -0
  2. distfl_client-1.0.0/PKG-INFO +332 -0
  3. distfl_client-1.0.0/README.md +301 -0
  4. distfl_client-1.0.0/distfl_client.egg-info/PKG-INFO +332 -0
  5. distfl_client-1.0.0/distfl_client.egg-info/SOURCES.txt +33 -0
  6. distfl_client-1.0.0/distfl_client.egg-info/dependency_links.txt +1 -0
  7. distfl_client-1.0.0/distfl_client.egg-info/entry_points.txt +3 -0
  8. distfl_client-1.0.0/distfl_client.egg-info/requires.txt +13 -0
  9. distfl_client-1.0.0/distfl_client.egg-info/top_level.txt +1 -0
  10. distfl_client-1.0.0/fl_client/__init__.py +7 -0
  11. distfl_client-1.0.0/fl_client/cli/__init__.py +4 -0
  12. distfl_client-1.0.0/fl_client/cli/main.py +242 -0
  13. distfl_client-1.0.0/fl_client/communication/__init__.py +5 -0
  14. distfl_client-1.0.0/fl_client/communication/compressor.py +56 -0
  15. distfl_client-1.0.0/fl_client/communication/serializer.py +132 -0
  16. distfl_client-1.0.0/fl_client/config/__init__.py +4 -0
  17. distfl_client-1.0.0/fl_client/config/config.py +145 -0
  18. distfl_client-1.0.0/fl_client/core/__init__.py +4 -0
  19. distfl_client-1.0.0/fl_client/core/client.py +1274 -0
  20. distfl_client-1.0.0/fl_client/core/connection.py +270 -0
  21. distfl_client-1.0.0/fl_client/core/session.py +165 -0
  22. distfl_client-1.0.0/fl_client/core/state_manager.py +168 -0
  23. distfl_client-1.0.0/fl_client/model/__init__.py +4 -0
  24. distfl_client-1.0.0/fl_client/model/wrapper.py +347 -0
  25. distfl_client-1.0.0/fl_client/storage/__init__.py +4 -0
  26. distfl_client-1.0.0/fl_client/storage/db.py +267 -0
  27. distfl_client-1.0.0/fl_client/training/__init__.py +5 -0
  28. distfl_client-1.0.0/fl_client/training/dataset.py +196 -0
  29. distfl_client-1.0.0/fl_client/training/trainer.py +431 -0
  30. distfl_client-1.0.0/fl_client/validation/__init__.py +4 -0
  31. distfl_client-1.0.0/fl_client/validation/checks.py +128 -0
  32. distfl_client-1.0.0/fl_client/web/__init__.py +1 -0
  33. distfl_client-1.0.0/fl_client/web/bridge.py +641 -0
  34. distfl_client-1.0.0/pyproject.toml +70 -0
  35. distfl_client-1.0.0/setup.cfg +4 -0
@@ -0,0 +1,9 @@
+ # ── Include pre-built frontend assets ──
+ recursive-include fl_client/web/static *
+
+ # ── Exclude dev-only files from sdist ──
+ prune tests
+ prune fl_client/web/ui
+ prune fl_client/dashboard
+ exclude example_config.yaml
+ exclude example_usage.py
@@ -0,0 +1,332 @@
+ Metadata-Version: 2.4
+ Name: distfl-client
+ Version: 1.0.0
+ Summary: Production-grade Python Client SDK for room-based Federated Learning
+ Author: DistFL Team
+ License: MIT
+ Keywords: federated-learning,machine-learning,distributed,pytorch
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: numpy>=1.24.0
+ Requires-Dist: pandas>=2.0.0
+ Requires-Dist: websockets>=12.0
+ Requires-Dist: httpx>=0.25.0
+ Requires-Dist: pyyaml>=6.0
+ Requires-Dist: fastapi>=0.110.0
+ Requires-Dist: uvicorn[standard]>=0.27.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=7.0; extra == "dev"
+ Requires-Dist: pytest-asyncio>=0.21; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
+
+ <p align="center">
+ <h1 align="center">DistFL</h1>
+ <p align="center">
+ <strong>Production-Grade Federated Learning Client SDK</strong>
+ </p>
+ <p align="center">
+ <a href="#installation"><img src="https://img.shields.io/badge/python-≥3.10-blue?logo=python&logoColor=white" alt="Python"></a>
+ <a href="https://pypi.org/project/distfl-client/"><img src="https://img.shields.io/pypi/v/distfl-client?color=green&label=PyPI" alt="PyPI"></a>
+ <a href="#license"><img src="https://img.shields.io/badge/license-MIT-purple" alt="License"></a>
+ <a href="#testing"><img src="https://img.shields.io/badge/tests-52%20passed-brightgreen" alt="Tests"></a>
+ </p>
+ </p>
+
+ ---
+
+ Bring your own model (PyTorch or Scikit-Learn), connect to a DistFL server, train locally on **private data**, and let the server aggregate updates — all via compressed WebSocket communication. No raw data ever leaves the client.
+
+ ```
+ ┌──────────────┐     ┌──────────────┐     ┌──────────────┐
+ │   Client A   │     │   Client B   │     │   Client C   │
+ │  (Hospital)  │     │    (Bank)    │     │    (Lab)     │
+ │  Local Data  │     │  Local Data  │     │  Local Data  │
+ └──────┬───────┘     └──────┬───────┘     └──────┬───────┘
+        │    model updates   │   (gzip+WS)        │
+        └────────────────────┼────────────────────┘
+
+                     ┌───────▼────────┐
+                     │ DistFL Server  │
+                     │  (Go Backend)  │
+                     │  FedAvg Agg.   │
+                     └───────┬────────┘
+
+                  aggregated global model
+                  broadcast to all clients
+ ```
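The FedAvg aggregation shown above runs in the Go backend, not in this SDK; as a rough illustration only (not the server's actual code), a sample-weighted average of per-layer client updates can be sketched in NumPy:

```python
import numpy as np

def fedavg(client_updates, sample_counts):
    """Sample-weighted average of per-layer weight lists (FedAvg sketch).

    client_updates: list of clients, each a list of per-layer arrays.
    sample_counts:  number of local training samples per client.
    """
    total = float(sum(sample_counts))
    num_layers = len(client_updates[0])
    return [
        sum((n / total) * np.asarray(update[layer], dtype=np.float64)
            for update, n in zip(client_updates, sample_counts))
        for layer in range(num_layers)
    ]

# Two clients, one layer each; client 2 contributed 3x the samples
global_w = fedavg([[[0.0, 0.0]], [[4.0, 8.0]]], [1, 3])
# global_w[0] is 0.25*[0, 0] + 0.75*[4, 8] = [3.0, 6.0]
```

Clients with more local data pull the global model proportionally harder, which is the core of the FedAvg weighting scheme.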
+
+ ---
+
+ ## ✨ Features
+
+ | Category | Details |
+ |---|---|
+ | **BYOM** | Use any PyTorch `nn.Module` or Scikit-Learn estimator with `partial_fit` |
+ | **Simple Lifecycle** | `initialize()` → `validate()` → `start()` — 3 calls to go from zero to training |
+ | **Room-Based FL** | Create rooms, share invite codes, configure training params per room |
+ | **Compressed WebSocket** | GZIP-compressed binary messages over persistent WebSocket connections |
+ | **Auto Reconnect** | Exponential backoff with configurable delays and heartbeat pings |
+ | **Crash Recovery** | SQLite-backed state persistence — no duplicate round submissions after restart |
+ | **Live Dashboard** | Built-in web UI with real-time loss curves, ΔW tracking, and training logs |
+ | **Prediction** | Extract globally-aggregated weights and run inference locally |
+ | **CLI** | `distfl run`, `distfl create-room`, `distfl join-room`, `distfl ui`, `distfl status` |
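The compressed-WebSocket wire format can be pictured as gzip over a JSON-encoded payload; a minimal sketch (the `pack`/`unpack` names are hypothetical, not the SDK's API, and the real message format may be binary rather than JSON):

```python
import gzip
import json

def pack(message: dict) -> bytes:
    # JSON-encode, then gzip-compress for transmission over the WebSocket
    return gzip.compress(json.dumps(message).encode("utf-8"))

def unpack(blob: bytes) -> dict:
    # Reverse: decompress, then JSON-decode
    return json.loads(gzip.decompress(blob).decode("utf-8"))

msg = {"round": 3, "weights": [[0.1, -0.2], [0.05]]}
blob = pack(msg)
restored = unpack(blob)
```

Weight matrices contain long runs of similar floats, so gzip typically shrinks them substantially before they hit the wire.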
+
+ ---
+
+ ## 📦 Installation
+
+ ```bash
+ pip install distfl-client
+ ```
+
+ **From source:**
+
+ ```bash
+ git clone https://github.com/AbhaySingh002/new-repo-code.git
+ cd new-repo-code/DistFL
+ pip install -e ".[dev]"
+ ```
+
+ ---
+
+ ## 🚀 Quick Start
+
+ ### 1. Room Creator
+
+ The creator uses `FLClient` to initialize the global model architecture, create a new room on the server, and wait for other participants to join before starting.
+
+ > [!NOTE]
+ > **Why `partial_fit`?**
+ > Federated Learning requires all clients to share the exact same weight matrix architecture. For Scikit-Learn models like `SGDClassifier`, the shape of the weights (`coef_` and `intercept_`) isn't initialized until the model sees training data. We run a single dummy `partial_fit` on the creator side to establish this shape before sending it to the server.
+
+ ```python
+ from sklearn.linear_model import SGDClassifier
+ from fl_client import FLClient
+ import pandas as pd
+ import numpy as np
+
+ # Prepare model (scikit-learn requires partial_fit to initialize weights)
+ model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
+                       learning_rate="constant", eta0=0.01)
+ df = pd.read_csv("./data.csv")
+ X = df.drop(columns=["label"]).values[:10].astype(np.float64)
+ y = df["label"].values[:10].astype(np.int64)
+ model.partial_fit(X, y, classes=[0, 1])
+
+ # Create room
+ client = FLClient(server_url="ws://localhost:8080")
+ room = client.create_room(
+     model=model,
+     data_path="./data.csv",
+     target="label",
+     training_config={"local_epochs": 1, "batch_size": 32, "learning_rate": 0.01},
+     room_name="Phishing Detection",
+ )
+
+ room_id = room["id"]
+ print(f"✅ Room created: {room_id}")
+ print(f"   Invite code: {room['invite_code']}")
+
+ # Wait for participants, then start
+ client.wait_for_clients(min_clients=2, timeout=120)
+ client.start_training()
+ ```
+
+ ### 2. Room Joiner
+
+ Each participant joins an existing room, validates their local dataset, and starts training. This follows the **3-Step Lifecycle**:
+
+ 1. **`initialize()`** — Connects to the server, fetches the room's data schema and model configuration, and injects the latest global model weights into your local model.
+ 2. **`validate()`** — Checks your local dataset (`data.csv`) against the room's expected schema (e.g., ensuring it has the correct target column and feature count) and performs a dummy forward pass to catch shape errors early.
+ 3. **`start()`** — Signals readiness to the server, then enters the federated training loop and blocks until it completes.
+
+ ```python
+ from sklearn.linear_model import SGDClassifier
+ from fl_client import FLClient
+
+ model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
+                       learning_rate="constant", eta0=0.01)
+ # ... partial_fit to initialize shape (same architecture as creator)
+
+ client = FLClient(server_url="ws://localhost:8080")
+ client.join(room_id, invite_code="abc123", model=model)
+ client.validate("./data.csv")
+ client.ready()
+ client.start(max_rounds=5)  # Blocks until training completes
+
+ print("✅ Training complete!")
+ ```
+
+ ### 3. PyTorch Models
+
+ ```python
+ import torch.nn as nn
+ from fl_client import FLClient
+
+ class PhishingMLP(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(30, 64), nn.ReLU(),
+             nn.Linear(64, 32), nn.ReLU(),
+             nn.Linear(32, 2),
+         )
+     def forward(self, x):
+         return self.net(x)
+
+ client = FLClient(server_url="ws://localhost:8080")
+ room = client.create_room(
+     model=PhishingMLP(),
+     data_path="./data.csv",
+     target="label",
+     training_config={"local_epochs": 2, "batch_size": 64, "learning_rate": 0.001},
+     room_name="PyTorch FL Room",
+ )
+ ```
+
+ ### 4. Prediction After Training
+
+ Because clients can disconnect, crash, or experience network drops, the DistFL SDK maintains a **local SQLite state database**.
+
+ > [!TIP]
+ > **Why connect to the DB?**
+ > The server does not hold your data. Your final, fully-trained aggregated model weights are saved to your local `fl_client_state.db` at the end of training. By loading the state for your specific `client_id` (e.g. `worker-1`), you can extract these weights and run predictions *locally* without ever needing to communicate with the server again.
+
+ ```python
+ from fl_client.storage.db import StateDB
+ from fl_client.model.wrapper import wrap_model
+
+ db = StateDB("fl_client_state.db")
+ state = db.load_state("worker-1")
+
+ wrapper = wrap_model(model)
+ wrapper.set_weights(state.last_weights)
+
+ predictions = model.predict(X_test)
+ accuracy = (predictions == y_test).mean()
+ print(f"✅ Accuracy: {accuracy * 100:.2f}%")
+ ```
219
+
220
+ ---
221
+
222
+ ## 💻 CLI Reference
223
+
224
+ ```bash
225
+ # Full lifecycle from a YAML config
226
+ distfl run --config config.yaml
227
+
228
+ # Create a room
229
+ distfl create-room --server-url ws://localhost:8080 --room-name "My Room"
230
+
231
+ # Join a room and train
232
+ distfl join-room ROOM_ID --data ./data.csv --server-url ws://localhost:8080
233
+
234
+ # Launch the real-time web dashboard
235
+ distfl ui --port 5050
236
+
237
+ # Inspect persisted client state
238
+ distfl status --client-id worker-1 --db-path fl_client_state.db
239
+
240
+ # Clear persisted state
241
+ distfl clear --client-id worker-1 --db-path fl_client_state.db
242
+ ```
243
+
244
+ ---
245
+
246
+ ## ⚙️ Configuration
247
+
248
+ All options can be set via **YAML file**, **CLI flags**, or **environment variables** (`FL_` prefix):
249
+
250
+ ```yaml
251
+ # Server connection
252
+ server_url: "ws://localhost:8080"
253
+ room_id: "" # Leave empty to create a new room
254
+ client_id: "" # Auto-generated if omitted
255
+
256
+ # Dataset
257
+ data_path: "./data.csv"
258
+ label_column: "label"
259
+
260
+ # Training hyperparameters
261
+ batch_size: 32
262
+ local_epochs: 2
263
+ learning_rate: 0.001
264
+
265
+ # State persistence
266
+ db_path: "fl_client_state.db" # SQLite for crash recovery
267
+
268
+ # Networking
269
+ reconnect_max_delay: 60.0 # Max backoff delay (seconds)
270
+ reconnect_base_delay: 1.0 # Initial reconnect delay
271
+ heartbeat_interval: 30.0 # WebSocket ping interval
272
+
273
+ # Dashboard
274
+ dashboard_port: 5050 # Real-time metrics UI (0 = disabled)
275
+
276
+ # Logging
277
+ log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
278
+ ```
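One plausible reading of the `reconnect_base_delay` / `reconnect_max_delay` pair is a capped exponential backoff; the SDK's actual reconnect logic may differ (e.g. by adding jitter), so treat this as a sketch of the idea:

```python
def backoff_delay(attempt: int, base: float = 1.0, max_delay: float = 60.0) -> float:
    # Delay doubles on each consecutive failed attempt, capped at max_delay
    return min(base * (2 ** attempt), max_delay)

delays = [backoff_delay(a) for a in range(8)]
# attempts 0..7 -> 1, 2, 4, 8, 16, 32, 60, 60 seconds
```

With the defaults above, a flapping server is retried quickly at first, then no more often than once per minute.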
279
+
280
+
281
+
282
+ ## 🧪 Supported Frameworks
283
+
284
+ | Framework | Requirements | Weight Extraction |
285
+ |---|---|---|
286
+ | **PyTorch** | Any `nn.Module` | `state_dict()` → 3D float32 lists |
287
+ | **Scikit-Learn** | Estimator with `partial_fit` (e.g. `SGDClassifier`, `SGDRegressor`) | `coef_` + `intercept_` → 3D float32 lists |
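For scikit-learn, the `coef_` + `intercept_` extraction in the table might look roughly like the sketch below. This is not the SDK's `wrap_model` internals, and the exact nesting of the "3D float32 lists" is an assumption:

```python
import numpy as np

def extract_sklearn_weights(model):
    # coef_ is (n_classes, n_features) for SGDClassifier; intercept_ is (n_classes,)
    coef = np.asarray(model.coef_, dtype=np.float32)
    intercept = np.asarray(model.intercept_, dtype=np.float32).reshape(1, -1)
    # One 2D list per "layer": [coef, intercept]
    return [coef.tolist(), intercept.tolist()]

class _DummyEstimator:  # stand-in for a fitted SGDClassifier
    coef_ = [[0.5, -0.5]]
    intercept_ = [0.1]

layers = extract_sklearn_weights(_DummyEstimator())
```

Converting to plain float32 lists keeps the payload JSON-serializable and framework-agnostic on the wire.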
288
+
289
+ ---
290
+
291
+ ## 🧪 Testing
292
+
293
+ ```bash
294
+ # Install dev dependencies
295
+ pip install -e ".[dev]"
296
+
297
+ # Run all 52 unit tests
298
+ python -m pytest tests/ -v
299
+ ```
300
+
301
+ ### Test Coverage
302
+
303
+ | Module | Tests | What's Covered |
304
+ |---|---|---|
305
+ | `test_compressor.py` | 7 | Compress/decompress round-trip, empty data, large payloads |
306
+ | `test_connection.py` | 8 | WS URL construction, connect/disconnect, send/receive |
307
+ | `test_serializer.py` | 9 | Serialize/deserialize, shape preservation, JSON round-trip |
308
+ | `test_storage.py` | 7 | SQLite save/load, upsert, clear, round logging |
309
+ | `test_trainer.py` | 4 | Train results, finite loss, accuracy metrics, multi-epoch |
310
+ | `test_validation.py` | 17 | NaN/Inf/shape/range checks, loss validation, weight shapes |
311
+
312
+ ---
313
+
314
+ ## 🔐 Privacy & Security
315
+
316
+ - **Data never leaves the client** — only model weight updates are transmitted
317
+ - **GZIP compression** — reduces bandwidth and adds a layer of obfuscation
318
+ - **Server-side validation** — NaN, Inf, out-of-range, shape mismatch, L2 norm, and duplicate submission checks
319
+ - **Invite codes** — rooms can be access-controlled via invite codes
320
+ - **Crash recovery** — SQLite persistence prevents duplicate round submissions
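The NaN/Inf and L2-norm checks listed above can also be mirrored client-side before submitting an update; a minimal sketch (the threshold value is illustrative, not the server's configured limit):

```python
import numpy as np

def update_is_valid(layers, max_l2_norm=1e3):
    # Flatten all layers into one vector for the checks
    flat = np.concatenate(
        [np.asarray(layer, dtype=np.float64).ravel() for layer in layers]
    )
    # Reject NaN/Inf values outright
    if not np.isfinite(flat).all():
        return False
    # Reject abnormally large (exploding) updates by L2 norm
    return bool(np.linalg.norm(flat) <= max_l2_norm)

ok = update_is_valid([[1.0, -2.0], [0.5]])
bad = update_is_valid([[float("nan")]])
```

Catching a bad update locally avoids wasting a round-trip on a submission the server would reject anyway.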
+
+ ---
+
+ ## 📄 License
+
+ MIT
+
+ ---
+
+ <p align="center">
+ Built with ❤️ for privacy-preserving machine learning
+ </p>