syft-flwr 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of syft-flwr might be problematic. Click here for more details.
- syft_flwr/__init__.py +1 -1
- syft_flwr/consts.py +2 -0
- syft_flwr/flower_client.py +150 -61
- syft_flwr/flower_server.py +12 -4
- syft_flwr/grid.py +446 -109
- syft_flwr/mounts.py +1 -1
- syft_flwr/run.py +2 -1
- syft_flwr/run_simulation.py +124 -24
- syft_flwr/utils.py +79 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.1.dist-info}/METADATA +3 -3
- syft_flwr-0.2.1.dist-info/RECORD +21 -0
- syft_flwr/flwr_compatibility.py +0 -121
- syft_flwr-0.1.7.dist-info/RECORD +0 -21
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.1.dist-info}/WHEEL +0 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.1.dist-info}/entry_points.txt +0 -0
- {syft_flwr-0.1.7.dist-info → syft_flwr-0.2.1.dist-info}/licenses/LICENSE +0 -0
syft_flwr/grid.py
CHANGED
|
@@ -1,32 +1,31 @@
|
|
|
1
|
+
import base64
|
|
1
2
|
import os
|
|
3
|
+
import random
|
|
2
4
|
import time
|
|
3
|
-
from typing import Iterable, cast
|
|
4
5
|
|
|
5
6
|
from flwr.common import ConfigRecord
|
|
6
7
|
from flwr.common.constant import MessageType
|
|
7
8
|
from flwr.common.message import Message
|
|
9
|
+
from flwr.common.record import RecordDict
|
|
8
10
|
from flwr.common.typing import Run
|
|
9
11
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
12
|
+
from flwr.server.grid import Grid
|
|
10
13
|
from loguru import logger
|
|
11
14
|
from syft_core import Client
|
|
12
|
-
from
|
|
13
|
-
from
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
RecordDict,
|
|
18
|
-
check_reply_to_field,
|
|
19
|
-
create_flwr_message,
|
|
20
|
-
)
|
|
15
|
+
from syft_crypto import EncryptedPayload, decrypt_message
|
|
16
|
+
from syft_rpc import SyftResponse, rpc, rpc_db
|
|
17
|
+
from typing_extensions import Dict, Iterable, List, Optional, Tuple, cast
|
|
18
|
+
|
|
19
|
+
from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
|
|
21
20
|
from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
|
|
22
|
-
from syft_flwr.utils import str_to_int
|
|
21
|
+
from syft_flwr.utils import check_reply_to_field, create_flwr_message, str_to_int
|
|
23
22
|
|
|
24
23
|
# this is what the superlink / supernode does
|
|
25
24
|
AGGREGATOR_NODE_ID = 1
|
|
26
25
|
|
|
27
|
-
|
|
28
26
|
# env vars
|
|
29
27
|
SYFT_FLWR_MSG_TIMEOUT = "SYFT_FLWR_MSG_TIMEOUT"
|
|
28
|
+
SYFT_FLWR_POLL_INTERVAL = "SYFT_FLWR_POLL_INTERVAL"
|
|
30
29
|
|
|
31
30
|
|
|
32
31
|
class SyftGrid(Grid):
|
|
@@ -36,39 +35,66 @@ class SyftGrid(Grid):
|
|
|
36
35
|
datasites: list[str] = [],
|
|
37
36
|
client: Client = None,
|
|
38
37
|
) -> None:
|
|
38
|
+
"""
|
|
39
|
+
SyftGrid is the server-side message orchestrator for federated learning in syft_flwr.
|
|
40
|
+
It acts as a bridge between Flower's server logic and SyftBox's communication layer:
|
|
41
|
+
|
|
42
|
+
Flower Server → SyftGrid → syft_rpc → SyftBox network → FL Clients
|
|
43
|
+
↑ ↓
|
|
44
|
+
└──────────── responses ←─────────────────┘
|
|
45
|
+
|
|
46
|
+
SyftGrid enables Flower's centralized server to communicate with distributed SyftBox
|
|
47
|
+
clients without knowing the underlying transport details.
|
|
48
|
+
|
|
49
|
+
Core functionalities:
|
|
50
|
+
- push_messages(): Sends messages to clients via syft_rpc, returns future IDs
|
|
51
|
+
- pull_messages(): Retrieves responses using futures
|
|
52
|
+
- send_and_receive(): Combines push/pull with timeout handling
|
|
53
|
+
"""
|
|
39
54
|
self._client = Client.load() if client is None else client
|
|
40
55
|
self._run: Optional[Run] = None
|
|
41
56
|
self.node = Node(node_id=AGGREGATOR_NODE_ID)
|
|
42
57
|
self.datasites = datasites
|
|
43
58
|
self.client_map = {str_to_int(ds): ds for ds in self.datasites}
|
|
59
|
+
|
|
60
|
+
# Check if encryption is enabled (default: True for production)
|
|
61
|
+
self._encryption_enabled = (
|
|
62
|
+
os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true").lower() != "false"
|
|
63
|
+
)
|
|
64
|
+
|
|
44
65
|
logger.debug(
|
|
45
66
|
f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
|
|
46
67
|
)
|
|
68
|
+
if self._encryption_enabled:
|
|
69
|
+
logger.info("🔐 End-to-end encryption is ENABLED for FL messages")
|
|
70
|
+
else:
|
|
71
|
+
logger.warning(
|
|
72
|
+
"⚠️ End-to-end encryption is DISABLED for FL messages (development mode / insecure)"
|
|
73
|
+
)
|
|
74
|
+
|
|
47
75
|
self.app_name = app_name
|
|
48
76
|
|
|
49
77
|
def set_run(self, run_id: int) -> None:
    """Bind this grid to a federated learning run.

    Args:
        run_id: Identifier of the FL run/session to attach to.

    Note:
        With the gRPC grid the SuperLink sets up the run ID; here the ID is
        supplied by the caller.
    """
    # Wrap the bare ID in an otherwise empty Flower Run object.
    empty_run = Run.create_empty(run_id)
    self._run = empty_run
|
|
55
89
|
|
|
56
90
|
@property
def run(self) -> Run:
    """Return a copy of the Run currently attached to this grid.

    Returns:
        A new Run built from the attributes of the Run stored in ``self._run``.
    """
    current = cast(Run, self._run)
    # Rebuild from the attribute dict so callers get a separate Run object.
    return Run(**vars(current))
|
|
72
98
|
|
|
73
99
|
def create_message(
|
|
74
100
|
self,
|
|
@@ -78,81 +104,106 @@ class SyftGrid(Grid):
|
|
|
78
104
|
group_id: str,
|
|
79
105
|
ttl: Optional[float] = None,
|
|
80
106
|
) -> Message:
|
|
81
|
-
"""Create a new message with
|
|
107
|
+
"""Create a new Flower message with proper metadata.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
content: Message payload as RecordDict (e.g., model parameters, metrics)
|
|
111
|
+
message_type: Type of FL message (e.g., MessageType.TRAIN, MessageType.EVALUATE)
|
|
112
|
+
dst_node_id: Destination node ID (client identifier)
|
|
113
|
+
group_id: Message group identifier for related messages
|
|
114
|
+
ttl: Time-to-live in seconds (optional, for message expiration)
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
A Flower Message object ready to be sent to a client
|
|
118
|
+
|
|
119
|
+
Note:
|
|
120
|
+
Automatically adds current run_id and server's node_id to metadata.
|
|
121
|
+
"""
|
|
82
122
|
return create_flwr_message(
|
|
83
123
|
content=content,
|
|
84
124
|
message_type=message_type,
|
|
85
125
|
dst_node_id=dst_node_id,
|
|
86
126
|
group_id=group_id,
|
|
87
127
|
ttl=ttl,
|
|
88
|
-
run_id=cast(Run, self._run).run_id,
|
|
89
|
-
src_node_id=self.node.node_id,
|
|
90
128
|
)
|
|
91
129
|
|
|
92
130
|
def get_node_ids(self) -> list[int]:
    """Return the node IDs of all known FL clients.

    Returns:
        The keys of ``self.client_map`` as a new list, in the map's
        iteration order.
    """
    return [node_id for node_id in self.client_map]
|
|
96
141
|
|
|
97
142
|
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
    """Send each message to its destination client without waiting for replies.

    Args:
        messages: Flower messages to deliver.

    Returns:
        List of future IDs for the messages that were sent. A message whose
        send helper returns a falsy ID is left out of the result.
    """
    future_ids = []

    for message in messages:
        dest_datasite, url, payload = self._prepare_message(message)

        # Choose the transport according to the encryption flag.
        send = (
            self._send_encrypted_message
            if self._encryption_enabled
            else self._send_unencrypted_message
        )
        future_id = send(url, payload, dest_datasite, message)

        if future_id:
            future_ids.append(future_id)

    return future_ids
|
|
127
171
|
|
|
128
|
-
def pull_messages(self, message_ids: List[str]) -> Dict[str, Message]:
    """Collect the replies that are currently available for the given futures.

    Args:
        message_ids: Future IDs previously returned by push_messages().

    Returns:
        Dict mapping each future ID to its reply Message. IDs whose future
        resolves to None, whose processing yields no message, or whose
        handling raises are absent from the result.
    """
    pulled = {}

    for future_id in message_ids:
        try:
            pending = rpc_db.get_future(future_id=future_id, client=self._client)
            response = pending.resolve()

            # Nothing to read yet for this future.
            if response is None:
                continue

            response.raise_for_status()
            reply = self._process_response(response, future_id)

            # NOTE(review): a resolved-but-unusable reply is not deleted, so
            # the ID presumably stays pending for callers - confirm intent.
            if reply:
                pulled[future_id] = reply
                rpc_db.delete_future(future_id=future_id, client=self._client)

        except Exception as e:
            # Best effort: one failing future must not abort the whole batch.
            logger.error(f"❌ Unexpected error pulling message {future_id}: {e}")

    self._log_pull_summary(pulled, message_ids)

    return pulled
|
|
158
209
|
|
|
@@ -164,47 +215,49 @@ class SyftGrid(Grid):
|
|
|
164
215
|
) -> Iterable[Message]:
|
|
165
216
|
"""Push messages to specified node IDs and pull the reply messages.
|
|
166
217
|
|
|
167
|
-
This method sends
|
|
168
|
-
|
|
169
|
-
|
|
218
|
+
This method sends messages to their destination nodes and waits for replies.
|
|
219
|
+
It continues polling until all replies are received or timeout is reached.
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
messages: Messages to send
|
|
223
|
+
timeout: Maximum time to wait for replies (seconds).
|
|
224
|
+
Can be overridden by SYFT_FLWR_MSG_TIMEOUT env var.
|
|
225
|
+
|
|
226
|
+
Returns:
|
|
227
|
+
Collection of reply messages received
|
|
170
228
|
"""
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
if timeout is not None:
|
|
174
|
-
logger.debug(
|
|
175
|
-
f"syft_flwr messages timeout = {timeout}: Will move on after {timeout} (s) if no reply is received"
|
|
176
|
-
)
|
|
177
|
-
else:
|
|
178
|
-
logger.debug(
|
|
179
|
-
"syft_flwr messages timeout = None: Will wait indefinitely for replies"
|
|
180
|
-
)
|
|
229
|
+
# Get timeout from environment or parameter
|
|
230
|
+
timeout = self._get_timeout(timeout)
|
|
181
231
|
|
|
182
|
-
# Push messages
|
|
232
|
+
# Push messages and get IDs
|
|
183
233
|
msg_ids = set(self.push_messages(messages))
|
|
234
|
+
if not msg_ids:
|
|
235
|
+
return []
|
|
184
236
|
|
|
185
|
-
#
|
|
186
|
-
|
|
187
|
-
ret = {}
|
|
188
|
-
while timeout is None or time.time() < end_time:
|
|
189
|
-
res_msgs = self.pull_messages(msg_ids)
|
|
190
|
-
ret.update(res_msgs)
|
|
191
|
-
msg_ids.difference_update(res_msgs.keys())
|
|
192
|
-
if len(msg_ids) == 0: # All messages received
|
|
193
|
-
break
|
|
194
|
-
time.sleep(3) # polling interval
|
|
195
|
-
|
|
196
|
-
if msg_ids:
|
|
197
|
-
logger.warning(
|
|
198
|
-
f"Timeout reached. {len(msg_ids)} message(s) sent out but not replied."
|
|
199
|
-
)
|
|
237
|
+
# Poll for responses
|
|
238
|
+
responses = self._poll_for_responses(msg_ids, timeout)
|
|
200
239
|
|
|
201
|
-
return
|
|
240
|
+
return responses.values()
|
|
202
241
|
|
|
203
242
|
def send_stop_signal(
|
|
204
243
|
self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
|
|
205
|
-
) ->
|
|
206
|
-
"""Send a stop signal to all
|
|
207
|
-
|
|
244
|
+
) -> List[Message]:
|
|
245
|
+
"""Send a stop signal to all connected FL clients.
|
|
246
|
+
|
|
247
|
+
Args:
|
|
248
|
+
group_id: Identifier for this group of stop messages
|
|
249
|
+
reason: Human-readable reason for stopping (default: "Training complete")
|
|
250
|
+
ttl: Time-to-live for stop messages in seconds (default: 60.0)
|
|
251
|
+
|
|
252
|
+
Returns:
|
|
253
|
+
List of stop Messages that were sent
|
|
254
|
+
|
|
255
|
+
Note:
|
|
256
|
+
Used to gracefully terminate FL clients when training completes or
|
|
257
|
+
when the server encounters an error. Clients will shut down upon
|
|
258
|
+
receiving this SYSTEM message with action="stop".
|
|
259
|
+
"""
|
|
260
|
+
stop_messages: List[Message] = [
|
|
208
261
|
self.create_message(
|
|
209
262
|
content=RecordDict(
|
|
210
263
|
{"config": ConfigRecord({"action": "stop", "reason": reason})}
|
|
@@ -219,3 +272,287 @@ class SyftGrid(Grid):
|
|
|
219
272
|
self.push_messages(stop_messages)
|
|
220
273
|
|
|
221
274
|
return stop_messages
|
|
275
|
+
|
|
276
|
+
def _check_message(self, message: Message) -> None:
    """Reject a Flower message whose metadata is not valid for sending.

    Args:
        message: The Flower Message to validate.

    Raises:
        ValueError: If the run_id differs from the current run, the
            src_node_id differs from this grid's node, message_id is not
            empty, the reply_to check fails, or ttl is not positive.
    """
    metadata = message.metadata
    # Conditions are chained with `and` so later checks only run when the
    # earlier ones hold.
    is_valid = (
        metadata.run_id == cast(Run, self._run).run_id
        and metadata.src_node_id == self.node.node_id
        and metadata.message_id == ""
        and check_reply_to_field(metadata)
        and metadata.ttl > 0
    )
    if is_valid:
        return

    logger.debug(f"Invalid message with metadata: {message.metadata}")
    raise ValueError(f"Invalid message: {message}")
|
|
298
|
+
|
|
299
|
+
def _prepare_message(self, msg: Message) -> Tuple[str, str, bytes]:
    """Stamp, validate and serialize a message for delivery.

    Args:
        msg: The Flower Message to send; its metadata is modified in place.

    Returns:
        Tuple of (destination_datasite, url, message_bytes).

    Raises:
        KeyError: If the message's dst_node_id is not in ``self.client_map``.
        ValueError: If the stamped message fails ``_check_message``.
    """
    # Write the run and source node IDs straight into the metadata's
    # instance dict.
    metadata_fields = msg.metadata.__dict__
    metadata_fields["_run_id"] = cast(Run, self._run).run_id
    metadata_fields["_src_node_id"] = self.node.node_id

    dest_datasite = self.client_map[msg.metadata.dst_node_id]
    url = rpc.make_url(dest_datasite, app_name=self.app_name, endpoint="messages")

    # Validate only after stamping, since the check compares those fields.
    self._check_message(msg)

    return dest_datasite, url, flower_message_to_bytes(msg)
|
|
316
|
+
|
|
317
|
+
def _retry_with_backoff(
|
|
318
|
+
self,
|
|
319
|
+
func,
|
|
320
|
+
max_retries: int = 3,
|
|
321
|
+
initial_delay: float = 0.1,
|
|
322
|
+
context: str = "",
|
|
323
|
+
check_error=None,
|
|
324
|
+
):
|
|
325
|
+
"""Generic retry logic with exponential backoff and jitter.
|
|
326
|
+
|
|
327
|
+
Args:
|
|
328
|
+
func: Function to retry
|
|
329
|
+
max_retries: Maximum number of retry attempts
|
|
330
|
+
initial_delay: Initial delay in seconds
|
|
331
|
+
context: Context string for logging
|
|
332
|
+
check_error: Optional function to check if error is retryable
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
Result of func if successful
|
|
336
|
+
|
|
337
|
+
Raises:
|
|
338
|
+
Last exception if all retries fail
|
|
339
|
+
"""
|
|
340
|
+
for attempt in range(max_retries):
|
|
341
|
+
try:
|
|
342
|
+
return func()
|
|
343
|
+
except Exception as e:
|
|
344
|
+
is_retryable = check_error(e) if check_error else True
|
|
345
|
+
if is_retryable and attempt < max_retries - 1:
|
|
346
|
+
jitter = random.uniform(0, 0.05)
|
|
347
|
+
delay = initial_delay * (2**attempt) + jitter
|
|
348
|
+
logger.debug(
|
|
349
|
+
f"{context} failed (attempt {attempt + 1}/{max_retries}): {e}. "
|
|
350
|
+
f"Retrying in {delay:.3f}s"
|
|
351
|
+
)
|
|
352
|
+
time.sleep(delay)
|
|
353
|
+
else:
|
|
354
|
+
raise
|
|
355
|
+
|
|
356
|
+
def _save_future_with_retry(self, future, dest_datasite: str) -> bool:
    """Persist a future via ``rpc_db``, retrying while the database is locked.

    Args:
        future: The future returned by ``rpc.send``.
        dest_datasite: Destination datasite, used only in log messages.

    Returns:
        True if the future was saved, False if saving ultimately failed.
    """

    def persist():
        return rpc_db.save_future(
            future=future, namespace=self.app_name, client=self._client
        )

    def is_db_locked(error) -> bool:
        # Only a "database is locked" error is worth retrying.
        return "database is locked" in str(error).lower()

    try:
        self._retry_with_backoff(
            func=persist,
            context=f"Database save for {dest_datasite}",
            check_error=is_db_locked,
        )
    except Exception as e:
        logger.warning(
            f"⚠️ Failed to save future to database for {dest_datasite}: {e}. "
            f"Message sent but future not persisted."
        )
        return False

    return True
|
|
377
|
+
|
|
378
|
+
def _send_encrypted_message(
    self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
) -> Optional[str]:
    """Send a message with encryption requested and return its future ID.

    Args:
        url: RPC URL of the destination endpoint.
        msg_bytes: Serialized Flower message.
        dest_datasite: Destination datasite, used in log messages.
        msg: The original message, used for logging and for the fallback.

    Returns:
        The future ID, or None when a KeyError/ValueError is raised. For any
        other exception, the result of ``_send_unencrypted_message``.
    """
    try:
        # The payload is base64-encoded into text before being sent.
        encoded_body = base64.b64encode(msg_bytes).decode("utf-8")
        future = rpc.send(
            url=url,
            body=encoded_body,
            client=self._client,
            encrypt=True,
        )

        size_mb = len(msg_bytes) / 1024 / 1024
        logger.debug(
            f"🔐 Pushed ENCRYPTED message to {dest_datasite} at {url} "
            f"with metadata {msg.metadata}; size {size_mb:.2f} MB"
        )

        # A failed save is only logged; the future ID is returned regardless.
        self._save_future_with_retry(future, dest_datasite)
        return future.id

    except (KeyError, ValueError) as e:
        # Treated as encryption setup problems: skip, no retry, no fallback.
        if isinstance(e, KeyError):
            error_type = "Encryption key"
        else:
            error_type = "Encryption parameter"
        logger.error(
            f"❌ {error_type} error for {dest_datasite}: {e}. "
            f"Skipping message to node {msg.metadata.dst_node_id}"
        )
        return None

    except Exception as e:
        # NOTE(review): this silently downgrades to plaintext even though
        # encryption was requested - confirm this is acceptable.
        logger.warning(
            f"⚠️ Encryption failed for {dest_datasite}: {e}. "
            f"Falling back to unencrypted transmission"
        )
        return self._send_unencrypted_message(url, msg_bytes, dest_datasite, msg)
|
|
418
|
+
|
|
419
|
+
def _send_unencrypted_message(
    self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
) -> Optional[str]:
    """Send a message without encryption and return its future ID.

    Args:
        url: RPC URL of the destination endpoint.
        msg_bytes: Serialized Flower message, sent as the raw body.
        dest_datasite: Destination datasite, used in log messages.
        msg: The original message, used only for logging its metadata.

    Returns:
        The future ID, or None if sending failed or the future could not be
        saved.
    """
    # Keep the try limited to the send so a later database failure is not
    # reported as a send failure.
    try:
        future = rpc.send(url=url, body=msg_bytes, client=self._client)
    except Exception as e:
        logger.error(f"❌ Failed to send message to {dest_datasite}: {e}")
        return None

    logger.debug(
        f"📤 Pushed PLAINTEXT message to {dest_datasite} at {url} "
        f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
    )

    # Use the same retrying save as the encrypted path so a transient
    # "database is locked" error does not drop the future. The helper logs
    # its own warning on failure.
    if not self._save_future_with_retry(future, dest_datasite):
        # As before, a future that was not saved is not reported to callers.
        return None

    return future.id
|
|
437
|
+
|
|
438
|
+
def _poll_for_responses(
    self, msg_ids: set, timeout: Optional[float]
) -> Dict[str, Message]:
    """Poll for replies until all have arrived or the timeout expires.

    Args:
        msg_ids: Future IDs to wait for; the set is copied, not modified.
        timeout: Maximum time to wait in seconds, or None to wait
            indefinitely.

    Returns:
        Dict mapping future ID to the reply Message for every reply received.

    Note:
        The pause between polls is read from the SYFT_FLWR_POLL_INTERVAL
        environment variable (seconds, default 3). A value that is not a
        finite, non-negative number is ignored with a warning.
    """
    default_interval = 3.0
    poll_interval = default_interval
    raw_interval = os.environ.get(SYFT_FLWR_POLL_INTERVAL)
    if raw_interval is not None:
        try:
            poll_interval = float(raw_interval)
        except ValueError:
            poll_interval = float("nan")
        # Rejects NaN, negatives and infinity, which time.sleep cannot take.
        if not (0 <= poll_interval < float("inf")):
            logger.warning(
                f"Ignoring invalid {SYFT_FLWR_POLL_INTERVAL}={raw_interval!r}; "
                f"using {default_interval}s"
            )
            poll_interval = default_interval

    end_time = None if timeout is None else time.time() + timeout
    responses = {}
    pending_ids = msg_ids.copy()

    while pending_ids and (end_time is None or time.time() < end_time):
        # Pull available messages
        batch = self.pull_messages(pending_ids)
        responses.update(batch)
        pending_ids.difference_update(batch.keys())

        if not pending_ids:
            break

        wait = poll_interval
        if end_time is not None:
            # Never sleep past the deadline.
            wait = max(0.0, min(wait, end_time - time.time()))
        time.sleep(wait)

    # Log any missing responses
    if pending_ids:
        logger.warning(
            f"Timeout reached. {len(pending_ids)} message(s) not received."
        )

    return responses
|
|
465
|
+
|
|
466
|
+
def _process_response(
    self, response: SyftResponse, msg_id: str
) -> Optional[Message]:
    """Turn one RPC response into a Flower Message.

    Args:
        response: The resolved response for a pushed message.
        msg_id: Future ID of the pushed message, used in log messages.

    Returns:
        The deserialized Message, or None when the body is empty, cannot be
        deserialized, or the message reports an error.
    """
    raw_body = response.body
    if not raw_body:
        logger.warning(f"⚠️ Empty response for message {msg_id}, skipping")
        return None

    # Attempt decryption only when encryption is enabled.
    if self._encryption_enabled:
        payload = self._try_decrypt_response(raw_body, msg_id)
    else:
        payload = raw_body

    try:
        message = bytes_to_flower_message(payload)
    except Exception as e:
        logger.error(
            f"❌ Failed to deserialize message {msg_id}: {e}. "
            f"Message may be corrupted or in incompatible format."
        )
        return None

    # An error reply is logged and dropped rather than returned.
    if message.has_error():
        error = message.error
        logger.error(
            f"❌ Message {msg_id} returned error with code={error.code}, "
            f"reason={error.reason}"
        )
        return None

    if self._encryption_enabled:
        encryption_status = "🔐 ENCRYPTED"
    else:
        encryption_status = "📥 PLAINTEXT"
    size_mb = len(payload) / 1024 / 1024
    logger.debug(
        f"{encryption_status} Pulled message from {response.url} "
        f"with metadata: {message.metadata}, "
        f"size: {size_mb:.2f} MB"
    )

    return message
|
|
510
|
+
|
|
511
|
+
def _try_decrypt_response(self, body: bytes, msg_id: str) -> bytes:
    """Decrypt a response body, returning it unchanged if that fails.

    Args:
        body: Raw response body.
        msg_id: Future ID of the pushed message, used in log messages.

    Returns:
        The base64-decoded decrypted payload, or ``body`` itself when parsing,
        decryption or decoding raises.
    """
    try:
        # Parse the body as an encrypted payload, decrypt it, then undo the
        # base64 encoding of the decrypted content.
        envelope = EncryptedPayload.model_validate_json(body.decode())
        decrypted_b64 = decrypt_message(envelope, client=self._client)
        plaintext = base64.b64decode(decrypted_b64)
        logger.debug(f"🔓 Successfully decrypted response for message {msg_id}")
    except Exception as e:
        # NOTE(review): any failure, including a real decryption error, is
        # treated as "plaintext" and only logged at debug level.
        logger.debug(
            f"📥 Response appears to be plaintext or decryption not needed "
            f"for message {msg_id}: {e}"
        )
        return body

    return plaintext
|
|
529
|
+
|
|
530
|
+
def _log_pull_summary(
    self, messages: Dict[str, Message], message_ids: List[str]
) -> None:
    """Log how many replies one pull produced.

    Args:
        messages: Replies collected by the pull, keyed by future ID.
        message_ids: Future IDs that were queried.
    """
    if not messages:
        # Nothing arrived; log only if something was actually queried.
        if message_ids:
            logger.debug(
                f"No messages pulled yet from {len(message_ids)} attempts "
                f"(clients may still be processing)"
            )
        return

    if self._encryption_enabled:
        logger.info(
            f"🔐 Successfully pulled {len(messages)} messages (encryption enabled)"
        )
    else:
        logger.info(f"📥 Successfully pulled {len(messages)} messages")
|
|
546
|
+
|
|
547
|
+
def _get_timeout(self, timeout: Optional[float]) -> Optional[float]:
    """Resolve the reply timeout from the environment or the parameter.

    Args:
        timeout: Timeout in seconds requested by the caller, or None.

    Returns:
        The value of the SYFT_FLWR_MSG_TIMEOUT environment variable as a
        float when it is set and parseable, otherwise ``timeout``. None means
        wait indefinitely.
    """
    env_timeout = os.environ.get(SYFT_FLWR_MSG_TIMEOUT)
    if env_timeout is not None:
        try:
            timeout = float(env_timeout)
        except ValueError:
            # A malformed override (e.g. an empty string) should not abort
            # the round; keep the caller's value.
            logger.warning(
                f"Ignoring invalid {SYFT_FLWR_MSG_TIMEOUT}={env_timeout!r}; "
                f"using timeout={timeout}"
            )

    if timeout is not None:
        logger.debug(f"Message timeout: {timeout}s")
    else:
        logger.debug("No timeout - will wait indefinitely for replies")

    return timeout
|
syft_flwr/mounts.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import os
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import List
|
|
5
4
|
|
|
6
5
|
import tomli
|
|
7
6
|
from loguru import logger
|
|
8
7
|
from syft_core import Client
|
|
9
8
|
from syft_rds.models import DockerMount, JobConfig
|
|
10
9
|
from syft_rds.syft_runtime.mounts import MountProvider
|
|
10
|
+
from typing_extensions import List
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class SyftFlwrMountProvider(MountProvider):
|
syft_flwr/run.py
CHANGED
|
@@ -5,11 +5,12 @@ from uuid import uuid4
|
|
|
5
5
|
from flwr.client.client_app import LoadClientAppError
|
|
6
6
|
from flwr.common import Context
|
|
7
7
|
from flwr.common.object_ref import load_app
|
|
8
|
+
from flwr.common.record import RecordDict
|
|
8
9
|
from flwr.server.server_app import LoadServerAppError
|
|
10
|
+
|
|
9
11
|
from syft_flwr.config import load_flwr_pyproject
|
|
10
12
|
from syft_flwr.flower_client import syftbox_flwr_client
|
|
11
13
|
from syft_flwr.flower_server import syftbox_flwr_server
|
|
12
|
-
from syft_flwr.flwr_compatibility import RecordDict
|
|
13
14
|
from syft_flwr.run_simulation import run
|
|
14
15
|
|
|
15
16
|
__all__ = ["syftbox_run_flwr_client", "syftbox_run_flwr_server", "run"]
|