syft-flwr 0.1.0 (tar.gz) → 0.1.1 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of syft-flwr might be problematic; see this release's page on the package registry for more details.

@@ -182,6 +182,7 @@ checkpoints/
182
182
  # Datasets
183
183
  **MedMNIST**
184
184
  **/datasets/
185
+ **/dataset/
185
186
 
186
187
  # Ruff cache
187
188
  .ruff_cache/
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: syft-flwr
3
- Version: 0.1.0
4
- Summary: Add your description here
3
+ Version: 0.1.1
4
+ Summary: syft_flwr is an open source framework that facilitate federated learning projects using Flower over the SyftBox protocol
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.9.2
7
7
  Requires-Dist: flwr-datasets[vision]>=0.5.0
@@ -22,4 +22,6 @@ Description-Content-Type: text/markdown
22
22
  `syft_flwr` is an open source framework that facilitate federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol
23
23
 
24
24
  ## Installation
25
- `pip install syft_flwr`
25
+ - Install uv: `brew install uv`
26
+ - Create a virtual environment: `uv venv`
27
+ - Install `syft-flwr`: `uv pip install syft-flwr`
@@ -3,4 +3,6 @@
3
3
  `syft_flwr` is an open source framework that facilitate federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol
4
4
 
5
5
  ## Installation
6
- `pip install syft_flwr`
6
+ - Install uv: `brew install uv`
7
+ - Create a virtual environment: `uv venv`
8
+ - Install `syft-flwr`: `uv pip install syft-flwr`
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "syft-flwr"
3
- version = "0.1.0"
4
- description = "Add your description here"
3
+ version = "0.1.1"
4
+ description = "syft_flwr is an open source framework that facilitate federated learning projects using Flower over the SyftBox protocol"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.9.2"
7
7
  dependencies = [
@@ -21,10 +21,6 @@ dependencies = [
21
21
  [project.scripts]
22
22
  syft_flwr = "syft_flwr.cli:main"
23
23
 
24
- [build-system]
25
- requires = ["hatchling"]
26
- build-backend = "hatchling.build"
27
-
28
24
  [tool.uv]
29
25
  dev-dependencies = [
30
26
  "ipykernel>=6.29.5",
@@ -33,6 +29,10 @@ dev-dependencies = [
33
29
  "pre-commit>=4.0.1",
34
30
  ]
35
31
 
32
+ [build-system]
33
+ requires = ["hatchling"]
34
+ build-backend = "hatchling.build"
35
+
36
36
  [tool.hatch.build.targets.wheel]
37
37
  packages = ["src/syft_flwr"]
38
38
  only-include = ["src", "pyproject.toml", "/README.md"]
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.0"
1
+ __version__ = "0.1.1"
2
2
 
3
3
  from syft_flwr.bootstrap import bootstrap
4
4
  from syft_flwr.run import run
@@ -1,3 +1,4 @@
1
+ import time
1
2
  from pathlib import Path
2
3
 
3
4
  from loguru import logger
@@ -50,8 +51,15 @@ def __update_pyproject_toml(
50
51
  deps.append(f"syft_flwr=={__version__}")
51
52
  pyproject_conf["project"]["dependencies"] = deps
52
53
 
53
- # always override the datasites and aggregator
54
54
  pyproject_conf["tool"]["syft_flwr"] = {}
55
+
56
+ # configure unique app name for each syft_flwr run
57
+ base_app_name = pyproject_conf["project"]["name"]
58
+ pyproject_conf["tool"]["syft_flwr"]["app_name"] = (
59
+ f"{aggregator}_{base_app_name}_{int(time.time())}"
60
+ )
61
+
62
+ # always override the datasites and aggregator
55
63
  pyproject_conf["tool"]["syft_flwr"]["datasites"] = datasites
56
64
  pyproject_conf["tool"]["syft_flwr"]["aggregator"] = aggregator
57
65
 
@@ -0,0 +1,87 @@
1
+ import sys
2
+ import traceback
3
+
4
+ from loguru import logger
5
+ from syft_event import SyftEvents
6
+ from syft_event.types import Request
7
+
8
+ from flwr.client import ClientApp
9
+ from flwr.common import Context
10
+ from flwr.common.constant import ErrorCode, MessageType
11
+ from flwr.common.message import Error, Message
12
+ from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
13
+ from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
14
+
15
+
16
+ def _handle_normal_message(
17
+ message: Message, client_app: ClientApp, context: Context
18
+ ) -> bytes:
19
+ # Normal message handling
20
+ logger.info(f"Receive message with metadata: {message.metadata}")
21
+ reply_message: Message = client_app(message=message, context=context)
22
+ res_bytes: bytes = flower_message_to_bytes(reply_message)
23
+ logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
24
+ return res_bytes
25
+
26
+
27
+ def _create_error_reply(message: Message, error: Error) -> bytes:
28
+ """Create and return error reply message in bytes."""
29
+ error_reply: Message = create_flwr_message(
30
+ content=RecordDict(),
31
+ reply_to=message,
32
+ message_type=message.metadata.message_type,
33
+ src_node_id=message.metadata.dst_node_id,
34
+ dst_node_id=message.metadata.src_node_id,
35
+ group_id=message.metadata.group_id,
36
+ run_id=message.metadata.run_id,
37
+ error=error,
38
+ )
39
+ error_bytes: bytes = flower_message_to_bytes(error_reply)
40
+ logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
41
+ return error_bytes
42
+
43
+
44
+ def syftbox_flwr_client(client_app: ClientApp, context: Context, app_name: str):
45
+ """Run the Flower ClientApp with SyftBox."""
46
+ syft_flwr_app_name = f"flwr/{app_name}"
47
+ box = SyftEvents(syft_flwr_app_name)
48
+ client_email = box.client.email
49
+ logger.info(f"Started SyftBox Flower Client on: {client_email}")
50
+ logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
51
+
52
+ @box.on_request("/messages")
53
+ def handle_messages(request: Request) -> None:
54
+ logger.info(
55
+ f"Received request id: {request.id}, size: {len(request.body) / 1024 / 1024} (MB)"
56
+ )
57
+ message: Message = bytes_to_flower_message(request.body)
58
+ try:
59
+ # Handle stop signal
60
+ if (
61
+ message.metadata.message_type == MessageType.SYSTEM
62
+ and message.content["config"]["action"] == "stop"
63
+ ):
64
+ logger.info(f"Received stop message: {message}")
65
+ box._stop_event.set()
66
+ return None
67
+
68
+ return _handle_normal_message(message, client_app, context)
69
+
70
+ except Exception as e:
71
+ error_traceback = traceback.format_exc()
72
+ error_message = f"Client: '{client_email}'. Error: {str(e)}. Traceback: {error_traceback}"
73
+ logger.error(error_message)
74
+
75
+ error = Error(
76
+ code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=error_message
77
+ )
78
+ box._stop_event.set()
79
+ return _create_error_reply(message, error)
80
+
81
+ try:
82
+ box.run_forever()
83
+ except Exception as e:
84
+ logger.error(
85
+ f"Fatal error in syftbox_flwr_client: {str(e)}\n{traceback.format_exc()}"
86
+ )
87
+ sys.exit(1)
@@ -0,0 +1,42 @@
1
+ import traceback
2
+ from random import randint
3
+
4
+ from loguru import logger
5
+
6
+ from flwr.common import Context
7
+ from flwr.server import ServerApp
8
+ from flwr.server.run_serverapp import run as run_server
9
+ from syft_flwr.grid import SyftGrid
10
+
11
+
12
+ def syftbox_flwr_server(
13
+ server_app: ServerApp,
14
+ context: Context,
15
+ datasites: list[str],
16
+ app_name: str,
17
+ ) -> Context:
18
+ """Run the Flower ServerApp with SyftBox."""
19
+ syft_flwr_app_name = f"flwr/{app_name}"
20
+ syft_grid = SyftGrid(app_name=syft_flwr_app_name, datasites=datasites)
21
+ run_id = randint(0, 1000)
22
+ syft_grid.set_run(run_id)
23
+ logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
24
+ logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
25
+
26
+ try:
27
+ updated_context = run_server(
28
+ syft_grid,
29
+ context=context,
30
+ loaded_server_app=server_app,
31
+ server_app_dir="",
32
+ )
33
+ logger.info(f"Server completed with context: {updated_context}")
34
+ except Exception as e:
35
+ logger.error(f"Server encountered an error: {str(e)}")
36
+ logger.error(f"Traceback: {traceback.format_exc()}")
37
+ updated_context = context
38
+ finally:
39
+ syft_grid.send_stop_signal(group_id="final", reason="Server stopped")
40
+ logger.info("Sending stop signals to the clients")
41
+
42
+ return updated_context
@@ -1,14 +1,16 @@
1
1
  import time
2
2
  from typing import Iterable, cast
3
3
 
4
- from flwr.common.message import Message
5
- from flwr.common.typing import Run
6
- from flwr.proto.node_pb2 import Node # pylint: disable=E0611
7
4
  from loguru import logger
8
5
  from syft_core import Client
9
6
  from syft_rpc import rpc, rpc_db
10
7
  from typing_extensions import Optional
11
8
 
9
+ from flwr.common import ConfigRecord
10
+ from flwr.common.constant import MessageType
11
+ from flwr.common.message import Message
12
+ from flwr.common.typing import Run
13
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
12
14
  from syft_flwr.flwr_compatibility import (
13
15
  Grid,
14
16
  RecordDict,
@@ -25,6 +27,7 @@ AGGREGATOR_NODE_ID = 1
25
27
  class SyftGrid(Grid):
26
28
  def __init__(
27
29
  self,
30
+ app_name: str,
28
31
  datasites: list[str] = [],
29
32
  client: Client = None,
30
33
  ) -> None:
@@ -36,6 +39,7 @@ class SyftGrid(Grid):
36
39
  logger.debug(
37
40
  f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
38
41
  )
42
+ self.app_name = app_name
39
43
 
40
44
  def set_run(self, run_id: int) -> None:
41
45
  # TODO: In Grpc Grid case, the superlink is the one which sets up the run id,
@@ -96,7 +100,9 @@ class SyftGrid(Grid):
96
100
  msg.metadata.__dict__["_src_node_id"] = self.node.node_id
97
101
  # RPC URL
98
102
  dest_datasite = self.client_map[msg.metadata.dst_node_id]
99
- url = rpc.make_url(dest_datasite, app_name="flwr", endpoint="messages")
103
+ url = rpc.make_url(
104
+ dest_datasite, app_name=self.app_name, endpoint="messages"
105
+ )
100
106
  # Check message
101
107
  self._check_message(msg)
102
108
  # Serialize message
@@ -107,7 +113,9 @@ class SyftGrid(Grid):
107
113
  f"Pushed message to {url} with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024} (Mb)"
108
114
  )
109
115
  # Save future
110
- rpc_db.save_future(future=future, namespace="flwr", client=self._client)
116
+ rpc_db.save_future(
117
+ future=future, namespace=self.app_name, client=self._client
118
+ )
111
119
  message_ids.append(future.id)
112
120
 
113
121
  return message_ids
@@ -147,13 +155,13 @@ class SyftGrid(Grid):
147
155
  self,
148
156
  messages: Iterable[Message],
149
157
  *,
150
- timeout: Optional[float] = None,
158
+ timeout: Optional[float] = 60,
151
159
  ) -> Iterable[Message]:
152
160
  """Push messages to specified node IDs and pull the reply messages.
153
161
 
154
162
  This method sends a list of messages to their destination node IDs and then
155
163
  waits for the replies. It continues to pull replies until either all replies are
156
- received or the specified timeout duration is exceeded.
164
+ received or the specified timeout duration (in seconds) is exceeded.
157
165
  """
158
166
  # Push messages
159
167
  msg_ids = set(self.push_messages(messages))
@@ -165,7 +173,33 @@ class SyftGrid(Grid):
165
173
  res_msgs = self.pull_messages(msg_ids)
166
174
  ret.update(res_msgs)
167
175
  msg_ids.difference_update(res_msgs.keys())
168
- if len(msg_ids) == 0:
176
+ if len(msg_ids) == 0: # All messages received
169
177
  break
170
- time.sleep(3)
178
+ time.sleep(3) # polling interval
179
+
180
+ if msg_ids:
181
+ logger.warning(
182
+ f"Timeout reached. {len(msg_ids)} message(s) sent out but not replied."
183
+ )
184
+
171
185
  return ret.values()
186
+
187
+ def send_stop_signal(
188
+ self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
189
+ ) -> list[Message]:
190
+ """Send a stop signal to all datasites (clients)."""
191
+ stop_messages: list[Message] = [
192
+ self.create_message(
193
+ content=RecordDict(
194
+ {"config": ConfigRecord({"action": "stop", "reason": reason})}
195
+ ),
196
+ message_type=MessageType.SYSTEM,
197
+ dst_node_id=node_id,
198
+ group_id=group_id,
199
+ ttl=ttl,
200
+ )
201
+ for node_id in self.get_node_ids()
202
+ ]
203
+ self.push_messages(stop_messages)
204
+
205
+ return stop_messages
@@ -6,7 +6,6 @@ from flwr.client.client_app import LoadClientAppError
6
6
  from flwr.common import Context
7
7
  from flwr.common.object_ref import load_app
8
8
  from flwr.server.server_app import LoadServerAppError
9
-
10
9
  from syft_flwr.config import load_flwr_pyproject
11
10
  from syft_flwr.flower_client import syftbox_flwr_client
12
11
  from syft_flwr.flower_server import syftbox_flwr_server
@@ -23,6 +22,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic"
23
22
  def syftbox_run_flwr_client(flower_project_dir: Path) -> None:
24
23
  pyproject_conf = load_flwr_pyproject(flower_project_dir)
25
24
  client_ref = pyproject_conf["tool"]["flwr"]["app"]["components"]["clientapp"]
25
+ app_name = pyproject_conf["tool"]["syft_flwr"]["app_name"]
26
26
 
27
27
  context = Context(
28
28
  run_id=uuid4().int,
@@ -37,13 +37,14 @@ def syftbox_run_flwr_client(flower_project_dir: Path) -> None:
37
37
  flower_project_dir,
38
38
  )
39
39
 
40
- syftbox_flwr_client(client_app, context)
40
+ syftbox_flwr_client(client_app, context, app_name)
41
41
 
42
42
 
43
43
  def syftbox_run_flwr_server(flower_project_dir: Path) -> None:
44
44
  pyproject_conf = load_flwr_pyproject(flower_project_dir)
45
45
  datasites = pyproject_conf["tool"]["syft_flwr"]["datasites"]
46
46
  server_ref = pyproject_conf["tool"]["flwr"]["app"]["components"]["serverapp"]
47
+ app_name = pyproject_conf["tool"]["syft_flwr"]["app_name"]
47
48
 
48
49
  context = Context(
49
50
  run_id=uuid4().int,
@@ -58,4 +59,4 @@ def syftbox_run_flwr_server(flower_project_dir: Path) -> None:
58
59
  flower_project_dir,
59
60
  )
60
61
 
61
- syftbox_flwr_server(server_app, context, datasites)
62
+ syftbox_flwr_server(server_app, context, datasites, app_name)
@@ -1,10 +1,11 @@
1
1
  from pathlib import Path
2
2
 
3
- from flwr.common import parameters_to_ndarrays
4
- from flwr.server.strategy import FedAvg
5
3
  from loguru import logger
6
4
  from safetensors.numpy import save_file
7
5
 
6
+ from flwr.common import parameters_to_ndarrays
7
+ from flwr.server.strategy import FedAvg
8
+
8
9
 
9
10
  class FedAvgWithModelSaving(FedAvg):
10
11
  """This is a custom strategy that behaves exactly like
@@ -23,9 +24,13 @@ class FedAvgWithModelSaving(FedAvg):
23
24
  ndarrays = parameters_to_ndarrays(parameters)
24
25
  tensor_dict = {f"layer_{i}": array for i, array in enumerate(ndarrays)}
25
26
  filename = self.save_path / f"parameters_round_{server_round}.safetensors"
26
- save_file(tensor_dict, str(filename))
27
-
28
- logger.info(f"Checkpoint saved to: {filename}")
27
+ if not self.save_path.exists():
28
+ logger.error(
29
+ f"Save directory {self.save_path} does NOT exist! Maybe it's deleted or moved."
30
+ )
31
+ else:
32
+ save_file(tensor_dict, str(filename))
33
+ logger.info(f"Checkpoint saved to: {filename}")
29
34
 
30
35
  def evaluate(self, server_round: int, parameters):
31
36
  """Evaluate model parameters using an evaluation function."""
@@ -1,65 +0,0 @@
1
- import sys
2
- import traceback
3
-
4
- from flwr.client import ClientApp
5
- from flwr.common import Context
6
- from flwr.common.constant import ErrorCode
7
- from flwr.common.message import Error, Message
8
- from loguru import logger
9
- from syft_event import SyftEvents
10
- from syft_event.types import Request
11
-
12
- from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
13
- from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
14
-
15
-
16
- def syftbox_flwr_client(client_app: ClientApp, context: Context):
17
- """Run the Flower ClientApp with SyftBox."""
18
-
19
- box = SyftEvents("flwr")
20
- client_email = box.client.email
21
- logger.info(f"Started SyftBox Flower Client on: {client_email}")
22
-
23
- @box.on_request("/messages")
24
- def handle_messages(request: Request) -> None:
25
- logger.info(
26
- f"Received request id: {request.id}, size: {len(request.body) / 1024 / 1024} (MB)"
27
- )
28
- message: Message = bytes_to_flower_message(request.body)
29
-
30
- try:
31
- reply_message: Message = client_app(message=message, context=context)
32
- res_bytes: bytes = flower_message_to_bytes(reply_message)
33
- logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
34
- return res_bytes
35
-
36
- except Exception as e:
37
- error_traceback = traceback.format_exc()
38
- error_message = f"Client: '{client_email}'. Error: {str(e)}. Traceback: {error_traceback}"
39
- logger.error(error_message)
40
-
41
- error = Error(
42
- code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=f"{error_message}"
43
- )
44
-
45
- error_reply: Message = create_flwr_message(
46
- content=RecordDict(),
47
- reply_to=message,
48
- message_type=message.metadata.message_type,
49
- src_node_id=message.metadata.dst_node_id,
50
- dst_node_id=message.metadata.src_node_id,
51
- group_id=message.metadata.group_id,
52
- run_id=message.metadata.run_id,
53
- error=error,
54
- )
55
- error_bytes: bytes = flower_message_to_bytes(error_reply)
56
- logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
57
- return error_bytes
58
-
59
- try:
60
- box.run_forever()
61
- except Exception as e:
62
- logger.error(
63
- f"Fatal error in syftbox_flwr_client: {str(e)}\n{traceback.format_exc()}"
64
- )
65
- sys.exit(1)
@@ -1,25 +0,0 @@
1
- from random import randint
2
-
3
- from flwr.common import Context
4
- from flwr.server import ServerApp
5
- from flwr.server.run_serverapp import run as run_server
6
- from loguru import logger
7
-
8
- from syft_flwr.grid import SyftGrid
9
-
10
-
11
- def syftbox_flwr_server(server_app: ServerApp, context: Context, datasites: list[str]):
12
- """Run the Flower ServerApp with SyftBox."""
13
- syft_grid = SyftGrid(datasites=datasites)
14
- run_id = randint(0, 1000)
15
- syft_grid.set_run(run_id)
16
- logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
17
-
18
- updated_context = run_server(
19
- syft_grid,
20
- context=context,
21
- loaded_server_app=server_app,
22
- server_app_dir="",
23
- )
24
- logger.info(f"Server completed with context: {updated_context}")
25
- return updated_context
File without changes