syft-flwr 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of syft-flwr might be problematic. Click here for more details.

syft_flwr/grid.py CHANGED
@@ -1,32 +1,31 @@
1
+ import base64
1
2
  import os
3
+ import random
2
4
  import time
3
- from typing import Iterable, cast
4
5
 
5
6
  from flwr.common import ConfigRecord
6
7
  from flwr.common.constant import MessageType
7
8
  from flwr.common.message import Message
9
+ from flwr.common.record import RecordDict
8
10
  from flwr.common.typing import Run
9
11
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
12
+ from flwr.server.grid import Grid
10
13
  from loguru import logger
11
14
  from syft_core import Client
12
- from syft_rpc import rpc, rpc_db
13
- from typing_extensions import Optional
14
-
15
- from syft_flwr.flwr_compatibility import (
16
- Grid,
17
- RecordDict,
18
- check_reply_to_field,
19
- create_flwr_message,
20
- )
15
+ from syft_crypto import EncryptedPayload, decrypt_message
16
+ from syft_rpc import SyftResponse, rpc, rpc_db
17
+ from typing_extensions import Dict, Iterable, List, Optional, Tuple, cast
18
+
19
+ from syft_flwr.consts import SYFT_FLWR_ENCRYPTION_ENABLED
21
20
  from syft_flwr.serde import bytes_to_flower_message, flower_message_to_bytes
22
- from syft_flwr.utils import str_to_int
21
+ from syft_flwr.utils import check_reply_to_field, create_flwr_message, str_to_int
23
22
 
24
23
# Node ID of the aggregator (server) node. Per the original note, this matches
# the ID the Flower SuperLink uses for itself.
AGGREGATOR_NODE_ID = 1

# Environment variable names read by SyftGrid.
# When set, overrides the `timeout` passed to send_and_receive (seconds).
SYFT_FLWR_MSG_TIMEOUT = "SYFT_FLWR_MSG_TIMEOUT"
# Seconds to sleep between polling rounds while waiting for replies.
SYFT_FLWR_POLL_INTERVAL = "SYFT_FLWR_POLL_INTERVAL"
30
29
 
31
30
 
32
31
  class SyftGrid(Grid):
@@ -36,39 +35,66 @@ class SyftGrid(Grid):
36
35
  datasites: list[str] = [],
37
36
  client: Client = None,
38
37
  ) -> None:
38
+ """
39
+ SyftGrid is the server-side message orchestrator for federated learning in syft_flwr.
40
+ It acts as a bridge between Flower's server logic and SyftBox's communication layer:
41
+
42
+ Flower Server → SyftGrid → syft_rpc → SyftBox network → FL Clients
43
+ ↑ ↓
44
+ └──────────── responses ←─────────────────┘
45
+
46
+ SyftGrid enables Flower's centralized server to communicate with distributed SyftBox
47
+ clients without knowing the underlying transport details.
48
+
49
+ Core functionalities:
50
+ - push_messages(): Sends messages to clients via syft_rpc, returns future IDs
51
+ - pull_messages(): Retrieves responses using futures
52
+ - send_and_receive(): Combines push/pull with timeout handling
53
+ """
39
54
  self._client = Client.load() if client is None else client
40
55
  self._run: Optional[Run] = None
41
56
  self.node = Node(node_id=AGGREGATOR_NODE_ID)
42
57
  self.datasites = datasites
43
58
  self.client_map = {str_to_int(ds): ds for ds in self.datasites}
59
+
60
+ # Check if encryption is enabled (default: True for production)
61
+ self._encryption_enabled = (
62
+ os.environ.get(SYFT_FLWR_ENCRYPTION_ENABLED, "true").lower() != "false"
63
+ )
64
+
44
65
  logger.debug(
45
66
  f"Initialize SyftGrid for '{self._client.email}' with datasites: {self.datasites}"
46
67
  )
68
+ if self._encryption_enabled:
69
+ logger.info("🔐 End-to-end encryption is ENABLED for FL messages")
70
+ else:
71
+ logger.warning(
72
+ "⚠️ End-to-end encryption is DISABLED for FL messages (development mode / insecure)"
73
+ )
74
+
47
75
  self.app_name = app_name
48
76
 
49
77
  def set_run(self, run_id: int) -> None:
50
- # TODO: In Grpc Grid case, the superlink is the one which sets up the run id,
51
- # do we need to do the same here, where the run id is set from an external context.
78
+ """Set the run ID for this federated learning session.
52
79
 
80
+ Args:
81
+ run_id: Unique identifier for the FL run/session
82
+
83
+ Note:
84
+ In Grpc Grid case, the superlink sets up the run id.
85
+ Here, the run id is set from an external context.
86
+ """
53
87
  # Convert to Flower Run object
54
88
  self._run = Run.create_empty(run_id)
55
89
 
56
90
  @property
57
91
  def run(self) -> Run:
58
- """Run ID."""
59
- return Run(**vars(cast(Run, self._run)))
92
+ """Get the current Flower Run object.
60
93
 
61
- def _check_message(self, message: Message) -> None:
62
- # Check if the message is valid
63
- if not (
64
- message.metadata.run_id == cast(Run, self._run).run_id
65
- and message.metadata.src_node_id == self.node.node_id
66
- and message.metadata.message_id == ""
67
- and check_reply_to_field(message.metadata)
68
- and message.metadata.ttl > 0
69
- ):
70
- logger.debug(f"Invalid message with metadata: {message.metadata}")
71
- raise ValueError(f"Invalid message: {message}")
94
+ Returns:
95
+ A copy of the current Run object with run metadata
96
+ """
97
+ return Run(**vars(cast(Run, self._run)))
72
98
 
73
99
  def create_message(
74
100
  self,
@@ -78,81 +104,106 @@ class SyftGrid(Grid):
78
104
  group_id: str,
79
105
  ttl: Optional[float] = None,
80
106
  ) -> Message:
81
- """Create a new message with specified parameters."""
107
+ """Create a new Flower message with proper metadata.
108
+
109
+ Args:
110
+ content: Message payload as RecordDict (e.g., model parameters, metrics)
111
+ message_type: Type of FL message (e.g., MessageType.TRAIN, MessageType.EVALUATE)
112
+ dst_node_id: Destination node ID (client identifier)
113
+ group_id: Message group identifier for related messages
114
+ ttl: Time-to-live in seconds (optional, for message expiration)
115
+
116
+ Returns:
117
+ A Flower Message object ready to be sent to a client
118
+
119
+ Note:
120
+ Automatically adds current run_id and server's node_id to metadata.
121
+ """
82
122
  return create_flwr_message(
83
123
  content=content,
84
124
  message_type=message_type,
85
125
  dst_node_id=dst_node_id,
86
126
  group_id=group_id,
87
127
  ttl=ttl,
88
- run_id=cast(Run, self._run).run_id,
89
- src_node_id=self.node.node_id,
90
128
  )
91
129
 
92
130
  def get_node_ids(self) -> list[int]:
93
- """Get node IDs of all connected nodes."""
94
- # it is map from datasites to node id
131
+ """Get node IDs of all connected FL clients.
132
+
133
+ Returns:
134
+ List of integer node IDs representing connected datasites/clients
135
+
136
+ Note:
137
+ Node IDs are deterministically generated from datasite email addresses
138
+ using str_to_int() for consistent client identification.
139
+ """
95
140
  return list(self.client_map.keys())
96
141
 
97
142
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
98
- """Push messages to specified node IDs."""
99
- # Construct Messages
100
- run_id = cast(Run, self._run).run_id
143
+ """Push FL messages to specified clients asynchronously.
144
+
145
+ Args:
146
+ messages: Iterable of Flower Messages to send to clients
147
+
148
+ Returns:
149
+ List of future IDs that can be used to retrieve responses
150
+ """
101
151
  message_ids = []
152
+
102
153
  for msg in messages:
103
- # Set metadata
104
- msg.metadata.__dict__["_run_id"] = run_id
105
- msg.metadata.__dict__["_src_node_id"] = self.node.node_id
106
- # RPC URL
107
- dest_datasite = self.client_map[msg.metadata.dst_node_id]
108
- url = rpc.make_url(
109
- dest_datasite, app_name=self.app_name, endpoint="messages"
110
- )
111
- # Check message
112
- self._check_message(msg)
113
- # Serialize message
114
- msg_bytes = flower_message_to_bytes(msg)
154
+ # Prepare message
155
+ dest_datasite, url, msg_bytes = self._prepare_message(msg)
156
+
115
157
  # Send message
116
- future = rpc.send(url=url, body=msg_bytes, client=self._client)
117
- logger.debug(
118
- f"Pushed message to {url} with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024} (Mb)"
119
- )
120
- # Save future
121
- rpc_db.save_future(
122
- future=future, namespace=self.app_name, client=self._client
123
- )
124
- message_ids.append(future.id)
158
+ if self._encryption_enabled:
159
+ future_id = self._send_encrypted_message(
160
+ url, msg_bytes, dest_datasite, msg
161
+ )
162
+ else:
163
+ future_id = self._send_unencrypted_message(
164
+ url, msg_bytes, dest_datasite, msg
165
+ )
166
+
167
+ if future_id:
168
+ message_ids.append(future_id)
125
169
 
126
170
  return message_ids
127
171
 
128
- def pull_messages(self, message_ids):
129
- """Pull messages based on message IDs."""
172
+ def pull_messages(self, message_ids: List[str]) -> Dict[str, Message]:
173
+ """Pull response messages from clients using future IDs.
174
+
175
+ Args:
176
+ message_ids: List of future IDs from push_messages()
177
+
178
+ Returns:
179
+ Dict mapping message_id to Flower Message response
180
+ """
130
181
  messages = {}
131
182
 
132
183
  for msg_id in message_ids:
133
- future = rpc_db.get_future(future_id=msg_id, client=self._client)
134
- response = future.resolve()
135
- if response is None:
136
- continue
184
+ try:
185
+ # Get and resolve future
186
+ future = rpc_db.get_future(future_id=msg_id, client=self._client)
187
+ response = future.resolve()
137
188
 
138
- response.raise_for_status()
189
+ if response is None:
190
+ continue # Message not ready yet
139
191
 
140
- if not response.body:
141
- raise ValueError(f"Empty response: {response}")
192
+ response.raise_for_status()
142
193
 
143
- message: Message = bytes_to_flower_message(response.body)
144
- if message.has_error():
145
- error = message.error
146
- logger.error(
147
- f"Message {msg_id} error with code={error.code}, reason={error.reason}"
148
- )
194
+ # Process the response
195
+ message = self._process_response(response, msg_id)
196
+
197
+ if message:
198
+ messages[msg_id] = message
199
+ rpc_db.delete_future(future_id=msg_id, client=self._client)
200
+
201
+ except Exception as e:
202
+ logger.error(f"❌ Unexpected error pulling message {msg_id}: {e}")
149
203
  continue
150
204
 
151
- logger.debug(
152
- f"Pulled message from {response.url} with metadata: {message.metadata}, size: {len(response.body) / 1024 / 1024} (Mb)"
153
- )
154
- messages[msg_id] = message
155
- rpc_db.delete_future(future_id=msg_id, client=self._client)
205
+ # Log summary
206
+ self._log_pull_summary(messages, message_ids)
156
207
 
157
208
  return messages
158
209
 
@@ -164,47 +215,49 @@ class SyftGrid(Grid):
164
215
  ) -> Iterable[Message]:
165
216
  """Push messages to specified node IDs and pull the reply messages.
166
217
 
167
- This method sends a list of messages to their destination node IDs and then
168
- waits for the replies. It continues to pull replies until either all replies are
169
- received or the specified timeout duration (in seconds) is exceeded.
218
+ This method sends messages to their destination nodes and waits for replies.
219
+ It continues polling until all replies are received or timeout is reached.
220
+
221
+ Args:
222
+ messages: Messages to send
223
+ timeout: Maximum time to wait for replies (seconds).
224
+ Can be overridden by SYFT_FLWR_MSG_TIMEOUT env var.
225
+
226
+ Returns:
227
+ Collection of reply messages received
170
228
  """
171
- if os.environ.get(SYFT_FLWR_MSG_TIMEOUT) is not None:
172
- timeout = float(os.environ.get(SYFT_FLWR_MSG_TIMEOUT))
173
- if timeout is not None:
174
- logger.debug(
175
- f"syft_flwr messages timeout = {timeout}: Will move on after {timeout} (s) if no reply is received"
176
- )
177
- else:
178
- logger.debug(
179
- "syft_flwr messages timeout = None: Will wait indefinitely for replies"
180
- )
229
+ # Get timeout from environment or parameter
230
+ timeout = self._get_timeout(timeout)
181
231
 
182
- # Push messages
232
+ # Push messages and get IDs
183
233
  msg_ids = set(self.push_messages(messages))
234
+ if not msg_ids:
235
+ return []
184
236
 
185
- # Pull messages
186
- end_time = time.time() + (timeout if timeout is not None else 0.0)
187
- ret = {}
188
- while timeout is None or time.time() < end_time:
189
- res_msgs = self.pull_messages(msg_ids)
190
- ret.update(res_msgs)
191
- msg_ids.difference_update(res_msgs.keys())
192
- if len(msg_ids) == 0: # All messages received
193
- break
194
- time.sleep(3) # polling interval
195
-
196
- if msg_ids:
197
- logger.warning(
198
- f"Timeout reached. {len(msg_ids)} message(s) sent out but not replied."
199
- )
237
+ # Poll for responses
238
+ responses = self._poll_for_responses(msg_ids, timeout)
200
239
 
201
- return ret.values()
240
+ return responses.values()
202
241
 
203
242
  def send_stop_signal(
204
243
  self, group_id: str, reason: str = "Training complete", ttl: float = 60.0
205
- ) -> list[Message]:
206
- """Send a stop signal to all datasites (clients)."""
207
- stop_messages: list[Message] = [
244
+ ) -> List[Message]:
245
+ """Send a stop signal to all connected FL clients.
246
+
247
+ Args:
248
+ group_id: Identifier for this group of stop messages
249
+ reason: Human-readable reason for stopping (default: "Training complete")
250
+ ttl: Time-to-live for stop messages in seconds (default: 60.0)
251
+
252
+ Returns:
253
+ List of stop Messages that were sent
254
+
255
+ Note:
256
+ Used to gracefully terminate FL clients when training completes or
257
+ when the server encounters an error. Clients will shut down upon
258
+ receiving this SYSTEM message with action="stop".
259
+ """
260
+ stop_messages: List[Message] = [
208
261
  self.create_message(
209
262
  content=RecordDict(
210
263
  {"config": ConfigRecord({"action": "stop", "reason": reason})}
@@ -219,3 +272,287 @@ class SyftGrid(Grid):
219
272
  self.push_messages(stop_messages)
220
273
 
221
274
  return stop_messages
275
+
276
+ def _check_message(self, message: Message) -> None:
277
+ """Validate a Flower message before sending.
278
+
279
+ Args:
280
+ message: The Flower Message to validate
281
+
282
+ Raises:
283
+ ValueError: If message metadata is invalid (wrong run_id, src_node_id,
284
+ missing ttl, or invalid reply_to field)
285
+
286
+ Note:
287
+ Ensures message belongs to current run and originates from this server node.
288
+ """
289
+ if not (
290
+ message.metadata.run_id == cast(Run, self._run).run_id
291
+ and message.metadata.src_node_id == self.node.node_id
292
+ and message.metadata.message_id == ""
293
+ and check_reply_to_field(message.metadata)
294
+ and message.metadata.ttl > 0
295
+ ):
296
+ logger.debug(f"Invalid message with metadata: {message.metadata}")
297
+ raise ValueError(f"Invalid message: {message}")
298
+
299
+ def _prepare_message(self, msg: Message) -> Tuple[str, str, bytes]:
300
+ """Prepare a message for sending.
301
+
302
+ Returns:
303
+ Tuple of (destination_datasite, url, message_bytes)
304
+ """
305
+ run_id = cast(Run, self._run).run_id
306
+ msg.metadata.__dict__["_run_id"] = run_id
307
+ msg.metadata.__dict__["_src_node_id"] = self.node.node_id
308
+
309
+ dest_datasite = self.client_map[msg.metadata.dst_node_id]
310
+ url = rpc.make_url(dest_datasite, app_name=self.app_name, endpoint="messages")
311
+
312
+ self._check_message(msg)
313
+ msg_bytes = flower_message_to_bytes(msg)
314
+
315
+ return dest_datasite, url, msg_bytes
316
+
317
+ def _retry_with_backoff(
318
+ self,
319
+ func,
320
+ max_retries: int = 3,
321
+ initial_delay: float = 0.1,
322
+ context: str = "",
323
+ check_error=None,
324
+ ):
325
+ """Generic retry logic with exponential backoff and jitter.
326
+
327
+ Args:
328
+ func: Function to retry
329
+ max_retries: Maximum number of retry attempts
330
+ initial_delay: Initial delay in seconds
331
+ context: Context string for logging
332
+ check_error: Optional function to check if error is retryable
333
+
334
+ Returns:
335
+ Result of func if successful
336
+
337
+ Raises:
338
+ Last exception if all retries fail
339
+ """
340
+ for attempt in range(max_retries):
341
+ try:
342
+ return func()
343
+ except Exception as e:
344
+ is_retryable = check_error(e) if check_error else True
345
+ if is_retryable and attempt < max_retries - 1:
346
+ jitter = random.uniform(0, 0.05)
347
+ delay = initial_delay * (2**attempt) + jitter
348
+ logger.debug(
349
+ f"{context} failed (attempt {attempt + 1}/{max_retries}): {e}. "
350
+ f"Retrying in {delay:.3f}s"
351
+ )
352
+ time.sleep(delay)
353
+ else:
354
+ raise
355
+
356
+ def _save_future_with_retry(self, future, dest_datasite: str) -> bool:
357
+ """Save future to database with retry logic for database locks.
358
+
359
+ Returns:
360
+ True if saved successfully, False if failed after retries
361
+ """
362
+ try:
363
+ self._retry_with_backoff(
364
+ func=lambda: rpc_db.save_future(
365
+ future=future, namespace=self.app_name, client=self._client
366
+ ),
367
+ context=f"Database save for {dest_datasite}",
368
+ check_error=lambda e: "database is locked" in str(e).lower(),
369
+ )
370
+ return True
371
+ except Exception as e:
372
+ logger.warning(
373
+ f"⚠️ Failed to save future to database for {dest_datasite}: {e}. "
374
+ f"Message sent but future not persisted."
375
+ )
376
+ return False
377
+
378
+ def _send_encrypted_message(
379
+ self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
380
+ ) -> Optional[str]:
381
+ """Send an encrypted message and return future ID if successful."""
382
+ try:
383
+ # Send encrypted message
384
+ future = rpc.send(
385
+ url=url,
386
+ body=base64.b64encode(msg_bytes).decode("utf-8"),
387
+ client=self._client,
388
+ encrypt=True,
389
+ )
390
+
391
+ logger.debug(
392
+ f"🔐 Pushed ENCRYPTED message to {dest_datasite} at {url} "
393
+ f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
394
+ )
395
+
396
+ # Save future to database (non-critical - log warning if fails)
397
+ self._save_future_with_retry(future, dest_datasite)
398
+ return future.id
399
+
400
+ except (KeyError, ValueError) as e:
401
+ # Encryption setup errors - don't retry or fallback
402
+ error_type = (
403
+ "Encryption key" if isinstance(e, KeyError) else "Encryption parameter"
404
+ )
405
+ logger.error(
406
+ f"❌ {error_type} error for {dest_datasite}: {e}. "
407
+ f"Skipping message to node {msg.metadata.dst_node_id}"
408
+ )
409
+ return None
410
+
411
+ except Exception as e:
412
+ # Other errors - fallback to unencrypted
413
+ logger.warning(
414
+ f"⚠️ Encryption failed for {dest_datasite}: {e}. "
415
+ f"Falling back to unencrypted transmission"
416
+ )
417
+ return self._send_unencrypted_message(url, msg_bytes, dest_datasite, msg)
418
+
419
+ def _send_unencrypted_message(
420
+ self, url: str, msg_bytes: bytes, dest_datasite: str, msg: Message
421
+ ) -> Optional[str]:
422
+ """Send an unencrypted message and return future ID if successful."""
423
+ try:
424
+ future = rpc.send(url=url, body=msg_bytes, client=self._client)
425
+ logger.debug(
426
+ f"📤 Pushed PLAINTEXT message to {dest_datasite} at {url} "
427
+ f"with metadata {msg.metadata}; size {len(msg_bytes) / 1024 / 1024:.2f} MB"
428
+ )
429
+ rpc_db.save_future(
430
+ future=future, namespace=self.app_name, client=self._client
431
+ )
432
+ return future.id
433
+
434
+ except Exception as e:
435
+ logger.error(f"❌ Failed to send message to {dest_datasite}: {e}")
436
+ return None
437
+
438
+ def _poll_for_responses(
439
+ self, msg_ids: set, timeout: Optional[float]
440
+ ) -> Dict[str, Message]:
441
+ """Poll for responses until all received or timeout."""
442
+ end_time = time.time() + (timeout if timeout is not None else float("inf"))
443
+ responses = {}
444
+ pending_ids = msg_ids.copy()
445
+
446
+ # Get polling interval from environment or use default
447
+ poll_interval = float(os.environ.get(SYFT_FLWR_POLL_INTERVAL, "3"))
448
+
449
+ while pending_ids and (timeout is None or time.time() < end_time):
450
+ # Pull available messages
451
+ batch = self.pull_messages(pending_ids)
452
+ responses.update(batch)
453
+ pending_ids.difference_update(batch.keys())
454
+
455
+ if pending_ids:
456
+ time.sleep(poll_interval) # Configurable polling interval
457
+
458
+ # Log any missing responses
459
+ if pending_ids:
460
+ logger.warning(
461
+ f"Timeout reached. {len(pending_ids)} message(s) not received."
462
+ )
463
+
464
+ return responses
465
+
466
+ def _process_response(
467
+ self, response: SyftResponse, msg_id: str
468
+ ) -> Optional[Message]:
469
+ """Process a single response and return the deserialized message."""
470
+ if not response.body:
471
+ logger.warning(f"⚠️ Empty response for message {msg_id}, skipping")
472
+ return None
473
+
474
+ response_body = response.body
475
+
476
+ # Try to decrypt if encryption is enabled
477
+ if self._encryption_enabled:
478
+ response_body = self._try_decrypt_response(response.body, msg_id)
479
+
480
+ # Deserialize message
481
+ try:
482
+ message = bytes_to_flower_message(response_body)
483
+ except Exception as e:
484
+ logger.error(
485
+ f"❌ Failed to deserialize message {msg_id}: {e}. "
486
+ f"Message may be corrupted or in incompatible format."
487
+ )
488
+ return None
489
+
490
+ # Check for errors in message
491
+ if message.has_error():
492
+ error = message.error
493
+ logger.error(
494
+ f"❌ Message {msg_id} returned error with code={error.code}, "
495
+ f"reason={error.reason}"
496
+ )
497
+ return None
498
+
499
+ # Log successful pull
500
+ encryption_status = (
501
+ "🔐 ENCRYPTED" if self._encryption_enabled else "📥 PLAINTEXT"
502
+ )
503
+ logger.debug(
504
+ f"{encryption_status} Pulled message from {response.url} "
505
+ f"with metadata: {message.metadata}, "
506
+ f"size: {len(response_body) / 1024 / 1024:.2f} MB"
507
+ )
508
+
509
+ return message
510
+
511
+ def _try_decrypt_response(self, body: bytes, msg_id: str) -> bytes:
512
+ """Try to decrypt response body if it's encrypted."""
513
+ try:
514
+ # Try to parse as encrypted payload
515
+ encrypted_payload = EncryptedPayload.model_validate_json(body.decode())
516
+ # Decrypt the message
517
+ decrypted_body = decrypt_message(encrypted_payload, client=self._client)
518
+ # The decrypted body should be a base64-encoded string
519
+ response_body = base64.b64decode(decrypted_body)
520
+ logger.debug(f"🔓 Successfully decrypted response for message {msg_id}")
521
+ return response_body
522
+ except Exception as e:
523
+ # If decryption fails, assume plaintext
524
+ logger.debug(
525
+ f"📥 Response appears to be plaintext or decryption not needed "
526
+ f"for message {msg_id}: {e}"
527
+ )
528
+ return body
529
+
530
+ def _log_pull_summary(
531
+ self, messages: Dict[str, Message], message_ids: List[str]
532
+ ) -> None:
533
+ """Log summary of pulled messages."""
534
+ if messages:
535
+ if self._encryption_enabled:
536
+ logger.info(
537
+ f"🔐 Successfully pulled {len(messages)} messages (encryption enabled)"
538
+ )
539
+ else:
540
+ logger.info(f"📥 Successfully pulled {len(messages)} messages")
541
+ elif message_ids:
542
+ logger.debug(
543
+ f"No messages pulled yet from {len(message_ids)} attempts "
544
+ f"(clients may still be processing)"
545
+ )
546
+
547
+ def _get_timeout(self, timeout: Optional[float]) -> Optional[float]:
548
+ """Get timeout value from environment or parameter."""
549
+ env_timeout = os.environ.get(SYFT_FLWR_MSG_TIMEOUT)
550
+ if env_timeout is not None:
551
+ timeout = float(env_timeout)
552
+
553
+ if timeout is not None:
554
+ logger.debug(f"Message timeout: {timeout}s")
555
+ else:
556
+ logger.debug("No timeout - will wait indefinitely for replies")
557
+
558
+ return timeout
syft_flwr/mounts.py CHANGED
@@ -1,13 +1,13 @@
1
1
  import json
2
2
  import os
3
3
  from pathlib import Path
4
- from typing import List
5
4
 
6
5
  import tomli
7
6
  from loguru import logger
8
7
  from syft_core import Client
9
8
  from syft_rds.models import DockerMount, JobConfig
10
9
  from syft_rds.syft_runtime.mounts import MountProvider
10
+ from typing_extensions import List
11
11
 
12
12
 
13
13
  class SyftFlwrMountProvider(MountProvider):
syft_flwr/run.py CHANGED
@@ -5,11 +5,12 @@ from uuid import uuid4
5
5
  from flwr.client.client_app import LoadClientAppError
6
6
  from flwr.common import Context
7
7
  from flwr.common.object_ref import load_app
8
+ from flwr.common.record import RecordDict
8
9
  from flwr.server.server_app import LoadServerAppError
10
+
9
11
  from syft_flwr.config import load_flwr_pyproject
10
12
  from syft_flwr.flower_client import syftbox_flwr_client
11
13
  from syft_flwr.flower_server import syftbox_flwr_server
12
- from syft_flwr.flwr_compatibility import RecordDict
13
14
  from syft_flwr.run_simulation import run
14
15
 
15
16
# Names exported by `from syft_flwr.run import *`.
__all__ = ["syftbox_run_flwr_client", "syftbox_run_flwr_server", "run"]