flwr-nightly 1.13.0.dev20241023__py3-none-any.whl → 1.13.0.dev20241025__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. Click here for more details.

@@ -10,7 +10,7 @@ license = "Apache-2.0"
10
10
  dependencies = [
11
11
  "flwr[simulation]>=1.12.0",
12
12
  "flwr-datasets[vision]>=0.3.0",
13
- "tensorflow>=2.11.1",
13
+ "tensorflow>=2.11.1,<2.18.0",
14
14
  ]
15
15
 
16
16
  [tool.hatch.build.targets.wheel]
flwr/proto/driver_pb2.py CHANGED
@@ -13,30 +13,39 @@ _sym_db = _symbol_database.Default()
13
13
 
14
14
 
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
+ from flwr.proto import message_pb2 as flwr_dot_proto_dot_message__pb2
16
17
  from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2
17
18
  from flwr.proto import run_pb2 as flwr_dot_proto_dot_run__pb2
18
19
  from flwr.proto import fab_pb2 as flwr_dot_proto_dot_fab__pb2
19
20
 
20
21
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc7\x03\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x62\x06proto3')
22
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x18\x66lwr/proto/message.proto\x1a\x15\x66lwr/proto/task.proto\x1a\x14\x66lwr/proto/run.proto\x1a\x14\x66lwr/proto/fab.proto\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes\",\n\x1aPullServerAppInputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\"\x7f\n\x1bPullServerAppInputsResponse\x12$\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x13.flwr.proto.Context\x12\x1c\n\x03run\x18\x02 \x01(\x0b\x32\x0f.flwr.proto.Run\x12\x1c\n\x03\x66\x61\x62\x18\x03 \x01(\x0b\x32\x0f.flwr.proto.Fab\"S\n\x1bPushServerAppOutputsRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x04\x12$\n\x07\x63ontext\x18\x02 \x01(\x0b\x32\x13.flwr.proto.Context\"\x1e\n\x1cPushServerAppOutputsResponse2\x9e\x05\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x12\x41\n\x06GetRun\x12\x19.flwr.proto.GetRunRequest\x1a\x1a.flwr.proto.GetRunResponse\"\x00\x12\x41\n\x06GetFab\x12\x19.flwr.proto.GetFabRequest\x1a\x1a.flwr.proto.GetFabResponse\"\x00\x12h\n\x13PullServerAppInputs\x12&.flwr.proto.PullServerAppInputsRequest\x1a\'.flwr.proto.PullServerAppInputsResponse\"\x00\x12k\n\x14PushServerAppOutputs\x12\'.flwr.proto.PushServerAppOutputsRequest\x1a(.flwr.proto.PushServerAppOutputsResponse\"\x00\x62\x06proto3')
22
23
 
23
24
  _globals = globals()
24
25
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
25
26
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.driver_pb2', _globals)
26
27
  if _descriptor._USE_C_DESCRIPTORS == False:
27
28
  DESCRIPTOR._options = None
28
- _globals['_GETNODESREQUEST']._serialized_start=129
29
- _globals['_GETNODESREQUEST']._serialized_end=162
30
- _globals['_GETNODESRESPONSE']._serialized_start=164
31
- _globals['_GETNODESRESPONSE']._serialized_end=215
32
- _globals['_PUSHTASKINSREQUEST']._serialized_start=217
33
- _globals['_PUSHTASKINSREQUEST']._serialized_end=281
34
- _globals['_PUSHTASKINSRESPONSE']._serialized_start=283
35
- _globals['_PUSHTASKINSRESPONSE']._serialized_end=322
36
- _globals['_PULLTASKRESREQUEST']._serialized_start=324
37
- _globals['_PULLTASKRESREQUEST']._serialized_end=394
38
- _globals['_PULLTASKRESRESPONSE']._serialized_start=396
39
- _globals['_PULLTASKRESRESPONSE']._serialized_end=461
40
- _globals['_DRIVER']._serialized_start=464
41
- _globals['_DRIVER']._serialized_end=919
29
+ _globals['_GETNODESREQUEST']._serialized_start=155
30
+ _globals['_GETNODESREQUEST']._serialized_end=188
31
+ _globals['_GETNODESRESPONSE']._serialized_start=190
32
+ _globals['_GETNODESRESPONSE']._serialized_end=241
33
+ _globals['_PUSHTASKINSREQUEST']._serialized_start=243
34
+ _globals['_PUSHTASKINSREQUEST']._serialized_end=307
35
+ _globals['_PUSHTASKINSRESPONSE']._serialized_start=309
36
+ _globals['_PUSHTASKINSRESPONSE']._serialized_end=348
37
+ _globals['_PULLTASKRESREQUEST']._serialized_start=350
38
+ _globals['_PULLTASKRESREQUEST']._serialized_end=420
39
+ _globals['_PULLTASKRESRESPONSE']._serialized_start=422
40
+ _globals['_PULLTASKRESRESPONSE']._serialized_end=487
41
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_start=489
42
+ _globals['_PULLSERVERAPPINPUTSREQUEST']._serialized_end=533
43
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_start=535
44
+ _globals['_PULLSERVERAPPINPUTSRESPONSE']._serialized_end=662
45
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_start=664
46
+ _globals['_PUSHSERVERAPPOUTPUTSREQUEST']._serialized_end=747
47
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_start=749
48
+ _globals['_PUSHSERVERAPPOUTPUTSRESPONSE']._serialized_end=779
49
+ _globals['_DRIVER']._serialized_start=782
50
+ _globals['_DRIVER']._serialized_end=1452
42
51
  # @@protoc_insertion_point(module_scope)
flwr/proto/driver_pb2.pyi CHANGED
@@ -3,7 +3,10 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.fab_pb2
7
+ import flwr.proto.message_pb2
6
8
  import flwr.proto.node_pb2
9
+ import flwr.proto.run_pb2
7
10
  import flwr.proto.task_pb2
8
11
  import google.protobuf.descriptor
9
12
  import google.protobuf.internal.containers
@@ -91,3 +94,59 @@ class PullTaskResResponse(google.protobuf.message.Message):
91
94
  ) -> None: ...
92
95
  def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
93
96
  global___PullTaskResResponse = PullTaskResResponse
97
+
98
+ class PullServerAppInputsRequest(google.protobuf.message.Message):
99
+ """PullServerAppInputs messages"""
100
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
101
+ RUN_ID_FIELD_NUMBER: builtins.int
102
+ run_id: builtins.int
103
+ def __init__(self,
104
+ *,
105
+ run_id: builtins.int = ...,
106
+ ) -> None: ...
107
+ def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ...
108
+ global___PullServerAppInputsRequest = PullServerAppInputsRequest
109
+
110
+ class PullServerAppInputsResponse(google.protobuf.message.Message):
111
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
112
+ CONTEXT_FIELD_NUMBER: builtins.int
113
+ RUN_FIELD_NUMBER: builtins.int
114
+ FAB_FIELD_NUMBER: builtins.int
115
+ @property
116
+ def context(self) -> flwr.proto.message_pb2.Context: ...
117
+ @property
118
+ def run(self) -> flwr.proto.run_pb2.Run: ...
119
+ @property
120
+ def fab(self) -> flwr.proto.fab_pb2.Fab: ...
121
+ def __init__(self,
122
+ *,
123
+ context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
124
+ run: typing.Optional[flwr.proto.run_pb2.Run] = ...,
125
+ fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
126
+ ) -> None: ...
127
+ def HasField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> builtins.bool: ...
128
+ def ClearField(self, field_name: typing_extensions.Literal["context",b"context","fab",b"fab","run",b"run"]) -> None: ...
129
+ global___PullServerAppInputsResponse = PullServerAppInputsResponse
130
+
131
+ class PushServerAppOutputsRequest(google.protobuf.message.Message):
132
+ """PushServerAppOutputs messages"""
133
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
134
+ RUN_ID_FIELD_NUMBER: builtins.int
135
+ CONTEXT_FIELD_NUMBER: builtins.int
136
+ run_id: builtins.int
137
+ @property
138
+ def context(self) -> flwr.proto.message_pb2.Context: ...
139
+ def __init__(self,
140
+ *,
141
+ run_id: builtins.int = ...,
142
+ context: typing.Optional[flwr.proto.message_pb2.Context] = ...,
143
+ ) -> None: ...
144
+ def HasField(self, field_name: typing_extensions.Literal["context",b"context"]) -> builtins.bool: ...
145
+ def ClearField(self, field_name: typing_extensions.Literal["context",b"context","run_id",b"run_id"]) -> None: ...
146
+ global___PushServerAppOutputsRequest = PushServerAppOutputsRequest
147
+
148
+ class PushServerAppOutputsResponse(google.protobuf.message.Message):
149
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
150
+ def __init__(self,
151
+ ) -> None: ...
152
+ global___PushServerAppOutputsResponse = PushServerAppOutputsResponse
@@ -46,6 +46,16 @@ class DriverStub(object):
46
46
  request_serializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.SerializeToString,
47
47
  response_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
48
48
  )
49
+ self.PullServerAppInputs = channel.unary_unary(
50
+ '/flwr.proto.Driver/PullServerAppInputs',
51
+ request_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
52
+ response_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
53
+ )
54
+ self.PushServerAppOutputs = channel.unary_unary(
55
+ '/flwr.proto.Driver/PushServerAppOutputs',
56
+ request_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
57
+ response_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
58
+ )
49
59
 
50
60
 
51
61
  class DriverServicer(object):
@@ -93,6 +103,20 @@ class DriverServicer(object):
93
103
  context.set_details('Method not implemented!')
94
104
  raise NotImplementedError('Method not implemented!')
95
105
 
106
+ def PullServerAppInputs(self, request, context):
107
+ """Pull ServerApp inputs
108
+ """
109
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
110
+ context.set_details('Method not implemented!')
111
+ raise NotImplementedError('Method not implemented!')
112
+
113
+ def PushServerAppOutputs(self, request, context):
114
+ """Push ServerApp outputs
115
+ """
116
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
117
+ context.set_details('Method not implemented!')
118
+ raise NotImplementedError('Method not implemented!')
119
+
96
120
 
97
121
  def add_DriverServicer_to_server(servicer, server):
98
122
  rpc_method_handlers = {
@@ -126,6 +150,16 @@ def add_DriverServicer_to_server(servicer, server):
126
150
  request_deserializer=flwr_dot_proto_dot_fab__pb2.GetFabRequest.FromString,
127
151
  response_serializer=flwr_dot_proto_dot_fab__pb2.GetFabResponse.SerializeToString,
128
152
  ),
153
+ 'PullServerAppInputs': grpc.unary_unary_rpc_method_handler(
154
+ servicer.PullServerAppInputs,
155
+ request_deserializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.FromString,
156
+ response_serializer=flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.SerializeToString,
157
+ ),
158
+ 'PushServerAppOutputs': grpc.unary_unary_rpc_method_handler(
159
+ servicer.PushServerAppOutputs,
160
+ request_deserializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.FromString,
161
+ response_serializer=flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.SerializeToString,
162
+ ),
129
163
  }
130
164
  generic_handler = grpc.method_handlers_generic_handler(
131
165
  'flwr.proto.Driver', rpc_method_handlers)
@@ -237,3 +271,37 @@ class Driver(object):
237
271
  flwr_dot_proto_dot_fab__pb2.GetFabResponse.FromString,
238
272
  options, channel_credentials,
239
273
  insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
274
+
275
+ @staticmethod
276
+ def PullServerAppInputs(request,
277
+ target,
278
+ options=(),
279
+ channel_credentials=None,
280
+ call_credentials=None,
281
+ insecure=False,
282
+ compression=None,
283
+ wait_for_ready=None,
284
+ timeout=None,
285
+ metadata=None):
286
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PullServerAppInputs',
287
+ flwr_dot_proto_dot_driver__pb2.PullServerAppInputsRequest.SerializeToString,
288
+ flwr_dot_proto_dot_driver__pb2.PullServerAppInputsResponse.FromString,
289
+ options, channel_credentials,
290
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
291
+
292
+ @staticmethod
293
+ def PushServerAppOutputs(request,
294
+ target,
295
+ options=(),
296
+ channel_credentials=None,
297
+ call_credentials=None,
298
+ insecure=False,
299
+ compression=None,
300
+ wait_for_ready=None,
301
+ timeout=None,
302
+ metadata=None):
303
+ return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/PushServerAppOutputs',
304
+ flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsRequest.SerializeToString,
305
+ flwr_dot_proto_dot_driver__pb2.PushServerAppOutputsResponse.FromString,
306
+ options, channel_credentials,
307
+ insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@@ -40,6 +40,16 @@ class DriverStub:
40
40
  flwr.proto.fab_pb2.GetFabResponse]
41
41
  """Get FAB"""
42
42
 
43
+ PullServerAppInputs: grpc.UnaryUnaryMultiCallable[
44
+ flwr.proto.driver_pb2.PullServerAppInputsRequest,
45
+ flwr.proto.driver_pb2.PullServerAppInputsResponse]
46
+ """Pull ServerApp inputs"""
47
+
48
+ PushServerAppOutputs: grpc.UnaryUnaryMultiCallable[
49
+ flwr.proto.driver_pb2.PushServerAppOutputsRequest,
50
+ flwr.proto.driver_pb2.PushServerAppOutputsResponse]
51
+ """Push ServerApp outputs"""
52
+
43
53
 
44
54
  class DriverServicer(metaclass=abc.ABCMeta):
45
55
  @abc.abstractmethod
@@ -90,5 +100,21 @@ class DriverServicer(metaclass=abc.ABCMeta):
90
100
  """Get FAB"""
91
101
  pass
92
102
 
103
+ @abc.abstractmethod
104
+ def PullServerAppInputs(self,
105
+ request: flwr.proto.driver_pb2.PullServerAppInputsRequest,
106
+ context: grpc.ServicerContext,
107
+ ) -> flwr.proto.driver_pb2.PullServerAppInputsResponse:
108
+ """Pull ServerApp inputs"""
109
+ pass
110
+
111
+ @abc.abstractmethod
112
+ def PushServerAppOutputs(self,
113
+ request: flwr.proto.driver_pb2.PushServerAppOutputsRequest,
114
+ context: grpc.ServicerContext,
115
+ ) -> flwr.proto.driver_pb2.PushServerAppOutputsResponse:
116
+ """Push ServerApp outputs"""
117
+ pass
118
+
93
119
 
94
120
  def add_DriverServicer_to_server(servicer: DriverServicer, server: grpc.Server) -> None: ...
flwr/server/app.py CHANGED
@@ -64,7 +64,7 @@ from flwr.proto.fleet_pb2_grpc import ( # pylint: disable=E0611
64
64
  )
65
65
  from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
66
66
  from flwr.superexec.app import load_executor
67
- from flwr.superexec.exec_grpc import run_superexec_api_grpc
67
+ from flwr.superexec.exec_grpc import run_exec_api_grpc
68
68
 
69
69
  from .client_manager import ClientManager
70
70
  from .history import History
@@ -329,8 +329,10 @@ def run_superlink() -> None:
329
329
  raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")
330
330
 
331
331
  # Start Exec API
332
- exec_server: grpc.Server = run_superexec_api_grpc(
332
+ exec_server: grpc.Server = run_exec_api_grpc(
333
333
  address=exec_address,
334
+ state_factory=state_factory,
335
+ ffs_factory=ffs_factory,
334
336
  executor=load_executor(args),
335
337
  certificates=certificates,
336
338
  config=parse_config_args(
@@ -34,7 +34,6 @@ from flwr.common.config import (
34
34
  from flwr.common.constant import DRIVER_API_DEFAULT_ADDRESS
35
35
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
36
36
  from flwr.common.object_ref import load_app
37
- from flwr.common.typing import UserConfig
38
37
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
39
38
  from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
39
  CreateRunRequest,
@@ -46,13 +45,14 @@ from .driver.grpc_driver import GrpcDriver
46
45
  from .server_app import LoadServerAppError, ServerApp
47
46
 
48
47
 
48
+ # pylint: disable-next=too-many-arguments,too-many-positional-arguments
49
49
  def run(
50
50
  driver: Driver,
51
+ context: Context,
51
52
  server_app_dir: str,
52
- server_app_run_config: UserConfig,
53
53
  server_app_attr: Optional[str] = None,
54
54
  loaded_server_app: Optional[ServerApp] = None,
55
- ) -> None:
55
+ ) -> Context:
56
56
  """Run ServerApp with a given Driver."""
57
57
  if not (server_app_attr is None) ^ (loaded_server_app is None):
58
58
  raise ValueError(
@@ -78,15 +78,11 @@ def run(
78
78
 
79
79
  server_app = _load()
80
80
 
81
- # Initialize Context
82
- context = Context(
83
- node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
84
- )
85
-
86
81
  # Call ServerApp
87
82
  server_app(driver=driver, context=context)
88
83
 
89
84
  log(DEBUG, "ServerApp finished running.")
85
+ return context
90
86
 
91
87
 
92
88
  # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
@@ -225,11 +221,19 @@ def run_server_app() -> None:
225
221
  root_certificates,
226
222
  )
227
223
 
224
+ # Initialize Context
225
+ context = Context(
226
+ node_id=0,
227
+ node_config={},
228
+ state=RecordSet(),
229
+ run_config=server_app_run_config,
230
+ )
231
+
228
232
  # Run the ServerApp with the Driver
229
233
  run(
230
234
  driver=driver,
235
+ context=context,
231
236
  server_app_dir=app_path,
232
- server_app_run_config=server_app_run_config,
233
237
  server_app_attr=server_app_attr,
234
238
  )
235
239
 
@@ -15,27 +15,35 @@
15
15
  """Driver API servicer."""
16
16
 
17
17
 
18
+ import threading
18
19
  import time
19
- from logging import DEBUG
20
+ from logging import DEBUG, INFO
20
21
  from typing import Optional
21
22
  from uuid import UUID
22
23
 
23
24
  import grpc
24
25
 
26
+ from flwr.common.constant import Status
25
27
  from flwr.common.logger import log
26
28
  from flwr.common.serde import (
29
+ context_from_proto,
30
+ context_to_proto,
27
31
  fab_from_proto,
28
32
  fab_to_proto,
33
+ run_to_proto,
29
34
  user_config_from_proto,
30
- user_config_to_proto,
31
35
  )
32
- from flwr.common.typing import Fab
36
+ from flwr.common.typing import Fab, RunStatus
33
37
  from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
34
38
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
39
  GetNodesRequest,
36
40
  GetNodesResponse,
41
+ PullServerAppInputsRequest,
42
+ PullServerAppInputsResponse,
37
43
  PullTaskResRequest,
38
44
  PullTaskResResponse,
45
+ PushServerAppOutputsRequest,
46
+ PushServerAppOutputsResponse,
39
47
  PushTaskInsRequest,
40
48
  PushTaskInsResponse,
41
49
  )
@@ -46,7 +54,6 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
46
54
  CreateRunResponse,
47
55
  GetRunRequest,
48
56
  GetRunResponse,
49
- Run,
50
57
  )
51
58
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
59
  from flwr.server.superlink.ffs.ffs import Ffs
@@ -63,6 +70,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
63
70
  ) -> None:
64
71
  self.state_factory = state_factory
65
72
  self.ffs_factory = ffs_factory
73
+ self.lock = threading.RLock()
66
74
 
67
75
  def GetNodes(
68
76
  self, request: GetNodesRequest, context: grpc.ServicerContext
@@ -177,15 +185,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
177
185
  if run is None:
178
186
  return GetRunResponse()
179
187
 
180
- return GetRunResponse(
181
- run=Run(
182
- run_id=run.run_id,
183
- fab_id=run.fab_id,
184
- fab_version=run.fab_version,
185
- override_config=user_config_to_proto(run.override_config),
186
- fab_hash=run.fab_hash,
187
- )
188
- )
188
+ return GetRunResponse(run=run_to_proto(run))
189
189
 
190
190
  def GetFab(
191
191
  self, request: GetFabRequest, context: grpc.ServicerContext
@@ -200,6 +200,58 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
200
200
 
201
201
  raise ValueError(f"Found no FAB with hash: {request.hash_str}")
202
202
 
203
+ def PullServerAppInputs(
204
+ self, request: PullServerAppInputsRequest, context: grpc.ServicerContext
205
+ ) -> PullServerAppInputsResponse:
206
+ """Pull ServerApp process inputs."""
207
+ log(DEBUG, "DriverServicer.PullServerAppInputs")
208
+ # Init access to LinkState and Ffs
209
+ state = self.state_factory.state()
210
+ ffs = self.ffs_factory.ffs()
211
+
212
+ # Lock access to LinkState, preventing obtaining the same pending run_id
213
+ with self.lock:
214
+ # If run_id is provided, use it, otherwise use the pending run_id
215
+ if request.HasField("run_id"):
216
+ run_id: Optional[int] = request.run_id
217
+ else:
218
+ run_id = state.get_pending_run_id()
219
+ # If there's no pending run, return an empty response
220
+ if run_id is None:
221
+ return PullServerAppInputsResponse()
222
+
223
+ # Retrieve Context, Run and Fab for the run_id
224
+ serverapp_ctxt = state.get_serverapp_context(run_id)
225
+ run = state.get_run(run_id)
226
+ fab = None
227
+ if run and run.fab_hash:
228
+ if result := ffs.get(run.fab_hash):
229
+ fab = Fab(run.fab_hash, result[0])
230
+ if run and fab:
231
+ # Update run status to STARTING
232
+ if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")):
233
+ log(INFO, "Starting run %d", run_id)
234
+ return PullServerAppInputsResponse(
235
+ context=(
236
+ context_to_proto(serverapp_ctxt) if serverapp_ctxt else None
237
+ ),
238
+ run=run_to_proto(run),
239
+ fab=fab_to_proto(fab),
240
+ )
241
+
242
+ # Raise an exception if the Run or Fab is not found,
243
+ # or if the status cannot be updated to STARTING
244
+ raise RuntimeError(f"Failed to start run {run_id}")
245
+
246
+ def PushServerAppOutputs(
247
+ self, request: PushServerAppOutputsRequest, context: grpc.ServicerContext
248
+ ) -> PushServerAppOutputsResponse:
249
+ """Push ServerApp process outputs."""
250
+ log(DEBUG, "DriverServicer.PushServerAppOutputs")
251
+ state = self.state_factory.state()
252
+ state.set_serverapp_context(request.run_id, context_from_proto(request.context))
253
+ return PushServerAppOutputsResponse()
254
+
203
255
 
204
256
  def _raise_if(validation_error: bool, detail: str) -> None:
205
257
  if validation_error:
@@ -22,7 +22,7 @@ from logging import ERROR, WARNING
22
22
  from typing import Optional
23
23
  from uuid import UUID, uuid4
24
24
 
25
- from flwr.common import log, now
25
+ from flwr.common import Context, log, now
26
26
  from flwr.common.constant import (
27
27
  MESSAGE_TTL_TOLERANCE,
28
28
  NODE_ID_NUM_BYTES,
@@ -65,6 +65,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
65
65
 
66
66
  # Map run_id to RunRecord
67
67
  self.run_ids: dict[int, RunRecord] = {}
68
+ self.contexts: dict[int, Context] = {}
68
69
  self.task_ins_store: dict[UUID, TaskIns] = {}
69
70
  self.task_res_store: dict[UUID, TaskRes] = {}
70
71
 
@@ -500,3 +501,13 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
500
501
  self.node_ids[node_id] = (time.time() + ping_interval, ping_interval)
501
502
  return True
502
503
  return False
504
+
505
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
506
+ """Get the context for the specified `run_id`."""
507
+ return self.contexts.get(run_id)
508
+
509
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
510
+ """Set the context for the specified `run_id`."""
511
+ if run_id not in self.run_ids:
512
+ raise ValueError(f"Run {run_id} not found")
513
+ self.contexts[run_id] = context
@@ -19,6 +19,7 @@ import abc
19
19
  from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common import Context
22
23
  from flwr.common.typing import Run, RunStatus, UserConfig
23
24
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
24
25
 
@@ -270,3 +271,31 @@ class LinkState(abc.ABC): # pylint: disable=R0904
270
271
  is_acknowledged : bool
271
272
  True if the ping is successfully acknowledged; otherwise, False.
272
273
  """
274
+
275
+ @abc.abstractmethod
276
+ def get_serverapp_context(self, run_id: int) -> Optional[Context]:
277
+ """Get the context for the specified `run_id`.
278
+
279
+ Parameters
280
+ ----------
281
+ run_id : int
282
+ The identifier of the run for which to retrieve the context.
283
+
284
+ Returns
285
+ -------
286
+ Optional[Context]
287
+ The context associated with the specified `run_id`, or `None` if no context
288
+ exists for the given `run_id`.
289
+ """
290
+
291
+ @abc.abstractmethod
292
+ def set_serverapp_context(self, run_id: int, context: Context) -> None:
293
+ """Set the context for the specified `run_id`.
294
+
295
+ Parameters
296
+ ----------
297
+ run_id : int
298
+ The identifier of the run for which to set the context.
299
+ context : Context
300
+ The context to be associated with the specified `run_id`.
301
+ """