flwr-nightly 1.13.0.dev20241025__py3-none-any.whl → 1.13.0.dev20241029__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. flwr/common/date.py +3 -3
  2. flwr/common/logger.py +31 -0
  3. flwr/common/serde.py +22 -0
  4. flwr/proto/driver_pb2.py +24 -23
  5. flwr/proto/driver_pb2_grpc.py +69 -0
  6. flwr/proto/driver_pb2_grpc.pyi +27 -0
  7. flwr/proto/log_pb2.py +29 -0
  8. flwr/proto/log_pb2.pyi +39 -0
  9. flwr/proto/log_pb2_grpc.py +4 -0
  10. flwr/proto/log_pb2_grpc.pyi +4 -0
  11. flwr/server/app.py +10 -8
  12. flwr/server/driver/driver.py +14 -0
  13. flwr/server/driver/grpc_driver.py +8 -15
  14. flwr/server/driver/inmemory_driver.py +3 -11
  15. flwr/server/run_serverapp.py +3 -4
  16. flwr/server/serverapp/app.py +193 -18
  17. flwr/server/superlink/driver/driver_servicer.py +34 -1
  18. flwr/server/superlink/linkstate/in_memory_linkstate.py +28 -2
  19. flwr/server/superlink/linkstate/linkstate.py +35 -0
  20. flwr/server/superlink/linkstate/sqlite_linkstate.py +50 -0
  21. flwr/simulation/run_simulation.py +2 -1
  22. flwr/superexec/deployment.py +3 -37
  23. flwr/superexec/exec_servicer.py +5 -72
  24. flwr/superexec/executor.py +3 -4
  25. flwr/superexec/simulation.py +4 -7
  26. {flwr_nightly-1.13.0.dev20241025.dist-info → flwr_nightly-1.13.0.dev20241029.dist-info}/METADATA +1 -1
  27. {flwr_nightly-1.13.0.dev20241025.dist-info → flwr_nightly-1.13.0.dev20241029.dist-info}/RECORD +30 -26
  28. {flwr_nightly-1.13.0.dev20241025.dist-info → flwr_nightly-1.13.0.dev20241029.dist-info}/LICENSE +0 -0
  29. {flwr_nightly-1.13.0.dev20241025.dist-info → flwr_nightly-1.13.0.dev20241029.dist-info}/WHEEL +0 -0
  30. {flwr_nightly-1.13.0.dev20241025.dist-info → flwr_nightly-1.13.0.dev20241029.dist-info}/entry_points.txt +0 -0
flwr/common/date.py CHANGED
@@ -15,9 +15,9 @@
15
15
  """Flower date utils."""
16
16
 
17
17
 
18
- from datetime import datetime, timezone
18
+ import datetime
19
19
 
20
20
 
21
- def now() -> datetime:
21
+ def now() -> datetime.datetime:
22
22
  """Construct a datetime from time.time() with time zone set to UTC."""
23
- return datetime.now(tz=timezone.utc)
23
+ return datetime.datetime.now(tz=datetime.timezone.utc)
flwr/common/logger.py CHANGED
@@ -16,8 +16,10 @@
16
16
 
17
17
 
18
18
  import logging
19
+ import sys
19
20
  from logging import WARN, LogRecord
20
21
  from logging.handlers import HTTPHandler
22
+ from queue import Queue
21
23
  from typing import TYPE_CHECKING, Any, Optional, TextIO
22
24
 
23
25
  # Create logger
@@ -259,3 +261,32 @@ def set_logger_propagation(
259
261
  if not child_logger.propagate:
260
262
  child_logger.log(logging.DEBUG, "Logger propagate set to False")
261
263
  return child_logger
264
+
265
+
266
+ def mirror_output_to_queue(log_queue: Queue[str]) -> None:
267
+ """Mirror stdout and stderr output to the provided queue."""
268
+
269
+ def get_write_fn(stream: TextIO) -> Any:
270
+ original_write = stream.write
271
+
272
+ def fn(s: str) -> int:
273
+ ret = original_write(s)
274
+ stream.flush()
275
+ log_queue.put(s)
276
+ return ret
277
+
278
+ return fn
279
+
280
+ sys.stdout.write = get_write_fn(sys.stdout) # type: ignore[method-assign]
281
+ sys.stderr.write = get_write_fn(sys.stderr) # type: ignore[method-assign]
282
+ console_handler.stream = sys.stdout
283
+
284
+
285
+ def restore_output() -> None:
286
+ """Restore stdout and stderr.
287
+
288
+ This will stop mirroring output to queues.
289
+ """
290
+ sys.stdout = sys.__stdout__
291
+ sys.stderr = sys.__stderr__
292
+ console_handler.stream = sys.stdout
flwr/common/serde.py CHANGED
@@ -40,6 +40,7 @@ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
40
40
  from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
41
41
  from flwr.proto.recordset_pb2 import SintList, StringList, UintList
42
42
  from flwr.proto.run_pb2 import Run as ProtoRun
43
+ from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
43
44
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
44
45
  from flwr.proto.transport_pb2 import (
45
46
  ClientMessage,
@@ -910,3 +911,24 @@ def clientappstatus_from_proto(
910
911
  if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
911
912
  code = typing.ClientAppOutputCode.UNKNOWN_ERROR
912
913
  return typing.ClientAppOutputStatus(code=code, message=msg.message)
914
+
915
+
916
+ # === Run status ===
917
+
918
+
919
+ def run_status_to_proto(run_status: typing.RunStatus) -> ProtoRunStatus:
920
+ """Serialize `RunStatus` to ProtoBuf."""
921
+ return ProtoRunStatus(
922
+ status=run_status.status,
923
+ sub_status=run_status.sub_status,
924
+ details=run_status.details,
925
+ )
926
+
927
+
928
+ def run_status_from_proto(run_status_proto: ProtoRunStatus) -> typing.RunStatus:
929
+ """Deserialize `RunStatus` from ProtoBuf."""
930
+ return typing.RunStatus(
931
+ status=run_status_proto.status,
932
+ sub_status=run_status_proto.sub_status,
933
+ details=run_status_proto.details,
934
+ )
flwr/proto/driver_pb2.py CHANGED
@@ -12,6 +12,7 @@ from google.protobuf.internal import builder as _builder
12
12
  _sym_db = _symbol_database.Default()
13
13
 
14
14
 
15
+ from flwr.proto import log_pb2 as flwr_dot_proto_dot_log__pb2
15
16
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
17
  from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
17
18
  from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
@@ -19,33 +20,33 @@ from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
19
20
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
20
21
 
21
22
 
22
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\",\n\x1aPullServerAppInputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9e\x05\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x62\x06proto3')
23
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x14\x66lwr/proto/log.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\",\n\x1aPullServerAppInputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 
\x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\xc5\x06\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x12\\\n\x0fUpdateRunStatus\x12\".flwr.proto.UpdateRunStatusRequest\x1a#.flwr.proto.UpdateRunStatusResponse\"\x00\x12G\n\x08PushLogs\x12\x1b.flwr.proto.PushLogsRequest\x1a\x1c.flwr.proto.PushLogsResponse\"\x00\x62\x06proto3')
23
24
 
24
25
  _globals = globals()
25
26
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
26
27
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals)
27
28
  if _descriptor._USE_C_DESCRIPTORS == False:
28
29
  DESCRIPTOR._options = None
29
- _globals['_GETNODESREQUEST']._serialized_start=155
30
- _globals['_GETNODESREQUEST']._serialized_end=188
31
- _globals['_GETNODESRESPONSE']._serialized_start=190
32
- _globals['_GETNODESRESPONSE']._serialized_end=241
33
- _globals['_PUSHTASKINSREQUEST']._serialized_start=243
34
- _globals['_PUSHTASKINSREQUEST']._serialized_end=307
35
- _globals['_PUSHTASKINSRESPONSE']._serialized_start=309
36
- _globals['_PUSHTASKINSRESPONSE']._serialized_end=348
37
- _globals['_PULLTASKRESREQUEST']._serialized_start=350
38
- _globals['_PULLTASKRESREQUEST']._serialized_end=420
39
- _globals['_PULLTASKRESRESPONSE']._serialized_start=422
40
- _globals['_PULLTASKRESRESPONSE']._serialized_end=487
41
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=489
42
- _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=533
43
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=535
44
- _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=662
45
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=664
46
- _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=747
47
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=749
48
- _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=779
49
- _globals['_DRIVER']._serialized_start=782
50
- _globals['_DRIVER']._serialized_end=1452
30
+ _globals['_GETNODESREQUEST']._serialized_start=177
31
+ _globals['_GETNODESREQUEST']._serialized_end=210
32
+ _globals['_GETNODESRESPONSE']._serialized_start=212
33
+ _globals['_GETNODESRESPONSE']._serialized_end=263
34
+ _globals['_PUSHTASKINSREQUEST']._serialized_start=265
35
+ _globals['_PUSHTASKINSREQUEST']._serialized_end=329
36
+ _globals['_PUSHTASKINSRESPONSE']._serialized_start=331
37
+ _globals['_PUSHTASKINSRESPONSE']._serialized_end=370
38
+ _globals['_PULLTASKRESREQUEST']._serialized_start=372
39
+ _globals['_PULLTASKRESREQUEST']._serialized_end=442
40
+ _globals['_PULLTASKRESRESPONSE']._serialized_start=444
41
+ _globals['_PULLTASKRESRESPONSE']._serialized_end=509
42
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=511
43
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=555
44
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=557
45
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=684
46
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=686
47
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=769
48
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=771
49
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=801
50
+ _globals['_DRIVER']._serialized_start=804
51
+ _globals['_DRIVER']._serialized_end=1641
51
52
  # @@protoc_insertion_point(module_scope)
flwr/proto/driver_pb2_grpc.py CHANGED
@@ -4,6 +4,7 @@ import grpc
4
4
 
5
5
  from flwr.proto import driver_pb2 as flwr_dot_proto_dot_driver__pb2
6
6
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
7
+ from flwr.proto import log_pb2 as flwr_dot_proto_dot_log__pb2
7
8
  from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
8
9
 
9
10
 
@@ -56,6 +57,16 @@ class DriverStub(object):
56
57
  request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
57
58
  response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
58
59
  )
60
+ self.UpdateRunStatus = channel.unary_unary(
61
+ '/flwr.proto.Driver/UpdateRunStatus',
62
+ request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
63
+ response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
64
+ )
65
+ self.PushLogs = channel.unary_unary(
66
+ '/flwr.proto.Driver/PushLogs',
67
+ request_serializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
68
+ response_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsResponse.FromString,
69
+ )
59
70
 
60
71
 
61
72
  class DriverServicer(object):
@@ -117,6 +128,20 @@ class DriverServicer(object):
117
128
  context.set_details('Method not implemented!')
118
129
  raise NotImplementedError('Method not implemented!')
119
130
 
131
+ def UpdateRunStatus(self, request, context):
132
+ """Update the status of a given run
133
+ """
134
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
135
+ context.set_details('Method not implemented!')
136
+ raise NotImplementedError('Method not implemented!')
137
+
138
+ def PushLogs(self, request, context):
139
+ """Push ServerApp logs
140
+ """
141
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
142
+ context.set_details('Method not implemented!')
143
+ raise NotImplementedError('Method not implemented!')
144
+
120
145
 
121
146
  def add_DriverServicer_to_server(servicer, server):
122
147
  rpc_method_handlers = {
@@ -160,6 +185,16 @@ def add_DriverServicer_to_server(servicer, server):
160
185
  request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString,
161
186
  response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString,
162
187
  ),
188
+ 'UpdateRunStatus': grpc.unary_unary_rpc_method_handler(
189
+ servicer.UpdateRunStatus,
190
+ request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
191
+ response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
192
+ ),
193
+ 'PushLogs': grpc.unary_unary_rpc_method_handler(
194
+ servicer.PushLogs,
195
+ request_deserializer=flwr_dot_proto_dot_log__pb2.PushLogsRequest.FromString,
196
+ response_serializer=flwr_dot_proto_dot_log__pb2.PushLogsResponse.SerializeToString,
197
+ ),
163
198
  }
164
199
  generic_handler = grpc.method_handlers_generic_handler(
165
200
  'flwr.proto.Driver', rpc_method_handlers)
@@ -305,3 +340,37 @@ class Driver(object):
305
340
  flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
306
341
  options, channel_credentials,
307
342
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
343
+
344
+ @staticmethod
345
+ def UpdateRunStatus(request,
346
+ target,
347
+ options=(),
348
+ channel_credentials=None,
349
+ call_credentials=None,
350
+ insecure=False,
351
+ compression=None,
352
+ wait_for_ready=None,
353
+ timeout=None,
354
+ metadata=None):
355
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/UpdateRunStatus',
356
+ flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
357
+ flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
358
+ options, channel_credentials,
359
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
360
+
361
+ @staticmethod
362
+ def PushLogs(request,
363
+ target,
364
+ options=(),
365
+ channel_credentials=None,
366
+ call_credentials=None,
367
+ insecure=False,
368
+ compression=None,
369
+ wait_for_ready=None,
370
+ timeout=None,
371
+ metadata=None):
372
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushLogs',
373
+ flwr_dot_proto_dot_log__pb2.PushLogsRequest.SerializeToString,
374
+ flwr_dot_proto_dot_log__pb2.PushLogsResponse.FromString,
375
+ options, channel_credentials,
376
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
flwr/proto/driver_pb2_grpc.pyi CHANGED
@@ -5,6 +5,7 @@ isort:skip_file
5
5
  import abc
6
6
  import flwr.proto.driver_pb2
7
7
  import flwr.proto.fab_pb2
8
+ import flwr.proto.log_pb2
8
9
  import flwr.proto.run_pb2
9
10
  import grpc
10
11
 
@@ -50,6 +51,16 @@ class DriverStub:
50
51
  flwr.proto.driver_pb2.PushServerAppOutputsResponse]
51
52
  """Push ServerApp outputs"""
52
53
 
54
+ UpdateRunStatus: grpc.UnaryUnaryMultiCallable[
55
+ flwr.proto.run_pb2.UpdateRunStatusRequest,
56
+ flwr.proto.run_pb2.UpdateRunStatusResponse]
57
+ """Update the status of a given run"""
58
+
59
+ PushLogs: grpc.UnaryUnaryMultiCallable[
60
+ flwr.proto.log_pb2.PushLogsRequest,
61
+ flwr.proto.log_pb2.PushLogsResponse]
62
+ """Push ServerApp logs"""
63
+
53
64
 
54
65
  class DriverServicer(metaclass=abc.ABCMeta):
55
66
  @abc.abstractmethod
@@ -116,5 +127,21 @@ class DriverServicer(metaclass=abc.ABCMeta):
116
127
  """Push ServerApp outputs"""
117
128
  pass
118
129
 
130
+ @abc.abstractmethod
131
+ def UpdateRunStatus(self,
132
+ request: flwr.proto.run_pb2.UpdateRunStatusRequest,
133
+ context: grpc.ServicerContext,
134
+ ) -> flwr.proto.run_pb2.UpdateRunStatusResponse:
135
+ """Update the status of a given run"""
136
+ pass
137
+
138
+ @abc.abstractmethod
139
+ def PushLogs(self,
140
+ request: flwr.proto.log_pb2.PushLogsRequest,
141
+ context: grpc.ServicerContext,
142
+ ) -> flwr.proto.log_pb2.PushLogsResponse:
143
+ """Push ServerApp logs"""
144
+ pass
145
+
119
146
 
120
147
  def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
flwr/proto/log_pb2.py ADDED
@@ -0,0 +1,29 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: flwr/proto/log.proto
4
+ # Protobuf Python Version: 4.25.0
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+ from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
+
17
+
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/log.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\"O\n\x0fPushLogsRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x0e\n\x06run_id\x18\x02 \x01(\x04\x12\x0c\n\x04logs\x18\x03 \x03(\t\"\x12\n\x10PushLogsResponseb\x06proto3')
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.log_pb2', _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals['_PUSHLOGSREQUEST']._serialized_start=59
26
+ _globals['_PUSHLOGSREQUEST']._serialized_end=138
27
+ _globals['_PUSHLOGSRESPONSE']._serialized_start=140
28
+ _globals['_PUSHLOGSRESPONSE']._serialized_end=158
29
+ # @@protoc_insertion_point(module_scope)
flwr/proto/log_pb2.pyi ADDED
@@ -0,0 +1,39 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
5
+ import builtins
6
+ import flwr.proto.node_pb2
7
+ import google.protobuf.descriptor
8
+ import google.protobuf.internal.containers
9
+ import google.protobuf.message
10
+ import typing
11
+ import typing_extensions
12
+
13
+ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
14
+
15
+ class PushLogsRequest(google.protobuf.message.Message):
16
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
17
+ NODE_FIELD_NUMBER: builtins.int
18
+ RUN_ID_FIELD_NUMBER: builtins.int
19
+ LOGS_FIELD_NUMBER: builtins.int
20
+ @property
21
+ def node(self) -> flwr.proto.node_pb2.Node: ...
22
+ run_id: builtins.int
23
+ @property
24
+ def logs(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
25
+ def __init__(self,
26
+ *,
27
+ node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
28
+ run_id: builtins.int = ...,
29
+ logs: typing.Optional[typing.Iterable[typing.Text]] = ...,
30
+ ) -> None: ...
31
+ def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
32
+ def ClearField(self, field_name: typing_extensions.Literal["logs",b"logs","node",b"node","run_id",b"run_id"]) -> None: ...
33
+ global___PushLogsRequest = PushLogsRequest
34
+
35
+ class PushLogsResponse(google.protobuf.message.Message):
36
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
37
+ def __init__(self,
38
+ ) -> None: ...
39
+ global___PushLogsResponse = PushLogsResponse
flwr/proto/log_pb2_grpc.py ADDED
@@ -0,0 +1,4 @@
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+
flwr/proto/log_pb2_grpc.pyi ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ @generated by mypy-protobuf. Do not edit manually!
3
+ isort:skip_file
4
+ """
flwr/server/app.py CHANGED
@@ -50,7 +50,6 @@ from flwr.common.constant import (
50
50
  TRANSPORT_TYPE_GRPC_ADAPTER,
51
51
  TRANSPORT_TYPE_GRPC_RERE,
52
52
  TRANSPORT_TYPE_REST,
53
- Status,
54
53
  )
55
54
  from flwr.common.exit_handlers import register_exit_handlers
56
55
  from flwr.common.logger import log
@@ -58,7 +57,6 @@ from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
58
57
  private_key_to_bytes,
59
58
  public_key_to_bytes,
60
59
  )
61
- from flwr.common.typing import RunStatus
62
60
  from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
63
61
  add_FleetServicer_to_server,
64
62
  )
@@ -345,7 +343,7 @@ def run_superlink() -> None:
345
343
  # Scheduler thread
346
344
  scheduler_th = threading.Thread(
347
345
  target=_flwr_serverapp_scheduler,
348
- args=(state_factory, args.driver_api_address),
346
+ args=(state_factory, args.driver_api_address, args.ssl_ca_certfile),
349
347
  )
350
348
  scheduler_th.start()
351
349
  bckg_threads.append(scheduler_th)
@@ -367,7 +365,9 @@ def run_superlink() -> None:
367
365
 
368
366
 
369
367
  def _flwr_serverapp_scheduler(
370
- state_factory: LinkStateFactory, driver_api_address: str
368
+ state_factory: LinkStateFactory,
369
+ driver_api_address: str,
370
+ ssl_ca_certfile: Optional[str],
371
371
  ) -> None:
372
372
  log(DEBUG, "Started flwr-serverapp scheduler thread.")
373
373
 
@@ -380,10 +380,6 @@ def _flwr_serverapp_scheduler(
380
380
 
381
381
  if pending_run_id:
382
382
 
383
- # Set run as starting
384
- state.update_run_status(
385
- run_id=pending_run_id, new_status=RunStatus(Status.STARTING, "", "")
386
- )
387
383
  log(
388
384
  INFO,
389
385
  "Launching `flwr-serverapp` subprocess with run-id %d. "
@@ -399,6 +395,12 @@ def _flwr_serverapp_scheduler(
399
395
  "--run-id",
400
396
  str(pending_run_id),
401
397
  ]
398
+ if ssl_ca_certfile:
399
+ command.append("--root-certificates")
400
+ command.append(ssl_ca_certfile)
401
+ else:
402
+ command.append("--insecure")
403
+
402
404
  subprocess.run(
403
405
  command,
404
406
  stdout=None,
flwr/server/driver/driver.py CHANGED
@@ -26,6 +26,20 @@ from flwr.common.typing import Run
26
26
  class Driver(ABC):
27
27
  """Abstract base Driver class for the Driver API."""
28
28
 
29
+ @abstractmethod
30
+ def init_run(self, run_id: int) -> None:
31
+ """Request a run to the SuperLink with a given `run_id`.
32
+
33
+ If a Run with the specified `run_id` exists, a local Run
34
+ object will be created. It enables further functionality
35
+ in the driver, such as sending `Messages`.
36
+
37
+ Parameters
38
+ ----------
39
+ run_id : int
40
+ The `run_id` of the Run this Driver object operates in.
41
+ """
42
+
29
43
  @property
30
44
  @abstractmethod
31
45
  def run(self) -> Run:
flwr/server/driver/grpc_driver.py CHANGED
@@ -60,8 +60,6 @@ class GrpcDriver(Driver):
60
60
 
61
61
  Parameters
62
62
  ----------
63
- run_id : int
64
- The identifier of the run.
65
63
  driver_service_address : str (default: "[::]:9091")
66
64
  The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
67
65
  root_certificates : Optional[bytes] (default: None)
@@ -72,11 +70,9 @@ class GrpcDriver(Driver):
72
70
 
73
71
  def __init__( # pylint: disable=too-many-arguments
74
72
  self,
75
- run_id: int,
76
73
  driver_service_address: str = DRIVER_API_DEFAULT_ADDRESS,
77
74
  root_certificates: Optional[bytes] = None,
78
75
  ) -> None:
79
- self._run_id = run_id
80
76
  self._addr = driver_service_address
81
77
  self._cert = root_certificates
82
78
  self._run: Optional[Run] = None
@@ -116,15 +112,17 @@ class GrpcDriver(Driver):
116
112
  channel.close()
117
113
  log(DEBUG, "[Driver] Disconnected")
118
114
 
119
- def _init_run(self) -> None:
115
+ def init_run(self, run_id: int) -> None:
116
+ """Initialize the run."""
120
117
  # Check if is initialized
121
118
  if self._run is not None:
122
119
  return
120
+
123
121
  # Get the run info
124
- req = GetRunRequest(run_id=self._run_id)
122
+ req = GetRunRequest(run_id=run_id)
125
123
  res: GetRunResponse = self._stub.GetRun(req)
126
124
  if not res.HasField("run"):
127
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
125
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
128
126
  self._run = Run(
129
127
  run_id=res.run.run_id,
130
128
  fab_id=res.run.fab_id,
@@ -136,7 +134,6 @@ class GrpcDriver(Driver):
136
134
  @property
137
135
  def run(self) -> Run:
138
136
  """Run information."""
139
- self._init_run()
140
137
  return Run(**vars(self._run))
141
138
 
142
139
  @property
@@ -150,7 +147,7 @@ class GrpcDriver(Driver):
150
147
  # Check if the message is valid
151
148
  if not (
152
149
  # Assume self._run being initialized
153
- message.metadata.run_id == self._run_id
150
+ message.metadata.run_id == cast(Run, self._run).run_id
154
151
  and message.metadata.src_node_id == self.node.node_id
155
152
  and message.metadata.message_id == ""
156
153
  and message.metadata.reply_to_message == ""
@@ -171,7 +168,6 @@ class GrpcDriver(Driver):
171
168
  This method constructs a new `Message` with given content and metadata.
172
169
  The `run_id` and `src_node_id` will be set automatically.
173
170
  """
174
- self._init_run()
175
171
  if ttl:
176
172
  warnings.warn(
177
173
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -182,7 +178,7 @@ class GrpcDriver(Driver):
182
178
 
183
179
  ttl_ = DEFAULT_TTL if ttl is None else ttl
184
180
  metadata = Metadata(
185
- run_id=self._run_id,
181
+ run_id=cast(Run, self._run).run_id,
186
182
  message_id="", # Will be set by the server
187
183
  src_node_id=self.node.node_id,
188
184
  dst_node_id=dst_node_id,
@@ -195,10 +191,9 @@ class GrpcDriver(Driver):
195
191
 
196
192
  def get_node_ids(self) -> list[int]:
197
193
  """Get node IDs."""
198
- self._init_run()
199
194
  # Call GrpcDriverStub method
200
195
  res: GetNodesResponse = self._stub.GetNodes(
201
- GetNodesRequest(run_id=self._run_id)
196
+ GetNodesRequest(run_id=cast(Run, self._run).run_id)
202
197
  )
203
198
  return [node.node_id for node in res.nodes]
204
199
 
@@ -208,7 +203,6 @@ class GrpcDriver(Driver):
208
203
  This method takes an iterable of messages and sends each message
209
204
  to the node specified in `dst_node_id`.
210
205
  """
211
- self._init_run()
212
206
  # Construct TaskIns
213
207
  task_ins_list: list[TaskIns] = []
214
208
  for msg in messages:
@@ -230,7 +224,6 @@ class GrpcDriver(Driver):
230
224
  This method is used to collect messages from the SuperLink that correspond to a
231
225
  set of given message IDs.
232
226
  """
233
- self._init_run()
234
227
  # Pull TaskRes
235
228
  res: PullTaskResResponse = self._stub.PullTaskRes(
236
229
  PullTaskResRequest(node=self.node, task_ids=message_ids)
flwr/server/driver/inmemory_driver.py CHANGED
@@ -35,8 +35,6 @@ class InMemoryDriver(Driver):
35
35
 
36
36
  Parameters
37
37
  ----------
38
- run_id : int
39
- The identifier of the run.
40
38
  state_factory : StateFactory
41
39
  A StateFactory embedding a state that this driver can interface with.
42
40
  pull_interval : float (default=0.1)
@@ -45,18 +43,15 @@ class InMemoryDriver(Driver):
45
43
 
46
44
  def __init__(
47
45
  self,
48
- run_id: int,
49
46
  state_factory: LinkStateFactory,
50
47
  pull_interval: float = 0.1,
51
48
  ) -> None:
52
- self._run_id = run_id
53
49
  self._run: Optional[Run] = None
54
50
  self.state = state_factory.state()
55
51
  self.pull_interval = pull_interval
56
52
  self.node = Node(node_id=0, anonymous=True)
57
53
 
58
54
  def _check_message(self, message: Message) -> None:
59
- self._init_run()
60
55
  # Check if the message is valid
61
56
  if not (
62
57
  message.metadata.run_id == cast(Run, self._run).run_id
@@ -67,19 +62,18 @@ class InMemoryDriver(Driver):
67
62
  ):
68
63
  raise ValueError(f"Invalid message: {message}")
69
64
 
70
- def _init_run(self) -> None:
65
+ def init_run(self, run_id: int) -> None:
71
66
  """Initialize the run."""
72
67
  if self._run is not None:
73
68
  return
74
- run = self.state.get_run(self._run_id)
69
+ run = self.state.get_run(run_id)
75
70
  if run is None:
76
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
71
+ raise RuntimeError(f"Cannot find the run with ID: {run_id}")
77
72
  self._run = run
78
73
 
79
74
  @property
80
75
  def run(self) -> Run:
81
76
  """Run ID."""
82
- self._init_run()
83
77
  return Run(**vars(cast(Run, self._run)))
84
78
 
85
79
  def create_message( # pylint: disable=too-many-arguments,R0917
@@ -95,7 +89,6 @@ class InMemoryDriver(Driver):
95
89
  This method constructs a new `Message` with given content and metadata.
96
90
  The `run_id` and `src_node_id` will be set automatically.
97
91
  """
98
- self._init_run()
99
92
  if ttl:
100
93
  warnings.warn(
101
94
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -119,7 +112,6 @@ class InMemoryDriver(Driver):
119
112
 
120
113
  def get_node_ids(self) -> list[int]:
121
114
  """Get node IDs."""
122
- self._init_run()
123
115
  return list(self.state.get_nodes(cast(Run, self._run).run_id))
124
116
 
125
117
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]: