federated-learning-framework 0.0.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- federated_learning_framework/__init__.py +9 -0
- federated_learning_framework/active_learning.py +7 -0
- federated_learning_framework/central_server.py +56 -0
- federated_learning_framework/client_device.py +59 -0
- federated_learning_framework/connection.py +54 -0
- federated_learning_framework/decorators.py +24 -0
- federated_learning_framework/encryption.py +22 -0
- federated_learning_framework/models/__init__.py +0 -0
- federated_learning_framework/models/abstract_model.py +18 -0
- federated_learning_framework/models/pytorch_model.py +28 -0
- federated_learning_framework/models/tensorflow_model.py +18 -0
- federated_learning_framework/utils.py +4 -0
- federated_learning_framework-0.0.61.dist-info/LICENSE +37 -0
- federated_learning_framework-0.0.61.dist-info/METADATA +333 -0
- federated_learning_framework-0.0.61.dist-info/RECORD +24 -0
- federated_learning_framework-0.0.61.dist-info/WHEEL +5 -0
- federated_learning_framework-0.0.61.dist-info/top_level.txt +2 -0
- tests/__init__.py +0 -0
- tests/test_active_learning.py +12 -0
- tests/test_central_server.py +18 -0
- tests/test_client_device.py +25 -0
- tests/test_connection.py +28 -0
- tests/test_encryption.py +9 -0
- tests/test_utils.py +10 -0
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .active_learning import query_active_learning
|
|
2
|
+
from .central_server import CentralServer
|
|
3
|
+
from .client_device import ClientDevice
|
|
4
|
+
from .connection import ConnectionServer, ConnectionClient
|
|
5
|
+
from .decorators import federated_learning_decorator, encryption_decorator
|
|
6
|
+
from .encryption import create_context, encrypt_weights, decrypt_weights
|
|
7
|
+
from .models.tensorflow_model import TensorFlowModel
|
|
8
|
+
from .models.pytorch_model import PyTorchModel
|
|
9
|
+
from .utils import setup_logging
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import numpy as np
|
|
4
|
+
from federated_learning_framework.connection import ConnectionServer
|
|
5
|
+
from websockets.exceptions import ConnectionClosedError
|
|
6
|
+
from federated_learning_framework.encryption import create_context
|
|
7
|
+
|
|
8
|
+
class CentralServer:
    """Coordinator that relays model weights between connected clients.

    The server owns a ``ConnectionServer`` transport and registers
    ``handle_client`` as its per-connection callback. Weights received from
    any client are stored as the current weights and re-broadcast to every
    connected client.
    """

    def __init__(self, connection_type='websocket', host='0.0.0.0', port=8089, context=None):
        """Create the server and its transport.

        Args:
            connection_type: Transport name forwarded to ``ConnectionServer``.
            host: Interface to bind to.
            port: TCP port to listen on.
            context: Encryption context; a fresh one is built with
                ``create_context()`` when omitted (or falsy).
        """
        self.model_weights = None  # last weights received from any client
        self.lock = asyncio.Lock()  # serialises weight updates/broadcasts
        self.clients = set()  # ids of currently connected clients
        self.logger = logging.getLogger(__name__)
        self.connection = ConnectionServer(connection_type, host, port, self.handle_client)
        self.context = context or create_context()

    async def run_server(self):
        """Start the transport; returns only when the transport stops."""
        self.logger.info("Central Server is starting...")
        await self.connection.start()

    async def handle_client(self, websocket, client_id):
        """Serve one client until its connection closes.

        Dict messages containing ``'weights'`` are broadcast to all clients;
        dict messages containing ``'data_request'`` are answered with a
        ``{'data': ...}`` reply. Anything else is ignored.

        Args:
            websocket: The raw connection object (unused here; I/O goes
                through ``self.connection`` keyed by ``client_id``).
            client_id: Identifier assigned by the transport.
        """
        # ConnectionClosed is the common base of ConnectionClosedOK and
        # ConnectionClosedError, so a *clean* client shutdown is handled too
        # instead of escaping the handler as an unhandled exception.
        from websockets.exceptions import ConnectionClosed

        self.clients.add(client_id)
        self.logger.info(f"Central Server: Client {client_id} connected")
        try:
            while True:
                message = await self.connection.receive(client_id)
                if not isinstance(message, dict):
                    continue
                if 'weights' in message:
                    await self.transmit_weights(message['weights'])
                elif 'data_request' in message:
                    data = await self.get_data_from_client(client_id)
                    await self.send_data_to_client(client_id, {'data': data})
        except ConnectionClosed:
            self.logger.info(f"Central Server: Client {client_id} disconnected")
        finally:
            # discard() cannot raise, so cleanup never masks the real error.
            self.clients.discard(client_id)

    async def transmit_weights(self, weights):
        """Store ``weights`` as the current model and send them to every client.

        A failed send to one client is logged and does not prevent delivery
        to the remaining clients.
        """
        async with self.lock:
            self.model_weights = weights
            recipients = list(self.clients)
            outcomes = await asyncio.gather(
                *(self.connection.send(client_id, {'weights': self.model_weights})
                  for client_id in recipients),
                return_exceptions=True,
            )
            for client_id, outcome in zip(recipients, outcomes):
                if isinstance(outcome, BaseException):
                    self.logger.warning(
                        f"Central Server: Failed to send weights to client {client_id}: {outcome!r}"
                    )
            self.logger.info("Central Server: Transmitted weights to clients")

    async def send_data_to_client(self, client_id, data):
        """Send ``data`` to the client identified by ``client_id``."""
        self.logger.info(f"Central Server: Sending data to client {client_id}")
        await self.connection.send(client_id, data)

    async def get_data_from_client(self, client_id):
        """Return placeholder data for ``client_id``.

        No request is actually sent: the method sleeps for one second and
        returns a random array of shape ``(10, 3072)``.
        """
        self.logger.info(f"Central Server: Requesting data from client {client_id}. Simulating response.")
        await asyncio.sleep(1)
        return np.random.rand(10, 3072)

    def query_active_learning(self, unlabeled_data, model, num_samples=5):
        """Pick the samples whose top predicted score is lowest.

        Args:
            unlabeled_data: Input passed unchanged to ``model.predict``.
            model: Object exposing ``predict``; its result must support
                ``.max(axis=1)`` (one row of scores per sample).
            num_samples: How many indices to return (default 5, the value
                that used to be hard-coded).

        Returns:
            Indices of the ``num_samples`` rows with the smallest row-wise
            maximum, in ascending order of that maximum.
        """
        scores = model.predict(unlabeled_data)
        ranked = np.argsort(scores.max(axis=1))
        return ranked[:num_samples]
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
from federated_learning_framework.encryption import encrypt_weights, decrypt_weights
|
|
5
|
+
from federated_learning_framework.models.tensorflow_model import TensorFlowModel
|
|
6
|
+
import websockets
|
|
7
|
+
|
|
8
|
+
class ClientDevice:
    """Federated-learning participant that trains locally and exchanges weights.

    Messages are pickled before being written to the websocket and unpickled
    after being read, matching the framing used by ``ConnectionServer`` /
    ``ConnectionClient`` (a raw ``dict`` is not a valid websocket payload).
    """

    def __init__(self, client_id, model: "TensorFlowModel", context):
        """Store the identity, model wrapper and encryption context.

        Args:
            client_id: Identifier reported to the server and used in logs.
            model: Model wrapper exposing ``get_weights``, ``set_weights``
                and ``train``.
            context: Encryption context passed to ``encrypt_weights`` /
                ``decrypt_weights``.
        """
        self.client_id = client_id
        self.model = model
        self.context = context
        self.connection = None  # set by connect_to_central_server()
        self.logger = logging.getLogger(__name__)

    async def _send_message(self, message):
        """Pickle ``message`` and write it to the open connection."""
        import pickle

        await self.connection.send(pickle.dumps(message))

    async def _recv_message(self):
        """Read one frame from the open connection and unpickle it."""
        import pickle

        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload. Only connect to servers you trust.
        return pickle.loads(await self.connection.recv())

    def _restore_shapes(self, flat_weights):
        """Reshape decrypted weight vectors to the model's current layout.

        ``decrypt_weights`` returns one sequence per encrypted vector, while
        ``encrypt_weights`` flattened each array first, so the original
        shapes have to be reapplied before calling ``set_weights``.

        Raises:
            ValueError: If the number of vectors differs from the number of
                weight arrays in the model, or a vector has the wrong size.
        """
        import numpy as np

        templates = self.model.get_weights()
        flat_weights = list(flat_weights)
        if len(flat_weights) != len(templates):
            raise ValueError(
                f"Expected {len(templates)} weight arrays, received {len(flat_weights)}"
            )
        return [
            np.asarray(flat, dtype=np.asarray(template).dtype).reshape(np.shape(template))
            for flat, template in zip(flat_weights, templates)
        ]

    async def connect_to_central_server(self, uri):
        """Open a websocket to ``uri`` and announce this client's id.

        Failures are logged, not raised; ``self.connection`` may then be
        ``None`` or unusable.
        """
        try:
            self.connection = await websockets.connect(uri)
            await self._send_message({'client_id': self.client_id})
            self.logger.info(f"Client {self.client_id}: Connected to central server at {uri}")
        except Exception as e:
            self.logger.error(f"Client {self.client_id}: Error connecting to central server: {e}")

    async def federated_learning(self, x_train, y_train):
        """Run receive -> train -> send rounds until no weights arrive.

        Each round decrypts the received weights, loads them into the model,
        trains for one epoch on ``(x_train, y_train)`` and sends the newly
        encrypted weights back. The loop ends when ``receive_weights``
        returns ``None``; any exception is logged and ends the loop.
        """
        try:
            while True:
                weights = await self.receive_weights()
                if weights is None:
                    break
                decrypted = decrypt_weights(self.context, weights)
                self.model.set_weights(self._restore_shapes(decrypted))
                self.model.train(x_train, y_train, epochs=1)
                new_weights = self.model.get_weights()
                await self.send_weights(encrypt_weights(self.context, new_weights))
        except Exception as e:
            self.logger.error(f"Client {self.client_id}: Error in federated learning loop: {e}")

    async def receive_weights(self):
        """Return the ``'weights'`` entry of the next message, or ``None`` on error."""
        try:
            message = await self._recv_message()
            self.logger.info(f"Client {self.client_id}: Received weights")
            return message['weights']
        except Exception as e:
            self.logger.error(f"Client {self.client_id}: Error receiving weights: {e}")

    async def send_weights(self, weights):
        """Send ``{'weights': weights}`` to the server; errors are logged."""
        try:
            await self._send_message({'weights': weights})
            self.logger.info(f"Client {self.client_id}: Sent weights to central server")
        except Exception as e:
            self.logger.error(f"Client {self.client_id}: Error sending weights: {e}")

    async def request_data(self):
        """Ask the server for data and return the reply's ``'data'`` entry.

        Returns ``None`` if the request or the reply fails (the error is
        logged).
        """
        try:
            await self._send_message({'data_request': True})
            data = await self._recv_message()
            self.logger.info(f"Client {self.client_id}: Received data from central server")
            return data['data']
        except Exception as e:
            self.logger.error(f"Client {self.client_id}: Error requesting data: {e}")
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import websockets
|
|
3
|
+
import pickle
|
|
4
|
+
from websockets.exceptions import ConnectionClosedError
|
|
5
|
+
|
|
6
|
+
class ConnectionServer:
    """Server-side transport that tracks client sockets by numeric id.

    Messages are framed with ``pickle``: ``send`` pickles the object before
    writing and ``receive`` unpickles what it reads.
    """

    def __init__(self, connection_type, host, port, client_handler):
        """Remember the listening parameters and the per-client callback.

        Args:
            connection_type: Transport name; only ``'websocket'`` is supported.
            host: Interface to bind to.
            port: TCP port to listen on.
            client_handler: Coroutine function called as
                ``client_handler(websocket, client_id)`` for every connection.
        """
        self.connection_type = connection_type
        self.host = host
        self.port = port
        self.client_handler = client_handler
        self.clients = {}  # client_id -> open websocket
        # Monotonic counter: ids stay unique even after entries are removed
        # from self.clients (len(self.clients) + 1 would be reused).
        self._next_client_id = 1

    async def start(self):
        """Listen for connections until cancelled.

        Raises:
            NotImplementedError: If ``connection_type`` is not ``'websocket'``.
        """
        if self.connection_type != 'websocket':
            raise NotImplementedError(f"Connection type {self.connection_type} not supported")
        async with websockets.serve(self.handle_client, self.host, self.port):
            await asyncio.Future()  # Run forever

    async def handle_client(self, websocket, path=None):
        """Register a new connection, run the client handler, then unregister.

        Args:
            websocket: The accepted connection.
            path: Request path. Optional so the handler also works when the
                websockets library invokes it with the connection only.
        """
        client_id = self._next_client_id
        self._next_client_id += 1
        self.clients[client_id] = websocket
        try:
            await self.client_handler(websocket, client_id)
        finally:
            # Drop the socket reference once the handler is done so closed
            # connections do not accumulate for the server's lifetime.
            self.clients.pop(client_id, None)

    async def send(self, client_id, message):
        """Pickle ``message`` and send it to the client ``client_id``.

        Raises:
            KeyError: If ``client_id`` is not currently registered.
        """
        client = self.clients[client_id]
        serialized_message = pickle.dumps(message)
        await client.send(serialized_message)

    async def receive(self, client_id):
        """Read one frame from client ``client_id`` and return the unpickled object.

        Raises:
            KeyError: If ``client_id`` is not currently registered.
        """
        client = self.clients[client_id]
        message = await client.recv()
        # SECURITY: pickle.loads runs arbitrary code embedded in the payload.
        # Anyone who can reach this port can execute code in the server
        # process; expose it to trusted peers only, or switch to a safe
        # serialisation format.
        return pickle.loads(message)
|
|
35
|
+
|
|
36
|
+
class ConnectionClient:
    """Client-side transport that exchanges pickled messages with a server."""

    def __init__(self, connection_type, uri):
        """Remember the transport type and server URI; no I/O happens here.

        Args:
            connection_type: Transport name; only ``'websocket'`` is supported.
            uri: Address handed to ``websockets.connect``.
        """
        self.connection_type = connection_type
        self.uri = uri
        self.connection = None  # set by connect()

    def _require_connection(self):
        """Return the open connection or fail with a clear error."""
        if self.connection is None:
            raise RuntimeError("ConnectionClient is not connected; call connect() first")
        return self.connection

    async def connect(self):
        """Open the connection to ``self.uri``.

        Raises:
            NotImplementedError: If ``connection_type`` is not ``'websocket'``.
        """
        if self.connection_type != 'websocket':
            raise NotImplementedError(f"Connection type {self.connection_type} not supported")
        self.connection = await websockets.connect(self.uri)

    async def send(self, message):
        """Pickle ``message`` and send it to the server.

        Raises:
            RuntimeError: If ``connect()`` has not been called successfully.
        """
        connection = self._require_connection()
        await connection.send(pickle.dumps(message))

    async def receive(self):
        """Read one frame from the server and return the unpickled object.

        Raises:
            RuntimeError: If ``connect()`` has not been called successfully.
        """
        connection = self._require_connection()
        payload = await connection.recv()
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload. Only connect to servers you trust.
        return pickle.loads(payload)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# federated_learning_framework/decorators.py
|
|
2
|
+
import functools
|
|
3
|
+
|
|
4
|
+
def federated_learning_decorator(uri, startup_delay=0.1):
    """Wrap a coroutine so a server is started and a client connected first.

    The wrapper looks for the keyword arguments ``central_server`` and
    ``client`` in the call. When present, the server's ``run_server()`` is
    launched as a background task (it never returns on its own, so awaiting
    it directly would block the wrapped function forever), the client is
    connected to ``uri``, and then the wrapped coroutine runs. The server
    task is cancelled once the wrapped coroutine finishes. Missing keyword
    arguments are simply skipped.

    Args:
        uri: Server address passed to ``client.connect_to_central_server``.
        startup_delay: Seconds to wait after launching the server before the
            client connects, giving the listener time to bind.

    Returns:
        A decorator for coroutine functions.
    """
    import asyncio

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            central_server = kwargs.get('central_server')
            client = kwargs.get('client')

            server_task = None
            if central_server is not None:
                server_task = asyncio.ensure_future(central_server.run_server())
                await asyncio.sleep(startup_delay)
                if server_task.done():
                    # Surface startup failures (e.g. port already in use)
                    # instead of continuing without a server.
                    server_task.result()
            try:
                if client is not None:
                    await client.connect_to_central_server(uri)
                return await func(*args, **kwargs)
            finally:
                if server_task is not None:
                    server_task.cancel()
                    # Wait for the cancellation to complete without letting
                    # the server's CancelledError replace func's outcome.
                    await asyncio.gather(server_task, return_exceptions=True)
        return wrapper
    return decorator
|
|
15
|
+
|
|
16
|
+
def encryption_decorator(context):
    """Wrap a coroutine so the ``client`` keyword argument gets ``context``.

    Before the wrapped coroutine runs, ``client.context`` is set to
    ``context``. Calls made without a ``client`` keyword argument are passed
    through unchanged instead of failing on ``None``.

    Args:
        context: Encryption context to install on the client.

    Returns:
        A decorator for coroutine functions.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            client = kwargs.get('client')
            if client is not None:
                client.context = context
            return await func(*args, **kwargs)
        return wrapper
    return decorator
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import tenseal as ts
|
|
2
|
+
|
|
3
|
+
def create_context():
    """Build and return a TenSEAL CKKS context.

    The context uses a polynomial modulus degree of 32768 and coefficient
    modulus bit sizes ``[60, 40, 40, 60]``; Galois keys are generated and the
    global scale is set to ``2**40``.
    """
    ckks_context = ts.context(
        ts.SCHEME_TYPE.CKKS,
        poly_modulus_degree=32768,
        coeff_mod_bit_sizes=[60, 40, 40, 60],
    )
    ckks_context.generate_galois_keys()
    ckks_context.global_scale = 2 ** 40
    return ckks_context
|
|
8
|
+
|
|
9
|
+
def encrypt_weights(context, model_weights):
    """Encrypt each weight array and return the serialized ciphertexts.

    Args:
        context: TenSEAL context used to build the CKKS vectors.
        model_weights: Iterable of arrays; each must provide ``flatten()``
            whose result provides ``tolist()``.

    Returns:
        A list with one serialized CKKS vector per input array, in order.
        The arrays are flattened, so their shapes are not preserved.
    """
    return [
        ts.ckks_vector(context, layer.flatten().tolist()).serialize()
        for layer in model_weights
    ]
|
|
16
|
+
|
|
17
|
+
def decrypt_weights(context, encrypted_weights):
    """Deserialize and decrypt each ciphertext.

    Args:
        context: TenSEAL context used to load the vectors.
        encrypted_weights: Iterable of serialized CKKS vectors, as produced
            by ``encrypt_weights``.

    Returns:
        A list holding the result of ``decrypt()`` for each vector, in
        order. No reshaping is done here.
    """
    return [
        ts.ckks_vector_from(context, blob).decrypt()
        for blob in encrypted_weights
    ]
|
|
File without changes
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
class AbstractModel(ABC):
    """Interface every framework model wrapper has to implement."""

    @abstractmethod
    def get_weights(self):
        """Return the model's current weights."""

    @abstractmethod
    def set_weights(self, weights):
        """Replace the model's weights with ``weights``."""

    @abstractmethod
    def train(self, x_train, y_train, epochs=1):
        """Train on ``(x_train, y_train)`` for ``epochs`` epochs."""

    @abstractmethod
    def predict(self, data):
        """Return the model's predictions for ``data``."""
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from federated_learning_framework.models.abstract_model import AbstractModel
|
|
4
|
+
|
|
5
|
+
class PyTorchModel(AbstractModel):
    """``AbstractModel`` implementation backed by a ``torch.nn.Module``."""

    def __init__(self, model):
        """Wrap ``model``, the module whose parameters are exchanged."""
        self.model = model

    def get_weights(self):
        """Return a NumPy copy of every parameter, in ``parameters()`` order.

        Parameters are detached and moved to the CPU first, so this also
        works for modules on an accelerator, and the returned arrays do not
        alias the live parameters.
        """
        return [
            param.detach().cpu().numpy().copy()
            for param in self.model.parameters()
        ]

    def set_weights(self, weights):
        """Load ``weights`` into the module's parameters.

        Each entry is converted to the parameter's dtype and device and
        reshaped to the parameter's shape, so flat vectors are accepted as
        long as the element count matches.

        Raises:
            ValueError: If the number of entries differs from the number of
                parameters (previously extra or missing entries were
                silently ignored).
            RuntimeError: If an entry cannot be reshaped to its parameter.
        """
        params = list(self.model.parameters())
        weights = list(weights)
        if len(weights) != len(params):
            raise ValueError(
                f"Expected {len(params)} weight arrays, received {len(weights)}"
            )
        with torch.no_grad():
            for param, weight in zip(params, weights):
                value = torch.as_tensor(weight, dtype=param.dtype, device=param.device)
                # In-place copy keeps the parameter object (and its device)
                # intact for any optimizer already holding a reference.
                param.copy_(value.reshape(param.shape))

    def train(self, x_train, y_train, epochs=1):
        """Run ``epochs`` full-batch SGD steps with cross-entropy loss.

        ``x_train`` is fed to the module as-is and ``y_train`` to
        ``nn.CrossEntropyLoss``; both are assumed to already be tensors the
        module accepts.
        """
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        self.model.train()  # enable training behaviour (dropout, batch norm)
        for _ in range(epochs):
            optimizer.zero_grad()
            outputs = self.model(x_train)
            loss = criterion(outputs, y_train)
            loss.backward()
            optimizer.step()

    def predict(self, data):
        """Return the module's output for ``data`` as a NumPy array.

        Inference runs in eval mode without gradient tracking; the module's
        previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                return self.model(data).detach().cpu().numpy()
        finally:
            self.model.train(was_training)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import tensorflow as tf
|
|
2
|
+
from federated_learning_framework.models.abstract_model import AbstractModel
|
|
3
|
+
|
|
4
|
+
class TensorFlowModel(AbstractModel):
    """``AbstractModel`` adapter that delegates to a wrapped model object.

    The wrapped object must provide ``get_weights``, ``set_weights``, ``fit``
    and ``predict``.
    """

    def __init__(self, model):
        """Keep a reference to the wrapped model."""
        self.model = model

    def get_weights(self):
        """Return whatever the wrapped model's ``get_weights()`` returns."""
        wrapped = self.model
        return wrapped.get_weights()

    def set_weights(self, weights):
        """Forward ``weights`` to the wrapped model's ``set_weights``."""
        wrapped = self.model
        wrapped.set_weights(weights)

    def train(self, x_train, y_train, epochs=1):
        """Fit the wrapped model silently (``verbose=0``) for ``epochs`` epochs."""
        fit_options = {'epochs': epochs, 'verbose': 0}
        self.model.fit(x_train, y_train, **fit_options)

    def predict(self, data):
        """Return the wrapped model's predictions for ``data``."""
        wrapped = self.model
        return wrapped.predict(data)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Custom License for Federated Learning Framework
|
|
2
|
+
|
|
3
|
+
This license governs the use of the Federated Learning Framework. By using this framework, you agree to the following terms and conditions:
|
|
4
|
+
|
|
5
|
+
## Academic Use
|
|
6
|
+
|
|
7
|
+
Academic use of this framework is permitted free of charge, provided that proper attribution is given. When using this framework in academic works, please include the following citation:
|
|
8
|
+
|
|
9
|
+
```
|
|
10
|
+
@framework{
|
|
11
|
+
author = {Mehrdad Javadi},
|
|
12
|
+
author_email = {mehrdaddjavadi@gmail.com},
|
|
13
|
+
title = {Federated Learning Framework},
|
|
14
|
+
year = {2024},
|
|
15
|
+
url = {https://github.com/mehrdaddjavadi/federated_learning_framework}
|
|
16
|
+
}
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Commercial, Governmental, and Other Non-Academic Use
|
|
20
|
+
|
|
21
|
+
For commercial, governmental, and other non-academic uses, please contact the author to obtain a usage license. All such uses are subject to approval and may require a licensing fee.
|
|
22
|
+
|
|
23
|
+
Contact: mehrdaddjavadi@gmail.com or mehrdadjavadi64@gmail.com
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
## Rights
|
|
27
|
+
|
|
28
|
+
All rights to this framework, including its use and distribution, belong to the author. Unauthorized use, distribution, or modification of this framework is strictly prohibited.
|
|
29
|
+
|
|
30
|
+
## Disclaimer
|
|
31
|
+
|
|
32
|
+
This framework is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and noninfringement. In no event shall the author be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the framework or the use or other dealings in the framework.
|
|
33
|
+
|
|
34
|
+
---
|
|
35
|
+
|
|
36
|
+
Mehrdad Javadi
|
|
37
|
+
2024
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: federated-learning-framework
|
|
3
|
+
Version: 0.0.61
|
|
4
|
+
Summary: A modular and extensible framework for federated learning applications.
|
|
5
|
+
Home-page: https://github.com/mehrdaddjavadi/federated_learning_framework
|
|
6
|
+
Author: Mehrdad Javadi
|
|
7
|
+
Author-email: mehrdaddjavadi@gmail.com
|
|
8
|
+
Keywords: federated learning,machine learning,deep learning,active learning,encryption,homomorphic encryption
|
|
9
|
+
Classifier: Development Status :: 4 - Beta
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Intended Audience :: Education
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.6
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.7
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
19
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Requires-Python: >=3.6
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Dist: numpy
|
|
25
|
+
Requires-Dist: tensorflow
|
|
26
|
+
Requires-Dist: websockets
|
|
27
|
+
Requires-Dist: pytest
|
|
28
|
+
Requires-Dist: tenseal
|
|
29
|
+
|
|
30
|
+
# Federated Learning Framework
|
|
31
|
+
|
|
32
|
+
## Overview
|
|
33
|
+
|
|
34
|
+
Welcome to the Federated Learning Framework, a modular and extensible solution for implementing federated learning across various applications. Harness the power of collective intelligence, ensure data privacy with homomorphic encryption, and apply it to domains like NLP, autonomous vehicles, drones, and more.
|
|
35
|
+
|
|
36
|
+
## Features
|
|
37
|
+
|
|
38
|
+
- **Modular and Extensible**: Easily customizable for different machine learning and deep learning applications.
|
|
39
|
+
- **Secure**: Utilizes homomorphic encryption to ensure data privacy.
|
|
40
|
+
- **Active Learning**: Incorporates active learning strategies to improve model performance.
|
|
41
|
+
- **Flexible Communication**: Supports various connection methods including socket programming.
|
|
42
|
+
- **Customizable**: Users can edit and control every part of the framework with various functions.
|
|
43
|
+
|
|
44
|
+
## Potential Applications
|
|
45
|
+
|
|
46
|
+
### Healthcare
|
|
47
|
+
|
|
48
|
+
Federated learning can be used to train models on patient data from multiple hospitals without sharing sensitive information. This approach can improve medical diagnostics and treatment recommendations while preserving patient privacy.
|
|
49
|
+
|
|
50
|
+
### Autonomous Vehicles
|
|
51
|
+
|
|
52
|
+
By collecting and learning from data across multiple autonomous vehicles, the framework can help improve the safety and performance of self-driving cars without exposing individual vehicle data.
|
|
53
|
+
|
|
54
|
+
### Drones
|
|
55
|
+
|
|
56
|
+
Drones can use federated learning to share and learn from data collected during their operations, enhancing their navigation, object detection, and other capabilities while ensuring data security.
|
|
57
|
+
|
|
58
|
+
### Natural Language Processing (NLP)
|
|
59
|
+
|
|
60
|
+
Federated learning can be applied to train NLP models on data from multiple sources, such as user devices, to improve language understanding and generation without compromising user privacy.
|
|
61
|
+
|
|
62
|
+
### Finance
|
|
63
|
+
|
|
64
|
+
Financial institutions can use federated learning to develop fraud detection and risk management models by leveraging data from multiple sources while keeping customer data secure.
|
|
65
|
+
|
|
66
|
+
### Smart Homes and IoT Devices
|
|
67
|
+
|
|
68
|
+
IoT devices in smart homes can collaboratively learn from user interactions to optimize performance and provide better services without sharing raw data.
|
|
69
|
+
|
|
70
|
+
## Detailed Component Description
|
|
71
|
+
|
|
72
|
+
### Central Server
|
|
73
|
+
|
|
74
|
+
**File:** `central_server.py`
|
|
75
|
+
|
|
76
|
+
The central server orchestrates the federated learning process by coordinating the communication and aggregation of model weights from various client devices.
|
|
77
|
+
|
|
78
|
+
**Key Functions:**
|
|
79
|
+
|
|
80
|
+
- `run_server`: Starts the server to handle client connections.
|
|
81
|
+
- `handle_client`: Manages incoming messages from clients.
|
|
82
|
+
- `transmit_weights`: Broadcasts the aggregated weights to clients.
|
|
83
|
+
- `send_data_to_client`: Sends specific data to a client.
|
|
84
|
+
- `get_data_from_client`: Requests and receives data from a client.
|
|
85
|
+
- `query_active_learning`: Implements active learning strategies to select data for labeling.
|
|
86
|
+
|
|
87
|
+
### Client Device
|
|
88
|
+
|
|
89
|
+
**File:** `client_device.py`
|
|
90
|
+
|
|
91
|
+
Client devices perform local training on their datasets and communicate with the central server.
|
|
92
|
+
|
|
93
|
+
**Key Functions:**
|
|
94
|
+
|
|
95
|
+
- `connect_to_central_server`: Connects to the central server.
|
|
96
|
+
- `federated_learning`: Coordinates local training and communication with the server.
|
|
97
|
+
- `receive_weights`: Receives model weights from the central server.
|
|
98
|
+
- `send_weights`: Sends model weights to the central server.
|
|
99
|
+
- `request_data`: Requests data from the central server and returns it.
|
|
100
|
+
|
|
101
|
+
### Encryption
|
|
102
|
+
|
|
103
|
+
**File:** `encryption.py`
|
|
104
|
+
|
|
105
|
+
Provides functions for creating encryption contexts and encrypting/decrypting model weights.
|
|
106
|
+
|
|
107
|
+
**Key Functions:**
|
|
108
|
+
|
|
109
|
+
- `create_context`: Sets up the encryption context using TenSEAL.
|
|
110
|
+
- `encrypt_weights`: Encrypts model weights.
|
|
111
|
+
- `decrypt_weights`: Decrypts encrypted model weights.
|
|
112
|
+
|
|
113
|
+
### Active Learning
|
|
114
|
+
|
|
115
|
+
**File:** `active_learning.py`
|
|
116
|
+
|
|
117
|
+
Implements active learning strategies to enhance the training process by selectively querying informative data points.
|
|
118
|
+
|
|
119
|
+
**Key Functions:**
|
|
120
|
+
|
|
121
|
+
- `query_active_learning`: Selects samples for labeling based on uncertainty.
|
|
122
|
+
|
|
123
|
+
### Connection
|
|
124
|
+
|
|
125
|
+
**File:** `connection.py`
|
|
126
|
+
|
|
127
|
+
Manages the connection types and protocols (e.g., WebSocket) for communication between the central server and client devices.
|
|
128
|
+
|
|
129
|
+
**Key Functions:**
|
|
130
|
+
|
|
131
|
+
- `ConnectionServer.start`: Starts a WebSocket server.
|
|
132
|
+
- `ConnectionClient.connect`: Establishes a WebSocket connection to the server.
|
|
133
|
+
|
|
134
|
+
### Decorators
|
|
135
|
+
|
|
136
|
+
**File:** `decorators.py`
|
|
137
|
+
|
|
138
|
+
Provides decorators for adding federated learning and encryption functionalities to functions.
|
|
139
|
+
|
|
140
|
+
**Key Functions:**
|
|
141
|
+
|
|
142
|
+
- `federated_learning_decorator`: Wraps a function to enable federated learning.
|
|
143
|
+
- `encryption_decorator`: Wraps a function to enable homomorphic encryption.
|
|
144
|
+
|
|
145
|
+
### Utilities
|
|
146
|
+
|
|
147
|
+
**File:** `utils.py`
|
|
148
|
+
|
|
149
|
+
Includes utility functions used throughout the framework.
|
|
150
|
+
|
|
151
|
+
## Installation
|
|
152
|
+
|
|
153
|
+
Clone the repository:
|
|
154
|
+
|
|
155
|
+
```sh
|
|
156
|
+
git clone https://github.com/mehrdaddjavadi/federated_learning_framework.git
|
|
157
|
+
```
|
|
158
|
+
|
|
159
|
+
Navigate to the directory:
|
|
160
|
+
|
|
161
|
+
```sh
|
|
162
|
+
cd federated_learning_framework
|
|
163
|
+
```
|
|
164
|
+
|
|
165
|
+
Install the dependencies:
|
|
166
|
+
|
|
167
|
+
```sh
|
|
168
|
+
pip install -r requirements.txt
|
|
169
|
+
```
|
|
170
|
+
|
|
171
|
+
## Usage
|
|
172
|
+
|
|
173
|
+
### Setting Up the Central Server
|
|
174
|
+
|
|
175
|
+
```python
|
|
176
|
+
import asyncio
|
|
177
|
+
from federated_learning_framework.central_server import CentralServer
|
|
178
|
+
|
|
179
|
+
async def main():
|
|
180
|
+
server = CentralServer()
|
|
181
|
+
await server.run_server()
|
|
182
|
+
|
|
183
|
+
asyncio.run(main())
|
|
184
|
+
```
|
|
185
|
+
|
|
186
|
+
### Setting Up the Central Server in an Interactive Environment Like Jupyter Notebook
|
|
187
|
+
|
|
188
|
+
```python
|
|
189
|
+
import nest_asyncio
|
|
190
|
+
import asyncio
|
|
191
|
+
from federated_learning_framework.central_server import CentralServer
|
|
192
|
+
|
|
193
|
+
nest_asyncio.apply()
|
|
194
|
+
|
|
195
|
+
async def main():
|
|
196
|
+
server = CentralServer()
|
|
197
|
+
await server.run_server()
|
|
198
|
+
|
|
199
|
+
# If running in an environment with an existing event loop
|
|
200
|
+
if __name__ == "__main__":
|
|
201
|
+
asyncio.run(main())
|
|
202
|
+
```
|
|
203
|
+
|
|
204
|
+
### Setting Up a Client Device
|
|
205
|
+
|
|
206
|
+
```python
|
|
207
|
+
import asyncio
|
|
208
|
+
import tensorflow as tf
|
|
209
|
+
from federated_learning_framework.client_device import ClientDevice
|
|
210
|
+
from federated_learning_framework.encryption import create_context
|
|
211
|
+
|
|
212
|
+
# Define your model
|
|
213
|
+
model = tf.keras.Sequential([
|
|
214
|
+
tf.keras.layers.Dense(4, activation='relu', input_shape=(3072,)),
|
|
215
|
+
tf.keras.layers.Dense(10, activation='softmax')
|
|
216
|
+
])
|
|
217
|
+
|
|
218
|
+
# Create context for encryption
|
|
219
|
+
context = create_context()
|
|
220
|
+
|
|
221
|
+
# Initialize the client device
|
|
222
|
+
client = ClientDevice(client_id=1, model=model, context=context)
|
|
223
|
+
|
|
224
|
+
async def main():
|
|
225
|
+
uri = "ws://localhost:8089"
|
|
226
|
+
await client.connect_to_central_server(uri)
|
|
227
|
+
x_train, y_train = ... # Load your training data
|
|
228
|
+
await client.federated_learning(x_train, y_train)
|
|
229
|
+
# Optionally receive data from central server
|
|
230
|
+
data = await client.request_data()
|
|
231
|
+
print(f"Received data: {data}")
|
|
232
|
+
|
|
233
|
+
asyncio.run(main())
|
|
234
|
+
```
|
|
235
|
+
|
|
236
|
+
### Sample Execution Script For Interactive Environments Like Colab And Jupyter Notebook
|
|
237
|
+
|
|
238
|
+
```python
|
|
239
|
+
import asyncio
|
|
240
|
+
import tensorflow as tf
|
|
241
|
+
import numpy as np
|
|
242
|
+
from federated_learning_framework.client_device import ClientDevice
|
|
243
|
+
from federated_learning_framework.central_server import CentralServer
|
|
244
|
+
from federated_learning_framework.encryption import create_context
|
|
245
|
+
from federated_learning_framework.models.tensorflow_model import TensorFlowModel
|
|
246
|
+
|
|
247
|
+
# Setup logging
|
|
248
|
+
import logging
|
|
249
|
+
logging.basicConfig(level=logging.INFO)
|
|
250
|
+
|
|
251
|
+
# Define a simple TensorFlow model
|
|
252
|
+
model = tf.keras.Sequential([
|
|
253
|
+
tf.keras.layers.Dense(4, activation='relu', input_shape=(3072,)),
|
|
254
|
+
tf.keras.layers.Dense(10, activation='softmax')
|
|
255
|
+
])
|
|
256
|
+
wrapped_model = TensorFlowModel(model)
|
|
257
|
+
|
|
258
|
+
# Create encryption context
|
|
259
|
+
context = create_context()
|
|
260
|
+
|
|
261
|
+
# Initialize server and clients
|
|
262
|
+
central_server = CentralServer(context=context)
|
|
263
|
+
client1 = ClientDevice(client_id=1, model=wrapped_model, context=context)
|
|
264
|
+
client2 = ClientDevice(client_id=2, model=wrapped_model, context=context)
|
|
265
|
+
|
|
266
|
+
# Dummy training data
|
|
267
|
+
x_train = np.random.rand(10, 3072)
|
|
268
|
+
y_train = np.random.randint(0, 10, 10)
|
|
269
|
+
|
|
270
|
+
async def main():
|
|
271
|
+
await asyncio.gather(
|
|
272
|
+
central_server.run_server(),
|
|
273
|
+
client1.connect_to_central_server("ws://localhost:8089"),
|
|
274
|
+
client2.connect_to_central_server("ws://localhost:8089"),
|
|
275
|
+
client1.federated_learning(x_train, y_train),
|
|
276
|
+
client2.federated_learning(x_train, y_train)
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
asyncio.run(main())
|
|
280
|
+
|
|
281
|
+
```
|
|
282
|
+
|
|
283
|
+
### Using Decorators
|
|
284
|
+
|
|
285
|
+
```python
|
|
286
|
+
import asyncio
|
|
287
|
+
import tensorflow as tf
|
|
288
|
+
from federated_learning_framework.decorators import federated_learning_decorator, encryption_decorator
|
|
289
|
+
from federated_learning_framework.client_device import ClientDevice
|
|
290
|
+
from federated_learning_framework.encryption import create_context
|
|
291
|
+
|
|
292
|
+
# Create context for encryption
|
|
293
|
+
context = create_context()
|
|
294
|
+
|
|
295
|
+
# Define your model
|
|
296
|
+
model = tf.keras.Sequential([
|
|
297
|
+
tf.keras.layers.Dense(4, activation='relu', input_shape=(3072,)),
|
|
298
|
+
tf.keras.layers.Dense(10, activation='softmax')
|
|
299
|
+
])
|
|
300
|
+
|
|
301
|
+
@federated_learning_decorator(uri="ws://localhost:8089")
|
|
302
|
+
@encryption_decorator(context=context)
|
|
303
|
+
async def main():
|
|
304
|
+
client = ClientDevice(client_id=1, model=model, context=context)
|
|
305
|
+
await client.connect_to_central_server('ws://localhost:8089')
|
|
306
|
+
x_train, y_train = ... # Load your training data
|
|
307
|
+
await client.federated_learning(x_train, y_train)
|
|
308
|
+
|
|
309
|
+
asyncio.run(main())
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
## Running Tests
|
|
313
|
+
|
|
314
|
+
To run the tests, execute the following command in the root directory:
|
|
315
|
+
|
|
316
|
+
```sh
|
|
317
|
+
python -m unittest discover -s tests
|
|
318
|
+
```
|
|
319
|
+
|
|
320
|
+
## License
|
|
321
|
+
|
|
322
|
+
The usage of this library is free for academic work with proper referencing. For business, governmental, and any other types of usage, please contact me directly. All rights are reserved.
|
|
323
|
+
|
|
324
|
+
**Contact:** mehrdaddjavadi@gmail.com
|
|
325
|
+
|
|
326
|
+
## Contributing
|
|
327
|
+
|
|
328
|
+
Feel free to contribute by submitting a pull request or opening an issue.
|
|
329
|
+
|
|
330
|
+
```
|
|
331
|
+
|
|
332
|
+
Copy and paste this into your README.md file. This format provides a clear, organized structure and includes all necessary details and instructions for potential users and contributors.
|
|
333
|
+
```
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
federated_learning_framework/__init__.py,sha256=HvZKThSQYzF5ofyU4UIgaaInB8mR9Y5jeZEd43l1_t0,482
|
|
2
|
+
federated_learning_framework/active_learning.py,sha256=jv6xzbRsHdzieMDwEtTpEn8YL_0M4a3DOawul0X5EnU,269
|
|
3
|
+
federated_learning_framework/central_server.py,sha256=eexr9XoZt0I5hGj1iG3HwX7CmnTX2KAjXfZGMkpRbB8,2620
|
|
4
|
+
federated_learning_framework/client_device.py,sha256=8b5FPXuWJ-PukR4CYzpY9JSVbn-kxtJRq-AhXXAy_rM,2659
|
|
5
|
+
federated_learning_framework/connection.py,sha256=D7zkKkgcFBhUEMcHUOLWZWJEO3jjxxFUrAishBDBT2w,1958
|
|
6
|
+
federated_learning_framework/decorators.py,sha256=gihFUEeIla0osy3319u8sMx0YxAdZqmAfkdRKmJsHXA,825
|
|
7
|
+
federated_learning_framework/encryption.py,sha256=bNarMUqYuyJG52yceGdSyW5jIDyVky2KnX7gqzFdKtY,840
|
|
8
|
+
federated_learning_framework/utils.py,sha256=t4dCUuAqvDach0qZObVoTfaQjXh06v5cpVfLxs_cRQ0,231
|
|
9
|
+
federated_learning_framework/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
federated_learning_framework/models/abstract_model.py,sha256=oNn4auUpt7XqOnfF5l0ebCWn9olPqb5ADwfibIcajQA,357
|
|
11
|
+
federated_learning_framework/models/pytorch_model.py,sha256=kibGM_dzM4uZIcxB3ZCdbv8KPkiKzFT9Vqnu4CLh9fg,988
|
|
12
|
+
federated_learning_framework/models/tensorflow_model.py,sha256=a-ZZbzoYjOL8Pza2KRn4q-TYBm-7OBuYeeOBzTFXuUs,548
|
|
13
|
+
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
tests/test_active_learning.py,sha256=qrnaCkYwdWTu_Jvz17jz0QElogPmKtayKy_qG7P_U08,535
|
|
15
|
+
tests/test_central_server.py,sha256=OOGvSOlKMyMHbrXwDOgdA-QYJMOP7uLCiLpzLBLoM9M,604
|
|
16
|
+
tests/test_client_device.py,sha256=hRLZWUUPCrF-BHK46QtpJBdGjO4WZm5tcoZGQtefKyk,932
|
|
17
|
+
tests/test_connection.py,sha256=E5lFwzvB8FCpjfuePwfc8eQq-CGDJBLW6GmtFB0FldA,921
|
|
18
|
+
tests/test_encryption.py,sha256=PHSRxqS4ltepe7ztOhNQcrfHOYhMMRKqKQJU0Dsa8-o,374
|
|
19
|
+
tests/test_utils.py,sha256=H7Cp9nOf4IOcLeNZKtAUZt-YY3x0c1WiUHAmQKROHWo,272
|
|
20
|
+
federated_learning_framework-0.0.61.dist-info/LICENSE,sha256=Igduf49yy8CINu7sRZg4TvMGy7jOi0xNgmL3e4PRoe0,1646
|
|
21
|
+
federated_learning_framework-0.0.61.dist-info/METADATA,sha256=K-JIgZKigswGbyx97Z4RBXcNDg9A8NrElpG9MAf9b3o,11188
|
|
22
|
+
federated_learning_framework-0.0.61.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
|
23
|
+
federated_learning_framework-0.0.61.dist-info/top_level.txt,sha256=bnGlrzFER9O24oPGb7hrGhFKBtcrx321yE2CJuwiRMY,35
|
|
24
|
+
federated_learning_framework-0.0.61.dist-info/RECORD,,
|
tests/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import tensorflow as tf
|
|
3
|
+
from federated_learning_framework.active_learning import query_active_learning
|
|
4
|
+
|
|
5
|
+
def test_active_learning():
    """Check that the active-learning query result has the requested length.

    A single-layer softmax classifier over 3072 input features is compiled,
    100 random samples are generated as the unlabeled pool, and
    ``query_active_learning`` is asked for 5 of them.
    """
    n_queries = 5

    output_layer = tf.keras.layers.Dense(10, input_shape=(3072,), activation='softmax')
    model = tf.keras.Sequential([output_layer])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    unlabeled_pool = np.random.rand(100, 3072)
    picked = query_active_learning(model, unlabeled_pool, n_queries)

    assert len(picked) == n_queries
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pytest
|
|
3
|
+
from federated_learning_framework.central_server import CentralServer
|
|
4
|
+
from federated_learning_framework.encryption import create_context
|
|
5
|
+
|
|
6
|
+
@pytest.mark.asyncio
async def test_central_server():
    """Run the central server as a background task, then cancel it.

    The test passes when awaiting the cancelled task raises
    ``asyncio.CancelledError``.
    """
    encryption_context = create_context()
    central = CentralServer(context=encryption_context)
    running = asyncio.create_task(central.run_server())

    # Give the server some time to start
    await asyncio.sleep(1)

    # Simulate client connections and other test scenarios here

    running.cancel()
    with pytest.raises(asyncio.CancelledError):
        await running
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pytest
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
from federated_learning_framework.client_device import ClientDevice
|
|
5
|
+
from federated_learning_framework.encryption import create_context
|
|
6
|
+
from federated_learning_framework.models.tensorflow_model import TensorFlowModel
|
|
7
|
+
|
|
8
|
+
@pytest.mark.asyncio
async def test_client_device():
    """Start a client connection attempt in the background, then cancel it.

    The test passes when awaiting the cancelled task raises
    ``asyncio.CancelledError``.
    """
    encryption_context = create_context()

    keras_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, input_shape=(3072,), activation='softmax')
    ])
    wrapped_model = TensorFlowModel(keras_model)
    device = ClientDevice(client_id=1, model=wrapped_model, context=encryption_context)

    # NOTE(review): no server is started inside this test; presumably the
    # connect call stays pending (or a server is expected to be listening on
    # port 8089) -- confirm, otherwise the final await will not raise
    # CancelledError.
    pending_connect = asyncio.create_task(
        device.connect_to_central_server('ws://localhost:8089')
    )

    # Give the client some time to connect
    await asyncio.sleep(1)

    # Simulate communication and other test scenarios here

    pending_connect.cancel()
    with pytest.raises(asyncio.CancelledError):
        await pending_connect
|
tests/test_connection.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import pytest
|
|
3
|
+
from federated_learning_framework.connection import ConnectionServer, ConnectionClient
|
|
4
|
+
|
|
5
|
+
@pytest.mark.asyncio
async def test_connection():
    """Round-trip one message through a websocket echo server.

    A ``ConnectionServer`` whose handler echoes the first message it receives
    is started as a background task on localhost:8089. A ``ConnectionClient``
    then connects, sends a string and asserts that the same string comes back.

    Cleanup runs in ``finally`` blocks so that a failing connect, send,
    receive or assertion does not leave the client socket open or the server
    task still running for later tests.
    """
    async def handle_client(connection, client_id):
        # Echo handler: send back whatever the client sent first.
        message = await connection.receive()
        await connection.send(message)

    server = ConnectionServer('websocket', 'localhost', 8089, handle_client)
    server_task = asyncio.create_task(server.start())

    try:
        await asyncio.sleep(1)  # Give the server some time to start

        client = ConnectionClient('websocket', 'ws://localhost:8089')
        await client.connect()

        try:
            test_message = "test_message"
            await client.send(test_message)
            received_message = await client.receive()

            assert received_message == test_message
        finally:
            # Close the client side even if the exchange above failed.
            await client.connection.close()
    finally:
        # Always stop the background server; awaiting a cancelled task is
        # expected to raise CancelledError.
        server_task.cancel()
        with pytest.raises(asyncio.CancelledError):
            await server_task
|
tests/test_encryption.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from federated_learning_framework.encryption import create_context, encrypt_weights, decrypt_weights
|
|
3
|
+
|
|
4
|
+
def test_encryption():
    """Round-trip one random 10x10 array through encrypt/decrypt.

    The first decrypted entry must be element-wise close (``np.allclose``)
    to the original array.
    """
    context = create_context()
    original = np.random.rand(10, 10)

    ciphertexts = encrypt_weights(context, [original])
    recovered = decrypt_weights(context, ciphertexts)

    assert np.allclose(original, recovered[0])
|
tests/test_utils.py
ADDED