syft-flwr 0.1.0__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of syft-flwr might be problematic. Click here for more details.

@@ -182,6 +182,7 @@ checkpoints/
182
182
  # Datasets
183
183
  **MedMNIST**
184
184
  **/datasets/
185
+ **/dataset/
185
186
 
186
187
  # Ruff cache
187
188
  .ruff_cache/
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: syft-flwr
3
- Version: 0.1.0
4
- Summary: Add your description here
3
+ Version: 0.1.2
4
+ Summary: syft_flwr is an open source framework that facilitate federated learning projects using Flower over the SyftBox protocol
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.9.2
7
7
  Requires-Dist: flwr-datasets[vision]>=0.5.0
@@ -22,4 +22,6 @@ Description-Content-Type: text/markdown
22
22
  `syft_flwr` is an open-source framework that facilitates federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol.
23
23
 
24
24
  ## Installation
25
- `pip install syft_flwr`
25
+ - Install uv: `brew install uv`
26
+ - Create a virtual environment: `uv venv`
27
+ - Install `syft-flwr`: `uv pip install syft-flwr`
@@ -3,4 +3,6 @@
3
3
  `syft_flwr` is an open-source framework that facilitates federated learning projects using [Flower](https://github.com/adap/flower) over the [SyftBox](https://github.com/OpenMined/syftbox) protocol.
4
4
 
5
5
  ## Installation
6
- `pip install syft_flwr`
6
+ - Install uv: `brew install uv`
7
+ - Create a virtual environment: `uv venv`
8
+ - Install `syft-flwr`: `uv pip install syft-flwr`
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "syft-flwr"
3
- version = "0.1.0"
4
- description = "Add your description here"
3
+ version = "0.1.2"
4
+ description = "syft_flwr is an open source framework that facilitate federated learning projects using Flower over the SyftBox protocol"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.9.2"
7
7
  dependencies = [
@@ -21,10 +21,6 @@ dependencies = [
21
21
  [project.scripts]
22
22
  syft_flwr = "syft_flwr.cli:main"
23
23
 
24
- [build-system]
25
- requires = ["hatchling"]
26
- build-backend = "hatchling.build"
27
-
28
24
  [tool.uv]
29
25
  dev-dependencies = [
30
26
  "ipykernel>=6.29.5",
@@ -33,6 +29,10 @@ dev-dependencies = [
33
29
  "pre-commit>=4.0.1",
34
30
  ]
35
31
 
32
+ [build-system]
33
+ requires = ["hatchling"]
34
+ build-backend = "hatchling.build"
35
+
36
36
  [tool.hatch.build.targets.wheel]
37
37
  packages = ["src/syft_flwr"]
38
38
  only-include = ["src", "pyproject.toml", "/README.md"]
@@ -1,4 +1,4 @@
1
- __version__ = "0.1.0"
1
+ __version__ = "0.1.2"
2
2
 
3
3
  from syft_flwr.bootstrap import bootstrap
4
4
  from syft_flwr.run import run
@@ -1,3 +1,4 @@
1
+ import time
1
2
  from pathlib import Path
2
3
 
3
4
  from loguru import logger
@@ -50,8 +51,15 @@ def __update_pyproject_toml(
50
51
  deps.append(f"syft_flwr=={__version__}")
51
52
  pyproject_conf["project"]["dependencies"] = deps
52
53
 
53
- # always override the datasites and aggregator
54
54
  pyproject_conf["tool"]["syft_flwr"] = {}
55
+
56
+ # configure unique app name for each syft_flwr run
57
+ base_app_name = pyproject_conf["project"]["name"]
58
+ pyproject_conf["tool"]["syft_flwr"]["app_name"] = (
59
+ f"{aggregator}_{base_app_name}_{int(time.time())}"
60
+ )
61
+
62
+ # always override the datasites and aggregator
55
63
  pyproject_conf["tool"]["syft_flwr"]["datasites"] = datasites
56
64
  pyproject_conf["tool"]["syft_flwr"]["aggregator"] = aggregator
57
65
 
@@ -0,0 +1,87 @@
1
+ import sys
2
+ import traceback
3
+
4
+ from loguru import logger
5
+ from syft_event import SyftEvents
6
+ from syft_event.types import Request
7
+
8
+ from flwr.client import ClientApp
9
+ from flwr.common import Context
10
+ from flwr.common.constant import ErrorCode, MessageType
11
+ from flwr.common.message import Error, Message
12
+ from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
13
+ from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
14
+
15
+
16
+ def _handle_normal_message(
17
+ message: Message, client_app: ClientApp, context: Context
18
+ ) -> bytes:
19
+ # Normal message handling
20
+ logger.info(f"Receive message with metadata: {message.metadata}")
21
+ reply_message: Message = client_app(message=message, context=context)
22
+ res_bytes: bytes = flower_message_to_bytes(reply_message)
23
+ logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
24
+ return res_bytes
25
+
26
+
27
+ def _create_error_reply(message: Message, error: Error) -> bytes:
28
+ """Create and return error reply message in bytes."""
29
+ error_reply: Message = create_flwr_message(
30
+ content=RecordDict(),
31
+ reply_to=message,
32
+ message_type=message.metadata.message_type,
33
+ src_node_id=message.metadata.dst_node_id,
34
+ dst_node_id=message.metadata.src_node_id,
35
+ group_id=message.metadata.group_id,
36
+ run_id=message.metadata.run_id,
37
+ error=error,
38
+ )
39
+ error_bytes: bytes = flower_message_to_bytes(error_reply)
40
+ logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
41
+ return error_bytes
42
+
43
+
44
+ def syftbox_flwr_client(client_app: ClientApp, context: Context, app_name: str):
45
+ """Run the Flower ClientApp with SyftBox."""
46
+ syft_flwr_app_name = f"flwr/{app_name}"
47
+ box = SyftEvents(syft_flwr_app_name)
48
+ client_email = box.client.email
49
+ logger.info(f"Started SyftBox Flower Client on: {client_email}")
50
+ logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
51
+
52
+ @box.on_request("/messages")
53
+ def handle_messages(request: Request) -> None:
54
+ logger.info(
55
+ f"Received request id: {request.id}, size: {len(request.body) / 1024 / 1024} (MB)"
56
+ )
57
+ message: Message = bytes_to_flower_message(request.body)
58
+ try:
59
+ # Handle stop signal
60
+ if (
61
+ message.metadata.message_type == MessageType.SYSTEM
62
+ and message.content["config"]["action"] == "stop"
63
+ ):
64
+ logger.info(f"Received stop message: {message}")
65
+ box._stop_event.set()
66
+ return None
67
+
68
+ return _handle_normal_message(message, client_app, context)
69
+
70
+ except Exception as e:
71
+ error_traceback = traceback.format_exc()
72
+ error_message = f"Client: '{client_email}'. Error: {str(e)}. Traceback: {error_traceback}"
73
+ logger.error(error_message)
74
+
75
+ error = Error(
76
+ code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=error_message
77
+ )
78
+ box._stop_event.set()
79
+ return _create_error_reply(message, error)
80
+
81
+ try:
82
+ box.run_forever()
83
+ except Exception as e:
84
+ logger.error(
85
+ f"Fatal error in syftbox_flwr_client: {str(e)}\n{traceback.format_exc()}"
86
+ )
87
+ sys.exit(1)
@@ -0,0 +1,42 @@
1
+ import traceback
2
+ from random import randint
3
+
4
+ from loguru import logger
5
+
6
+ from flwr.common import Context
7
+ from flwr.server import ServerApp
8
+ from flwr.server.run_serverapp import run as run_server
9
+ from syft_flwr.grid import SyftGrid
10
+
11
+
12
+ def syftbox_flwr_server(
13
+ server_app: ServerApp,
14
+ context: Context,
15
+ datasites: list[str],
16
+ app_name: str,
17
+ ) -> Context:
18
+ """Run the Flower ServerApp with SyftBox."""
19
+ syft_flwr_app_name = f"flwr/{app_name}"
20
+ syft_grid = SyftGrid(app_name=syft_flwr_app_name, datasites=datasites)
21
+ run_id = randint(0, 1000)
22
+ syft_grid.set_run(run_id)
23
+ logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
24
+ logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
25
+
26
+ try:
27
+ updated_context = run_server(
28
+ syft_grid,
29
+ context=context,
30
+ loaded_server_app=server_app,
31
+ server_app_dir="",
32
+ )
33
+ logger.info(f"Server completed with context: {updated_context}")
34
+ except Exception as e:
35
+ logger.error(f"Server encountered an error: {str(e)}")
36
+ logger.error(f"Traceback: {traceback.format_exc()}")
37
+ updated_context = context
38
+ finally:
39
+ syft_grid.send_stop_signal(group_id="final", reason="Server stopped")
40
+ logger.info("Sending stop signals to the clients")
41
+
42
+ return updated_context
@@ -1,6 +1,9 @@
1
+ import os
1
2
  import time
2
3
  from typing import Iterable, cast
3
4
 
5
+ from flwr.common import ConfigRecord
6
+ from flwr.common.constant import MessageType
4
7
  from flwr.common.message import Message
5
8
  from flwr.common.typing import Run
6
9
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
@@ -22,9 +25,14 @@ from syft_flwr.utils import str_to_int
22
25
  AGGREGATOR_NODE_ID = 1
23
26
 
24
27
 
28
+ # env vars
29
+ SYFT_FLWR_MSG_TIMEOUT = "SYFT_FLWR_MSG_TIMEOUT"
30
+
31
+
25
32
  class SyftGrid(Grid):
26
33
  def __init__(
27
34
  self,
35
+ app_name: str,
28
36
  datasites: list[str] = [],
29
37
  client: Client = None,
30
38
  ) -> None:
@@ -36,6 +44,7 @@ class SyftGrid(Grid):
36
44
  logger.debug(
37
45
  f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
38
46
  )
47
+ self.app_name = app_name
39
48
 
40
49
  def set_run(self, run_id: int) -> None:
41
50
  # TODO: In Grpc Grid case, the superlink is the one which sets up the run id,
@@ -96,7 +105,9 @@ class SyftGrid(Grid):
96
105
  msg.metadata.__dict__["_src_node_id"] = self.node.node_id
97
106
  # RPC URL
98
107
  dest_datasite = self.client_map[msg.metadata.dst_node_id]
99
- url = rpc.make_url(dest_datasite, app_name="flwr", endpoint="messages")
108
+ url = rpc.make_url(
109
+ dest_datasite, app_name=self.app_name, endpoint="messages"
110
+ )
100
111
  # Check message
101
112
  self._check_message(msg)
102
113
  # Serialize message
@@ -107,7 +118,9 @@ class SyftGrid(Grid):
107
118
  f"Pushed message to {url} with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024} (Mb)"
108
119
  )
109
120
  # Save future
110
- rpc_db.save_future(future=future, namespace="flwr", client=self._client)
121
+ rpc_db.save_future(
122
+ future=future, namespace=self.app_name, client=self._client
123
+ )
111
124
  message_ids.append(future.id)
112
125
 
113
126
  return message_ids
@@ -153,8 +166,19 @@ class SyftGrid(Grid):
153
166
 
154
167
  This method sends a list of messages to their destination node IDs and then
155
168
  waits for the replies. It continues to pull replies until either all replies are
156
- received or the specified timeout duration is exceeded.
169
+ received or the specified timeout duration (in seconds) is exceeded.
157
170
  """
171
+ if os.environ.get(SYFT_FLWR_MSG_TIMEOUT) is not None:
172
+ timeout = float(os.environ.get(SYFT_FLWR_MSG_TIMEOUT))
173
+ if timeout is not None:
174
+ logger.debug(
175
+ f"syft_flwr messages timeout = {timeout}: Will move on after {timeout} (s) if no reply is received"
176
+ )
177
+ else:
178
+ logger.debug(
179
+ "syft_flwr messages timeout = None: Will wait indefinitely for replies"
180
+ )
181
+
158
182
  # Push messages
159
183
  msg_ids = set(self.push_messages(messages))
160
184
 
@@ -165,7 +189,33 @@ class SyftGrid(Grid):
165
189
  res_msgs = self.pull_messages(msg_ids)
166
190
  ret.update(res_msgs)
167
191
  msg_ids.difference_update(res_msgs.keys())
168
- if len(msg_ids) == 0:
192
+ if len(msg_ids) == 0: # All messages received
169
193
  break
170
- time.sleep(3)
194
+ time.sleep(3) # polling interval
195
+
196
+ if msg_ids:
197
+ logger.warning(
198
+ f"Timeout reached. {len(msg_ids)} message(s) sent out but not replied."
199
+ )
200
+
171
201
  return ret.values()
202
+
203
+ def send_stop_signal(
204
+ self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
205
+ ) -> list[Message]:
206
+ """Send a stop signal to all datasites (clients)."""
207
+ stop_messages: list[Message] = [
208
+ self.create_message(
209
+ content=RecordDict(
210
+ {"config": ConfigRecord({"action": "stop", "reason": reason})}
211
+ ),
212
+ message_type=MessageType.SYSTEM,
213
+ dst_node_id=node_id,
214
+ group_id=group_id,
215
+ ttl=ttl,
216
+ )
217
+ for node_id in self.get_node_ids()
218
+ ]
219
+ self.push_messages(stop_messages)
220
+
221
+ return stop_messages
@@ -6,7 +6,6 @@ from flwr.client.client_app import LoadClientAppError
6
6
  from flwr.common import Context
7
7
  from flwr.common.object_ref import load_app
8
8
  from flwr.server.server_app import LoadServerAppError
9
-
10
9
  from syft_flwr.config import load_flwr_pyproject
11
10
  from syft_flwr.flower_client import syftbox_flwr_client
12
11
  from syft_flwr.flower_server import syftbox_flwr_server
@@ -23,6 +22,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="pydantic"
23
22
  def syftbox_run_flwr_client(flower_project_dir: Path) -> None:
24
23
  pyproject_conf = load_flwr_pyproject(flower_project_dir)
25
24
  client_ref = pyproject_conf["tool"]["flwr"]["app"]["components"]["clientapp"]
25
+ app_name = pyproject_conf["tool"]["syft_flwr"]["app_name"]
26
26
 
27
27
  context = Context(
28
28
  run_id=uuid4().int,
@@ -37,13 +37,14 @@ def syftbox_run_flwr_client(flower_project_dir: Path) -> None:
37
37
  flower_project_dir,
38
38
  )
39
39
 
40
- syftbox_flwr_client(client_app, context)
40
+ syftbox_flwr_client(client_app, context, app_name)
41
41
 
42
42
 
43
43
  def syftbox_run_flwr_server(flower_project_dir: Path) -> None:
44
44
  pyproject_conf = load_flwr_pyproject(flower_project_dir)
45
45
  datasites = pyproject_conf["tool"]["syft_flwr"]["datasites"]
46
46
  server_ref = pyproject_conf["tool"]["flwr"]["app"]["components"]["serverapp"]
47
+ app_name = pyproject_conf["tool"]["syft_flwr"]["app_name"]
47
48
 
48
49
  context = Context(
49
50
  run_id=uuid4().int,
@@ -58,4 +59,4 @@ def syftbox_run_flwr_server(flower_project_dir: Path) -> None:
58
59
  flower_project_dir,
59
60
  )
60
61
 
61
- syftbox_flwr_server(server_app, context, datasites)
62
+ syftbox_flwr_server(server_app, context, datasites, app_name)
@@ -1,10 +1,11 @@
1
1
  from pathlib import Path
2
2
 
3
- from flwr.common import parameters_to_ndarrays
4
- from flwr.server.strategy import FedAvg
5
3
  from loguru import logger
6
4
  from safetensors.numpy import save_file
7
5
 
6
+ from flwr.common import parameters_to_ndarrays
7
+ from flwr.server.strategy import FedAvg
8
+
8
9
 
9
10
  class FedAvgWithModelSaving(FedAvg):
10
11
  """This is a custom strategy that behaves exactly like
@@ -23,9 +24,13 @@ class FedAvgWithModelSaving(FedAvg):
23
24
  ndarrays = parameters_to_ndarrays(parameters)
24
25
  tensor_dict = {f"layer_{i}": array for i, array in enumerate(ndarrays)}
25
26
  filename = self.save_path / f"parameters_round_{server_round}.safetensors"
26
- save_file(tensor_dict, str(filename))
27
-
28
- logger.info(f"Checkpoint saved to: {filename}")
27
+ if not self.save_path.exists():
28
+ logger.error(
29
+ f"Save directory {self.save_path} does NOT exist! Maybe it's deleted or moved."
30
+ )
31
+ else:
32
+ save_file(tensor_dict, str(filename))
33
+ logger.info(f"Checkpoint saved to: {filename}")
29
34
 
30
35
  def evaluate(self, server_round: int, parameters):
31
36
  """Evaluate model parameters using an evaluation function."""
@@ -1,65 +0,0 @@
1
- import sys
2
- import traceback
3
-
4
- from flwr.client import ClientApp
5
- from flwr.common import Context
6
- from flwr.common.constant import ErrorCode
7
- from flwr.common.message import Error, Message
8
- from loguru import logger
9
- from syft_event import SyftEvents
10
- from syft_event.types import Request
11
-
12
- from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
13
- from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
14
-
15
-
16
- def syftbox_flwr_client(client_app: ClientApp, context: Context):
17
- """Run the Flower ClientApp with SyftBox."""
18
-
19
- box = SyftEvents("flwr")
20
- client_email = box.client.email
21
- logger.info(f"Started SyftBox Flower Client on: {client_email}")
22
-
23
- @box.on_request("/messages")
24
- def handle_messages(request: Request) -> None:
25
- logger.info(
26
- f"Received request id: {request.id}, size: {len(request.body) / 1024 / 1024} (MB)"
27
- )
28
- message: Message = bytes_to_flower_message(request.body)
29
-
30
- try:
31
- reply_message: Message = client_app(message=message, context=context)
32
- res_bytes: bytes = flower_message_to_bytes(reply_message)
33
- logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
34
- return res_bytes
35
-
36
- except Exception as e:
37
- error_traceback = traceback.format_exc()
38
- error_message = f"Client: '{client_email}'. Error: {str(e)}. Traceback: {error_traceback}"
39
- logger.error(error_message)
40
-
41
- error = Error(
42
- code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=f"{error_message}"
43
- )
44
-
45
- error_reply: Message = create_flwr_message(
46
- content=RecordDict(),
47
- reply_to=message,
48
- message_type=message.metadata.message_type,
49
- src_node_id=message.metadata.dst_node_id,
50
- dst_node_id=message.metadata.src_node_id,
51
- group_id=message.metadata.group_id,
52
- run_id=message.metadata.run_id,
53
- error=error,
54
- )
55
- error_bytes: bytes = flower_message_to_bytes(error_reply)
56
- logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
57
- return error_bytes
58
-
59
- try:
60
- box.run_forever()
61
- except Exception as e:
62
- logger.error(
63
- f"Fatal error in syftbox_flwr_client: {str(e)}\n{traceback.format_exc()}"
64
- )
65
- sys.exit(1)
@@ -1,25 +0,0 @@
1
- from random import randint
2
-
3
- from flwr.common import Context
4
- from flwr.server import ServerApp
5
- from flwr.server.run_serverapp import run as run_server
6
- from loguru import logger
7
-
8
- from syft_flwr.grid import SyftGrid
9
-
10
-
11
- def syftbox_flwr_server(server_app: ServerApp, context: Context, datasites: list[str]):
12
- """Run the Flower ServerApp with SyftBox."""
13
- syft_grid = SyftGrid(datasites=datasites)
14
- run_id = randint(0, 1000)
15
- syft_grid.set_run(run_id)
16
- logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
17
-
18
- updated_context = run_server(
19
- syft_grid,
20
- context=context,
21
- loaded_server_app=server_app,
22
- server_app_dir="",
23
- )
24
- logger.info(f"Server completed with context: {updated_context}")
25
- return updated_context
File without changes