syft-flwr 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- syft_flwr/__init__.py +15 -0
- syft_flwr/bootstrap.py +102 -0
- syft_flwr/cli.py +116 -0
- syft_flwr/config.py +36 -0
- syft_flwr/consts.py +2 -0
- syft_flwr/flower_client.py +199 -0
- syft_flwr/flower_server.py +50 -0
- syft_flwr/grid.py +580 -0
- syft_flwr/mounts.py +62 -0
- syft_flwr/run.py +63 -0
- syft_flwr/run_simulation.py +328 -0
- syft_flwr/serde.py +15 -0
- syft_flwr/strategy/__init__.py +3 -0
- syft_flwr/strategy/fedavg.py +38 -0
- syft_flwr/templates/main.py.tpl +31 -0
- syft_flwr/utils.py +126 -0
- syft_flwr-0.4.3.dist-info/METADATA +31 -0
- syft_flwr-0.4.3.dist-info/RECORD +21 -0
- syft_flwr-0.4.3.dist-info/WHEEL +4 -0
- syft_flwr-0.4.3.dist-info/entry_points.txt +2 -0
- syft_flwr-0.4.3.dist-info/licenses/LICENSE +201 -0
syft_flwr/grid.py
ADDED
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import os
|
|
3
|
+
import random
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
from flwr.common import ConfigRecord
|
|
7
|
+
from flwr.common.constant import MessageType
|
|
8
|
+
from flwr.common.message import Message
|
|
9
|
+
from flwr.common.record import RecordDict
|
|
10
|
+
from flwr.common.typing import Run
|
|
11
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
12
|
+
from flwr.server.grid import Grid
|
|
13
|
+
from loguru import logger
|
|
14
|
+
from syft_core import Client
|
|
15
|
+
from syft_crypto import EncryptedPayload, decrypt_message
|
|
16
|
+
from syft_rpc import SyftResponse, rpc, rpc_db
|
|
17
|
+
from typing_extensions import Dict, Iterable, List, Optional, Tuple, cast
|
|
18
|
+
|
|
19
|
+
from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
|
|
20
|
+
from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
|
|
21
|
+
from syft_flwr.utils import check_reply_to_field, create_flwr_message, str_to_int
|
|
22
|
+
|
|
23
|
+
# This is what Flower's SuperLink / SuperNode normally do.
|
|
24
|
+
AGGREGATOR_NODE_ID = 1
|
|
25
|
+
|
|
26
|
+
# env vars
|
|
27
|
+
SYFT_FLWR_MSG_TIMEOUT = "SYFT_FLWR_MSG_TIMEOUT"
|
|
28
|
+
SYFT_FLWR_POLL_INTERVAL = "SYFT_FLWR_POLL_INTERVAL"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SyftGrid(Grid):
    """Server-side Flower ``Grid`` that exchanges FL messages over ``syft_rpc``.

    It acts as a bridge between Flower's server logic and SyftBox's
    communication layer:

        Flower Server → SyftGrid → syft_rpc → SyftBox network → FL Clients
              ↑                                                      ↓
              └──────────────────── responses ←──────────────────────┘

    Core functionalities:
        - push_messages(): sends messages to clients via syft_rpc and returns
          future IDs
        - pull_messages(): retrieves responses using those future IDs
        - send_and_receive(): combines push/pull with timeout handling
    """

    def __init__(
        self,
        app_name: str,
        datasites: Optional[list[str]] = None,
        client: Optional[Client] = None,
    ) -> None:
        """Initialise the grid.

        Args:
            app_name: Application name; used to build RPC URLs and as the
                namespace under which futures are stored.
            datasites: Datasite identifiers of the FL clients. ``None``
                (the default) means no clients.
            client: SyftBox client to use; ``Client.load()`` is called when
                omitted.
        """
        self._client = Client.load() if client is None else client
        self._run: Optional[Run] = None
        self.node = Node(node_id=AGGREGATOR_NODE_ID)
        # Copy the argument so the instance never aliases a caller-owned list
        # (the previous ``[]`` default was one list shared by every instance).
        self.datasites = list(datasites) if datasites is not None else []
        # Node ID (derived from the datasite string) -> datasite.
        self.client_map = {str_to_int(ds): ds for ds in self.datasites}

        # Encryption is on unless the env var is explicitly "false"
        # (case-insensitive).
        self._encryption_enabled = (
            os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true").lower() != "false"
        )

        logger.debug(
            f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
        )
        if self._encryption_enabled:
            logger.info("🔐 End-to-end encryption is ENABLED for FL messages")
        else:
            logger.warning(
                "⚠️ End-to-end encryption is DISABLED for FL messages (development mode / insecure)"
            )

        self.app_name = app_name

    def set_run(self, run_id: int) -> None:
        """Set the run ID for this federated learning session.

        Args:
            run_id: Unique identifier for the FL run/session

        Note:
            In Grpc Grid case, the superlink sets up the run id.
            Here, the run id is set from an external context.
        """
        # Wrap the bare ID in an (otherwise empty) Flower Run object.
        self._run = Run.create_empty(run_id)

    def _require_run(self) -> Run:
        """Return the current run, failing clearly if ``set_run`` was not called.

        Raises:
            RuntimeError: If no run has been set yet.
        """
        if self._run is None:
            raise RuntimeError(
                "SyftGrid run is not set; call set_run() before using the grid"
            )
        return self._run

    @property
    def run(self) -> Run:
        """Get the current Flower Run object.

        Returns:
            A new Run built from the attributes of the current one.

        Raises:
            RuntimeError: If ``set_run`` has not been called.
        """
        return Run(**vars(self._require_run()))

    def create_message(
        self,
        content: RecordDict,
        message_type: str,
        dst_node_id: int,
        group_id: str,
        ttl: Optional[float] = None,
    ) -> Message:
        """Create a new Flower message.

        Args:
            content: Message payload as RecordDict (e.g., model parameters, metrics)
            message_type: Type of FL message (e.g., MessageType.TRAIN, MessageType.EVALUATE)
            dst_node_id: Destination node ID (client identifier)
            group_id: Message group identifier for related messages
            ttl: Time-to-live in seconds (optional, for message expiration)

        Returns:
            A Flower Message object ready to be passed to ``push_messages``

        Note:
            The run ID and this server's node ID are not set here; they are
            stamped onto the metadata when the message is pushed.
        """
        return create_flwr_message(
            content=content,
            message_type=message_type,
            dst_node_id=dst_node_id,
            group_id=group_id,
            ttl=ttl,
        )

    def get_node_ids(self) -> list[int]:
        """Get node IDs of all configured FL clients.

        Returns:
            List of integer node IDs, one per datasite, produced by
            ``str_to_int()`` on the datasite string.
        """
        return list(self.client_map)

    def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
        """Push FL messages to their destination clients.

        Args:
            messages: Iterable of Flower Messages to send to clients

        Returns:
            List of future IDs that can be used to retrieve responses.
            Messages that could not be sent contribute no ID.
        """
        message_ids = []

        for msg in messages:
            dest_datasite, url, msg_bytes = self._prepare_message(msg)

            if self._encryption_enabled:
                future_id = self._send_encrypted_message(
                    url, msg_bytes, dest_datasite, msg
                )
            else:
                future_id = self._send_unencrypted_message(
                    url, msg_bytes, dest_datasite, msg
                )

            if future_id:
                message_ids.append(future_id)

        return message_ids

    def pull_messages(self, message_ids: List[str]) -> Tuple[Dict[str, Message], set]:
        """Pull response messages from clients using future IDs.

        Args:
            message_ids: List of future IDs from push_messages()

        Returns:
            Tuple of:
            - Dict mapping message_id to Flower Message response (includes
              both successes and client errors)
            - Set of message_ids that are completed, i.e. a response arrived
              (whether or not it could be turned into a Message)
        """
        messages: Dict[str, Message] = {}
        completed_ids: set = set()

        for msg_id in message_ids:
            try:
                future = rpc_db.get_future(future_id=msg_id, client=self._client)
                response = future.resolve()

                if response is None:
                    continue  # Message not ready yet

                try:
                    response.raise_for_status()
                except Exception as e:
                    # An error response is still a response: fall through so
                    # the future is deleted and the ID marked completed,
                    # instead of re-polling it until the timeout.
                    logger.error(
                        f"❌ Message {msg_id} returned an error response: {e}"
                    )
                    message = None
                else:
                    message = self._process_response(response, msg_id)

                # Always delete the future once we get a response (success or
                # error). This prevents retrying failed messages indefinitely.
                rpc_db.delete_future(future_id=msg_id, client=self._client)

                # Mark as completed regardless of success/failure
                completed_ids.add(msg_id)

                if message:
                    messages[msg_id] = message

            except Exception as e:
                # Lookup/resolve/delete problems: leave the ID pending so the
                # caller's next poll tries again.
                logger.error(f"❌ Unexpected error pulling message {msg_id}: {e}")
                continue

        self._log_pull_summary(messages, message_ids)

        return messages, completed_ids

    def send_and_receive(
        self,
        messages: Iterable[Message],
        *,
        timeout: Optional[float] = None,
    ) -> Iterable[Message]:
        """Push messages to specified node IDs and pull the reply messages.

        This method sends messages to their destination nodes and waits for
        replies. It continues polling until all replies are received or the
        timeout is reached.

        Args:
            messages: Messages to send
            timeout: Maximum time to wait for replies (seconds). When omitted,
                the SYFT_FLWR_MSG_TIMEOUT env var is used, and failing that a
                default of 120 seconds.

        Returns:
            List of reply messages received (possibly fewer than were sent)
        """
        timeout = self._get_timeout(timeout)

        msg_ids = set(self.push_messages(messages))
        if not msg_ids:
            return []

        responses = self._poll_for_responses(msg_ids, timeout)

        # Materialise as a list so both return paths yield the same type.
        return list(responses.values())

    def send_stop_signal(
        self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
    ) -> List[Message]:
        """Send a stop signal to all configured FL clients.

        Args:
            group_id: Identifier for this group of stop messages
            reason: Human-readable reason for stopping (default: "Training complete")
            ttl: Time-to-live for stop messages in seconds (default: 60.0)

        Returns:
            List of stop Messages that were created and pushed

        Note:
            Each message is a SYSTEM message whose config record carries
            action="stop" and the given reason. Replies are not awaited.
        """
        stop_messages: List[Message] = [
            self.create_message(
                content=RecordDict(
                    {"config": ConfigRecord({"action": "stop", "reason": reason})}
                ),
                message_type=MessageType.SYSTEM,
                dst_node_id=node_id,
                group_id=group_id,
                ttl=ttl,
            )
            for node_id in self.get_node_ids()
        ]
        self.push_messages(stop_messages)

        return stop_messages

    def _check_message(self, message: Message) -> None:
        """Validate a Flower message before sending.

        A message is valid when it belongs to the current run, originates
        from this server node, has no message ID yet, passes
        ``check_reply_to_field`` and has a positive TTL.

        Args:
            message: The Flower Message to validate

        Raises:
            ValueError: If any of the checks above fails
            RuntimeError: If ``set_run`` has not been called
        """
        metadata = message.metadata
        is_valid = (
            metadata.run_id == self._require_run().run_id
            and metadata.src_node_id == self.node.node_id
            and metadata.message_id == ""
            and check_reply_to_field(metadata)
            and metadata.ttl > 0
        )
        if not is_valid:
            logger.debug(f"Invalid message with metadata: {message.metadata}")
            raise ValueError(f"Invalid message: {message}")

    def _prepare_message(self, msg: Message) -> Tuple[str, str, bytes]:
        """Prepare a message for sending.

        Stamps the current run ID and this node's ID onto the message
        metadata, resolves the destination datasite and RPC URL, validates
        the message and serialises it.

        Returns:
            Tuple of (destination_datasite, url, message_bytes)

        Raises:
            RuntimeError: If ``set_run`` has not been called
            KeyError: If the destination node ID is not a known client
            ValueError: If the message fails validation
        """
        run_id = self._require_run().run_id
        # Written straight into the metadata's attribute dict under the
        # private field names.
        msg.metadata.__dict__["_run_id"] = run_id
        msg.metadata.__dict__["_src_node_id"] = self.node.node_id

        dest_datasite = self.client_map[msg.metadata.dst_node_id]
        url = rpc.make_url(dest_datasite, app_name=self.app_name, endpoint="messages")

        self._check_message(msg)
        msg_bytes = flower_message_to_bytes(msg)

        return dest_datasite, url, msg_bytes

    def _retry_with_backoff(
        self,
        func,
        max_retries: int = 3,
        initial_delay: float = 0.1,
        context: str = "",
        check_error=None,
    ):
        """Generic retry logic with exponential backoff and jitter.

        Args:
            func: Zero-argument callable to run
            max_retries: Maximum number of attempts; values below 1 are
                treated as 1 so that ``func`` is always called
            initial_delay: Delay before the first retry, in seconds; doubled
                on each subsequent retry
            context: Context string for logging
            check_error: Optional predicate taking the exception and telling
                whether it is retryable; all errors are retryable if omitted

        Returns:
            Result of func if successful

        Raises:
            The exception from func when it is not retryable or when the
            last attempt fails
        """
        # Guard against max_retries <= 0, which would otherwise skip the loop
        # and return None without ever calling func.
        attempts = max(1, max_retries)
        for attempt in range(attempts):
            try:
                return func()
            except Exception as e:
                is_retryable = check_error(e) if check_error else True
                if not is_retryable or attempt >= attempts - 1:
                    raise
                jitter = random.uniform(0, 0.05)
                delay = initial_delay * (2**attempt) + jitter
                logger.debug(
                    f"{context} failed (attempt {attempt + 1}/{attempts}): {e}. "
                    f"Retrying in {delay:.3f}s"
                )
                time.sleep(delay)

    def _save_future_with_retry(self, future, dest_datasite: str) -> bool:
        """Save future to database with retry logic for database locks.

        Only errors whose text contains "database is locked" are retried.

        Returns:
            True if saved successfully, False if failed after retries
        """
        try:
            self._retry_with_backoff(
                func=lambda: rpc_db.save_future(
                    future=future, namespace=self.app_name, client=self._client
                ),
                context=f"Database save for {dest_datasite}",
                check_error=lambda e: "database is locked" in str(e).lower(),
            )
            return True
        except Exception as e:
            logger.warning(
                f"⚠️ Failed to save future to database for {dest_datasite}: {e}. "
                f"Message sent but future not persisted."
            )
            return False

    def _send_encrypted_message(
        self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
    ) -> Optional[str]:
        """Send an encrypted message and return future ID if successful.

        Returns:
            The future ID, or None when the message was skipped or its future
            could not be persisted.
        """
        # Only the send itself is inside the try, so a failure after a
        # successful send cannot trigger a second (plaintext) send.
        try:
            future = rpc.send(
                url=url,
                body=base64.b64encode(msg_bytes).decode("utf-8"),
                client=self._client,
                encrypt=True,
            )
        except (KeyError, ValueError) as e:
            # Encryption setup errors - don't retry or fallback
            error_type = (
                "Encryption key" if isinstance(e, KeyError) else "Encryption parameter"
            )
            logger.error(
                f"❌ {error_type} error for {dest_datasite}: {e}. "
                f"Skipping message to node {msg.metadata.dst_node_id}"
            )
            return None
        except Exception as e:
            # NOTE(review): this silently downgrades to plaintext although
            # encryption is enabled — confirm this fallback is acceptable.
            logger.warning(
                f"⚠️ Encryption failed for {dest_datasite}: {e}. "
                f"Falling back to unencrypted transmission"
            )
            return self._send_unencrypted_message(url, msg_bytes, dest_datasite, msg)

        logger.debug(
            f"🔐 Pushed ENCRYPTED message to {dest_datasite} at {url} "
            f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
        )

        # pull_messages() looks futures up by ID through rpc_db, so an ID
        # whose future was not persisted is not handed back for polling.
        if not self._save_future_with_retry(future, dest_datasite):
            return None
        return future.id

    def _send_unencrypted_message(
        self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
    ) -> Optional[str]:
        """Send an unencrypted message and return future ID if successful.

        Returns:
            The future ID, or None when sending failed or the future could
            not be persisted.
        """
        try:
            future = rpc.send(url=url, body=msg_bytes, client=self._client)
        except Exception as e:
            logger.error(f"❌ Failed to send message to {dest_datasite}: {e}")
            return None

        logger.debug(
            f"📤 Pushed PLAINTEXT message to {dest_datasite} at {url} "
            f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
        )

        # Same lock-aware save as the encrypted path, so a transient
        # "database is locked" no longer gets reported as a send failure.
        if not self._save_future_with_retry(future, dest_datasite):
            return None
        return future.id

    def _poll_for_responses(
        self, msg_ids: set, timeout: Optional[float]
    ) -> Dict[str, Message]:
        """Poll for responses until all received or timeout.

        Args:
            msg_ids: Future IDs to wait for (not modified)
            timeout: Maximum seconds to wait; None waits indefinitely

        Returns:
            Dict mapping message ID to the reply Message received for it
        """
        # Monotonic clock: the deadline is unaffected by wall-clock changes.
        deadline = None if timeout is None else time.monotonic() + timeout
        responses: Dict[str, Message] = {}
        pending_ids = msg_ids.copy()

        # Polling interval from the environment (default 3 s). Clamped at 0
        # because time.sleep() rejects negative values.
        poll_interval = max(0.0, float(os.environ.get(SYFT_FLWR_POLL_INTERVAL, "3")))

        while pending_ids and (deadline is None or time.monotonic() < deadline):
            batch, completed = self.pull_messages(list(pending_ids))
            responses.update(batch)
            # Remove all completed IDs (both successes and failures)
            pending_ids.difference_update(completed)

            if not pending_ids:
                break

            # Never sleep past the deadline.
            sleep_for = poll_interval
            if deadline is not None:
                sleep_for = min(sleep_for, max(0.0, deadline - time.monotonic()))
            time.sleep(sleep_for)

        if pending_ids:
            logger.warning(
                f"Timeout reached. {len(pending_ids)} message(s) not received."
            )

        return responses

    def _process_response(
        self, response: SyftResponse, msg_id: str
    ) -> Optional[Message]:
        """Process a single response and return the deserialized message.

        Returns:
            The Flower Message, including messages that carry an error, or
            None when the body is empty or cannot be deserialized.
        """
        if not response.body:
            logger.warning(f"⚠️ Empty response for message {msg_id}, skipping")
            return None

        response_body = response.body

        # Try to decrypt if encryption is enabled
        if self._encryption_enabled:
            response_body = self._try_decrypt_response(response.body, msg_id)

        try:
            message = bytes_to_flower_message(response_body)
        except Exception as e:
            logger.error(
                f"❌ Failed to deserialize message {msg_id}: {e}. "
                f"Message may be corrupted or in incompatible format."
            )
            return None

        if message.has_error():
            error = message.error
            logger.error(
                f"❌ Message {msg_id} returned error with code={error.code}, "
                f"reason={error.reason}. Returning error message to Flower for proper failure handling."
            )
        else:
            encryption_status = (
                "🔐 ENCRYPTED" if self._encryption_enabled else "📥 PLAINTEXT"
            )
            logger.debug(
                f"{encryption_status} Pulled message from {response.url} "
                f"with metadata: {message.metadata}, "
                f"size: {len(response_body) / 1024 / 1024:.2f} MB"
            )

        # Error messages are returned too, so the caller can handle failures.
        return message

    def _try_decrypt_response(self, body: bytes, msg_id: str) -> bytes:
        """Try to decrypt response body if it's encrypted.

        Returns:
            The base64-decoded decrypted payload, or ``body`` unchanged when
            parsing, decryption or decoding fails.
        """
        try:
            encrypted_payload = EncryptedPayload.model_validate_json(body.decode())
            decrypted_body = decrypt_message(encrypted_payload, client=self._client)
            # The decrypted body should be a base64-encoded string
            response_body = base64.b64decode(decrypted_body)
            logger.debug(f"🔓 Successfully decrypted response for message {msg_id}")
            return response_body
        except Exception as e:
            # NOTE(review): any failure here is treated as "plaintext", so an
            # unencrypted reply is accepted even with encryption enabled.
            logger.debug(
                f"📥 Response appears to be plaintext or decryption not needed "
                f"for message {msg_id}: {e}"
            )
            return body

    def _log_pull_summary(
        self, messages: Dict[str, Message], message_ids: List[str]
    ) -> None:
        """Log summary of pulled messages."""
        if messages:
            if self._encryption_enabled:
                logger.info(
                    f"🔐 Successfully pulled {len(messages)} messages (encryption enabled)"
                )
            else:
                logger.info(f"📥 Successfully pulled {len(messages)} messages")
        elif message_ids:
            logger.debug(
                f"No messages pulled yet from {len(message_ids)} attempts "
                f"(clients may still be processing)"
            )

    def _get_timeout(self, timeout: Optional[float]) -> Optional[float]:
        """Get timeout value from environment or parameter.

        Priority:
        1. Explicit timeout parameter
        2. SYFT_FLWR_MSG_TIMEOUT environment variable
        3. Default: 120 seconds (to prevent indefinite waiting)

        Raises:
            ValueError: If the environment variable is set but is not a number
        """
        if timeout is not None:
            logger.debug(f"Message timeout: {timeout}s (from parameter)")
            return timeout

        env_timeout = os.environ.get(SYFT_FLWR_MSG_TIMEOUT)
        if env_timeout is not None:
            timeout = float(env_timeout)
            logger.debug(f"Message timeout: {timeout}s (from env var)")
            return timeout

        # Default to 120 seconds to prevent indefinite waiting
        default_timeout = 120.0
        logger.debug(f"Message timeout: {default_timeout}s (default)")
        return default_timeout
|
syft_flwr/mounts.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import tomli
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from syft_core import Client
|
|
8
|
+
from syft_rds.models import DockerMount, JobConfig
|
|
9
|
+
from syft_rds.syft_runtime.mounts import MountProvider
|
|
10
|
+
from typing_extensions import List
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SyftFlwrMountProvider(MountProvider):
    """Mount provider that exposes a SyftBox config and the flwr RPC folder."""

    def _simplify_config(self, config_path: Path, simplified_config_path: Path) -> None:
        """Write a container-ready copy of the client config.

        The copy has ``refresh_token`` removed and ``data_dir`` set to
        ``/app/SyftBox`` so it can be mounted into the container.

        Args:
            config_path: Path of the JSON config to read.
            simplified_config_path: Path the simplified JSON is written to.
        """
        # Read and write as UTF-8 explicitly rather than relying on the
        # platform's default text encoding.
        with open(config_path, "r", encoding="utf-8") as fp:
            config = json.load(fp)
        modified_config = config.copy()
        modified_config["data_dir"] = "/app/SyftBox"
        modified_config.pop("refresh_token", None)
        with open(simplified_config_path, "w", encoding="utf-8") as fp:
            json.dump(modified_config, fp)

    def get_mounts(self, job_config: JobConfig) -> List[DockerMount]:
        """Build the Docker mounts for a job.

        Args:
            job_config: Job configuration; its ``function_folder`` must
                contain a ``pyproject.toml`` defining
                ``tool.syft_flwr.app_name``.

        Returns:
            Two mounts: the simplified config (read-only, at
            ``/app/config.json``) and the app's ``rpc/messages`` folder
            (read-write, under the client's datasite inside the container).
        """
        client = Client.load()
        client_email = client.email
        flwr_app_data = client.app_data("flwr")

        config_path = client.config_path
        simplified_dir = client.config_path.parent / ".simplified_configs"
        simplified_dir.mkdir(parents=True, exist_ok=True)
        simplified_config_path = simplified_dir / f"{client_email}.config.json"
        self._simplify_config(config_path, simplified_config_path)

        # Read app name from pyproject.toml
        with open(job_config.function_folder / "pyproject.toml", "rb") as fp:
            toml_dict = tomli.load(fp)
        syft_flwr_app_name = toml_dict["tool"]["syft_flwr"]["app_name"]

        # Built by string interpolation on purpose: joining with "/" would let
        # an absolute app name replace the app-data prefix entirely.
        rpc_messages_source = Path(f"{flwr_app_data}/{syft_flwr_app_name}/rpc/messages")
        rpc_messages_source.mkdir(parents=True, exist_ok=True)
        # NOTE(review): world-writable; presumably so the container user can
        # write into the mounted folder — confirm this is required.
        os.chmod(rpc_messages_source, 0o777)

        mounts = [
            DockerMount(
                source=simplified_config_path,
                target="/app/config.json",
                mode="ro",
            ),
            DockerMount(
                source=rpc_messages_source,
                target=f"/app/SyftBox/datasites/{client_email}/app_data/flwr/{syft_flwr_app_name}/rpc/messages",
                mode="rw",
            ),
        ]

        logger.debug(f"Mounts: {mounts}")

        return mounts
|