syft-flwr 0.1.7__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of syft-flwr might be problematic. Click here for more details.
- syft_flwr/__init__.py +1 -1
- syft_flwr/consts.py +2 -0
- syft_flwr/flower_client.py +151 -60
- syft_flwr/flower_server.py +12 -4
- syft_flwr/grid.py +372 -99
- syft_flwr/mounts.py +1 -1
- syft_flwr/run_simulation.py +124 -24
- syft_flwr/utils.py +45 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.0.dist-info}/METADATA +3 -3
- syft_flwr-0.2.0.dist-info/RECORD +22 -0
- syft_flwr-0.1.7.dist-info/RECORD +0 -21
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.0.dist-info}/WHEEL +0 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.0.dist-info}/entry_points.txt +0 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.0.dist-info}/licenses/LICENSE +0 -0
syft_flwr/__init__.py
CHANGED
syft_flwr/consts.py
ADDED
syft_flwr/flower_client.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import sys
|
|
2
3
|
import traceback
|
|
3
4
|
|
|
@@ -8,85 +9,175 @@ from flwr.common.message import Error, Message
|
|
|
8
9
|
from loguru import logger
|
|
9
10
|
from syft_event import SyftEvents
|
|
10
11
|
from syft_event.types import Request
|
|
12
|
+
from typing_extensions import Optional, Union
|
|
11
13
|
|
|
12
14
|
from syft_flwr.flwr_compatibility import RecordDict, create_flwr_message
|
|
13
15
|
from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
|
|
16
|
+
from syft_flwr.utils import setup_client
|
|
14
17
|
|
|
15
18
|
|
|
16
|
-
|
|
17
|
-
message
|
|
18
|
-
) -> bytes:
|
|
19
|
-
# Normal message handling
|
|
20
|
-
logger.info(f"Receive message with metadata: {message.metadata}")
|
|
21
|
-
reply_message: Message = client_app(message=message, context=context)
|
|
22
|
-
res_bytes: bytes = flower_message_to_bytes(reply_message)
|
|
23
|
-
logger.info(f"Reply message size: {len(res_bytes)/2**20} MB")
|
|
24
|
-
return res_bytes
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def _create_error_reply(message: Message, error: Error) -> bytes:
|
|
28
|
-
"""Create and return error reply message in bytes."""
|
|
29
|
-
error_reply: Message = create_flwr_message(
|
|
30
|
-
content=RecordDict(),
|
|
31
|
-
reply_to=message,
|
|
32
|
-
message_type=message.metadata.message_type,
|
|
33
|
-
src_node_id=message.metadata.dst_node_id,
|
|
34
|
-
dst_node_id=message.metadata.src_node_id,
|
|
35
|
-
group_id=message.metadata.group_id,
|
|
36
|
-
run_id=message.metadata.run_id,
|
|
37
|
-
error=error,
|
|
38
|
-
)
|
|
39
|
-
error_bytes: bytes = flower_message_to_bytes(error_reply)
|
|
40
|
-
logger.info(f"Error reply message size: {len(error_bytes)/2**20} MB")
|
|
41
|
-
return error_bytes
|
|
19
|
+
class MessageHandler:
|
|
20
|
+
"""Handles message processing for Flower client."""
|
|
42
21
|
|
|
22
|
+
def __init__(
|
|
23
|
+
self, client_app: ClientApp, context: Context, encryption_enabled: bool
|
|
24
|
+
):
|
|
25
|
+
self.client_app = client_app
|
|
26
|
+
self.context = context
|
|
27
|
+
self.encryption_enabled = encryption_enabled
|
|
43
28
|
|
|
44
|
-
def
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
29
|
+
def prepare_reply(self, data: bytes) -> Union[str, bytes]:
|
|
30
|
+
"""Prepare reply data based on encryption setting."""
|
|
31
|
+
if self.encryption_enabled:
|
|
32
|
+
logger.info(f"🔒 Preparing ENCRYPTED reply, size: {len(data)/2**20:.2f} MB")
|
|
33
|
+
return base64.b64encode(data).decode("utf-8")
|
|
34
|
+
else:
|
|
35
|
+
logger.info(f"📤 Preparing PLAINTEXT reply, size: {len(data)/2**20:.2f} MB")
|
|
36
|
+
return data
|
|
37
|
+
|
|
38
|
+
def process_message(self, message: Message) -> Union[str, bytes]:
|
|
39
|
+
"""Process normal Flower message and return reply."""
|
|
40
|
+
logger.info(f"Processing message with metadata: {message.metadata}")
|
|
41
|
+
reply_message = self.client_app(message=message, context=self.context)
|
|
42
|
+
reply_bytes = flower_message_to_bytes(reply_message)
|
|
43
|
+
return self.prepare_reply(reply_bytes)
|
|
44
|
+
|
|
45
|
+
def create_error_reply(
|
|
46
|
+
self, message: Optional[Message], error: Error
|
|
47
|
+
) -> Union[str, bytes]:
|
|
48
|
+
"""Create error reply message."""
|
|
49
|
+
error_reply = create_flwr_message(
|
|
50
|
+
content=RecordDict(),
|
|
51
|
+
reply_to=message,
|
|
52
|
+
message_type=message.metadata.message_type if message else MessageType.TASK,
|
|
53
|
+
src_node_id=message.metadata.dst_node_id if message else 0,
|
|
54
|
+
dst_node_id=message.metadata.src_node_id if message else 0,
|
|
55
|
+
group_id=message.metadata.group_id if message else "",
|
|
56
|
+
run_id=message.metadata.run_id if message else 0,
|
|
57
|
+
error=error,
|
|
58
|
+
)
|
|
59
|
+
error_bytes = flower_message_to_bytes(error_reply)
|
|
60
|
+
logger.info(f"Error reply size: {len(error_bytes)/2**20:.2f} MB")
|
|
61
|
+
return self.prepare_reply(error_bytes)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class RequestProcessor:
|
|
65
|
+
"""Processes incoming requests and handles encryption/decryption."""
|
|
66
|
+
|
|
67
|
+
def __init__(
|
|
68
|
+
self, message_handler: MessageHandler, box: SyftEvents, client_email: str
|
|
69
|
+
):
|
|
70
|
+
self.message_handler = message_handler
|
|
71
|
+
self.box = box
|
|
72
|
+
self.client_email = client_email
|
|
73
|
+
|
|
74
|
+
def decode_request_body(self, request_body: Union[bytes, str]) -> bytes:
|
|
75
|
+
"""Decode request body, handling base64 if encrypted."""
|
|
76
|
+
if not self.message_handler.encryption_enabled:
|
|
77
|
+
return request_body
|
|
78
|
+
|
|
79
|
+
try:
|
|
80
|
+
# Convert to string if bytes
|
|
81
|
+
if isinstance(request_body, bytes):
|
|
82
|
+
request_body_str = request_body.decode("utf-8")
|
|
83
|
+
else:
|
|
84
|
+
request_body_str = request_body
|
|
85
|
+
# Decode base64
|
|
86
|
+
decoded = base64.b64decode(request_body_str)
|
|
87
|
+
logger.debug("🔓 Decoded base64 message")
|
|
88
|
+
return decoded
|
|
89
|
+
except Exception:
|
|
90
|
+
# Not base64 or decoding failed, use as-is
|
|
91
|
+
return request_body
|
|
92
|
+
|
|
93
|
+
def is_stop_signal(self, message: Message) -> bool:
|
|
94
|
+
"""Check if message is a stop signal."""
|
|
95
|
+
if message.metadata.message_type != MessageType.SYSTEM:
|
|
96
|
+
return False
|
|
97
|
+
|
|
98
|
+
# Check for stop action in config
|
|
99
|
+
if "config" in message.content and "action" in message.content["config"]:
|
|
100
|
+
return message.content["config"]["action"] == "stop"
|
|
101
|
+
|
|
102
|
+
# Alternative stop signal format
|
|
103
|
+
return message.metadata.group_id == "final"
|
|
104
|
+
|
|
105
|
+
def process(self, request: Request) -> Optional[Union[str, bytes]]:
|
|
106
|
+
"""Process incoming request and return response."""
|
|
107
|
+
original_sender = request.headers.get("X-Syft-Original-Sender", "unknown")
|
|
108
|
+
encryption_status = (
|
|
109
|
+
"🔐 ENCRYPTED"
|
|
110
|
+
if self.message_handler.encryption_enabled
|
|
111
|
+
else "📥 PLAINTEXT"
|
|
112
|
+
)
|
|
51
113
|
|
|
52
|
-
@box.on_request("/messages")
|
|
53
|
-
def handle_messages(request: Request) -> None:
|
|
54
114
|
logger.info(
|
|
55
|
-
f"Received request
|
|
115
|
+
f"{encryption_status} Received request from {original_sender}, "
|
|
116
|
+
f"id: {request.id}, size: {len(request.body) / 1024 / 1024:.2f} MB"
|
|
56
117
|
)
|
|
57
|
-
|
|
118
|
+
|
|
119
|
+
# Parse message
|
|
58
120
|
try:
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
"
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
121
|
+
request_body = self.decode_request_body(request.body)
|
|
122
|
+
message = bytes_to_flower_message(request_body)
|
|
123
|
+
|
|
124
|
+
if self.message_handler.encryption_enabled:
|
|
125
|
+
logger.debug(
|
|
126
|
+
f"🔓 Successfully decrypted message from {original_sender}"
|
|
127
|
+
)
|
|
128
|
+
except Exception as e:
|
|
129
|
+
logger.error(
|
|
130
|
+
f"❌ Failed to deserialize message from {original_sender}: {e}"
|
|
131
|
+
)
|
|
132
|
+
error = Error(
|
|
133
|
+
code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION,
|
|
134
|
+
reason=f"Message deserialization failed: {e}",
|
|
135
|
+
)
|
|
136
|
+
return self.message_handler.create_error_reply(None, error)
|
|
137
|
+
|
|
138
|
+
# Handle message
|
|
139
|
+
try:
|
|
140
|
+
# Check for stop signal
|
|
141
|
+
if self.is_stop_signal(message):
|
|
142
|
+
logger.info("Received stop signal")
|
|
143
|
+
self.box._stop_event.set()
|
|
144
|
+
return None
|
|
145
|
+
|
|
146
|
+
# Process normal message
|
|
147
|
+
return self.message_handler.process_message(message)
|
|
78
148
|
|
|
79
149
|
except Exception as e:
|
|
80
|
-
|
|
81
|
-
error_message = f"Client: '{client_email}'. Error: {str(e)}. Traceback: {error_traceback}"
|
|
150
|
+
error_message = f"Client: '{self.client_email}'. Error: {str(e)}. Traceback: {traceback.format_exc()}"
|
|
82
151
|
logger.error(error_message)
|
|
83
152
|
|
|
84
153
|
error = Error(
|
|
85
154
|
code=ErrorCode.CLIENT_APP_RAISED_EXCEPTION, reason=error_message
|
|
86
155
|
)
|
|
87
|
-
box._stop_event.set()
|
|
88
|
-
return
|
|
156
|
+
self.box._stop_event.set()
|
|
157
|
+
return self.message_handler.create_error_reply(message, error)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def syftbox_flwr_client(client_app: ClientApp, context: Context, app_name: str):
|
|
161
|
+
"""Run the Flower ClientApp with SyftBox."""
|
|
162
|
+
# Setup
|
|
163
|
+
client, encryption_enabled, syft_flwr_app_name = setup_client(app_name)
|
|
164
|
+
box = SyftEvents(app_name=syft_flwr_app_name, client=client)
|
|
165
|
+
|
|
166
|
+
logger.info(f"Started SyftBox Flower Client on: {box.client.email}")
|
|
167
|
+
logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
|
|
168
|
+
|
|
169
|
+
# Create handlers
|
|
170
|
+
message_handler = MessageHandler(client_app, context, encryption_enabled)
|
|
171
|
+
processor = RequestProcessor(message_handler, box, box.client.email)
|
|
172
|
+
|
|
173
|
+
# Register message handler
|
|
174
|
+
@box.on_request(
|
|
175
|
+
"/messages", auto_decrypt=encryption_enabled, encrypt_reply=encryption_enabled
|
|
176
|
+
)
|
|
177
|
+
def handle_messages(request: Request) -> Optional[Union[str, bytes]]:
|
|
178
|
+
return processor.process(request)
|
|
89
179
|
|
|
180
|
+
# Run
|
|
90
181
|
try:
|
|
91
182
|
box.run_forever()
|
|
92
183
|
except Exception as e:
|
syft_flwr/flower_server.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import traceback
|
|
2
2
|
from random import randint
|
|
3
3
|
|
|
4
|
-
from loguru import logger
|
|
5
|
-
|
|
6
4
|
from flwr.common import Context
|
|
7
5
|
from flwr.server import ServerApp
|
|
8
6
|
from flwr.server.run_serverapp import run as run_server
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
9
|
from syft_flwr.grid import SyftGrid
|
|
10
|
+
from syft_flwr.utils import setup_client
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def syftbox_flwr_server(
|
|
@@ -16,10 +17,17 @@ def syftbox_flwr_server(
|
|
|
16
17
|
app_name: str,
|
|
17
18
|
) -> Context:
|
|
18
19
|
"""Run the Flower ServerApp with SyftBox."""
|
|
19
|
-
syft_flwr_app_name =
|
|
20
|
-
|
|
20
|
+
client, _, syft_flwr_app_name = setup_client(app_name)
|
|
21
|
+
|
|
22
|
+
# Construct the SyftGrid
|
|
23
|
+
syft_grid = SyftGrid(
|
|
24
|
+
app_name=syft_flwr_app_name, datasites=datasites, client=client
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Set the run id (random for now)
|
|
21
28
|
run_id = randint(0, 1000)
|
|
22
29
|
syft_grid.set_run(run_id)
|
|
30
|
+
|
|
23
31
|
logger.info(f"Started SyftBox Flower Server on: {syft_grid._client.email}")
|
|
24
32
|
logger.info(f"syft_flwr app name: {syft_flwr_app_name}")
|
|
25
33
|
|
syft_flwr/grid.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import os
|
|
2
3
|
import time
|
|
3
|
-
from typing import Iterable, cast
|
|
4
4
|
|
|
5
5
|
from flwr.common import ConfigRecord
|
|
6
6
|
from flwr.common.constant import MessageType
|
|
@@ -9,9 +9,11 @@ from flwr.common.typing import Run
|
|
|
9
9
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
10
10
|
from loguru import logger
|
|
11
11
|
from syft_core import Client
|
|
12
|
-
from
|
|
13
|
-
from
|
|
12
|
+
from syft_crypto import EncryptedPayload, decrypt_message
|
|
13
|
+
from syft_rpc import SyftResponse, rpc, rpc_db
|
|
14
|
+
from typing_extensions import Dict, Iterable, List, Optional, Tuple, cast
|
|
14
15
|
|
|
16
|
+
from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
|
|
15
17
|
from syft_flwr.flwr_compatibility import (
|
|
16
18
|
Grid,
|
|
17
19
|
RecordDict,
|
|
@@ -24,7 +26,6 @@ from syft_flwr.utils import str_to_int
|
|
|
24
26
|
# this is what superlink super node do
|
|
25
27
|
AGGREGATOR_NODE_ID = 1
|
|
26
28
|
|
|
27
|
-
|
|
28
29
|
# env vars
|
|
29
30
|
SYFT_FLWR_MSG_TIMEOUT = "SYFT_FLWR_MSG_TIMEOUT"
|
|
30
31
|
|
|
@@ -36,39 +37,66 @@ class SyftGrid(Grid):
|
|
|
36
37
|
datasites: list[str] = [],
|
|
37
38
|
client: Client = None,
|
|
38
39
|
) -> None:
|
|
40
|
+
"""
|
|
41
|
+
SyftGrid is the server-side message orchestrator for federated learning in syft_flwr.
|
|
42
|
+
It acts as a bridge between Flower's server logic and SyftBox's communication layer:
|
|
43
|
+
|
|
44
|
+
Flower Server → SyftGrid → syft_rpc → SyftBox network → FL Clients
|
|
45
|
+
↑ ↓
|
|
46
|
+
└──────────── responses ←─────────────────┘
|
|
47
|
+
|
|
48
|
+
SyftGrid enables Flower's centralized server to communicate with distributed SyftBox
|
|
49
|
+
clients without knowing the underlying transport details.
|
|
50
|
+
|
|
51
|
+
Core functionalities:
|
|
52
|
+
- push_messages(): Sends messages to clients via syft_rpc, returns future IDs
|
|
53
|
+
- pull_messages(): Retrieves responses using futures
|
|
54
|
+
- send_and_receive(): Combines push/pull with timeout handling
|
|
55
|
+
"""
|
|
39
56
|
self._client = Client.load() if client is None else client
|
|
40
57
|
self._run: Optional[Run] = None
|
|
41
58
|
self.node = Node(node_id=AGGREGATOR_NODE_ID)
|
|
42
59
|
self.datasites = datasites
|
|
43
60
|
self.client_map = {str_to_int(ds): ds for ds in self.datasites}
|
|
61
|
+
|
|
62
|
+
# Check if encryption is enabled (default: True for production)
|
|
63
|
+
self._encryption_enabled = (
|
|
64
|
+
os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true").lower() != "false"
|
|
65
|
+
)
|
|
66
|
+
|
|
44
67
|
logger.debug(
|
|
45
68
|
f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
|
|
46
69
|
)
|
|
70
|
+
if self._encryption_enabled:
|
|
71
|
+
logger.info("🔐 End-to-end encryption is ENABLED for FL messages")
|
|
72
|
+
else:
|
|
73
|
+
logger.warning(
|
|
74
|
+
"⚠️ End-to-end encryption is DISABLED for FL messages (development mode / insecure)"
|
|
75
|
+
)
|
|
76
|
+
|
|
47
77
|
self.app_name = app_name
|
|
48
78
|
|
|
49
79
|
def set_run(self, run_id: int) -> None:
|
|
50
|
-
|
|
51
|
-
# do we need to do the same here, where the run id is set from an external context.
|
|
80
|
+
"""Set the run ID for this federated learning session.
|
|
52
81
|
|
|
82
|
+
Args:
|
|
83
|
+
run_id: Unique identifier for the FL run/session
|
|
84
|
+
|
|
85
|
+
Note:
|
|
86
|
+
In Grpc Grid case, the superlink sets up the run id.
|
|
87
|
+
Here, the run id is set from an external context.
|
|
88
|
+
"""
|
|
53
89
|
# Convert to Flower Run object
|
|
54
90
|
self._run = Run.create_empty(run_id)
|
|
55
91
|
|
|
56
92
|
@property
|
|
57
93
|
def run(self) -> Run:
|
|
58
|
-
"""Run
|
|
59
|
-
return Run(**vars(cast(Run, self._run)))
|
|
94
|
+
"""Get the current Flower Run object.
|
|
60
95
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
and message.metadata.src_node_id == self.node.node_id
|
|
66
|
-
and message.metadata.message_id == ""
|
|
67
|
-
and check_reply_to_field(message.metadata)
|
|
68
|
-
and message.metadata.ttl > 0
|
|
69
|
-
):
|
|
70
|
-
logger.debug(f"Invalid message with metadata: {message.metadata}")
|
|
71
|
-
raise ValueError(f"Invalid message: {message}")
|
|
96
|
+
Returns:
|
|
97
|
+
A copy of the current Run object with run metadata
|
|
98
|
+
"""
|
|
99
|
+
return Run(**vars(cast(Run, self._run)))
|
|
72
100
|
|
|
73
101
|
def create_message(
|
|
74
102
|
self,
|
|
@@ -78,7 +106,21 @@ class SyftGrid(Grid):
|
|
|
78
106
|
group_id: str,
|
|
79
107
|
ttl: Optional[float] = None,
|
|
80
108
|
) -> Message:
|
|
81
|
-
"""Create a new message with
|
|
109
|
+
"""Create a new Flower message with proper metadata.
|
|
110
|
+
|
|
111
|
+
Args:
|
|
112
|
+
content: Message payload as RecordDict (e.g., model parameters, metrics)
|
|
113
|
+
message_type: Type of FL message (e.g., MessageType.TRAIN, MessageType.EVALUATE)
|
|
114
|
+
dst_node_id: Destination node ID (client identifier)
|
|
115
|
+
group_id: Message group identifier for related messages
|
|
116
|
+
ttl: Time-to-live in seconds (optional, for message expiration)
|
|
117
|
+
|
|
118
|
+
Returns:
|
|
119
|
+
A Flower Message object ready to be sent to a client
|
|
120
|
+
|
|
121
|
+
Note:
|
|
122
|
+
Automatically adds current run_id and server's node_id to metadata.
|
|
123
|
+
"""
|
|
82
124
|
return create_flwr_message(
|
|
83
125
|
content=content,
|
|
84
126
|
message_type=message_type,
|
|
@@ -90,69 +132,82 @@ class SyftGrid(Grid):
|
|
|
90
132
|
)
|
|
91
133
|
|
|
92
134
|
def get_node_ids(self) -> list[int]:
|
|
93
|
-
"""Get node IDs of all connected
|
|
94
|
-
|
|
135
|
+
"""Get node IDs of all connected FL clients.
|
|
136
|
+
|
|
137
|
+
Returns:
|
|
138
|
+
List of integer node IDs representing connected datasites/clients
|
|
139
|
+
|
|
140
|
+
Note:
|
|
141
|
+
Node IDs are deterministically generated from datasite email addresses
|
|
142
|
+
using str_to_int() for consistent client identification.
|
|
143
|
+
"""
|
|
95
144
|
return list(self.client_map.keys())
|
|
96
145
|
|
|
97
146
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
98
|
-
"""Push messages to specified
|
|
99
|
-
|
|
100
|
-
|
|
147
|
+
"""Push FL messages to specified clients asynchronously.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
messages: Iterable of Flower Messages to send to clients
|
|
151
|
+
|
|
152
|
+
Returns:
|
|
153
|
+
List of future IDs that can be used to retrieve responses
|
|
154
|
+
"""
|
|
101
155
|
message_ids = []
|
|
156
|
+
|
|
102
157
|
for msg in messages:
|
|
103
|
-
#
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
# RPC URL
|
|
107
|
-
dest_datasite = self.client_map[msg.metadata.dst_node_id]
|
|
108
|
-
url = rpc.make_url(
|
|
109
|
-
dest_datasite, app_name=self.app_name, endpoint="messages"
|
|
110
|
-
)
|
|
111
|
-
# Check message
|
|
112
|
-
self._check_message(msg)
|
|
113
|
-
# Serialize message
|
|
114
|
-
msg_bytes = flower_message_to_bytes(msg)
|
|
158
|
+
# Prepare message
|
|
159
|
+
dest_datasite, url, msg_bytes = self._prepare_message(msg)
|
|
160
|
+
|
|
115
161
|
# Send message
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
162
|
+
if self._encryption_enabled:
|
|
163
|
+
future_id = self._send_encrypted_message(
|
|
164
|
+
url, msg_bytes, dest_datasite, msg
|
|
165
|
+
)
|
|
166
|
+
else:
|
|
167
|
+
future_id = self._send_unencrypted_message(
|
|
168
|
+
url, msg_bytes, dest_datasite, msg
|
|
169
|
+
)
|
|
170
|
+
|
|
171
|
+
if future_id:
|
|
172
|
+
message_ids.append(future_id)
|
|
125
173
|
|
|
126
174
|
return message_ids
|
|
127
175
|
|
|
128
|
-
def pull_messages(self, message_ids):
|
|
129
|
-
"""Pull messages
|
|
176
|
+
def pull_messages(self, message_ids: List[str]) -> Dict[str, Message]:
|
|
177
|
+
"""Pull response messages from clients using future IDs.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
message_ids: List of future IDs from push_messages()
|
|
181
|
+
|
|
182
|
+
Returns:
|
|
183
|
+
Dict mapping message_id to Flower Message response
|
|
184
|
+
"""
|
|
130
185
|
messages = {}
|
|
131
186
|
|
|
132
187
|
for msg_id in message_ids:
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
188
|
+
try:
|
|
189
|
+
# Get and resolve future
|
|
190
|
+
future = rpc_db.get_future(future_id=msg_id, client=self._client)
|
|
191
|
+
response = future.resolve()
|
|
137
192
|
|
|
138
|
-
|
|
193
|
+
if response is None:
|
|
194
|
+
continue # Message not ready yet
|
|
139
195
|
|
|
140
|
-
|
|
141
|
-
raise ValueError(f"Empty response: {response}")
|
|
196
|
+
response.raise_for_status()
|
|
142
197
|
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
198
|
+
# Process the response
|
|
199
|
+
message = self._process_response(response, msg_id)
|
|
200
|
+
|
|
201
|
+
if message:
|
|
202
|
+
messages[msg_id] = message
|
|
203
|
+
rpc_db.delete_future(future_id=msg_id, client=self._client)
|
|
204
|
+
|
|
205
|
+
except Exception as e:
|
|
206
|
+
logger.error(f"❌ Unexpected error pulling message {msg_id}: {e}")
|
|
149
207
|
continue
|
|
150
208
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
)
|
|
154
|
-
messages[msg_id] = message
|
|
155
|
-
rpc_db.delete_future(future_id=msg_id, client=self._client)
|
|
209
|
+
# Log summary
|
|
210
|
+
self._log_pull_summary(messages, message_ids)
|
|
156
211
|
|
|
157
212
|
return messages
|
|
158
213
|
|
|
@@ -164,47 +219,49 @@ class SyftGrid(Grid):
|
|
|
164
219
|
) -> Iterable[Message]:
|
|
165
220
|
"""Push messages to specified node IDs and pull the reply messages.
|
|
166
221
|
|
|
167
|
-
This method sends
|
|
168
|
-
|
|
169
|
-
|
|
222
|
+
This method sends messages to their destination nodes and waits for replies.
|
|
223
|
+
It continues polling until all replies are received or timeout is reached.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
messages: Messages to send
|
|
227
|
+
timeout: Maximum time to wait for replies (seconds).
|
|
228
|
+
Can be overridden by SYFT_FLWR_MSG_TIMEOUT env var.
|
|
229
|
+
|
|
230
|
+
Returns:
|
|
231
|
+
Collection of reply messages received
|
|
170
232
|
"""
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
if timeout is not None:
|
|
174
|
-
logger.debug(
|
|
175
|
-
f"syft_flwr messages timeout = {timeout}: Will move on after {timeout} (s) if no reply is received"
|
|
176
|
-
)
|
|
177
|
-
else:
|
|
178
|
-
logger.debug(
|
|
179
|
-
"syft_flwr messages timeout = None: Will wait indefinitely for replies"
|
|
180
|
-
)
|
|
233
|
+
# Get timeout from environment or parameter
|
|
234
|
+
timeout = self._get_timeout(timeout)
|
|
181
235
|
|
|
182
|
-
# Push messages
|
|
236
|
+
# Push messages and get IDs
|
|
183
237
|
msg_ids = set(self.push_messages(messages))
|
|
238
|
+
if not msg_ids:
|
|
239
|
+
return []
|
|
184
240
|
|
|
185
|
-
#
|
|
186
|
-
|
|
187
|
-
ret = {}
|
|
188
|
-
while timeout is None or time.time() < end_time:
|
|
189
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
190
|
-
ret.update(res_msgs)
|
|
191
|
-
msg_ids.difference_update(res_msgs.keys())
|
|
192
|
-
if len(msg_ids) == 0: # All messages received
|
|
193
|
-
break
|
|
194
|
-
time.sleep(3) # polling interval
|
|
195
|
-
|
|
196
|
-
if msg_ids:
|
|
197
|
-
logger.warning(
|
|
198
|
-
f"Timeout reached. {len(msg_ids)} message(s) sent out but not replied."
|
|
199
|
-
)
|
|
241
|
+
# Poll for responses
|
|
242
|
+
responses = self._poll_for_responses(msg_ids, timeout)
|
|
200
243
|
|
|
201
|
-
return
|
|
244
|
+
return responses.values()
|
|
202
245
|
|
|
203
246
|
def send_stop_signal(
|
|
204
247
|
self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
|
|
205
|
-
) ->
|
|
206
|
-
"""Send a stop signal to all
|
|
207
|
-
|
|
248
|
+
) -> List[Message]:
|
|
249
|
+
"""Send a stop signal to all connected FL clients.
|
|
250
|
+
|
|
251
|
+
Args:
|
|
252
|
+
group_id: Identifier for this group of stop messages
|
|
253
|
+
reason: Human-readable reason for stopping (default: "Training complete")
|
|
254
|
+
ttl: Time-to-live for stop messages in seconds (default: 60.0)
|
|
255
|
+
|
|
256
|
+
Returns:
|
|
257
|
+
List of stop Messages that were sent
|
|
258
|
+
|
|
259
|
+
Note:
|
|
260
|
+
Used to gracefully terminate FL clients when training completes or
|
|
261
|
+
when the server encounters an error. Clients will shut down upon
|
|
262
|
+
receiving this SYSTEM message with action="stop".
|
|
263
|
+
"""
|
|
264
|
+
stop_messages: List[Message] = [
|
|
208
265
|
self.create_message(
|
|
209
266
|
content=RecordDict(
|
|
210
267
|
{"config": ConfigRecord({"action": "stop", "reason": reason})}
|
|
@@ -219,3 +276,219 @@ class SyftGrid(Grid):
|
|
|
219
276
|
self.push_messages(stop_messages)
|
|
220
277
|
|
|
221
278
|
return stop_messages
|
|
279
|
+
|
|
280
|
+
def _check_message(self, message: Message) -> None:
|
|
281
|
+
"""Validate a Flower message before sending.
|
|
282
|
+
|
|
283
|
+
Args:
|
|
284
|
+
message: The Flower Message to validate
|
|
285
|
+
|
|
286
|
+
Raises:
|
|
287
|
+
ValueError: If message metadata is invalid (wrong run_id, src_node_id,
|
|
288
|
+
missing ttl, or invalid reply_to field)
|
|
289
|
+
|
|
290
|
+
Note:
|
|
291
|
+
Ensures message belongs to current run and originates from this server node.
|
|
292
|
+
"""
|
|
293
|
+
if not (
|
|
294
|
+
message.metadata.run_id == cast(Run, self._run).run_id
|
|
295
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
296
|
+
and message.metadata.message_id == ""
|
|
297
|
+
and check_reply_to_field(message.metadata)
|
|
298
|
+
and message.metadata.ttl > 0
|
|
299
|
+
):
|
|
300
|
+
logger.debug(f"Invalid message with metadata: {message.metadata}")
|
|
301
|
+
raise ValueError(f"Invalid message: {message}")
|
|
302
|
+
|
|
303
|
+
def _prepare_message(self, msg: Message) -> Tuple[str, str, bytes]:
|
|
304
|
+
"""Prepare a message for sending.
|
|
305
|
+
|
|
306
|
+
Returns:
|
|
307
|
+
Tuple of (destination_datasite, url, message_bytes)
|
|
308
|
+
"""
|
|
309
|
+
run_id = cast(Run, self._run).run_id
|
|
310
|
+
msg.metadata.__dict__["_run_id"] = run_id
|
|
311
|
+
msg.metadata.__dict__["_src_node_id"] = self.node.node_id
|
|
312
|
+
|
|
313
|
+
dest_datasite = self.client_map[msg.metadata.dst_node_id]
|
|
314
|
+
url = rpc.make_url(dest_datasite, app_name=self.app_name, endpoint="messages")
|
|
315
|
+
|
|
316
|
+
self._check_message(msg)
|
|
317
|
+
msg_bytes = flower_message_to_bytes(msg)
|
|
318
|
+
|
|
319
|
+
return dest_datasite, url, msg_bytes
|
|
320
|
+
|
|
321
|
+
def _send_encrypted_message(
|
|
322
|
+
self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
|
|
323
|
+
) -> Optional[str]:
|
|
324
|
+
"""Send an encrypted message and return future ID if successful."""
|
|
325
|
+
try:
|
|
326
|
+
future = rpc.send(
|
|
327
|
+
url=url,
|
|
328
|
+
body=base64.b64encode(msg_bytes).decode("utf-8"),
|
|
329
|
+
client=self._client,
|
|
330
|
+
encrypt=True,
|
|
331
|
+
)
|
|
332
|
+
logger.debug(
|
|
333
|
+
f"🔐 Pushed ENCRYPTED message to {dest_datasite} at {url} "
|
|
334
|
+
f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
|
|
335
|
+
)
|
|
336
|
+
rpc_db.save_future(
|
|
337
|
+
future=future, namespace=self.app_name, client=self._client
|
|
338
|
+
)
|
|
339
|
+
return future.id
|
|
340
|
+
|
|
341
|
+
except (KeyError, ValueError) as e:
|
|
342
|
+
error_type = (
|
|
343
|
+
"Encryption key" if isinstance(e, KeyError) else "Encryption parameter"
|
|
344
|
+
)
|
|
345
|
+
logger.error(
|
|
346
|
+
f"❌ {error_type} error for {dest_datasite}: {e}. "
|
|
347
|
+
f"Skipping message to node {msg.metadata.dst_node_id}"
|
|
348
|
+
)
|
|
349
|
+
return None
|
|
350
|
+
|
|
351
|
+
except Exception as e:
|
|
352
|
+
logger.warning(
|
|
353
|
+
f"⚠️ Encryption failed for {dest_datasite}: {e}. "
|
|
354
|
+
f"Falling back to unencrypted transmission"
|
|
355
|
+
)
|
|
356
|
+
return self._send_unencrypted_message(url, msg_bytes, dest_datasite, msg)
|
|
357
|
+
|
|
358
|
+
def _send_unencrypted_message(
|
|
359
|
+
self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
|
|
360
|
+
) -> Optional[str]:
|
|
361
|
+
"""Send an unencrypted message and return future ID if successful."""
|
|
362
|
+
try:
|
|
363
|
+
future = rpc.send(url=url, body=msg_bytes, client=self._client)
|
|
364
|
+
logger.debug(
|
|
365
|
+
f"📤 Pushed PLAINTEXT message to {dest_datasite} at {url} "
|
|
366
|
+
f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
|
|
367
|
+
)
|
|
368
|
+
rpc_db.save_future(
|
|
369
|
+
future=future, namespace=self.app_name, client=self._client
|
|
370
|
+
)
|
|
371
|
+
return future.id
|
|
372
|
+
|
|
373
|
+
except Exception as e:
|
|
374
|
+
logger.error(f"❌ Failed to send message to {dest_datasite}: {e}")
|
|
375
|
+
return None
|
|
376
|
+
|
|
377
|
+
def _poll_for_responses(
|
|
378
|
+
self, msg_ids: set, timeout: Optional[float]
|
|
379
|
+
) -> Dict[str, Message]:
|
|
380
|
+
"""Poll for responses until all received or timeout."""
|
|
381
|
+
end_time = time.time() + (timeout if timeout is not None else float("inf"))
|
|
382
|
+
responses = {}
|
|
383
|
+
pending_ids = msg_ids.copy()
|
|
384
|
+
|
|
385
|
+
while pending_ids and (timeout is None or time.time() < end_time):
|
|
386
|
+
# Pull available messages
|
|
387
|
+
batch = self.pull_messages(pending_ids)
|
|
388
|
+
responses.update(batch)
|
|
389
|
+
pending_ids.difference_update(batch.keys())
|
|
390
|
+
|
|
391
|
+
if pending_ids:
|
|
392
|
+
time.sleep(3) # Polling interval
|
|
393
|
+
|
|
394
|
+
# Log any missing responses
|
|
395
|
+
if pending_ids:
|
|
396
|
+
logger.warning(
|
|
397
|
+
f"Timeout reached. {len(pending_ids)} message(s) not received."
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
return responses
|
|
401
|
+
|
|
402
|
+
def _process_response(
|
|
403
|
+
self, response: SyftResponse, msg_id: str
|
|
404
|
+
) -> Optional[Message]:
|
|
405
|
+
"""Process a single response and return the deserialized message."""
|
|
406
|
+
if not response.body:
|
|
407
|
+
logger.warning(f"⚠️ Empty response for message {msg_id}, skipping")
|
|
408
|
+
return None
|
|
409
|
+
|
|
410
|
+
response_body = response.body
|
|
411
|
+
|
|
412
|
+
# Try to decrypt if encryption is enabled
|
|
413
|
+
if self._encryption_enabled:
|
|
414
|
+
response_body = self._try_decrypt_response(response.body, msg_id)
|
|
415
|
+
|
|
416
|
+
# Deserialize message
|
|
417
|
+
try:
|
|
418
|
+
message = bytes_to_flower_message(response_body)
|
|
419
|
+
except Exception as e:
|
|
420
|
+
logger.error(
|
|
421
|
+
f"❌ Failed to deserialize message {msg_id}: {e}. "
|
|
422
|
+
f"Message may be corrupted or in incompatible format."
|
|
423
|
+
)
|
|
424
|
+
return None
|
|
425
|
+
|
|
426
|
+
# Check for errors in message
|
|
427
|
+
if message.has_error():
|
|
428
|
+
error = message.error
|
|
429
|
+
logger.error(
|
|
430
|
+
f"❌ Message {msg_id} returned error with code={error.code}, "
|
|
431
|
+
f"reason={error.reason}"
|
|
432
|
+
)
|
|
433
|
+
return None
|
|
434
|
+
|
|
435
|
+
# Log successful pull
|
|
436
|
+
encryption_status = (
|
|
437
|
+
"🔐 ENCRYPTED" if self._encryption_enabled else "📥 PLAINTEXT"
|
|
438
|
+
)
|
|
439
|
+
logger.debug(
|
|
440
|
+
f"{encryption_status} Pulled message from {response.url} "
|
|
441
|
+
f"with metadata: {message.metadata}, "
|
|
442
|
+
f"size: {len(response_body) / 1024 / 1024:.2f} MB"
|
|
443
|
+
)
|
|
444
|
+
|
|
445
|
+
return message
|
|
446
|
+
|
|
447
|
+
def _try_decrypt_response(self, body: bytes, msg_id: str) -> bytes:
    """Best-effort decryption of a response body.

    The body is parsed as an ``EncryptedPayload`` JSON document, decrypted
    with ``self._client``, and the decrypted value is base64-decoded.

    Args:
        body: Raw response bytes.
        msg_id: Identifier of the originating message, used only in log output.

    Returns:
        The decoded plaintext bytes on success; the original ``body``
        unchanged if any step raises.
    """
    try:
        decrypted = decrypt_message(
            EncryptedPayload.model_validate_json(body.decode()),
            client=self._client,
        )
        # The decrypted value is expected to be base64 text wrapping the bytes.
        plaintext = base64.b64decode(decrypted)
        logger.debug(f"🔓 Successfully decrypted response for message {msg_id}")
        return plaintext
    except Exception as exc:
        # Any failure is treated as "this was plaintext all along".
        logger.debug(
            f"📥 Response appears to be plaintext or decryption not needed "
            f"for message {msg_id}: {exc}"
        )
        return body
|
|
465
|
+
|
|
466
|
+
def _log_pull_summary(
    self, messages: Dict[str, Message], message_ids: List[str]
) -> None:
    """Emit one log line summarising the outcome of a pull.

    Args:
        messages: Messages that were successfully pulled.
        message_ids: Identifiers that were requested in this pull.
    """
    if not messages:
        # Nothing arrived; only worth mentioning if something was requested.
        if message_ids:
            logger.debug(
                f"No messages pulled yet from {len(message_ids)} attempts "
                f"(clients may still be processing)"
            )
        return

    if not self._encryption_enabled:
        logger.info(f"📥 Successfully pulled {len(messages)} messages")
        return

    logger.info(
        f"🔐 Successfully pulled {len(messages)} messages (encryption enabled)"
    )
|
|
482
|
+
|
|
483
|
+
def _get_timeout(self, timeout: Optional[float]) -> Optional[float]:
    """Resolve the effective message timeout.

    The environment variable named by ``SYFT_FLWR_MSG_TIMEOUT`` takes
    precedence over the ``timeout`` argument whenever it is set.

    Args:
        timeout: Caller-supplied timeout in seconds, or ``None``.

    Returns:
        The environment value converted with ``float`` if the variable is
        set, otherwise ``timeout`` as given (possibly ``None``).

    Raises:
        ValueError: If the environment value cannot be parsed by ``float``.
    """
    override = os.environ.get(SYFT_FLWR_MSG_TIMEOUT)
    effective = timeout if override is None else float(override)

    if effective is None:
        logger.debug("No timeout - will wait indefinitely for replies")
    else:
        logger.debug(f"Message timeout: {effective}s")

    return effective
|
syft_flwr/mounts.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import List
|
|
5
4
|
|
|
6
5
|
import tomli
|
|
7
6
|
from loguru import logger
|
|
8
7
|
from syft_core import Client
|
|
9
8
|
from syft_rds.models import DockerMount, JobConfig
|
|
10
9
|
from syft_rds.syft_runtime.mounts import MountProvider
|
|
10
|
+
from typing_extensions import List
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class SyftFlwrMountProvider(MountProvider):
|
syft_flwr/run_simulation.py
CHANGED
|
@@ -1,33 +1,109 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import os
|
|
3
|
-
import
|
|
3
|
+
import sys
|
|
4
|
+
import tempfile
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
7
|
from loguru import logger
|
|
8
|
+
from syft_core import Client
|
|
9
|
+
from syft_crypto import did_path, ensure_bootstrap, get_did_document, private_key_path
|
|
7
10
|
from syft_rds.client.rds_client import RDSClient
|
|
8
|
-
from syft_rds.orchestra import
|
|
9
|
-
from typing_extensions import Union
|
|
11
|
+
from syft_rds.orchestra import SingleRDSStack, remove_rds_stack_dir
|
|
12
|
+
from typing_extensions import Optional, Union
|
|
10
13
|
|
|
11
14
|
from syft_flwr.config import load_flwr_pyproject
|
|
15
|
+
from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
|
|
16
|
+
from syft_flwr.utils import create_temp_client
|
|
12
17
|
|
|
13
18
|
|
|
14
19
|
def _setup_mock_rds_clients(
    project_dir: Path, aggregator: str, datasites: list[str]
) -> tuple[Path, list[RDSClient], RDSClient]:
    """Create mock RDS sessions for the aggregator and every datasite.

    The simulated network lives in a directory named after the project
    inside the system temp directory; ``remove_rds_stack_dir`` is called on
    it before any client is created.

    Args:
        project_dir: Project directory; only its ``name`` is used here.
        aggregator: Email of the aggregator (data scientist) participant.
        datasites: Emails of the datasite (data owner) participants.

    Returns:
        ``(network_dir, datasite_rds_clients, aggregator_rds_client)``.
    """
    network_root = Path(tempfile.gettempdir(), project_dir.name)
    remove_rds_stack_dir(root_dir=network_root)

    def _open_session(email: str) -> RDSClient:
        # One temp SyftBox client -> one RDS stack -> one session per email.
        syftbox_client = create_temp_client(email=email, workspace_dir=network_root)
        return SingleRDSStack(client=syftbox_client).init_session(host=email)

    # Aggregator first, then the datasites in the order given.
    ds_rds_client = _open_session(aggregator)
    do_rds_clients = [_open_session(datasite) for datasite in datasites]

    return network_root, do_rds_clients, ds_rds_client
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _bootstrap_encryption_keys(
    do_clients: list[RDSClient], ds_client: RDSClient
) -> None:
    """Bootstrap encryption keys for the server and every client.

    Does nothing beyond a warning when the environment variable named by
    ``SYFT_FLWR_ENCRYPTION_ENABLED`` is set to ``"false"`` (case-insensitive).
    Otherwise each participant is bootstrapped in turn, then every
    participant's view of every other participant's DID file is checked and
    logged; a missing DID file produces a warning, not an error.

    Args:
        do_clients: RDS clients of the datasites.
        ds_client: RDS client of the aggregator / server.

    Raises:
        Exception: Whatever a bootstrap step raises is logged and re-raised.
    """
    flag = os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true")
    if flag.lower() == "false":
        logger.warning("⚠️ Encryption disabled - skipping key bootstrap")
        return

    logger.info("🔐 Bootstrapping encryption keys for all participants...")

    participants: list[Client] = []

    # --- server ---
    try:
        server_box: Client = ds_client._syftbox_client
        ensure_bootstrap(server_box)
        server_did = did_path(server_box, server_box.email)
        server_key = private_key_path(server_box)
        logger.debug(
            f"✅ Server {ds_client.email} bootstrapped with private encryption "
            f"keys at {server_key} and did path at {server_did}"
        )
        participants.append(server_box)
    except Exception as exc:
        logger.error(f"❌ Failed to bootstrap server {ds_client.email}: {exc}")
        raise

    # --- datasite clients ---
    for do_client in do_clients:
        try:
            box: Client = do_client._syftbox_client
            ensure_bootstrap(box)
            box_did = did_path(box, box.email)
            box_did_doc = get_did_document(box, box.email)
            box_key = private_key_path(box)
            logger.debug(
                f"✅ Client {do_client.email} bootstrapped with private encryption "
                f"keys at {box_key} and did path at {box_did} "
                f"with content: {box_did_doc}"
            )
            participants.append(box)
        except Exception as exc:
            logger.error(f"❌ Failed to bootstrap client {do_client.email}: {exc}")
            raise

    # --- cross-visibility check: every participant against every other ---
    for observer in participants:
        for subject in participants:
            if observer.email == subject.email:
                continue
            did_file = did_path(observer, subject.email)
            if did_file.exists():
                logger.debug(
                    f"✅ {observer.email} can see {subject.email}'s DID at {did_file}"
                )
            else:
                logger.warning(
                    f"⚠️ {observer.email} cannot find {subject.email}'s DID at {did_file}"
                )

    logger.info("🔐 All participants bootstrapped for E2E encryption ✅✅✅")
|
|
31
107
|
|
|
32
108
|
|
|
33
109
|
async def _run_main_py(
|
|
@@ -35,7 +111,7 @@ async def _run_main_py(
|
|
|
35
111
|
config_path: Path,
|
|
36
112
|
client_email: str,
|
|
37
113
|
log_dir: Path,
|
|
38
|
-
dataset_path: Union[str, Path]
|
|
114
|
+
dataset_path: Optional[Union[str, Path]] = None,
|
|
39
115
|
) -> int:
|
|
40
116
|
"""Run the `main.py` file for a given client"""
|
|
41
117
|
log_file_path = log_dir / f"{client_email}.log"
|
|
@@ -49,7 +125,7 @@ async def _run_main_py(
|
|
|
49
125
|
try:
|
|
50
126
|
with open(log_file_path, "w") as f:
|
|
51
127
|
process = await asyncio.create_subprocess_exec(
|
|
52
|
-
|
|
128
|
+
sys.executable, # Use the current Python executable
|
|
53
129
|
str(main_py_path),
|
|
54
130
|
"-s",
|
|
55
131
|
stdout=f,
|
|
@@ -131,7 +207,7 @@ async def _run_simulated_flwr_project(
|
|
|
131
207
|
return run_success
|
|
132
208
|
|
|
133
209
|
|
|
134
|
-
def
|
|
210
|
+
def validate_bootstraped_project(project_dir: Path) -> None:
|
|
135
211
|
"""Validate a bootstraped `syft_flwr` project directory"""
|
|
136
212
|
if not project_dir.exists():
|
|
137
213
|
raise FileNotFoundError(f"Project directory {project_dir} does not exist")
|
|
@@ -159,42 +235,65 @@ def _validate_mock_dataset_paths(mock_dataset_paths: list[str]) -> list[Path]:
|
|
|
159
235
|
|
|
160
236
|
def run(
|
|
161
237
|
project_dir: Union[str, Path], mock_dataset_paths: list[Union[str, Path]]
|
|
162
|
-
) ->
|
|
163
|
-
"""Run a syft_flwr project in simulation mode over mock data
|
|
238
|
+
) -> Union[bool, asyncio.Task]:
|
|
239
|
+
"""Run a syft_flwr project in simulation mode over mock data.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
bool: True if simulation succeeded, False otherwise (synchronous execution)
|
|
243
|
+
asyncio.Task: Task handle if running in async environment (e.g., Jupyter)
|
|
244
|
+
"""
|
|
164
245
|
|
|
165
246
|
project_dir = Path(project_dir).expanduser().resolve()
|
|
166
|
-
|
|
247
|
+
validate_bootstraped_project(project_dir)
|
|
167
248
|
mock_dataset_paths = _validate_mock_dataset_paths(mock_dataset_paths)
|
|
168
249
|
|
|
169
|
-
|
|
250
|
+
# Skip module validation during testing to avoid parallel test issues
|
|
251
|
+
skip_module_check = (
|
|
252
|
+
os.environ.get("SYFT_FLWR_SKIP_MODULE_CHECK", "false").lower() == "true"
|
|
253
|
+
)
|
|
254
|
+
pyproject_conf = load_flwr_pyproject(
|
|
255
|
+
project_dir, check_module=not skip_module_check
|
|
256
|
+
)
|
|
170
257
|
datasites = pyproject_conf["tool"]["syft_flwr"]["datasites"]
|
|
171
258
|
aggregator = pyproject_conf["tool"]["syft_flwr"]["aggregator"]
|
|
172
259
|
|
|
173
|
-
|
|
260
|
+
simulated_syftbox_network_dir, do_clients, ds_client = _setup_mock_rds_clients(
|
|
174
261
|
project_dir, aggregator, datasites
|
|
175
262
|
)
|
|
176
263
|
|
|
264
|
+
_bootstrap_encryption_keys(do_clients, ds_client)
|
|
265
|
+
|
|
266
|
+
simulation_success = False # Track success status
|
|
267
|
+
|
|
177
268
|
async def main():
|
|
269
|
+
nonlocal simulation_success
|
|
178
270
|
try:
|
|
179
271
|
run_success = await _run_simulated_flwr_project(
|
|
180
272
|
project_dir, do_clients, ds_client, mock_dataset_paths
|
|
181
273
|
)
|
|
274
|
+
simulation_success = run_success
|
|
182
275
|
if run_success:
|
|
183
276
|
logger.success("Simulation completed successfully ✅")
|
|
184
277
|
else:
|
|
185
278
|
logger.error("Simulation failed ❌")
|
|
186
279
|
except Exception as e:
|
|
187
280
|
logger.error(f"Simulation failed ❌: {e}")
|
|
281
|
+
simulation_success = False
|
|
188
282
|
finally:
|
|
189
283
|
# Clean up the RDS stack
|
|
190
|
-
remove_rds_stack_dir(
|
|
191
|
-
logger.debug(f"Removed RDS stack: {
|
|
284
|
+
remove_rds_stack_dir(simulated_syftbox_network_dir)
|
|
285
|
+
logger.debug(f"Removed RDS stack: {simulated_syftbox_network_dir}")
|
|
286
|
+
# Also remove the .syftbox folder that saves the config files and private keys
|
|
287
|
+
remove_rds_stack_dir(simulated_syftbox_network_dir.parent / ".syftbox")
|
|
288
|
+
|
|
289
|
+
return simulation_success
|
|
192
290
|
|
|
193
291
|
try:
|
|
194
292
|
loop = asyncio.get_running_loop()
|
|
195
293
|
logger.debug(f"Running in an environment with an existing event loop {loop}")
|
|
196
294
|
# We are in an environment with an existing event loop (like Jupyter)
|
|
197
|
-
asyncio.create_task(main())
|
|
295
|
+
task = asyncio.create_task(main())
|
|
296
|
+
return task # Return the task so callers can await it
|
|
198
297
|
except RuntimeError:
|
|
199
298
|
logger.debug("No existing event loop, creating and running one")
|
|
200
299
|
# No existing event loop, create and run one (for scripts)
|
|
@@ -202,3 +301,4 @@ def run(
|
|
|
202
301
|
asyncio.set_event_loop(loop)
|
|
203
302
|
loop.run_until_complete(main())
|
|
204
303
|
loop.close()
|
|
304
|
+
return simulation_success # Return success status for synchronous execution
|
syft_flwr/utils.py
CHANGED
|
@@ -3,6 +3,13 @@ import re
|
|
|
3
3
|
import zlib
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from syft_core import Client, SyftClientConfig
|
|
8
|
+
from syft_crypto.x3dh_bootstrap import ensure_bootstrap
|
|
9
|
+
from typing_extensions import Tuple
|
|
10
|
+
|
|
11
|
+
from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
|
|
12
|
+
|
|
6
13
|
EMAIL_REGEX = r"^[^@]+@[^@]+\.[^@]+$"
|
|
7
14
|
|
|
8
15
|
|
|
@@ -34,3 +41,41 @@ def run_syft_flwr() -> bool:
|
|
|
34
41
|
return True
|
|
35
42
|
except FileNotFoundError:
|
|
36
43
|
return False
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def create_temp_client(email: str, workspace_dir: Path) -> Client:
    """Create a temporary Client instance for testing.

    A ``SyftClientConfig`` is built for ``email`` with ``workspace_dir`` as
    its data directory, saved to
    ``<workspace_dir>/.syftbox/<local-part-of-email>_config.json``, and
    wrapped in a ``Client``.

    The localhost server/client ports are offset by a value derived from the
    workspace path so that different workspaces get different URLs.

    Args:
        email: Email of the participant the client represents.
        workspace_dir: Root directory of the simulated workspace.

    Returns:
        A ``Client`` built from the saved configuration.
    """
    # The built-in hash() of a str is salted per interpreter process, so it
    # would yield different ports for the same workspace on every run.
    # CRC32 of the path bytes is stable across processes and runs.
    workspace_hash = zlib.crc32(os.fsencode(workspace_dir)) % 10000
    server_port = 8080 + workspace_hash
    client_port = 8082 + workspace_hash
    config: SyftClientConfig = SyftClientConfig(
        email=email,
        data_dir=workspace_dir,
        server_url=f"http://localhost:{server_port}",
        client_url=f"http://localhost:{client_port}",
        path=workspace_dir / ".syftbox" / f"{email.split('@')[0]}_config.json",
    ).save()
    logger.debug(f"Created temp client {email} with config {config}")
    return Client(config)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def setup_client(app_name: str) -> Tuple[Client, bool, str]:
    """Load the SyftBox client and apply the encryption setting.

    Encryption counts as enabled unless the environment variable named by
    ``SYFT_FLWR_ENCRYPTION_ENABLED`` is set to ``"false"`` (case-insensitive).
    When enabled, the loaded client is replaced by the result of
    ``ensure_bootstrap``.

    Args:
        app_name: Application name used to build the returned ``flwr/...`` string.

    Returns:
        ``(client, encryption_enabled, f"flwr/{app_name}")``.
    """
    client = Client.load()

    setting = os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true")
    encryption_enabled = setting.lower() != "false"

    if not encryption_enabled:
        logger.warning("⚠️ Encryption disabled - skipping client key bootstrap")
        logger.warning(
            "⚠️ End-to-end encryption is DISABLED for FL messages (development mode / insecure)"
        )
    else:
        client = ensure_bootstrap(client)
        logger.info("🔐 End-to-end encryption is ENABLED for FL messages")

    return client, encryption_enabled, f"flwr/{app_name}"
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: syft-flwr
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0
|
|
4
4
|
Summary: syft_flwr is an open source framework that facilitate federated learning projects using Flower over the SyftBox protocol
|
|
5
5
|
License-File: LICENSE
|
|
6
|
-
Requires-Python: >=3.
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
7
|
Requires-Dist: flwr-datasets[vision]>=0.5.0
|
|
8
8
|
Requires-Dist: flwr[simulation]>=1.20.0
|
|
9
9
|
Requires-Dist: loguru>=0.7.3
|
|
10
10
|
Requires-Dist: safetensors>=0.6.2
|
|
11
|
-
Requires-Dist: syft-rds
|
|
11
|
+
Requires-Dist: syft-rds>=0.2.1
|
|
12
12
|
Requires-Dist: tomli-w>=1.2.0
|
|
13
13
|
Requires-Dist: tomli>=2.2.1
|
|
14
14
|
Requires-Dist: typing-extensions>=4.13.0
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
syft_flwr/__init__.py,sha256=ayTCluE-u-ttvrPNq2REmFOfBOd6Zmr_QBF9oh2BpJc,426
|
|
2
|
+
syft_flwr/bootstrap.py,sha256=-T6SRh_p6u6uWpbTPZ6-URsAfMQAI2jakpjZAh0UUlw,3690
|
|
3
|
+
syft_flwr/cli.py,sha256=imctwdQMxQeGQZaiKSX1Mo2nU_-RmA-cGB3H4huuUeA,3274
|
|
4
|
+
syft_flwr/config.py,sha256=4hwkovGtFOLNULjJwoGYcA0uT4y3vZSrxndXqYXquMY,821
|
|
5
|
+
syft_flwr/consts.py,sha256=u3QK-Wp8D2Va7iZcp5z4ormVm_FAUDeK4u-w81UL_eY,107
|
|
6
|
+
syft_flwr/flower_client.py,sha256=-HsPp2Uw0RlILthezQHzZdwyVAzklbh_VHnl_nANgx8,7210
|
|
7
|
+
syft_flwr/flower_server.py,sha256=ZNDUR1U79M0BaG7n39TGUkVHV2NYi-LDsN8FqKJFfFQ,1508
|
|
8
|
+
syft_flwr/flwr_compatibility.py,sha256=vURf9rfsZ1uPm04szw6RpGYxtlG3BE4tW3YijptiGyk,3197
|
|
9
|
+
syft_flwr/grid.py,sha256=zjhkKHIYxRoCmda75Bw1L1Qra7b5DXhMTpY7L3Ujy_4,17799
|
|
10
|
+
syft_flwr/mounts.py,sha256=hp0TKVot16SaPYO10Y_mSJGei7aiNteJfK4U4vynWmU,2330
|
|
11
|
+
syft_flwr/run.py,sha256=OPW9bVt366DT-U-SxMpMLNXASwTZjp7XNNXfDP767f4,2153
|
|
12
|
+
syft_flwr/run_simulation.py,sha256=frHytbsxYLjiCM4r4m1NVQOc1j98hm4sQQoBLeagJi8,11539
|
|
13
|
+
syft_flwr/serde.py,sha256=5fCI-cRUOh5wE7cXQd4J6jex1grRGnyD1Jx-VlEDOXM,495
|
|
14
|
+
syft_flwr/utils.py,sha256=SC-lnCydP9t2_FNlUZEUFDcb6wtIE9v0soiW8nH7G0w,2594
|
|
15
|
+
syft_flwr/strategy/__init__.py,sha256=mpUmExjjFkqU8gg41XsOBKfO3aqCBe7XPJSU-_P7smU,97
|
|
16
|
+
syft_flwr/strategy/fedavg.py,sha256=N8jULUkjvuaBIEVINowyQln8W8yFhkO-J8k0-iPcGMA,1562
|
|
17
|
+
syft_flwr/templates/main.py.tpl,sha256=p0uK97jvLGk3LJdy1_HF1R5BQgIjaTGkYnr-csfh39M,791
|
|
18
|
+
syft_flwr-0.2.0.dist-info/METADATA,sha256=R6CudwzWKXVdL-wKi-iJd_Fm7yEGc9eNXgiTSu5sPsM,1254
|
|
19
|
+
syft_flwr-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
20
|
+
syft_flwr-0.2.0.dist-info/entry_points.txt,sha256=o7oT0dCoHn-3WyIwdDw1lBh2q-GvhB_8s0hWeJU4myc,49
|
|
21
|
+
syft_flwr-0.2.0.dist-info/licenses/LICENSE,sha256=0msOUar8uPZTqkAOTBp4rCzd7Jl9eRhfKiNufwrsg7k,11361
|
|
22
|
+
syft_flwr-0.2.0.dist-info/RECORD,,
|
syft_flwr-0.1.7.dist-info/RECORD
DELETED
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
syft_flwr/__init__.py,sha256=oGfKpHbO65Irb9RHFJHb7-RyS-rlcszl_tIxbssgdXU,426
|
|
2
|
-
syft_flwr/bootstrap.py,sha256=-T6SRh_p6u6uWpbTPZ6-URsAfMQAI2jakpjZAh0UUlw,3690
|
|
3
|
-
syft_flwr/cli.py,sha256=imctwdQMxQeGQZaiKSX1Mo2nU_-RmA-cGB3H4huuUeA,3274
|
|
4
|
-
syft_flwr/config.py,sha256=4hwkovGtFOLNULjJwoGYcA0uT4y3vZSrxndXqYXquMY,821
|
|
5
|
-
syft_flwr/flower_client.py,sha256=5UrtfwSTUDwPQlir7mQWLLVxulGcT4Mcy17uz1-UAlk,3685
|
|
6
|
-
syft_flwr/flower_server.py,sha256=sJwSEqePmkmWKGFXm2E44Ugoc6aaz-6tM7UaeWM2-co,1353
|
|
7
|
-
syft_flwr/flwr_compatibility.py,sha256=vURf9rfsZ1uPm04szw6RpGYxtlG3BE4tW3YijptiGyk,3197
|
|
8
|
-
syft_flwr/grid.py,sha256=Me2tivW0v1ApTjdjQffUc9f1UCHh1LtkYcKUjE82iZ8,7735
|
|
9
|
-
syft_flwr/mounts.py,sha256=ry3_3eD4aPkRahk9eibfu88TpQjgp_KQ96G7yj692x4,2319
|
|
10
|
-
syft_flwr/run.py,sha256=OPW9bVt366DT-U-SxMpMLNXASwTZjp7XNNXfDP767f4,2153
|
|
11
|
-
syft_flwr/run_simulation.py,sha256=t3shhfzAWDUlf6iQmwf5sS9APZQE9mkaZ9MLCYs9Ng0,6922
|
|
12
|
-
syft_flwr/serde.py,sha256=5fCI-cRUOh5wE7cXQd4J6jex1grRGnyD1Jx-VlEDOXM,495
|
|
13
|
-
syft_flwr/utils.py,sha256=3dDYEB7btq9hxZ9UsfQWh3i44OerAhGXc5XaX5wO3-o,955
|
|
14
|
-
syft_flwr/strategy/__init__.py,sha256=mpUmExjjFkqU8gg41XsOBKfO3aqCBe7XPJSU-_P7smU,97
|
|
15
|
-
syft_flwr/strategy/fedavg.py,sha256=N8jULUkjvuaBIEVINowyQln8W8yFhkO-J8k0-iPcGMA,1562
|
|
16
|
-
syft_flwr/templates/main.py.tpl,sha256=p0uK97jvLGk3LJdy1_HF1R5BQgIjaTGkYnr-csfh39M,791
|
|
17
|
-
syft_flwr-0.1.7.dist-info/METADATA,sha256=-FVHhD66zCCZ3FUZDGwzISjNFJuKhzWp3gHlKKRDqb4,1255
|
|
18
|
-
syft_flwr-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
19
|
-
syft_flwr-0.1.7.dist-info/entry_points.txt,sha256=o7oT0dCoHn-3WyIwdDw1lBh2q-GvhB_8s0hWeJU4myc,49
|
|
20
|
-
syft_flwr-0.1.7.dist-info/licenses/LICENSE,sha256=0msOUar8uPZTqkAOTBp4rCzd7Jl9eRhfKiNufwrsg7k,11361
|
|
21
|
-
syft_flwr-0.1.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|