syft-flwr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of syft-flwr might be problematic. Click here for more details.
- README.md +6 -0
- pyproject.toml +52 -0
- syft_flwr/__init__.py +6 -0
- syft_flwr/bootstrap.py +97 -0
- syft_flwr/cli.py +116 -0
- syft_flwr/config.py +36 -0
- syft_flwr/flower_client.py +65 -0
- syft_flwr/flower_server.py +25 -0
- syft_flwr/flwr_compatibility.py +121 -0
- syft_flwr/grid.py +171 -0
- syft_flwr/run.py +61 -0
- syft_flwr/run_simulation.py +204 -0
- syft_flwr/serde.py +15 -0
- syft_flwr/strategy/__init__.py +3 -0
- syft_flwr/strategy/fedavg.py +33 -0
- syft_flwr/templates/main.py.tpl +27 -0
- syft_flwr/utils.py +36 -0
- syft_flwr-0.1.0.dist-info/METADATA +25 -0
- syft_flwr-0.1.0.dist-info/RECORD +22 -0
- syft_flwr-0.1.0.dist-info/WHEEL +4 -0
- syft_flwr-0.1.0.dist-info/entry_points.txt +2 -0
- syft_flwr-0.1.0.dist-info/licenses/LICENSE +201 -0
syft_flwr/grid.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from typing import Iterable, cast
|
|
3
|
+
|
|
4
|
+
from flwr.common.message import Message
|
|
5
|
+
from flwr.common.typing import Run
|
|
6
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from syft_core import Client
|
|
9
|
+
from syft_rpc import rpc, rpc_db
|
|
10
|
+
from typing_extensions import Optional
|
|
11
|
+
|
|
12
|
+
from syft_flwr.flwr_compatibility import (
|
|
13
|
+
Grid,
|
|
14
|
+
RecordDict,
|
|
15
|
+
check_reply_to_field,
|
|
16
|
+
create_flwr_message,
|
|
17
|
+
)
|
|
18
|
+
from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
|
|
19
|
+
from syft_flwr.utils import str_to_int
|
|
20
|
+
|
|
21
|
+
# Fixed node ID used for the aggregator (this mirrors what the Flower SuperLink / SuperNode does)
|
|
22
|
+
AGGREGATOR_NODE_ID = 1
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SyftGrid(Grid):
    """Flower ``Grid`` that exchanges messages with datasites over ``syft_rpc``.

    Outgoing messages are serialized to bytes and sent to the ``messages``
    endpoint of the ``flwr`` app on the destination datasite. The returned RPC
    futures are persisted via ``rpc_db`` and later resolved to collect replies.
    """

    def __init__(
        self,
        datasites: Optional[list[str]] = None,
        client: Optional[Client] = None,
    ) -> None:
        """Create a grid for the given datasites.

        Args:
            datasites: Datasite identifiers to talk to. ``None`` (the default)
                means no datasites. A fresh list is created per instance
                instead of sharing one mutable default across all instances.
            client: SyftBox client to use; ``Client.load()`` is called when
                omitted.
        """
        self._client = Client.load() if client is None else client
        self._run: Optional[Run] = None
        self.node = Node(node_id=AGGREGATOR_NODE_ID)
        self.datasites = [] if datasites is None else datasites
        # Node IDs are derived from the datasite string; this map is the only
        # way to get back from a node ID to the datasite it belongs to.
        self.client_map = {str_to_int(ds): ds for ds in self.datasites}
        logger.debug(
            f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
        )

    def set_run(self, run_id: int) -> None:
        """Set the current run from an externally supplied run ID."""
        # TODO: In Grpc Grid case, the superlink is the one which sets up the run id,
        # do we need to do the same here, where the run id is set from an external context.

        # Convert to Flower Run object
        self._run = Run.create_empty(run_id)

    @property
    def run(self) -> Run:
        """Return a new ``Run`` built from the attributes of the current run."""
        return Run(**vars(cast(Run, self._run)))

    def _check_message(self, message: Message) -> None:
        """Validate the metadata of an outgoing message.

        Raises:
            ValueError: If the run ID or source node ID do not match this grid,
                the message already has a message ID, the reply-to check
                fails, or the TTL is not positive.
        """
        # Check if the message is valid
        if not (
            message.metadata.run_id == cast(Run, self._run).run_id
            and message.metadata.src_node_id == self.node.node_id
            and message.metadata.message_id == ""
            and check_reply_to_field(message.metadata)
            and message.metadata.ttl > 0
        ):
            logger.debug(f"Invalid message with metadata: {message.metadata}")
            raise ValueError(f"Invalid message: {message}")

    def create_message(
        self,
        content: RecordDict,
        message_type: str,
        dst_node_id: int,
        group_id: str,
        ttl: Optional[float] = None,
    ) -> Message:
        """Create a new message with specified parameters.

        The run ID and source node ID are filled in from this grid.
        """
        return create_flwr_message(
            content=content,
            message_type=message_type,
            dst_node_id=dst_node_id,
            group_id=group_id,
            ttl=ttl,
            run_id=cast(Run, self._run).run_id,
            src_node_id=self.node.node_id,
        )

    def get_node_ids(self) -> list[int]:
        """Get node IDs of all connected nodes."""
        # The keys of client_map are the node IDs derived from the datasites.
        return list(self.client_map.keys())

    def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
        """Push messages to specified node IDs.

        Returns:
            The IDs of the RPC futures created for the messages, in order.

        Raises:
            KeyError: If a message's destination node ID is not in
                ``client_map``.
            ValueError: If a message fails ``_check_message``.
        """
        # Construct Messages
        run_id = cast(Run, self._run).run_id
        message_ids = []
        for msg in messages:
            # Set metadata (written through __dict__ exactly as the original
            # code does, bypassing the public attribute interface)
            msg.metadata.__dict__["_run_id"] = run_id
            msg.metadata.__dict__["_src_node_id"] = self.node.node_id
            # RPC URL
            dest_datasite = self.client_map[msg.metadata.dst_node_id]
            url = rpc.make_url(dest_datasite, app_name="flwr", endpoint="messages")
            # Check message
            self._check_message(msg)
            # Serialize message
            msg_bytes = flower_message_to_bytes(msg)
            # Send message
            future = rpc.send(url=url, body=msg_bytes, client=self._client)
            logger.debug(
                f"Pushed message to {url} with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024} (Mb)"
            )
            # Save future so that it can be looked up again by ID when pulling
            rpc_db.save_future(future=future, namespace="flwr", client=self._client)
            message_ids.append(future.id)

        return message_ids

    def pull_messages(self, message_ids: Iterable[str]) -> dict[str, Message]:
        """Pull messages based on message IDs.

        Returns:
            A mapping from future/message ID to the reply, containing only the
            IDs whose reply has arrived and carries no error.

        Raises:
            ValueError: If a resolved response has an empty body.
        """
        messages = {}

        for msg_id in message_ids:
            future = rpc_db.get_future(future_id=msg_id, client=self._client)
            response = future.resolve()
            if response is None:
                # No reply yet; the caller is expected to poll again.
                continue

            response.raise_for_status()

            if not response.body:
                raise ValueError(f"Empty response: {response}")

            message: Message = bytes_to_flower_message(response.body)
            if message.has_error():
                error = message.error
                logger.error(
                    f"Message {msg_id} error with code={error.code}, reason={error.reason}"
                )
                # NOTE(review): the future is neither returned nor deleted
                # here, so `send_and_receive` keeps polling this ID until its
                # timeout (forever when timeout is None) - confirm intended.
                continue

            logger.debug(
                f"Pulled message from {response.url} with metadata: {message.metadata}, size: {len(response.body) / 1024 / 1024} (Mb)"
            )
            messages[msg_id] = message
            rpc_db.delete_future(future_id=msg_id, client=self._client)

        return messages

    def send_and_receive(
        self,
        messages: Iterable[Message],
        *,
        timeout: Optional[float] = None,
    ) -> Iterable[Message]:
        """Push messages to specified node IDs and pull the reply messages.

        This method sends a list of messages to their destination node IDs and then
        waits for the replies. It continues to pull replies until either all replies are
        received or the specified timeout duration is exceeded.

        Args:
            messages: Messages to send.
            timeout: Maximum number of seconds to wait; ``None`` waits until
                every reply has been pulled.

        Returns:
            The replies received so far (possibly fewer than were sent).
        """
        # Push messages
        msg_ids = set(self.push_messages(messages))

        # Pull messages
        end_time = time.time() + (timeout if timeout is not None else 0.0)
        ret = {}
        while timeout is None or time.time() < end_time:
            res_msgs = self.pull_messages(msg_ids)
            ret.update(res_msgs)
            msg_ids.difference_update(res_msgs.keys())
            if len(msg_ids) == 0:
                break
            # Fixed polling interval between pull attempts
            time.sleep(3)
        return ret.values()
|
syft_flwr/run.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
|
|
5
|
+
from flwr.client.client_app import LoadClientAppError
|
|
6
|
+
from flwr.common import Context
|
|
7
|
+
from flwr.common.object_ref import load_app
|
|
8
|
+
from flwr.server.server_app import LoadServerAppError
|
|
9
|
+
|
|
10
|
+
from syft_flwr.config import load_flwr_pyproject
|
|
11
|
+
from syft_flwr.flower_client import syftbox_flwr_client
|
|
12
|
+
from syft_flwr.flower_server import syftbox_flwr_server
|
|
13
|
+
from syft_flwr.flwr_compatibility import RecordDict
|
|
14
|
+
from syft_flwr.run_simulation import run
|
|
15
|
+
|
|
16
|
+
__all__ = ["syftbox_run_flwr_client", "syftbox_run_flwr_server", "run"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Suppress Pydantic deprecation warnings
|
|
20
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def syftbox_run_flwr_client(flower_project_dir: Path) -> None:
    """Load the project's Flower client app and run it through SyftBox.

    The client app reference and the app config are read from the
    ``[tool.flwr.app]`` table of the project's pyproject file. The same config
    is used as both node config and run config of a freshly created context.
    """
    app_table = load_flwr_pyproject(flower_project_dir)["tool"]["flwr"]["app"]
    client_ref = app_table["components"]["clientapp"]
    app_config = app_table["config"]

    # Run and node IDs are random per invocation.
    context = Context(
        run_id=uuid4().int,
        node_id=uuid4().int,
        node_config=app_config,
        state=RecordDict(),
        run_config=app_config,
    )

    client_app = load_app(client_ref, LoadClientAppError, flower_project_dir)
    syftbox_flwr_client(client_app, context)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def syftbox_run_flwr_server(flower_project_dir: Path) -> None:
    """Load the project's Flower server app and run it through SyftBox.

    The datasites come from ``[tool.syft_flwr]`` and the server app reference
    and app config from ``[tool.flwr.app]`` of the project's pyproject file.
    The same config is used as both node config and run config.
    """
    tool_table = load_flwr_pyproject(flower_project_dir)["tool"]
    datasites = tool_table["syft_flwr"]["datasites"]
    app_table = tool_table["flwr"]["app"]
    server_ref = app_table["components"]["serverapp"]
    app_config = app_table["config"]

    # Run and node IDs are random per invocation.
    context = Context(
        run_id=uuid4().int,
        node_id=uuid4().int,
        node_config=app_config,
        state=RecordDict(),
        run_config=app_config,
    )

    server_app = load_app(server_ref, LoadServerAppError, flower_project_dir)
    syftbox_flwr_server(server_app, context, datasites)
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from syft_rds.client.rds_client import RDSClient
|
|
8
|
+
from syft_rds.orchestra import remove_rds_stack_dir, setup_rds_server
|
|
9
|
+
from typing_extensions import Union
|
|
10
|
+
|
|
11
|
+
from syft_flwr.config import load_flwr_pyproject
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _setup_mock_rds_clients(
    project_dir: Path, aggregator: str, datasites: list[str]
) -> tuple[str, list[RDSClient], RDSClient]:
    """Create mock RDS clients for the aggregator and every datasite.

    Returns:
        A tuple ``(key, do_clients, ds_client)`` where ``key`` is the unique
        stack key (project directory name plus a random UUID), ``do_clients``
        holds one session per datasite in input order, and ``ds_client`` is
        the aggregator's session.
    """
    key = f"{project_dir.name}_{uuid.uuid4()}"
    remove_rds_stack_dir(key)

    # Aggregator first, then each datasite, all under the same stack key.
    ds_client = setup_rds_server(email=aggregator, key=key).init_session(
        host=aggregator
    )
    do_clients = [
        setup_rds_server(email=site, key=key).init_session(host=site)
        for site in datasites
    ]

    return key, do_clients, ds_client
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def _stop_process(process, grace_seconds: float = 5.0) -> None:
    """Terminate a still-running child process, killing it if it won't exit."""
    if process.returncode is not None:
        return
    try:
        process.terminate()
    except ProcessLookupError:
        # The process exited between the returncode check and terminate().
        return
    try:
        await asyncio.wait_for(process.wait(), timeout=grace_seconds)
    except asyncio.TimeoutError:
        process.kill()
        await process.wait()


async def _run_main_py(
    main_py_path: Path,
    config_path: Path,
    client_email: str,
    log_dir: Path,
    dataset_path: Union[str, Path, None] = None,
) -> int:
    """Run the `main.py` file for a given client in a subprocess.

    The child's stdout and stderr are written to ``<log_dir>/<client_email>.log``.

    Args:
        main_py_path: Script to execute.
        config_path: Exported to the child as ``SYFTBOX_CLIENT_CONFIG_PATH``.
        client_email: Used for the log file name and in log messages.
        log_dir: Directory that receives the log file.
        dataset_path: Exported to the child as ``DATA_DIR``.

    Returns:
        The child's exit code, or ``1`` if it could not be started or waited on.

    Raises:
        asyncio.CancelledError: Re-raised after the child has been stopped
            when the surrounding task is cancelled.
    """
    # Local import: `sys` is not among the module-level imports of this file.
    import sys

    log_file_path = log_dir / f"{client_email}.log"

    # setting up env variables
    env = os.environ.copy()
    env["SYFTBOX_CLIENT_CONFIG_PATH"] = str(config_path)
    # NOTE(review): when dataset_path is None this exports the literal string
    # "None" as DATA_DIR; kept as-is because the child may rely on it.
    env["DATA_DIR"] = str(dataset_path)

    process = None
    # running the main.py file asynchronously in a subprocess
    try:
        with open(log_file_path, "w") as f:
            process = await asyncio.create_subprocess_exec(
                # Use the interpreter running this code rather than whatever
                # "python" resolves to on PATH (possibly a different env).
                sys.executable or "python",
                str(main_py_path),
                "-s",
                stdout=f,
                stderr=f,
                env=env,
            )
            return_code = await process.wait()
            logger.debug(
                f"`{client_email}` returns code {return_code} for running `{main_py_path}`"
            )
            return return_code
    except asyncio.CancelledError:
        # Cancelling the task does not stop the child by itself; without this
        # the subprocess would keep running after the simulation has ended.
        if process is not None:
            await _stop_process(process)
        raise
    except Exception as e:
        logger.error(f"Error running `{main_py_path}` for `{client_email}`: {e}")
        return 1
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
async def _run_simulated_flwr_project(
    project_dir: Path,
    do_clients: list[RDSClient],
    ds_client: RDSClient,
    mock_dataset_paths: list[Union[str, Path]],
) -> bool:
    """Run the DS (server) and all DO clients concurrently.

    Every participant runs the project's ``main.py`` in its own subprocess and
    logs to ``<project_dir>/simulation_logs``. Once the DS process returns,
    its log is echoed and any DO task still running is cancelled.

    Returns:
        ``True`` if the DS process exited with code 0, ``False`` otherwise.
    """
    log_dir = project_dir / "simulation_logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"📝 Log directory: {log_dir}")

    main_py_path = project_dir / "main.py"

    logger.info(
        f"Running DS client '{ds_client.email}' with config path {ds_client._syftbox_client.config_path}"
    )
    ds_coro = _run_main_py(
        main_py_path,
        ds_client._syftbox_client.config_path,
        ds_client.email,
        log_dir,
    )
    ds_task: asyncio.Task = asyncio.create_task(ds_coro)

    client_tasks: list[asyncio.Task] = []
    # Clients and datasets are paired positionally.
    for client, mock_dataset_path in zip(do_clients, mock_dataset_paths):
        logger.info(
            f"Running DO client '{client.email}' with config path {client._syftbox_client.config_path} on mock dataset {mock_dataset_path}"
        )
        do_coro = _run_main_py(
            main_py_path,
            client._syftbox_client.config_path,
            client.email,
            log_dir,
            mock_dataset_path,
        )
        client_tasks.append(asyncio.create_task(do_coro))

    ds_return_code = await ds_task

    # log out ds client logs
    with open(log_dir / f"{ds_client.email}.log", "r") as log_file:
        log_content = log_file.read().strip()
        logger.info(f"DS client '{ds_client.email}' logs:\n{log_content}")

    # The DS finishing ends the simulation: stop whatever DO is still running.
    logger.debug("Cancelling DO client tasks as DS client returned")
    for pending in (task for task in client_tasks if not task.done()):
        pending.cancel()

    # Wait for the cancellations to be processed; exceptions are collected,
    # not raised.
    await asyncio.gather(*client_tasks, return_exceptions=True)

    return ds_return_code == 0
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _validate_bootstraped_project(project_dir: Path) -> None:
    """Check that `project_dir` looks like a bootstrapped `syft_flwr` project.

    Raises:
        FileNotFoundError: If the directory, its ``main.py`` or its
            ``pyproject.toml`` does not exist.
        NotADirectoryError: If the path exists but is not a directory.
    """
    if not project_dir.exists():
        raise FileNotFoundError(f"Project directory {project_dir} does not exist")
    if not project_dir.is_dir():
        raise NotADirectoryError(f"Project directory {project_dir} is not a directory")

    # Checked in this order so the first missing file is the one reported.
    for required_file in ("main.py", "pyproject.toml"):
        if not (project_dir / required_file).exists():
            raise FileNotFoundError(f"{required_file} not found at {project_dir}")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _validate_mock_dataset_paths(mock_dataset_paths: list[str]) -> list[Path]:
    """Resolve every mock dataset path and make sure it exists.

    Returns:
        The paths with ``~`` expanded and resolved, in input order.

    Raises:
        ValueError: For the first path that does not exist.
    """

    def _resolve_existing(raw_path) -> Path:
        candidate = Path(raw_path).expanduser().resolve()
        if candidate.exists():
            return candidate
        raise ValueError(f"Mock dataset path {candidate} does not exist")

    return [_resolve_existing(raw_path) for raw_path in mock_dataset_paths]
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def run(
    project_dir: Union[str, Path], mock_dataset_paths: list[Union[str, Path]]
) -> None:
    """Run a syft_flwr project in simulation mode over mock data.

    Args:
        project_dir: Bootstrapped project directory (must contain ``main.py``
            and ``pyproject.toml``).
        mock_dataset_paths: One existing dataset path per datasite, paired
            positionally with the datasites listed in the project config.

    Raises:
        FileNotFoundError, NotADirectoryError: If the project directory is
            invalid.
        ValueError: If a mock dataset path does not exist.

    When called while an event loop is already running (e.g. in Jupyter), the
    simulation is scheduled on that loop and this function returns
    immediately; otherwise it blocks until the simulation has finished.
    """

    project_dir = Path(project_dir).expanduser().resolve()
    _validate_bootstraped_project(project_dir)
    mock_dataset_paths = _validate_mock_dataset_paths(mock_dataset_paths)

    pyproject_conf = load_flwr_pyproject(project_dir)
    datasites = pyproject_conf["tool"]["syft_flwr"]["datasites"]
    aggregator = pyproject_conf["tool"]["syft_flwr"]["aggregator"]

    key, do_clients, ds_client = _setup_mock_rds_clients(
        project_dir, aggregator, datasites
    )

    async def main():
        try:
            run_success = await _run_simulated_flwr_project(
                project_dir, do_clients, ds_client, mock_dataset_paths
            )
            if run_success:
                logger.success("Simulation completed successfully ✅")
            else:
                logger.error("Simulation failed ❌")
        except Exception as e:
            logger.error(f"Simulation failed ❌: {e}")
        finally:
            # Clean up the RDS stack
            remove_rds_stack_dir(key)
            logger.debug(f"Removed RDS stack: {key}")

    # Only the loop lookup is guarded, so a RuntimeError raised elsewhere is
    # not mistaken for "no running loop".
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        logger.debug("No existing event loop, creating and running one")
        # No existing event loop (plain scripts): asyncio.run creates a loop
        # and always closes it, even when main() raises.
        asyncio.run(main())
        return

    logger.debug(f"Running in an environment with an existing event loop {loop}")
    # We are in an environment with an existing event loop (like Jupyter).
    task = loop.create_task(main())
    # The event loop only keeps weak references to tasks, so hold a strong
    # reference until the task is done to keep it from being garbage-collected
    # mid-run.
    background_tasks = run.__dict__.setdefault("_background_tasks", set())
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
|
syft_flwr/serde.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from flwr.common.message import Message
|
|
2
|
+
from flwr.common.serde import message_from_proto, message_to_proto
|
|
3
|
+
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def bytes_to_flower_message(data: bytes) -> Message:
    """Parse serialized protobuf bytes into a Flower ``Message``."""
    proto = ProtoMessage()
    proto.ParseFromString(data)
    return message_from_proto(proto)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def flower_message_to_bytes(message: Message) -> bytes:
    """Convert a Flower ``Message`` to its protobuf form and serialize it."""
    return message_to_proto(message).SerializeToString()
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
from flwr.common import parameters_to_ndarrays
|
|
4
|
+
from flwr.server.strategy import FedAvg
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from safetensors.numpy import save_file
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FedAvgWithModelSaving(FedAvg):
    """FedAvg variant that checkpoints the global model to disk.

    Behaves like ``FedAvg``, except that each call to ``evaluate`` first
    writes the given parameters to a ``.safetensors`` file under ``save_path``.
    Ref: https://discuss.flower.ai/t/how-do-i-save-the-global-model-after-training/71/2
    """

    def __init__(self, save_path: str, *args, **kwargs):
        """Create the checkpoint directory, then initialise ``FedAvg``.

        Args:
            save_path: Directory for checkpoints; created (with parents) if
                it does not exist.
            *args, **kwargs: Forwarded unchanged to ``FedAvg``.
        """
        checkpoint_dir = Path(save_path)
        checkpoint_dir.mkdir(exist_ok=True, parents=True)
        self.save_path = checkpoint_dir
        super().__init__(*args, **kwargs)

    def _save_global_model(self, server_round: int, parameters):
        """Write `parameters` to ``parameters_round_<server_round>.safetensors``."""
        # One tensor per array, keyed by its position: layer_0, layer_1, ...
        tensor_dict = {}
        for index, array in enumerate(parameters_to_ndarrays(parameters)):
            tensor_dict[f"layer_{index}"] = array

        filename = self.save_path / f"parameters_round_{server_round}.safetensors"
        save_file(tensor_dict, str(filename))

        logger.info(f"Checkpoint saved to: {filename}")

    def evaluate(self, server_round: int, parameters):
        """Save the parameters, then delegate to ``FedAvg.evaluate``."""
        self._save_global_model(server_round, parameters)
        return super().evaluate(server_round, parameters)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from syft_core import Client
|
|
5
|
+
|
|
6
|
+
from syft_flwr.config import load_flwr_pyproject
|
|
7
|
+
from syft_flwr.run import syftbox_run_flwr_client, syftbox_run_flwr_server
|
|
8
|
+
|
|
9
|
+
# Read from the environment; neither value is used further down in this script.
DATA_DIR = os.getenv("DATA_DIR")
OUTPUT_DIR = os.getenv("OUTPUT_DIR")


flower_project_dir = Path(__file__).parent.absolute()
client = Client.load()
config = load_flwr_pyproject(flower_project_dir)

_syft_flwr_conf = config["tool"]["syft_flwr"]
_aggregator = _syft_flwr_conf["aggregator"]

is_client = client.email in _syft_flwr_conf["datasites"]
# The aggregator is normally a single email string. `in` on a string is a
# substring test ("bob@x.org" in "jbob@x.org" is True), so compare for
# equality in that case; a collection of aggregators still uses membership.
if isinstance(_aggregator, str):
    is_server = client.email == _aggregator
else:
    is_server = client.email in _aggregator

# The role is decided by where this client's email appears in the project
# config; datasite membership takes precedence over the aggregator role.
if is_client:
    # run by each DO
    syftbox_run_flwr_client(flower_project_dir)
elif is_server:
    # run by the DS
    syftbox_run_flwr_server(flower_project_dir)
else:
    raise ValueError(f"{client.email} is not in config.datasites or config.aggregator")
|
syft_flwr/utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import zlib
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
# Minimal email shape: something, a single "@", something, a dot, something.
EMAIL_REGEX = r"^[^@]+@[^@]+\.[^@]+$"


def is_valid_datasite(datasite: str) -> bool:
    """Return whether `datasite` has the shape of an email address.

    Returns a real ``bool`` as annotated, rather than the ``re.Match`` object
    or ``None`` that ``re.match`` yields.
    """
    return re.match(EMAIL_REGEX, datasite) is not None
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def str_to_int(input_string: str) -> int:
    """Map a string to an integer using the CRC-32 checksum of its encoding.

    The result is the unsigned value returned by ``zlib.crc32``.
    """
    encoded = input_string.encode()
    checksum = zlib.crc32(encoded)
    return checksum
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_syftbox_dataset_path() -> Path:
    """Return the dataset path named by the ``DATA_DIR`` environment variable.

    Falls back to ``.data/`` when the variable is not set.

    Raises:
        FileNotFoundError: If the resulting path does not exist.
    """
    dataset_dir = Path(os.environ.get("DATA_DIR", ".data/"))
    if dataset_dir.exists():
        return dataset_dir
    raise FileNotFoundError(
        f"Path {dataset_dir} does not exist (must be a valid file or directory)"
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def run_syft_flwr() -> bool:
    """Tell whether we are running with syft_flwr rather than plain flwr.

    The only signal used is whether the dataset path resolved by
    `get_syftbox_dataset_path` (driven by ``DATA_DIR``) exists.
    """
    try:
        get_syftbox_dataset_path()
    except FileNotFoundError:
        return False
    else:
        return True
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: syft-flwr
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Python: >=3.9.2
|
|
7
|
+
Requires-Dist: flwr-datasets[vision]>=0.5.0
|
|
8
|
+
Requires-Dist: flwr[simulation]==1.17.0
|
|
9
|
+
Requires-Dist: loguru>=0.7.3
|
|
10
|
+
Requires-Dist: safetensors>=0.5.0
|
|
11
|
+
Requires-Dist: syft-core==0.2.3
|
|
12
|
+
Requires-Dist: syft-event==0.2.0
|
|
13
|
+
Requires-Dist: syft-rds==0.1.0
|
|
14
|
+
Requires-Dist: syft-rpc==0.2.0
|
|
15
|
+
Requires-Dist: tomli-w>=1.2.0
|
|
16
|
+
Requires-Dist: tomli>=2.2.1
|
|
17
|
+
Requires-Dist: typing-extensions>=4.13.0
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
|
|
20
|
+
# syft_flwr
|
|
21
|
+
|
|
22
|
+
`syft_flwr` is an open source framework that facilitates federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol.
|
|
23
|
+
|
|
24
|
+
## Installation
|
|
25
|
+
`pip install syft_flwr`
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
README.md,sha256=rxa8da_OJlI6RVDAXa3tfmZC_dGEN_VFPJXJNMfXl-Q,250
|
|
2
|
+
pyproject.toml,sha256=k7335xXFv_VDmvxmIk_wy6LriaTc92j0th88WDjJ-F0,1148
|
|
3
|
+
syft_flwr/__init__.py,sha256=Tm2mxtuCAXctyyi89c9psk0PuGdv1TOauCg7-gFPyRk,127
|
|
4
|
+
syft_flwr/bootstrap.py,sha256=P6bDLSnCQqrbxQ6-QBc0sR75L5ods8oLKJqItZzNlao,3446
|
|
5
|
+
syft_flwr/cli.py,sha256=imctwdQMxQeGQZaiKSX1Mo2nU_-RmA-cGB3H4huuUeA,3274
|
|
6
|
+
syft_flwr/config.py,sha256=4hwkovGtFOLNULjJwoGYcA0uT4y3vZSrxndXqYXquMY,821
|
|
7
|
+
syft_flwr/flower_client.py,sha256=up15MSEBExiZPgcXHo1EIflWaVN0nO6y08gXrzpumSM,2418
|
|
8
|
+
syft_flwr/flower_server.py,sha256=dz7dvEgFm76s8G266RfN46TohnMPGAk8K6DPWb23LsE,784
|
|
9
|
+
syft_flwr/flwr_compatibility.py,sha256=vURf9rfsZ1uPm04szw6RpGYxtlG3BE4tW3YijptiGyk,3197
|
|
10
|
+
syft_flwr/grid.py,sha256=76Wc1O1W8iGMr9QS97o9wTaU_87UB4fqYA5HnlIHy_U,6087
|
|
11
|
+
syft_flwr/run.py,sha256=gLm_zjzCuWVNKVK2ocJzx_izXDJAW411nl6IXDyrkFM,2008
|
|
12
|
+
syft_flwr/run_simulation.py,sha256=t3shhfzAWDUlf6iQmwf5sS9APZQE9mkaZ9MLCYs9Ng0,6922
|
|
13
|
+
syft_flwr/serde.py,sha256=5fCI-cRUOh5wE7cXQd4J6jex1grRGnyD1Jx-VlEDOXM,495
|
|
14
|
+
syft_flwr/utils.py,sha256=3dDYEB7btq9hxZ9UsfQWh3i44OerAhGXc5XaX5wO3-o,955
|
|
15
|
+
syft_flwr/strategy/__init__.py,sha256=mpUmExjjFkqU8gg41XsOBKfO3aqCBe7XPJSU-_P7smU,97
|
|
16
|
+
syft_flwr/strategy/fedavg.py,sha256=UCgXL_0woRjs7iHaqN16UYNJ7J6ogY3sO6620g1W3sU,1364
|
|
17
|
+
syft_flwr/templates/main.py.tpl,sha256=p0uK97jvLGk3LJdy1_HF1R5BQgIjaTGkYnr-csfh39M,791
|
|
18
|
+
syft_flwr-0.1.0.dist-info/METADATA,sha256=wnKYSmLIbSoKAd7A1Wk7PatwD_g5TuByahxzJYJitcc,799
|
|
19
|
+
syft_flwr-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
20
|
+
syft_flwr-0.1.0.dist-info/entry_points.txt,sha256=o7oT0dCoHn-3WyIwdDw1lBh2q-GvhB_8s0hWeJU4myc,49
|
|
21
|
+
syft_flwr-0.1.0.dist-info/licenses/LICENSE,sha256=0msOUar8uPZTqkAOTBp4rCzd7Jl9eRhfKiNufwrsg7k,11361
|
|
22
|
+
syft_flwr-0.1.0.dist-info/RECORD,,
|