distfl-client 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distfl_client-1.0.0/MANIFEST.in +9 -0
- distfl_client-1.0.0/PKG-INFO +332 -0
- distfl_client-1.0.0/README.md +301 -0
- distfl_client-1.0.0/distfl_client.egg-info/PKG-INFO +332 -0
- distfl_client-1.0.0/distfl_client.egg-info/SOURCES.txt +33 -0
- distfl_client-1.0.0/distfl_client.egg-info/dependency_links.txt +1 -0
- distfl_client-1.0.0/distfl_client.egg-info/entry_points.txt +3 -0
- distfl_client-1.0.0/distfl_client.egg-info/requires.txt +13 -0
- distfl_client-1.0.0/distfl_client.egg-info/top_level.txt +1 -0
- distfl_client-1.0.0/fl_client/__init__.py +7 -0
- distfl_client-1.0.0/fl_client/cli/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/cli/main.py +242 -0
- distfl_client-1.0.0/fl_client/communication/__init__.py +5 -0
- distfl_client-1.0.0/fl_client/communication/compressor.py +56 -0
- distfl_client-1.0.0/fl_client/communication/serializer.py +132 -0
- distfl_client-1.0.0/fl_client/config/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/config/config.py +145 -0
- distfl_client-1.0.0/fl_client/core/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/core/client.py +1274 -0
- distfl_client-1.0.0/fl_client/core/connection.py +270 -0
- distfl_client-1.0.0/fl_client/core/session.py +165 -0
- distfl_client-1.0.0/fl_client/core/state_manager.py +168 -0
- distfl_client-1.0.0/fl_client/model/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/model/wrapper.py +347 -0
- distfl_client-1.0.0/fl_client/storage/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/storage/db.py +267 -0
- distfl_client-1.0.0/fl_client/training/__init__.py +5 -0
- distfl_client-1.0.0/fl_client/training/dataset.py +196 -0
- distfl_client-1.0.0/fl_client/training/trainer.py +431 -0
- distfl_client-1.0.0/fl_client/validation/__init__.py +4 -0
- distfl_client-1.0.0/fl_client/validation/checks.py +128 -0
- distfl_client-1.0.0/fl_client/web/__init__.py +1 -0
- distfl_client-1.0.0/fl_client/web/bridge.py +641 -0
- distfl_client-1.0.0/pyproject.toml +70 -0
- distfl_client-1.0.0/setup.cfg +4 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: distfl-client
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Production-grade Python Client SDK for room-based Federated Learning
|
|
5
|
+
Author: DistFL Team
|
|
6
|
+
License: MIT
|
|
7
|
+
Keywords: federated-learning,machine-learning,distributed,pytorch
|
|
8
|
+
Classifier: Development Status :: 4 - Beta
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.10
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
Requires-Dist: torch>=2.0.0
|
|
20
|
+
Requires-Dist: numpy>=1.24.0
|
|
21
|
+
Requires-Dist: pandas>=2.0.0
|
|
22
|
+
Requires-Dist: websockets>=12.0
|
|
23
|
+
Requires-Dist: httpx>=0.25.0
|
|
24
|
+
Requires-Dist: pyyaml>=6.0
|
|
25
|
+
Requires-Dist: fastapi>=0.110.0
|
|
26
|
+
Requires-Dist: uvicorn[standard]>=0.27.0
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
29
|
+
Requires-Dist: pytest-asyncio>=0.21; extra == "dev"
|
|
30
|
+
Requires-Dist: pytest-cov>=4.0; extra == "dev"
|
|
31
|
+
|
|
32
|
+
<p align="center">
|
|
33
|
+
<h1 align="center">DistFL</h1>
|
|
34
|
+
<p align="center">
|
|
35
|
+
<strong>Production-Grade Federated Learning Client SDK</strong>
|
|
36
|
+
</p>
|
|
37
|
+
<p align="center">
|
|
38
|
+
<a href="#installation"><img src="https://img.shields.io/badge/python-≥3.10-blue?logo=python&logoColor=white" alt="Python"></a>
|
|
39
|
+
<a href="https://pypi.org/project/distfl-client/"><img src="https://img.shields.io/pypi/v/distfl-client?color=green&label=PyPI" alt="PyPI"></a>
|
|
40
|
+
<a href="#license"><img src="https://img.shields.io/badge/license-MIT-purple" alt="License"></a>
|
|
41
|
+
<a href="#testing"><img src="https://img.shields.io/badge/tests-52%20passed-brightgreen" alt="Tests"></a>
|
|
42
|
+
</p>
|
|
43
|
+
</p>
|
|
44
|
+
|
|
45
|
+
---
|
|
46
|
+
|
|
47
|
+
Bring your own model (PyTorch or Scikit-Learn), connect to a DistFL server, train locally on **private data**, and let the server aggregate updates — all via compressed WebSocket communication. No raw data ever leaves the client.
|
|
48
|
+
|
|
49
|
+
```
|
|
50
|
+
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
|
51
|
+
│ Client A │ │ Client B │ │ Client C │
|
|
52
|
+
│ (Hospital) │ │ (Bank) │ │ (Lab) │
|
|
53
|
+
│ Local Data │ │ Local Data │ │ Local Data │
|
|
54
|
+
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
|
|
55
|
+
│ model updates │ (gzip+WS) │
|
|
56
|
+
└────────────────────┼────────────────────┘
|
|
57
|
+
│
|
|
58
|
+
┌───────▼────────┐
|
|
59
|
+
│ DistFL Server │
|
|
60
|
+
│ (Go Backend) │
|
|
61
|
+
│ FedAvg Agg. │
|
|
62
|
+
└───────┬────────┘
|
|
63
|
+
│
|
|
64
|
+
aggregated global model
|
|
65
|
+
broadcast to all clients
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
---
|
|
69
|
+
|
|
70
|
+
## ✨ Features
|
|
71
|
+
|
|
72
|
+
| Category | Details |
|
|
73
|
+
|---|---|
|
|
74
|
+
| **BYOM** | Use any PyTorch `nn.Module` or Scikit-Learn estimator with `partial_fit` |
|
|
75
|
+
| **Simple Lifecycle** | `initialize()` → `validate()` → `start()` — 3 calls to go from zero to training |
|
|
76
|
+
| **Room-Based FL** | Create rooms, share invite codes, configure training params per room |
|
|
77
|
+
| **Compressed WebSocket** | GZIP-compressed binary messages over persistent WebSocket connections |
|
|
78
|
+
| **Auto Reconnect** | Exponential backoff with configurable delays and heartbeat pings |
|
|
79
|
+
| **Crash Recovery** | SQLite-backed state persistence — no duplicate round submissions after restart |
|
|
80
|
+
| **Live Dashboard** | Built-in web UI with real-time loss curves, ΔW tracking, and training logs |
|
|
81
|
+
| **Prediction** | Extract globally-aggregated weights and run inference locally |
|
|
82
|
+
| **CLI** | `distfl run`, `distfl create-room`, `distfl join-room`, `distfl ui`, `distfl status` |
|
|
83
|
+
|
|
84
|
+
---
|
|
85
|
+
|
|
86
|
+
## 📦 Installation
|
|
87
|
+
|
|
88
|
+
```bash
|
|
89
|
+
pip install distfl-client
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
**From source:**
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
git clone https://github.com/AbhaySingh002/new-repo-code.git
|
|
96
|
+
cd new-repo-code/DistFL
|
|
97
|
+
pip install -e ".[dev]"
|
|
98
|
+
```
|
|
99
|
+
|
|
100
|
+
---
|
|
101
|
+
|
|
102
|
+
## 🚀 Quick Start
|
|
103
|
+
|
|
104
|
+
### 1. Room Creator
|
|
105
|
+
|
|
106
|
+
The creator relies on the `FLClient` to initialize the global model architecture, create a new room on the server, and wait for other participants to join before starting.
|
|
107
|
+
|
|
108
|
+
> [!NOTE]
|
|
109
|
+
> **Why `partial_fit`?**
|
|
110
|
+
> Federated Learning requires all clients to share the exact same weight matrix architecture. For Scikit-Learn models like `SGDClassifier`, the shape of the weights (`coef_` and `intercept_`) isn't initialized until it sees training data. We run a single dummy `partial_fit` on the creator side to establish this shape before sending it to the server.
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
from sklearn.linear_model import SGDClassifier
|
|
114
|
+
from fl_client import FLClient
|
|
115
|
+
import pandas as pd
|
|
116
|
+
import numpy as np
|
|
117
|
+
|
|
118
|
+
# Prepare model (scikit-learn requires partial_fit to initialize weights)
|
|
119
|
+
model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
|
|
120
|
+
learning_rate="constant", eta0=0.01)
|
|
121
|
+
df = pd.read_csv("./data.csv")
|
|
122
|
+
X = df.drop(columns=["label"]).values[:10].astype(np.float64)
|
|
123
|
+
y = df["label"].values[:10].astype(np.int64)
|
|
124
|
+
model.partial_fit(X, y, classes=[0, 1])
|
|
125
|
+
|
|
126
|
+
# Create room
|
|
127
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
128
|
+
room = client.create_room(
|
|
129
|
+
model=model,
|
|
130
|
+
data_path="./data.csv",
|
|
131
|
+
target="label",
|
|
132
|
+
training_config={"local_epochs": 1, "batch_size": 32, "learning_rate": 0.01},
|
|
133
|
+
room_name="Phishing Detection",
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
room_id = room["id"]
|
|
137
|
+
print(f"✅ Room created: {room_id}")
|
|
138
|
+
print(f" Invite code: {room['invite_code']}")
|
|
139
|
+
|
|
140
|
+
# Wait for participants, then start
|
|
141
|
+
client.wait_for_clients(min_clients=2, timeout=120)
|
|
142
|
+
client.start_training()
|
|
143
|
+
```
|
|
144
|
+
|
|
145
|
+
### 2. Room Joiner
|
|
146
|
+
|
|
147
|
+
Each participant joins an existing room, validates their local dataset, and starts training. This follows the **3-Step Lifecycle**:
|
|
148
|
+
|
|
149
|
+
1. **`initialize()`** — Connects to the server, fetches the room's data schema and model configuration, and injects the latest global model weights into your local model.
|
|
150
|
+
2. **`validate()`** — Checks your local dataset (`data.csv`) against the room's expected schema (e.g., ensuring it has the correct target column and feature count) and performs a dummy forward pass to catch shape errors early.
|
|
151
|
+
3. **`start()`** — Enters the federated training loop and blocks until training completes (readiness is signaled explicitly via `ready()`, as shown below).
|
|
152
|
+
|
|
153
|
+
```python
|
|
154
|
+
from sklearn.linear_model import SGDClassifier
|
|
155
|
+
from fl_client import FLClient
|
|
156
|
+
|
|
157
|
+
model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
|
|
158
|
+
learning_rate="constant", eta0=0.01)
|
|
159
|
+
# ... partial_fit to initialize shape (same architecture as creator)
|
|
160
|
+
|
|
161
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
162
|
+
client.join(room_id, invite_code="abc123", model=model)
|
|
163
|
+
client.validate("./data.csv")
|
|
164
|
+
client.ready()
|
|
165
|
+
client.start(max_rounds=5) # Blocks until training completes
|
|
166
|
+
|
|
167
|
+
print("✅ Training complete!")
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
### 3. PyTorch Models
|
|
171
|
+
|
|
172
|
+
```python
|
|
173
|
+
import torch.nn as nn
|
|
174
|
+
from fl_client import FLClient
|
|
175
|
+
|
|
176
|
+
class PhishingMLP(nn.Module):
|
|
177
|
+
def __init__(self):
|
|
178
|
+
super().__init__()
|
|
179
|
+
self.net = nn.Sequential(
|
|
180
|
+
nn.Linear(30, 64), nn.ReLU(),
|
|
181
|
+
nn.Linear(64, 32), nn.ReLU(),
|
|
182
|
+
nn.Linear(32, 2),
|
|
183
|
+
)
|
|
184
|
+
def forward(self, x):
|
|
185
|
+
return self.net(x)
|
|
186
|
+
|
|
187
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
188
|
+
room = client.create_room(
|
|
189
|
+
model=PhishingMLP(),
|
|
190
|
+
data_path="./data.csv",
|
|
191
|
+
target="label",
|
|
192
|
+
training_config={"local_epochs": 2, "batch_size": 64, "learning_rate": 0.001},
|
|
193
|
+
room_name="PyTorch FL Room",
|
|
194
|
+
)
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
### 4. Prediction After Training
|
|
198
|
+
|
|
199
|
+
Because clients can disconnect, crash, or experience network drops, the DistFL SDK maintains a **local SQLite State Database**.
|
|
200
|
+
|
|
201
|
+
> [!TIP]
|
|
202
|
+
> **Why connect to the DB?**
|
|
203
|
+
> The server does not hold your data. Your final, fully-trained aggregated model weights are saved to your local `fl_client_state.db` at the end of training. By loading the state for your specific `client_id` (e.g. `worker-1`), you can extract these weights and run predictions *locally* without ever needing to communicate with the server again.
|
|
204
|
+
|
|
205
|
+
```python
|
|
206
|
+
from fl_client.storage.db import StateDB
|
|
207
|
+
from fl_client.model.wrapper import wrap_model
|
|
208
|
+
|
|
209
|
+
db = StateDB("fl_client_state.db")
|
|
210
|
+
state = db.load_state("worker-1")
|
|
211
|
+
|
|
212
|
+
wrapper = wrap_model(model)
|
|
213
|
+
wrapper.set_weights(state.last_weights)
|
|
214
|
+
|
|
215
|
+
predictions = model.predict(X_test)
|
|
216
|
+
accuracy = (predictions == y_test).mean()
|
|
217
|
+
print(f"✅ Accuracy: {accuracy * 100:.2f}%")
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
---
|
|
221
|
+
|
|
222
|
+
## 💻 CLI Reference
|
|
223
|
+
|
|
224
|
+
```bash
|
|
225
|
+
# Full lifecycle from a YAML config
|
|
226
|
+
distfl run --config config.yaml
|
|
227
|
+
|
|
228
|
+
# Create a room
|
|
229
|
+
distfl create-room --server-url ws://localhost:8080 --room-name "My Room"
|
|
230
|
+
|
|
231
|
+
# Join a room and train
|
|
232
|
+
distfl join-room ROOM_ID --data ./data.csv --server-url ws://localhost:8080
|
|
233
|
+
|
|
234
|
+
# Launch the real-time web dashboard
|
|
235
|
+
distfl ui --port 5050
|
|
236
|
+
|
|
237
|
+
# Inspect persisted client state
|
|
238
|
+
distfl status --client-id worker-1 --db-path fl_client_state.db
|
|
239
|
+
|
|
240
|
+
# Clear persisted state
|
|
241
|
+
distfl clear --client-id worker-1 --db-path fl_client_state.db
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
---
|
|
245
|
+
|
|
246
|
+
## ⚙️ Configuration
|
|
247
|
+
|
|
248
|
+
All options can be set via **YAML file**, **CLI flags**, or **environment variables** (`FL_` prefix):
|
|
249
|
+
|
|
250
|
+
```yaml
|
|
251
|
+
# Server connection
|
|
252
|
+
server_url: "ws://localhost:8080"
|
|
253
|
+
room_id: "" # Leave empty to create a new room
|
|
254
|
+
client_id: "" # Auto-generated if omitted
|
|
255
|
+
|
|
256
|
+
# Dataset
|
|
257
|
+
data_path: "./data.csv"
|
|
258
|
+
label_column: "label"
|
|
259
|
+
|
|
260
|
+
# Training hyperparameters
|
|
261
|
+
batch_size: 32
|
|
262
|
+
local_epochs: 2
|
|
263
|
+
learning_rate: 0.001
|
|
264
|
+
|
|
265
|
+
# State persistence
|
|
266
|
+
db_path: "fl_client_state.db" # SQLite for crash recovery
|
|
267
|
+
|
|
268
|
+
# Networking
|
|
269
|
+
reconnect_max_delay: 60.0 # Max backoff delay (seconds)
|
|
270
|
+
reconnect_base_delay: 1.0 # Initial reconnect delay
|
|
271
|
+
heartbeat_interval: 30.0 # WebSocket ping interval
|
|
272
|
+
|
|
273
|
+
# Dashboard
|
|
274
|
+
dashboard_port: 5050 # Real-time metrics UI (0 = disabled)
|
|
275
|
+
|
|
276
|
+
# Logging
|
|
277
|
+
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
|
|
278
|
+
```
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
## 🧪 Supported Frameworks
|
|
283
|
+
|
|
284
|
+
| Framework | Requirements | Weight Extraction |
|
|
285
|
+
|---|---|---|
|
|
286
|
+
| **PyTorch** | Any `nn.Module` | `state_dict()` → 3D float32 lists |
|
|
287
|
+
| **Scikit-Learn** | Estimator with `partial_fit` (e.g. `SGDClassifier`, `SGDRegressor`) | `coef_` + `intercept_` → 3D float32 lists |
|
|
288
|
+
|
|
289
|
+
---
|
|
290
|
+
|
|
291
|
+
## 🧪 Testing
|
|
292
|
+
|
|
293
|
+
```bash
|
|
294
|
+
# Install dev dependencies
|
|
295
|
+
pip install -e ".[dev]"
|
|
296
|
+
|
|
297
|
+
# Run all 52 unit tests
|
|
298
|
+
python -m pytest tests/ -v
|
|
299
|
+
```
|
|
300
|
+
|
|
301
|
+
### Test Coverage
|
|
302
|
+
|
|
303
|
+
| Module | Tests | What's Covered |
|
|
304
|
+
|---|---|---|
|
|
305
|
+
| `test_compressor.py` | 7 | Compress/decompress round-trip, empty data, large payloads |
|
|
306
|
+
| `test_connection.py` | 8 | WS URL construction, connect/disconnect, send/receive |
|
|
307
|
+
| `test_serializer.py` | 9 | Serialize/deserialize, shape preservation, JSON round-trip |
|
|
308
|
+
| `test_storage.py` | 7 | SQLite save/load, upsert, clear, round logging |
|
|
309
|
+
| `test_trainer.py` | 4 | Train results, finite loss, accuracy metrics, multi-epoch |
|
|
310
|
+
| `test_validation.py` | 17 | NaN/Inf/shape/range checks, loss validation, weight shapes |
|
|
311
|
+
|
|
312
|
+
---
|
|
313
|
+
|
|
314
|
+
## 🔐 Privacy & Security
|
|
315
|
+
|
|
316
|
+
- **Data never leaves the client** — only model weight updates are transmitted
|
|
317
|
+
- **GZIP compression** — reduces bandwidth; note that compression is not encryption and does not by itself protect the confidentiality of model updates
|
|
318
|
+
- **Server-side validation** — NaN, Inf, out-of-range, shape mismatch, L2 norm, and duplicate submission checks
|
|
319
|
+
- **Invite codes** — rooms can be access-controlled via invite codes
|
|
320
|
+
- **Crash recovery** — SQLite persistence prevents duplicate round submissions
|
|
321
|
+
|
|
322
|
+
---
|
|
323
|
+
|
|
324
|
+
## 📄 License
|
|
325
|
+
|
|
326
|
+
MIT
|
|
327
|
+
|
|
328
|
+
---
|
|
329
|
+
|
|
330
|
+
<p align="center">
|
|
331
|
+
Built with ❤️ for privacy-preserving machine learning
|
|
332
|
+
</p>
|
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<h1 align="center">DistFL</h1>
|
|
3
|
+
<p align="center">
|
|
4
|
+
<strong>Production-Grade Federated Learning Client SDK</strong>
|
|
5
|
+
</p>
|
|
6
|
+
<p align="center">
|
|
7
|
+
<a href="#installation"><img src="https://img.shields.io/badge/python-≥3.10-blue?logo=python&logoColor=white" alt="Python"></a>
|
|
8
|
+
<a href="https://pypi.org/project/distfl-client/"><img src="https://img.shields.io/pypi/v/distfl-client?color=green&label=PyPI" alt="PyPI"></a>
|
|
9
|
+
<a href="#license"><img src="https://img.shields.io/badge/license-MIT-purple" alt="License"></a>
|
|
10
|
+
<a href="#testing"><img src="https://img.shields.io/badge/tests-52%20passed-brightgreen" alt="Tests"></a>
|
|
11
|
+
</p>
|
|
12
|
+
</p>
|
|
13
|
+
|
|
14
|
+
---
|
|
15
|
+
|
|
16
|
+
Bring your own model (PyTorch or Scikit-Learn), connect to a DistFL server, train locally on **private data**, and let the server aggregate updates — all via compressed WebSocket communication. No raw data ever leaves the client.
|
|
17
|
+
|
|
18
|
+
```
|
|
19
|
+
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
|
20
|
+
│ Client A │ │ Client B │ │ Client C │
|
|
21
|
+
│ (Hospital) │ │ (Bank) │ │ (Lab) │
|
|
22
|
+
│ Local Data │ │ Local Data │ │ Local Data │
|
|
23
|
+
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
|
|
24
|
+
│ model updates │ (gzip+WS) │
|
|
25
|
+
└────────────────────┼────────────────────┘
|
|
26
|
+
│
|
|
27
|
+
┌───────▼────────┐
|
|
28
|
+
│ DistFL Server │
|
|
29
|
+
│ (Go Backend) │
|
|
30
|
+
│ FedAvg Agg. │
|
|
31
|
+
└───────┬────────┘
|
|
32
|
+
│
|
|
33
|
+
aggregated global model
|
|
34
|
+
broadcast to all clients
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
---
|
|
38
|
+
|
|
39
|
+
## ✨ Features
|
|
40
|
+
|
|
41
|
+
| Category | Details |
|
|
42
|
+
|---|---|
|
|
43
|
+
| **BYOM** | Use any PyTorch `nn.Module` or Scikit-Learn estimator with `partial_fit` |
|
|
44
|
+
| **Simple Lifecycle** | `initialize()` → `validate()` → `start()` — 3 calls to go from zero to training |
|
|
45
|
+
| **Room-Based FL** | Create rooms, share invite codes, configure training params per room |
|
|
46
|
+
| **Compressed WebSocket** | GZIP-compressed binary messages over persistent WebSocket connections |
|
|
47
|
+
| **Auto Reconnect** | Exponential backoff with configurable delays and heartbeat pings |
|
|
48
|
+
| **Crash Recovery** | SQLite-backed state persistence — no duplicate round submissions after restart |
|
|
49
|
+
| **Live Dashboard** | Built-in web UI with real-time loss curves, ΔW tracking, and training logs |
|
|
50
|
+
| **Prediction** | Extract globally-aggregated weights and run inference locally |
|
|
51
|
+
| **CLI** | `distfl run`, `distfl create-room`, `distfl join-room`, `distfl ui`, `distfl status` |
|
|
52
|
+
|
|
53
|
+
---
|
|
54
|
+
|
|
55
|
+
## 📦 Installation
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
pip install distfl-client
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
**From source:**
|
|
62
|
+
|
|
63
|
+
```bash
|
|
64
|
+
git clone https://github.com/AbhaySingh002/new-repo-code.git
|
|
65
|
+
cd new-repo-code/DistFL
|
|
66
|
+
pip install -e ".[dev]"
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
---
|
|
70
|
+
|
|
71
|
+
## 🚀 Quick Start
|
|
72
|
+
|
|
73
|
+
### 1. Room Creator
|
|
74
|
+
|
|
75
|
+
The creator relies on the `FLClient` to initialize the global model architecture, create a new room on the server, and wait for other participants to join before starting.
|
|
76
|
+
|
|
77
|
+
> [!NOTE]
|
|
78
|
+
> **Why `partial_fit`?**
|
|
79
|
+
> Federated Learning requires all clients to share the exact same weight matrix architecture. For Scikit-Learn models like `SGDClassifier`, the shape of the weights (`coef_` and `intercept_`) isn't initialized until it sees training data. We run a single dummy `partial_fit` on the creator side to establish this shape before sending it to the server.
|
|
80
|
+
|
|
81
|
+
```python
|
|
82
|
+
from sklearn.linear_model import SGDClassifier
|
|
83
|
+
from fl_client import FLClient
|
|
84
|
+
import pandas as pd
|
|
85
|
+
import numpy as np
|
|
86
|
+
|
|
87
|
+
# Prepare model (scikit-learn requires partial_fit to initialize weights)
|
|
88
|
+
model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
|
|
89
|
+
learning_rate="constant", eta0=0.01)
|
|
90
|
+
df = pd.read_csv("./data.csv")
|
|
91
|
+
X = df.drop(columns=["label"]).values[:10].astype(np.float64)
|
|
92
|
+
y = df["label"].values[:10].astype(np.int64)
|
|
93
|
+
model.partial_fit(X, y, classes=[0, 1])
|
|
94
|
+
|
|
95
|
+
# Create room
|
|
96
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
97
|
+
room = client.create_room(
|
|
98
|
+
model=model,
|
|
99
|
+
data_path="./data.csv",
|
|
100
|
+
target="label",
|
|
101
|
+
training_config={"local_epochs": 1, "batch_size": 32, "learning_rate": 0.01},
|
|
102
|
+
room_name="Phishing Detection",
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
room_id = room["id"]
|
|
106
|
+
print(f"✅ Room created: {room_id}")
|
|
107
|
+
print(f" Invite code: {room['invite_code']}")
|
|
108
|
+
|
|
109
|
+
# Wait for participants, then start
|
|
110
|
+
client.wait_for_clients(min_clients=2, timeout=120)
|
|
111
|
+
client.start_training()
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
### 2. Room Joiner
|
|
115
|
+
|
|
116
|
+
Each participant joins an existing room, validates their local dataset, and starts training. This follows the **3-Step Lifecycle**:
|
|
117
|
+
|
|
118
|
+
1. **`initialize()`** — Connects to the server, fetches the room's data schema and model configuration, and injects the latest global model weights into your local model.
|
|
119
|
+
2. **`validate()`** — Checks your local dataset (`data.csv`) against the room's expected schema (e.g., ensuring it has the correct target column and feature count) and performs a dummy forward pass to catch shape errors early.
|
|
120
|
+
3. **`start()`** — Enters the federated training loop and blocks until training completes (readiness is signaled explicitly via `ready()`, as shown below).
|
|
121
|
+
|
|
122
|
+
```python
|
|
123
|
+
from sklearn.linear_model import SGDClassifier
|
|
124
|
+
from fl_client import FLClient
|
|
125
|
+
|
|
126
|
+
model = SGDClassifier(loss="log_loss", penalty="l2", max_iter=1,
|
|
127
|
+
learning_rate="constant", eta0=0.01)
|
|
128
|
+
# ... partial_fit to initialize shape (same architecture as creator)
|
|
129
|
+
|
|
130
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
131
|
+
client.join(room_id, invite_code="abc123", model=model)
|
|
132
|
+
client.validate("./data.csv")
|
|
133
|
+
client.ready()
|
|
134
|
+
client.start(max_rounds=5) # Blocks until training completes
|
|
135
|
+
|
|
136
|
+
print("✅ Training complete!")
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
### 3. PyTorch Models
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
import torch.nn as nn
|
|
143
|
+
from fl_client import FLClient
|
|
144
|
+
|
|
145
|
+
class PhishingMLP(nn.Module):
|
|
146
|
+
def __init__(self):
|
|
147
|
+
super().__init__()
|
|
148
|
+
self.net = nn.Sequential(
|
|
149
|
+
nn.Linear(30, 64), nn.ReLU(),
|
|
150
|
+
nn.Linear(64, 32), nn.ReLU(),
|
|
151
|
+
nn.Linear(32, 2),
|
|
152
|
+
)
|
|
153
|
+
def forward(self, x):
|
|
154
|
+
return self.net(x)
|
|
155
|
+
|
|
156
|
+
client = FLClient(server_url="ws://localhost:8080")
|
|
157
|
+
room = client.create_room(
|
|
158
|
+
model=PhishingMLP(),
|
|
159
|
+
data_path="./data.csv",
|
|
160
|
+
target="label",
|
|
161
|
+
training_config={"local_epochs": 2, "batch_size": 64, "learning_rate": 0.001},
|
|
162
|
+
room_name="PyTorch FL Room",
|
|
163
|
+
)
|
|
164
|
+
```
|
|
165
|
+
|
|
166
|
+
### 4. Prediction After Training
|
|
167
|
+
|
|
168
|
+
Because clients can disconnect, crash, or experience network drops, the DistFL SDK maintains a **local SQLite State Database**.
|
|
169
|
+
|
|
170
|
+
> [!TIP]
|
|
171
|
+
> **Why connect to the DB?**
|
|
172
|
+
> The server does not hold your data. Your final, fully-trained aggregated model weights are saved to your local `fl_client_state.db` at the end of training. By loading the state for your specific `client_id` (e.g. `worker-1`), you can extract these weights and run predictions *locally* without ever needing to communicate with the server again.
|
|
173
|
+
|
|
174
|
+
```python
|
|
175
|
+
from fl_client.storage.db import StateDB
|
|
176
|
+
from fl_client.model.wrapper import wrap_model
|
|
177
|
+
|
|
178
|
+
db = StateDB("fl_client_state.db")
|
|
179
|
+
state = db.load_state("worker-1")
|
|
180
|
+
|
|
181
|
+
wrapper = wrap_model(model)
|
|
182
|
+
wrapper.set_weights(state.last_weights)
|
|
183
|
+
|
|
184
|
+
predictions = model.predict(X_test)
|
|
185
|
+
accuracy = (predictions == y_test).mean()
|
|
186
|
+
print(f"✅ Accuracy: {accuracy * 100:.2f}%")
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
---
|
|
190
|
+
|
|
191
|
+
## 💻 CLI Reference
|
|
192
|
+
|
|
193
|
+
```bash
|
|
194
|
+
# Full lifecycle from a YAML config
|
|
195
|
+
distfl run --config config.yaml
|
|
196
|
+
|
|
197
|
+
# Create a room
|
|
198
|
+
distfl create-room --server-url ws://localhost:8080 --room-name "My Room"
|
|
199
|
+
|
|
200
|
+
# Join a room and train
|
|
201
|
+
distfl join-room ROOM_ID --data ./data.csv --server-url ws://localhost:8080
|
|
202
|
+
|
|
203
|
+
# Launch the real-time web dashboard
|
|
204
|
+
distfl ui --port 5050
|
|
205
|
+
|
|
206
|
+
# Inspect persisted client state
|
|
207
|
+
distfl status --client-id worker-1 --db-path fl_client_state.db
|
|
208
|
+
|
|
209
|
+
# Clear persisted state
|
|
210
|
+
distfl clear --client-id worker-1 --db-path fl_client_state.db
|
|
211
|
+
```
|
|
212
|
+
|
|
213
|
+
---
|
|
214
|
+
|
|
215
|
+
## ⚙️ Configuration
|
|
216
|
+
|
|
217
|
+
All options can be set via **YAML file**, **CLI flags**, or **environment variables** (`FL_` prefix):
|
|
218
|
+
|
|
219
|
+
```yaml
|
|
220
|
+
# Server connection
|
|
221
|
+
server_url: "ws://localhost:8080"
|
|
222
|
+
room_id: "" # Leave empty to create a new room
|
|
223
|
+
client_id: "" # Auto-generated if omitted
|
|
224
|
+
|
|
225
|
+
# Dataset
|
|
226
|
+
data_path: "./data.csv"
|
|
227
|
+
label_column: "label"
|
|
228
|
+
|
|
229
|
+
# Training hyperparameters
|
|
230
|
+
batch_size: 32
|
|
231
|
+
local_epochs: 2
|
|
232
|
+
learning_rate: 0.001
|
|
233
|
+
|
|
234
|
+
# State persistence
|
|
235
|
+
db_path: "fl_client_state.db" # SQLite for crash recovery
|
|
236
|
+
|
|
237
|
+
# Networking
|
|
238
|
+
reconnect_max_delay: 60.0 # Max backoff delay (seconds)
|
|
239
|
+
reconnect_base_delay: 1.0 # Initial reconnect delay
|
|
240
|
+
heartbeat_interval: 30.0 # WebSocket ping interval
|
|
241
|
+
|
|
242
|
+
# Dashboard
|
|
243
|
+
dashboard_port: 5050 # Real-time metrics UI (0 = disabled)
|
|
244
|
+
|
|
245
|
+
# Logging
|
|
246
|
+
log_level: "INFO" # DEBUG, INFO, WARNING, ERROR
|
|
247
|
+
```
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
## 🧪 Supported Frameworks
|
|
252
|
+
|
|
253
|
+
| Framework | Requirements | Weight Extraction |
|
|
254
|
+
|---|---|---|
|
|
255
|
+
| **PyTorch** | Any `nn.Module` | `state_dict()` → 3D float32 lists |
|
|
256
|
+
| **Scikit-Learn** | Estimator with `partial_fit` (e.g. `SGDClassifier`, `SGDRegressor`) | `coef_` + `intercept_` → 3D float32 lists |
|
|
257
|
+
|
|
258
|
+
---
|
|
259
|
+
|
|
260
|
+
## 🧪 Testing
|
|
261
|
+
|
|
262
|
+
```bash
|
|
263
|
+
# Install dev dependencies
|
|
264
|
+
pip install -e ".[dev]"
|
|
265
|
+
|
|
266
|
+
# Run all 52 unit tests
|
|
267
|
+
python -m pytest tests/ -v
|
|
268
|
+
```
|
|
269
|
+
|
|
270
|
+
### Test Coverage
|
|
271
|
+
|
|
272
|
+
| Module | Tests | What's Covered |
|
|
273
|
+
|---|---|---|
|
|
274
|
+
| `test_compressor.py` | 7 | Compress/decompress round-trip, empty data, large payloads |
|
|
275
|
+
| `test_connection.py` | 8 | WS URL construction, connect/disconnect, send/receive |
|
|
276
|
+
| `test_serializer.py` | 9 | Serialize/deserialize, shape preservation, JSON round-trip |
|
|
277
|
+
| `test_storage.py` | 7 | SQLite save/load, upsert, clear, round logging |
|
|
278
|
+
| `test_trainer.py` | 4 | Train results, finite loss, accuracy metrics, multi-epoch |
|
|
279
|
+
| `test_validation.py` | 17 | NaN/Inf/shape/range checks, loss validation, weight shapes |
|
|
280
|
+
|
|
281
|
+
---
|
|
282
|
+
|
|
283
|
+
## 🔐 Privacy & Security
|
|
284
|
+
|
|
285
|
+
- **Data never leaves the client** — only model weight updates are transmitted
|
|
286
|
+
- **GZIP compression** — reduces bandwidth; note that compression is not encryption and does not by itself protect the confidentiality of model updates
|
|
287
|
+
- **Server-side validation** — NaN, Inf, out-of-range, shape mismatch, L2 norm, and duplicate submission checks
|
|
288
|
+
- **Invite codes** — rooms can be access-controlled via invite codes
|
|
289
|
+
- **Crash recovery** — SQLite persistence prevents duplicate round submissions
|
|
290
|
+
|
|
291
|
+
---
|
|
292
|
+
|
|
293
|
+
## 📄 License
|
|
294
|
+
|
|
295
|
+
MIT
|
|
296
|
+
|
|
297
|
+
---
|
|
298
|
+
|
|
299
|
+
<p align="center">
|
|
300
|
+
Built with ❤️ for privacy-preserving machine learning
|
|
301
|
+
</p>
|