flwr-nightly 1.10.0.dev20240707__py3-none-any.whl → 1.11.0.dev20240724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. See the registry's page for this release for more details.

Files changed (99):
  1. flwr/cli/build.py +16 -2
  2. flwr/cli/config_utils.py +47 -27
  3. flwr/cli/install.py +17 -1
  4. flwr/cli/new/new.py +32 -21
  5. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +15 -5
  6. flwr/cli/new/templates/app/code/client.jax.py.tpl +2 -1
  7. flwr/cli/new/templates/app/code/client.mlx.py.tpl +36 -13
  8. flwr/cli/new/templates/app/code/client.numpy.py.tpl +2 -1
  9. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +16 -5
  10. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +6 -3
  11. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +25 -5
  12. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +22 -19
  13. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  14. flwr/cli/new/templates/app/code/flwr_tune/server.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -0
  16. flwr/cli/new/templates/app/code/server.jax.py.tpl +16 -8
  17. flwr/cli/new/templates/app/code/server.mlx.py.tpl +12 -7
  18. flwr/cli/new/templates/app/code/server.numpy.py.tpl +16 -8
  19. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +15 -13
  20. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +17 -10
  21. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +16 -13
  22. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +14 -2
  23. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -2
  24. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -3
  25. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  26. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +9 -12
  27. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +38 -0
  28. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +17 -11
  29. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +17 -12
  30. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +12 -12
  31. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +13 -12
  32. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +12 -12
  33. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +15 -12
  34. flwr/cli/run/run.py +133 -54
  35. flwr/client/app.py +56 -24
  36. flwr/client/client_app.py +28 -8
  37. flwr/client/grpc_adapter_client/connection.py +3 -2
  38. flwr/client/grpc_client/connection.py +3 -2
  39. flwr/client/grpc_rere_client/connection.py +17 -6
  40. flwr/client/message_handler/message_handler.py +1 -1
  41. flwr/client/node_state.py +59 -12
  42. flwr/client/node_state_tests.py +4 -3
  43. flwr/client/rest_client/connection.py +19 -8
  44. flwr/client/supernode/app.py +39 -39
  45. flwr/client/typing.py +2 -2
  46. flwr/common/config.py +92 -2
  47. flwr/common/constant.py +3 -0
  48. flwr/common/context.py +24 -9
  49. flwr/common/logger.py +25 -0
  50. flwr/common/object_ref.py +84 -21
  51. flwr/common/serde.py +45 -0
  52. flwr/common/telemetry.py +17 -0
  53. flwr/common/typing.py +5 -0
  54. flwr/proto/common_pb2.py +36 -0
  55. flwr/proto/common_pb2.pyi +121 -0
  56. flwr/proto/common_pb2_grpc.py +4 -0
  57. flwr/proto/common_pb2_grpc.pyi +4 -0
  58. flwr/proto/driver_pb2.py +24 -19
  59. flwr/proto/driver_pb2.pyi +21 -1
  60. flwr/proto/exec_pb2.py +20 -11
  61. flwr/proto/exec_pb2.pyi +41 -1
  62. flwr/proto/run_pb2.py +12 -7
  63. flwr/proto/run_pb2.pyi +22 -1
  64. flwr/proto/task_pb2.py +7 -8
  65. flwr/server/__init__.py +2 -0
  66. flwr/server/compat/legacy_context.py +5 -4
  67. flwr/server/driver/grpc_driver.py +82 -140
  68. flwr/server/run_serverapp.py +40 -18
  69. flwr/server/server_app.py +56 -10
  70. flwr/server/serverapp_components.py +52 -0
  71. flwr/server/superlink/driver/driver_servicer.py +18 -3
  72. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  73. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  74. flwr/server/superlink/fleet/vce/backend/backend.py +4 -4
  75. flwr/server/superlink/fleet/vce/backend/raybackend.py +10 -10
  76. flwr/server/superlink/fleet/vce/vce_api.py +149 -117
  77. flwr/server/superlink/state/in_memory_state.py +11 -3
  78. flwr/server/superlink/state/sqlite_state.py +23 -8
  79. flwr/server/superlink/state/state.py +7 -2
  80. flwr/server/typing.py +2 -0
  81. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  82. flwr/simulation/__init__.py +1 -1
  83. flwr/simulation/app.py +4 -3
  84. flwr/simulation/ray_transport/ray_actor.py +15 -19
  85. flwr/simulation/ray_transport/ray_client_proxy.py +22 -9
  86. flwr/simulation/run_simulation.py +269 -70
  87. flwr/superexec/app.py +17 -11
  88. flwr/superexec/deployment.py +111 -35
  89. flwr/superexec/exec_grpc.py +5 -1
  90. flwr/superexec/exec_servicer.py +6 -1
  91. flwr/superexec/executor.py +21 -0
  92. flwr/superexec/simulation.py +181 -0
  93. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/METADATA +3 -2
  94. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/RECORD +97 -91
  95. flwr/cli/new/templates/app/code/server.hf.py.tpl +0 -17
  96. flwr/cli/new/templates/app/pyproject.hf.toml.tpl +0 -37
  97. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/LICENSE +0 -0
  98. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/WHEEL +0 -0
  99. {flwr_nightly-1.10.0.dev20240707.dist-info → flwr_nightly-1.11.0.dev20240724.dist-info}/entry_points.txt +0 -0
flwr/proto/exec_pb2.pyi CHANGED
@@ -3,7 +3,9 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.transport_pb2
6
7
  import google.protobuf.descriptor
8
+ import google.protobuf.internal.containers
7
9
  import google.protobuf.message
8
10
  import typing
9
11
  import typing_extensions
@@ -12,13 +14,51 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
14
 
13
15
  class StartRunRequest(google.protobuf.message.Message):
14
16
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
17
+ class OverrideConfigEntry(google.protobuf.message.Message):
18
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ KEY_FIELD_NUMBER: builtins.int
20
+ VALUE_FIELD_NUMBER: builtins.int
21
+ key: typing.Text
22
+ @property
23
+ def value(self) -> flwr.proto.transport_pb2.Scalar: ...
24
+ def __init__(self,
25
+ *,
26
+ key: typing.Text = ...,
27
+ value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
28
+ ) -> None: ...
29
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
30
+ def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
31
+
32
+ class FederationConfigEntry(google.protobuf.message.Message):
33
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
34
+ KEY_FIELD_NUMBER: builtins.int
35
+ VALUE_FIELD_NUMBER: builtins.int
36
+ key: typing.Text
37
+ @property
38
+ def value(self) -> flwr.proto.transport_pb2.Scalar: ...
39
+ def __init__(self,
40
+ *,
41
+ key: typing.Text = ...,
42
+ value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
43
+ ) -> None: ...
44
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
45
+ def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
46
+
15
47
  FAB_FILE_FIELD_NUMBER: builtins.int
48
+ OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
49
+ FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
16
50
  fab_file: builtins.bytes
51
+ @property
52
+ def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
53
+ @property
54
+ def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
17
55
  def __init__(self,
18
56
  *,
19
57
  fab_file: builtins.bytes = ...,
58
+ override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
59
+ federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
20
60
  ) -> None: ...
21
- def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file"]) -> None: ...
61
+ def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
22
62
  global___StartRunRequest = StartRunRequest
23
63
 
24
64
  class StartRunResponse(google.protobuf.message.Message):
flwr/proto/run_pb2.py CHANGED
@@ -12,19 +12,24 @@ from google.protobuf.internal import builder as _builder
12
12
  _sym_db = _symbol_database.Default()
13
13
 
14
14
 
15
+ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
15
16
 
16
17
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\":\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc3\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
18
19
 
19
20
  _globals = globals()
20
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
22
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.run_pb2', _globals)
22
23
  if _descriptor._USE_C_DESCRIPTORS == False:
23
24
  DESCRIPTOR._options = None
24
- _globals['_RUN']._serialized_start=36
25
- _globals['_RUN']._serialized_end=94
26
- _globals['_GETRUNREQUEST']._serialized_start=96
27
- _globals['_GETRUNREQUEST']._serialized_end=127
28
- _globals['_GETRUNRESPONSE']._serialized_start=129
29
- _globals['_GETRUNRESPONSE']._serialized_end=175
25
+ _globals['_RUN_OVERRIDECONFIGENTRY']._options = None
26
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
27
+ _globals['_RUN']._serialized_start=65
28
+ _globals['_RUN']._serialized_end=260
29
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=187
30
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=260
31
+ _globals['_GETRUNREQUEST']._serialized_start=262
32
+ _globals['_GETRUNREQUEST']._serialized_end=293
33
+ _globals['_GETRUNRESPONSE']._serialized_start=295
34
+ _globals['_GETRUNRESPONSE']._serialized_end=341
30
35
  # @@protoc_insertion_point(module_scope)
flwr/proto/run_pb2.pyi CHANGED
@@ -3,7 +3,9 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.transport_pb2
6
7
  import google.protobuf.descriptor
8
+ import google.protobuf.internal.containers
7
9
  import google.protobuf.message
8
10
  import typing
9
11
  import typing_extensions
@@ -12,19 +14,38 @@ DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
12
14
 
13
15
  class Run(google.protobuf.message.Message):
14
16
  DESCRIPTOR: google.protobuf.descriptor.Descriptor
17
+ class OverrideConfigEntry(google.protobuf.message.Message):
18
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
19
+ KEY_FIELD_NUMBER: builtins.int
20
+ VALUE_FIELD_NUMBER: builtins.int
21
+ key: typing.Text
22
+ @property
23
+ def value(self) -> flwr.proto.transport_pb2.Scalar: ...
24
+ def __init__(self,
25
+ *,
26
+ key: typing.Text = ...,
27
+ value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
28
+ ) -> None: ...
29
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
30
+ def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
31
+
15
32
  RUN_ID_FIELD_NUMBER: builtins.int
16
33
  FAB_ID_FIELD_NUMBER: builtins.int
17
34
  FAB_VERSION_FIELD_NUMBER: builtins.int
35
+ OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
18
36
  run_id: builtins.int
19
37
  fab_id: typing.Text
20
38
  fab_version: typing.Text
39
+ @property
40
+ def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
21
41
  def __init__(self,
22
42
  *,
23
43
  run_id: builtins.int = ...,
24
44
  fab_id: typing.Text = ...,
25
45
  fab_version: typing.Text = ...,
46
+ override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
26
47
  ) -> None: ...
27
- def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","run_id",b"run_id"]) -> None: ...
48
+ def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
28
49
  global___Run = Run
29
50
 
30
51
  class GetRunRequest(google.protobuf.message.Message):
flwr/proto/task_pb2.py CHANGED
@@ -14,21 +14,20 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
- from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
17
  from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
19
18
 
20
19
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
20
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
22
21
 
23
22
  _globals = globals()
24
23
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
25
24
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
26
25
  if _descriptor._USE_C_DESCRIPTORS == False:
27
26
  DESCRIPTOR._options = None
28
- _globals['_TASK']._serialized_start=141
29
- _globals['_TASK']._serialized_end=406
30
- _globals['_TASKINS']._serialized_start=408
31
- _globals['_TASKINS']._serialized_end=500
32
- _globals['_TASKRES']._serialized_start=502
33
- _globals['_TASKRES']._serialized_end=594
27
+ _globals['_TASK']._serialized_start=113
28
+ _globals['_TASK']._serialized_end=378
29
+ _globals['_TASKINS']._serialized_start=380
30
+ _globals['_TASKINS']._serialized_end=472
31
+ _globals['_TASKRES']._serialized_start=474
32
+ _globals['_TASKRES']._serialized_end=566
34
33
  # @@protoc_insertion_point(module_scope)
flwr/server/__init__.py CHANGED
@@ -28,6 +28,7 @@ from .run_serverapp import run_server_app as run_server_app
28
28
  from .server import Server as Server
29
29
  from .server_app import ServerApp as ServerApp
30
30
  from .server_config import ServerConfig as ServerConfig
31
+ from .serverapp_components import ServerAppComponents as ServerAppComponents
31
32
 
32
33
  __all__ = [
33
34
  "ClientManager",
@@ -36,6 +37,7 @@ __all__ = [
36
37
  "LegacyContext",
37
38
  "Server",
38
39
  "ServerApp",
40
+ "ServerAppComponents",
39
41
  "ServerConfig",
40
42
  "SimpleClientManager",
41
43
  "run_server_app",
flwr/server/compat/legacy_context.py CHANGED
@@ -18,7 +18,7 @@
18
18
  from dataclasses import dataclass
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import Context, RecordSet
21
+ from flwr.common import Context
22
22
 
23
23
  from ..client_manager import ClientManager, SimpleClientManager
24
24
  from ..history import History
@@ -35,9 +35,9 @@ class LegacyContext(Context):
35
35
  client_manager: ClientManager
36
36
  history: History
37
37
 
38
- def __init__(
38
+ def __init__( # pylint: disable=too-many-arguments
39
39
  self,
40
- state: RecordSet,
40
+ context: Context,
41
41
  config: Optional[ServerConfig] = None,
42
42
  strategy: Optional[Strategy] = None,
43
43
  client_manager: Optional[ClientManager] = None,
@@ -52,4 +52,5 @@ class LegacyContext(Context):
52
52
  self.strategy = strategy
53
53
  self.client_manager = client_manager
54
54
  self.history = History()
55
- super().__init__(state)
55
+
56
+ super().__init__(**vars(context))
flwr/server/driver/grpc_driver.py CHANGED
@@ -16,19 +16,21 @@
16
16
 
17
17
  import time
18
18
  import warnings
19
- from logging import DEBUG, ERROR, WARNING
20
- from typing import Iterable, List, Optional, Tuple, cast
19
+ from logging import DEBUG, WARNING
20
+ from typing import Iterable, List, Optional, cast
21
21
 
22
22
  import grpc
23
23
 
24
24
  from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
- from flwr.common.serde import message_from_taskres, message_to_taskins
27
+ from flwr.common.serde import (
28
+ message_from_taskres,
29
+ message_to_taskins,
30
+ user_config_from_proto,
31
+ )
28
32
  from flwr.common.typing import Run
29
33
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
30
- CreateRunRequest,
31
- CreateRunResponse,
32
34
  GetNodesRequest,
33
35
  GetNodesResponse,
34
36
  PullTaskResRequest,
@@ -53,167 +55,103 @@ Call `connect()` on the `GrpcDriverStub` instance before calling any of the othe
53
55
  """
54
56
 
55
57
 
56
- class GrpcDriverStub:
57
- """`GrpcDriverStub` provides access to the gRPC Driver API/service.
58
+ class GrpcDriver(Driver):
59
+ """`GrpcDriver` provides an interface to the Driver API.
58
60
 
59
61
  Parameters
60
62
  ----------
61
- driver_service_address : Optional[str]
62
- The IPv4 or IPv6 address of the Driver API server.
63
- Defaults to `"[::]:9091"`.
63
+ run_id : int
64
+ The identifier of the run.
65
+ driver_service_address : str (default: "[::]:9091")
66
+ The address (URL, IPv6, IPv4) of the SuperLink Driver API service.
64
67
  root_certificates : Optional[bytes] (default: None)
65
68
  The PEM-encoded root certificates as a byte string.
66
69
  If provided, a secure connection using the certificates will be
67
70
  established to an SSL-enabled Flower server.
68
71
  """
69
72
 
70
- def __init__(
73
+ def __init__( # pylint: disable=too-many-arguments
71
74
  self,
75
+ run_id: int,
72
76
  driver_service_address: str = DEFAULT_SERVER_ADDRESS_DRIVER,
73
77
  root_certificates: Optional[bytes] = None,
74
78
  ) -> None:
75
- self.driver_service_address = driver_service_address
76
- self.root_certificates = root_certificates
77
- self.channel: Optional[grpc.Channel] = None
78
- self.stub: Optional[DriverStub] = None
79
+ self._run_id = run_id
80
+ self._addr = driver_service_address
81
+ self._cert = root_certificates
82
+ self._run: Optional[Run] = None
83
+ self._grpc_stub: Optional[DriverStub] = None
84
+ self._channel: Optional[grpc.Channel] = None
85
+ self.node = Node(node_id=0, anonymous=True)
79
86
 
80
- def is_connected(self) -> bool:
81
- """Return True if connected to the Driver API server, otherwise False."""
82
- return self.channel is not None
87
+ @property
88
+ def _is_connected(self) -> bool:
89
+ """Check if connected to the Driver API server."""
90
+ return self._channel is not None
91
+
92
+ def _connect(self) -> None:
93
+ """Connect to the Driver API.
83
94
 
84
- def connect(self) -> None:
85
- """Connect to the Driver API."""
95
+ This will not call GetRun.
96
+ """
86
97
  event(EventType.DRIVER_CONNECT)
87
- if self.channel is not None or self.stub is not None:
98
+ if self._is_connected:
88
99
  log(WARNING, "Already connected")
89
100
  return
90
- self.channel = create_channel(
91
- server_address=self.driver_service_address,
92
- insecure=(self.root_certificates is None),
93
- root_certificates=self.root_certificates,
101
+ self._channel = create_channel(
102
+ server_address=self._addr,
103
+ insecure=(self._cert is None),
104
+ root_certificates=self._cert,
94
105
  )
95
- self.stub = DriverStub(self.channel)
96
- log(DEBUG, "[Driver] Connected to %s", self.driver_service_address)
106
+ self._grpc_stub = DriverStub(self._channel)
107
+ log(DEBUG, "[Driver] Connected to %s", self._addr)
97
108
 
98
- def disconnect(self) -> None:
109
+ def _disconnect(self) -> None:
99
110
  """Disconnect from the Driver API."""
100
111
  event(EventType.DRIVER_DISCONNECT)
101
- if self.channel is None or self.stub is None:
112
+ if not self._is_connected:
102
113
  log(DEBUG, "Already disconnected")
103
114
  return
104
- channel = self.channel
105
- self.channel = None
106
- self.stub = None
115
+ channel: grpc.Channel = self._channel
116
+ self._channel = None
117
+ self._grpc_stub = None
107
118
  channel.close()
108
119
  log(DEBUG, "[Driver] Disconnected")
109
120
 
110
- def create_run(self, req: CreateRunRequest) -> CreateRunResponse:
111
- """Request for run ID."""
112
- # Check if channel is open
113
- if self.stub is None:
114
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
115
- raise ConnectionError("`GrpcDriverStub` instance not connected")
116
-
117
- # Call Driver API
118
- res: CreateRunResponse = self.stub.CreateRun(request=req)
119
- return res
120
-
121
- def get_run(self, req: GetRunRequest) -> GetRunResponse:
122
- """Get run information."""
123
- # Check if channel is open
124
- if self.stub is None:
125
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
126
- raise ConnectionError("`GrpcDriverStub` instance not connected")
127
-
128
- # Call gRPC Driver API
129
- res: GetRunResponse = self.stub.GetRun(request=req)
130
- return res
131
-
132
- def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
133
- """Get client IDs."""
134
- # Check if channel is open
135
- if self.stub is None:
136
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
137
- raise ConnectionError("`GrpcDriverStub` instance not connected")
138
-
139
- # Call gRPC Driver API
140
- res: GetNodesResponse = self.stub.GetNodes(request=req)
141
- return res
142
-
143
- def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse:
144
- """Schedule tasks."""
145
- # Check if channel is open
146
- if self.stub is None:
147
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
148
- raise ConnectionError("`GrpcDriverStub` instance not connected")
149
-
150
- # Call gRPC Driver API
151
- res: PushTaskInsResponse = self.stub.PushTaskIns(request=req)
152
- return res
153
-
154
- def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse:
155
- """Get task results."""
156
- # Check if channel is open
157
- if self.stub is None:
158
- log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
159
- raise ConnectionError("`GrpcDriverStub` instance not connected")
160
-
161
- # Call Driver API
162
- res: PullTaskResResponse = self.stub.PullTaskRes(request=req)
163
- return res
164
-
165
-
166
- class GrpcDriver(Driver):
167
- """`Driver` class provides an interface to the Driver API.
168
-
169
- Parameters
170
- ----------
171
- run_id : int
172
- The identifier of the run.
173
- stub : Optional[GrpcDriverStub] (default: None)
174
- The ``GrpcDriverStub`` instance used to communicate with the SuperLink.
175
- If None, an instance connected to "[::]:9091" will be created.
176
- """
177
-
178
- def __init__( # pylint: disable=too-many-arguments
179
- self,
180
- run_id: int,
181
- stub: Optional[GrpcDriverStub] = None,
182
- ) -> None:
183
- self._run_id = run_id
184
- self._run: Optional[Run] = None
185
- self.stub = stub if stub is not None else GrpcDriverStub()
186
- self.node = Node(node_id=0, anonymous=True)
121
+ def _init_run(self) -> None:
122
+ # Check if is initialized
123
+ if self._run is not None:
124
+ return
125
+ # Get the run info
126
+ req = GetRunRequest(run_id=self._run_id)
127
+ res: GetRunResponse = self._stub.GetRun(req)
128
+ if not res.HasField("run"):
129
+ raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
130
+ self._run = Run(
131
+ run_id=res.run.run_id,
132
+ fab_id=res.run.fab_id,
133
+ fab_version=res.run.fab_version,
134
+ override_config=user_config_from_proto(res.run.override_config),
135
+ )
187
136
 
188
137
  @property
189
138
  def run(self) -> Run:
190
139
  """Run information."""
191
- self._get_stub_and_run_id()
192
- return Run(**vars(cast(Run, self._run)))
140
+ self._init_run()
141
+ return Run(**vars(self._run))
193
142
 
194
- def _get_stub_and_run_id(self) -> Tuple[GrpcDriverStub, int]:
195
- # Check if is initialized
196
- if self._run is None:
197
- # Connect
198
- if not self.stub.is_connected():
199
- self.stub.connect()
200
- # Get the run info
201
- req = GetRunRequest(run_id=self._run_id)
202
- res = self.stub.get_run(req)
203
- if not res.HasField("run"):
204
- raise RuntimeError(f"Cannot find the run with ID: {self._run_id}")
205
- self._run = Run(
206
- run_id=res.run.run_id,
207
- fab_id=res.run.fab_id,
208
- fab_version=res.run.fab_version,
209
- )
210
-
211
- return self.stub, self._run.run_id
143
+ @property
144
+ def _stub(self) -> DriverStub:
145
+ """Driver stub."""
146
+ if not self._is_connected:
147
+ self._connect()
148
+ return cast(DriverStub, self._grpc_stub)
212
149
 
213
150
  def _check_message(self, message: Message) -> None:
214
151
  # Check if the message is valid
215
152
  if not (
216
- message.metadata.run_id == cast(Run, self._run).run_id
153
+ # Assume self._run being initialized
154
+ message.metadata.run_id == self._run_id
217
155
  and message.metadata.src_node_id == self.node.node_id
218
156
  and message.metadata.message_id == ""
219
157
  and message.metadata.reply_to_message == ""
@@ -234,7 +172,7 @@ class GrpcDriver(Driver):
234
172
  This method constructs a new `Message` with given content and metadata.
235
173
  The `run_id` and `src_node_id` will be set automatically.
236
174
  """
237
- _, run_id = self._get_stub_and_run_id()
175
+ self._init_run()
238
176
  if ttl:
239
177
  warnings.warn(
240
178
  "A custom TTL was set, but note that the SuperLink does not enforce "
@@ -245,7 +183,7 @@ class GrpcDriver(Driver):
245
183
 
246
184
  ttl_ = DEFAULT_TTL if ttl is None else ttl
247
185
  metadata = Metadata(
248
- run_id=run_id,
186
+ run_id=self._run_id,
249
187
  message_id="", # Will be set by the server
250
188
  src_node_id=self.node.node_id,
251
189
  dst_node_id=dst_node_id,
@@ -258,9 +196,11 @@ class GrpcDriver(Driver):
258
196
 
259
197
  def get_node_ids(self) -> List[int]:
260
198
  """Get node IDs."""
261
- stub, run_id = self._get_stub_and_run_id()
199
+ self._init_run()
262
200
  # Call GrpcDriverStub method
263
- res = stub.get_nodes(GetNodesRequest(run_id=run_id))
201
+ res: GetNodesResponse = self._stub.GetNodes(
202
+ GetNodesRequest(run_id=self._run_id)
203
+ )
264
204
  return [node.node_id for node in res.nodes]
265
205
 
266
206
  def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
@@ -269,7 +209,7 @@ class GrpcDriver(Driver):
269
209
  This method takes an iterable of messages and sends each message
270
210
  to the node specified in `dst_node_id`.
271
211
  """
272
- stub, _ = self._get_stub_and_run_id()
212
+ self._init_run()
273
213
  # Construct TaskIns
274
214
  task_ins_list: List[TaskIns] = []
275
215
  for msg in messages:
@@ -280,7 +220,9 @@ class GrpcDriver(Driver):
280
220
  # Add to list
281
221
  task_ins_list.append(taskins)
282
222
  # Call GrpcDriverStub method
283
- res = stub.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list))
223
+ res: PushTaskInsResponse = self._stub.PushTaskIns(
224
+ PushTaskInsRequest(task_ins_list=task_ins_list)
225
+ )
284
226
  return list(res.task_ids)
285
227
 
286
228
  def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
@@ -289,9 +231,9 @@ class GrpcDriver(Driver):
289
231
  This method is used to collect messages from the SuperLink that correspond to a
290
232
  set of given message IDs.
291
233
  """
292
- stub, _ = self._get_stub_and_run_id()
234
+ self._init_run()
293
235
  # Pull TaskRes
294
- res = stub.pull_task_res(
236
+ res: PullTaskResResponse = self._stub.PullTaskRes(
295
237
  PullTaskResRequest(node=self.node, task_ids=message_ids)
296
238
  )
297
239
  # Convert TaskRes to Message
@@ -331,7 +273,7 @@ class GrpcDriver(Driver):
331
273
  def close(self) -> None:
332
274
  """Disconnect from the SuperLink if connected."""
333
275
  # Check if `connect` was called before
334
- if not self.stub.is_connected():
276
+ if not self._is_connected:
335
277
  return
336
278
  # Disconnect
337
- self.stub.disconnect()
279
+ self._disconnect()
flwr/server/run_serverapp.py CHANGED
@@ -22,13 +22,22 @@ from pathlib import Path
22
22
  from typing import Optional
23
23
 
24
24
  from flwr.common import Context, EventType, RecordSet, event
25
- from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
25
+ from flwr.common.config import (
26
+ get_flwr_dir,
27
+ get_fused_config,
28
+ get_project_config,
29
+ get_project_dir,
30
+ )
26
31
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
27
32
  from flwr.common.object_ref import load_app
28
- from flwr.proto.driver_pb2 import CreateRunRequest # pylint: disable=E0611
33
+ from flwr.common.typing import UserConfig
34
+ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
35
+ CreateRunRequest,
36
+ CreateRunResponse,
37
+ )
29
38
 
30
39
  from .driver import Driver
31
- from .driver.grpc_driver import GrpcDriver, GrpcDriverStub
40
+ from .driver.grpc_driver import GrpcDriver
32
41
  from .server_app import LoadServerAppError, ServerApp
33
42
 
34
43
  ADDRESS_DRIVER_API = "0.0.0.0:9091"
@@ -37,6 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
37
46
  def run(
38
47
  driver: Driver,
39
48
  server_app_dir: str,
49
+ server_app_run_config: UserConfig,
40
50
  server_app_attr: Optional[str] = None,
41
51
  loaded_server_app: Optional[ServerApp] = None,
42
52
  ) -> None:
@@ -47,9 +57,6 @@ def run(
47
57
  "but not both."
48
58
  )
49
59
 
50
- if server_app_dir is not None:
51
- sys.path.insert(0, str(Path(server_app_dir).absolute()))
52
-
53
60
  # Load ServerApp if needed
54
61
  def _load() -> ServerApp:
55
62
  if server_app_attr:
@@ -69,7 +76,9 @@ def run(
69
76
  server_app = _load()
70
77
 
71
78
  # Initialize Context
72
- context = Context(state=RecordSet())
79
+ context = Context(
80
+ node_id=0, node_config={}, state=RecordSet(), run_config=server_app_run_config
81
+ )
73
82
 
74
83
  # Call ServerApp
75
84
  server_app(driver=driver, context=context)
@@ -144,22 +153,29 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
144
153
  "For more details, use: ``flower-server-app -h``"
145
154
  )
146
155
 
147
- stub = GrpcDriverStub(
148
- driver_service_address=args.superlink, root_certificates=root_certificates
149
- )
156
+ # Initialize GrpcDriver
150
157
  if args.run_id is not None:
151
158
  # User provided `--run-id`, but not `server-app`
152
- run_id = args.run_id
159
+ driver = GrpcDriver(
160
+ run_id=args.run_id,
161
+ driver_service_address=args.superlink,
162
+ root_certificates=root_certificates,
163
+ )
153
164
  else:
154
165
  # User provided `server-app`, but not `--run-id`
155
166
  # Create run if run_id is not provided
156
- stub.connect()
167
+ driver = GrpcDriver(
168
+ run_id=0, # Will be overwritten
169
+ driver_service_address=args.superlink,
170
+ root_certificates=root_certificates,
171
+ )
172
+ # Create run
157
173
  req = CreateRunRequest(fab_id=args.fab_id, fab_version=args.fab_version)
158
- res = stub.create_run(req)
159
- run_id = res.run_id
174
+ res: CreateRunResponse = driver._stub.CreateRun(req) # pylint: disable=W0212
175
+ # Overwrite driver._run_id
176
+ driver._run_id = res.run_id # pylint: disable=W0212
160
177
 
161
- # Initialize GrpcDriver
162
- driver = GrpcDriver(run_id=run_id, stub=stub)
178
+ server_app_run_config = {}
163
179
 
164
180
  # Dynamically obtain ServerApp path based on run_id
165
181
  if args.run_id is not None:
@@ -168,7 +184,8 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
168
184
  run_ = driver.run
169
185
  server_app_dir = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
170
186
  config = get_project_config(server_app_dir)
171
- server_app_attr = config["flower"]["components"]["serverapp"]
187
+ server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
188
+ server_app_run_config = get_fused_config(run_, flwr_dir)
172
189
  else:
173
190
  # User provided `server-app`, but not `--run-id`
174
191
  server_app_dir = str(Path(args.dir).absolute())
@@ -182,7 +199,12 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
182
199
  )
183
200
 
184
201
  # Run the ServerApp with the Driver
185
- run(driver=driver, server_app_dir=server_app_dir, server_app_attr=server_app_attr)
202
+ run(
203
+ driver=driver,
204
+ server_app_dir=server_app_dir,
205
+ server_app_run_config=server_app_run_config,
206
+ server_app_attr=server_app_attr,
207
+ )
186
208
 
187
209
  # Clean up
188
210
  driver.close()