flwr-nightly 1.10.0.dev20240721__py3-none-any.whl → 1.10.0.dev20240723__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic.

Files changed (74):
  1. flwr/cli/config_utils.py +20 -18
  2. flwr/cli/new/new.py +1 -1
  3. flwr/cli/new/templates/app/code/{client.hf.py.tpl → client.huggingface.py.tpl} +7 -5
  4. flwr/cli/new/templates/app/code/client.mlx.py.tpl +28 -10
  5. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +7 -5
  6. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +2 -2
  7. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +17 -7
  8. flwr/cli/new/templates/app/code/flwr_tune/app.py.tpl +20 -17
  9. flwr/cli/new/templates/app/code/flwr_tune/client.py.tpl +5 -3
  10. flwr/cli/new/templates/app/code/{server.hf.py.tpl → server.huggingface.py.tpl} +2 -1
  11. flwr/cli/new/templates/app/code/server.jax.py.tpl +2 -1
  12. flwr/cli/new/templates/app/code/server.mlx.py.tpl +2 -1
  13. flwr/cli/new/templates/app/code/server.numpy.py.tpl +2 -1
  14. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +1 -1
  15. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +2 -1
  16. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +1 -1
  17. flwr/cli/new/templates/app/code/{task.hf.py.tpl → task.huggingface.py.tpl} +13 -1
  18. flwr/cli/new/templates/app/code/task.mlx.py.tpl +14 -1
  19. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +13 -2
  20. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +13 -1
  21. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  22. flwr/cli/new/templates/app/{pyproject.hf.toml.tpl → pyproject.huggingface.toml.tpl} +2 -2
  23. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  24. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +6 -6
  25. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  26. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +2 -2
  27. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  28. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +4 -4
  29. flwr/cli/run/run.py +35 -28
  30. flwr/client/app.py +3 -3
  31. flwr/client/grpc_rere_client/connection.py +6 -2
  32. flwr/client/node_state.py +3 -3
  33. flwr/client/rest_client/connection.py +6 -2
  34. flwr/client/supernode/app.py +12 -43
  35. flwr/common/config.py +23 -17
  36. flwr/common/context.py +7 -7
  37. flwr/common/object_ref.py +84 -21
  38. flwr/common/serde.py +45 -0
  39. flwr/common/telemetry.py +17 -0
  40. flwr/common/typing.py +5 -1
  41. flwr/proto/common_pb2.py +13 -1
  42. flwr/proto/common_pb2.pyi +114 -0
  43. flwr/proto/driver_pb2.py +22 -21
  44. flwr/proto/driver_pb2.pyi +7 -4
  45. flwr/proto/exec_pb2.py +18 -13
  46. flwr/proto/exec_pb2.pyi +27 -5
  47. flwr/proto/run_pb2.py +10 -9
  48. flwr/proto/run_pb2.pyi +7 -4
  49. flwr/proto/task_pb2.py +7 -8
  50. flwr/server/compat/legacy_context.py +5 -4
  51. flwr/server/driver/grpc_driver.py +6 -2
  52. flwr/server/run_serverapp.py +3 -5
  53. flwr/server/superlink/driver/driver_servicer.py +14 -3
  54. flwr/server/superlink/fleet/message_handler/message_handler.py +13 -2
  55. flwr/server/superlink/fleet/vce/backend/__init__.py +1 -1
  56. flwr/server/superlink/fleet/vce/vce_api.py +4 -4
  57. flwr/server/superlink/state/in_memory_state.py +2 -2
  58. flwr/server/superlink/state/sqlite_state.py +2 -2
  59. flwr/server/superlink/state/state.py +3 -3
  60. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +18 -2
  61. flwr/simulation/__init__.py +1 -1
  62. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  63. flwr/simulation/run_simulation.py +39 -11
  64. flwr/superexec/app.py +4 -5
  65. flwr/superexec/deployment.py +19 -8
  66. flwr/superexec/exec_grpc.py +3 -2
  67. flwr/superexec/exec_servicer.py +3 -1
  68. flwr/superexec/executor.py +10 -5
  69. flwr/superexec/simulation.py +41 -15
  70. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/METADATA +1 -1
  71. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/RECORD +74 -74
  72. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/LICENSE +0 -0
  73. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/WHEEL +0 -0
  74. {flwr_nightly-1.10.0.dev20240721.dist-info → flwr_nightly-1.10.0.dev20240723.dist-info}/entry_points.txt +0 -0
flwr/proto/run_pb2.py CHANGED
@@ -12,9 +12,10 @@ from google.protobuf.internal import builder as _builder
12
12
  _sym_db = _symbol_database.Default()
13
13
 
14
14
 
15
+ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
15
16
 
16
17
 
17
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\"\xaf\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1a\x35\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14\x66lwr/proto/run.proto\x12\nflwr.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xc3\x01\n\x03Run\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x0e\n\x06\x66\x61\x62_id\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61\x62_version\x18\x03 \x01(\t\x12<\n\x0foverride_config\x18\x04 \x03(\x0b\x32#.flwr.proto.Run.OverrideConfigEntry\x1aI\n\x13OverrideConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\".\n\x0eGetRunResponse\x12\x1c\n\x03run\x18\x01 \x01(\x0b\x32\x0f.flwr.proto.Runb\x06proto3')
18
19
 
19
20
  _globals = globals()
20
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -23,12 +24,12 @@ if _descriptor._USE_C_DESCRIPTORS == False:
23
24
  DESCRIPTOR._options = None
24
25
  _globals['_RUN_OVERRIDECONFIGENTRY']._options = None
25
26
  _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_options = b'8\001'
26
- _globals['_RUN']._serialized_start=37
27
- _globals['_RUN']._serialized_end=212
28
- _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=159
29
- _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=212
30
- _globals['_GETRUNREQUEST']._serialized_start=214
31
- _globals['_GETRUNREQUEST']._serialized_end=245
32
- _globals['_GETRUNRESPONSE']._serialized_start=247
33
- _globals['_GETRUNRESPONSE']._serialized_end=293
27
+ _globals['_RUN']._serialized_start=65
28
+ _globals['_RUN']._serialized_end=260
29
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_start=187
30
+ _globals['_RUN_OVERRIDECONFIGENTRY']._serialized_end=260
31
+ _globals['_GETRUNREQUEST']._serialized_start=262
32
+ _globals['_GETRUNREQUEST']._serialized_end=293
33
+ _globals['_GETRUNRESPONSE']._serialized_start=295
34
+ _globals['_GETRUNRESPONSE']._serialized_end=341
34
35
  # @@protoc_insertion_point(module_scope)
flwr/proto/run_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.transport_pb2
6
7
  import google.protobuf.descriptor
7
8
  import google.protobuf.internal.containers
8
9
  import google.protobuf.message
@@ -18,12 +19,14 @@ class Run(google.protobuf.message.Message):
18
19
  KEY_FIELD_NUMBER: builtins.int
19
20
  VALUE_FIELD_NUMBER: builtins.int
20
21
  key: typing.Text
21
- value: typing.Text
22
+ @property
23
+ def value(self) -> flwr.proto.transport_pb2.Scalar: ...
22
24
  def __init__(self,
23
25
  *,
24
26
  key: typing.Text = ...,
25
- value: typing.Text = ...,
27
+ value: typing.Optional[flwr.proto.transport_pb2.Scalar] = ...,
26
28
  ) -> None: ...
29
+ def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
27
30
  def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
28
31
 
29
32
  RUN_ID_FIELD_NUMBER: builtins.int
@@ -34,13 +37,13 @@ class Run(google.protobuf.message.Message):
34
37
  fab_id: typing.Text
35
38
  fab_version: typing.Text
36
39
  @property
37
- def override_config(self) -> google.protobuf.internal.containers.ScalarMap[typing.Text, typing.Text]: ...
40
+ def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
38
41
  def __init__(self,
39
42
  *,
40
43
  run_id: builtins.int = ...,
41
44
  fab_id: typing.Text = ...,
42
45
  fab_version: typing.Text = ...,
43
- override_config: typing.Optional[typing.Mapping[typing.Text, typing.Text]] = ...,
46
+ override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
44
47
  ) -> None: ...
45
48
  def ClearField(self, field_name: typing_extensions.Literal["fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
46
49
  global___Run = Run
flwr/proto/task_pb2.py CHANGED
@@ -14,21 +14,20 @@ _sym_db = _symbol_database.Default()
14
14
 
15
15
  from flwr.proto import node_pb2 as flwr_dot_proto_dot_node__pb2
16
16
  from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
- from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
17
  from flwr.proto import error_pb2 as flwr_dot_proto_dot_error__pb2
19
18
 
20
19
 
21
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
20
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x16\x66lwr/proto/error.proto\"\x89\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\x01\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x11\n\tpushed_at\x18\x05 \x01(\x01\x12\x0b\n\x03ttl\x18\x06 \x01(\x01\x12\x10\n\x08\x61ncestry\x18\x07 \x03(\t\x12\x11\n\ttask_type\x18\x08 \x01(\t\x12(\n\trecordset\x18\t \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\n \x01(\x0b\x32\x11.flwr.proto.Error\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Taskb\x06proto3')
22
21
 
23
22
  _globals = globals()
24
23
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
25
24
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'flwr.proto.task_pb2', _globals)
26
25
  if _descriptor._USE_C_DESCRIPTORS == False:
27
26
  DESCRIPTOR._options = None
28
- _globals['_TASK']._serialized_start=141
29
- _globals['_TASK']._serialized_end=406
30
- _globals['_TASKINS']._serialized_start=408
31
- _globals['_TASKINS']._serialized_end=500
32
- _globals['_TASKRES']._serialized_start=502
33
- _globals['_TASKRES']._serialized_end=594
27
+ _globals['_TASK']._serialized_start=113
28
+ _globals['_TASK']._serialized_end=378
29
+ _globals['_TASKINS']._serialized_start=380
30
+ _globals['_TASKINS']._serialized_end=472
31
+ _globals['_TASKRES']._serialized_start=474
32
+ _globals['_TASKRES']._serialized_end=566
34
33
  # @@protoc_insertion_point(module_scope)
@@ -18,7 +18,7 @@
18
18
  from dataclasses import dataclass
19
19
  from typing import Optional
20
20
 
21
- from flwr.common import Context, RecordSet
21
+ from flwr.common import Context
22
22
 
23
23
  from ..client_manager import ClientManager, SimpleClientManager
24
24
  from ..history import History
@@ -35,9 +35,9 @@ class LegacyContext(Context):
35
35
  client_manager: ClientManager
36
36
  history: History
37
37
 
38
- def __init__(
38
+ def __init__( # pylint: disable=too-many-arguments
39
39
  self,
40
- state: RecordSet,
40
+ context: Context,
41
41
  config: Optional[ServerConfig] = None,
42
42
  strategy: Optional[Strategy] = None,
43
43
  client_manager: Optional[ClientManager] = None,
@@ -52,4 +52,5 @@ class LegacyContext(Context):
52
52
  self.strategy = strategy
53
53
  self.client_manager = client_manager
54
54
  self.history = History()
55
- super().__init__(node_id=0, node_config={}, state=state, run_config={})
55
+
56
+ super().__init__(**vars(context))
@@ -24,7 +24,11 @@ import grpc
24
24
  from flwr.common import DEFAULT_TTL, EventType, Message, Metadata, RecordSet, event
25
25
  from flwr.common.grpc import create_channel
26
26
  from flwr.common.logger import log
27
- from flwr.common.serde import message_from_taskres, message_to_taskins
27
+ from flwr.common.serde import (
28
+ message_from_taskres,
29
+ message_to_taskins,
30
+ user_config_from_proto,
31
+ )
28
32
  from flwr.common.typing import Run
29
33
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
30
34
  GetNodesRequest,
@@ -127,7 +131,7 @@ class GrpcDriver(Driver):
127
131
  run_id=res.run.run_id,
128
132
  fab_id=res.run.fab_id,
129
133
  fab_version=res.run.fab_version,
130
- override_config=dict(res.run.override_config.items()),
134
+ override_config=user_config_from_proto(res.run.override_config),
131
135
  )
132
136
 
133
137
  @property
@@ -19,7 +19,7 @@ import argparse
19
19
  import sys
20
20
  from logging import DEBUG, INFO, WARN
21
21
  from pathlib import Path
22
- from typing import Dict, Optional
22
+ from typing import Optional
23
23
 
24
24
  from flwr.common import Context, EventType, RecordSet, event
25
25
  from flwr.common.config import (
@@ -30,6 +30,7 @@ from flwr.common.config import (
30
30
  )
31
31
  from flwr.common.logger import log, update_console_handler, warn_deprecated_feature
32
32
  from flwr.common.object_ref import load_app
33
+ from flwr.common.typing import UserConfig
33
34
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
34
35
  CreateRunRequest,
35
36
  CreateRunResponse,
@@ -45,7 +46,7 @@ ADDRESS_DRIVER_API = "0.0.0.0:9091"
45
46
  def run(
46
47
  driver: Driver,
47
48
  server_app_dir: str,
48
- server_app_run_config: Dict[str, str],
49
+ server_app_run_config: UserConfig,
49
50
  server_app_attr: Optional[str] = None,
50
51
  loaded_server_app: Optional[ServerApp] = None,
51
52
  ) -> None:
@@ -56,9 +57,6 @@ def run(
56
57
  "but not both."
57
58
  )
58
59
 
59
- if server_app_dir is not None:
60
- sys.path.insert(0, str(Path(server_app_dir).absolute()))
61
-
62
60
  # Load ServerApp if needed
63
61
  def _load() -> ServerApp:
64
62
  if server_app_attr:
@@ -23,6 +23,7 @@ from uuid import UUID
23
23
  import grpc
24
24
 
25
25
  from flwr.common.logger import log
26
+ from flwr.common.serde import user_config_from_proto, user_config_to_proto
26
27
  from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
27
28
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
28
29
  CreateRunRequest,
@@ -72,7 +73,7 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
72
73
  run_id = state.create_run(
73
74
  request.fab_id,
74
75
  request.fab_version,
75
- dict(request.override_config.items()),
76
+ user_config_from_proto(request.override_config),
76
77
  )
77
78
  return CreateRunResponse(run_id=run_id)
78
79
 
@@ -149,8 +150,18 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
149
150
 
150
151
  # Retrieve run information
151
152
  run = state.get_run(request.run_id)
152
- run_proto = None if run is None else Run(**vars(run))
153
- return GetRunResponse(run=run_proto)
153
+
154
+ if run is None:
155
+ return GetRunResponse()
156
+
157
+ return GetRunResponse(
158
+ run=Run(
159
+ run_id=run.run_id,
160
+ fab_id=run.fab_id,
161
+ fab_version=run.fab_version,
162
+ override_config=user_config_to_proto(run.override_config),
163
+ )
164
+ )
154
165
 
155
166
 
156
167
  def _raise_if(validation_error: bool, detail: str) -> None:
@@ -19,6 +19,7 @@ import time
19
19
  from typing import List, Optional
20
20
  from uuid import UUID
21
21
 
22
+ from flwr.common.serde import user_config_to_proto
22
23
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
23
24
  CreateNodeRequest,
24
25
  CreateNodeResponse,
@@ -113,5 +114,15 @@ def get_run(
113
114
  ) -> GetRunResponse:
114
115
  """Get run information."""
115
116
  run = state.get_run(request.run_id)
116
- run_proto = None if run is None else Run(**vars(run))
117
- return GetRunResponse(run=run_proto)
117
+
118
+ if run is None:
119
+ return GetRunResponse()
120
+
121
+ return GetRunResponse(
122
+ run=Run(
123
+ run_id=run.run_id,
124
+ fab_id=run.fab_id,
125
+ fab_version=run.fab_version,
126
+ override_config=user_config_to_proto(run.override_config),
127
+ )
128
+ )
@@ -38,7 +38,7 @@ else:
38
38
 
39
39
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
40
40
 
41
- pip install -U flwr["simulation"]
41
+ pip install -U "flwr[simulation]"
42
42
  """
43
43
 
44
44
 
@@ -72,8 +72,8 @@ def _register_node_states(
72
72
  node_states[node_id] = NodeState(
73
73
  node_id=node_id,
74
74
  node_config={
75
- PARTITION_ID_KEY: str(partition_id),
76
- NUM_PARTITIONS_KEY: str(num_partitions),
75
+ PARTITION_ID_KEY: partition_id,
76
+ NUM_PARTITIONS_KEY: num_partitions,
77
77
  },
78
78
  )
79
79
 
@@ -347,8 +347,8 @@ def start_vce(
347
347
  if client_app_attr:
348
348
  app = _get_load_client_app_fn(
349
349
  default_app_ref=client_app_attr,
350
- dir_arg=app_dir,
351
- flwr_dir_arg=flwr_dir,
350
+ project_dir=app_dir,
351
+ flwr_dir=flwr_dir,
352
352
  multi_app=True,
353
353
  )(run.fab_id, run.fab_version)
354
354
 
@@ -23,7 +23,7 @@ from uuid import UUID, uuid4
23
23
 
24
24
  from flwr.common import log, now
25
25
  from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
26
- from flwr.common.typing import Run
26
+ from flwr.common.typing import Run, UserConfig
27
27
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
28
28
  from flwr.server.superlink.state.state import State
29
29
  from flwr.server.utils import validate_task_ins_or_res
@@ -279,7 +279,7 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
279
279
  self,
280
280
  fab_id: str,
281
281
  fab_version: str,
282
- override_config: Dict[str, str],
282
+ override_config: UserConfig,
283
283
  ) -> int:
284
284
  """Create a new run for the specified `fab_id` and `fab_version`."""
285
285
  # Sample a random int64 as run_id
@@ -25,7 +25,7 @@ from uuid import UUID, uuid4
25
25
 
26
26
  from flwr.common import log, now
27
27
  from flwr.common.constant import NODE_ID_NUM_BYTES, RUN_ID_NUM_BYTES
28
- from flwr.common.typing import Run
28
+ from flwr.common.typing import Run, UserConfig
29
29
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
30
30
  from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
31
31
  from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
@@ -619,7 +619,7 @@ class SqliteState(State): # pylint: disable=R0904
619
619
  self,
620
620
  fab_id: str,
621
621
  fab_version: str,
622
- override_config: Dict[str, str],
622
+ override_config: UserConfig,
623
623
  ) -> int:
624
624
  """Create a new run for the specified `fab_id` and `fab_version`."""
625
625
  # Sample a random int64 as run_id
@@ -16,10 +16,10 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import Dict, List, Optional, Set
19
+ from typing import List, Optional, Set
20
20
  from uuid import UUID
21
21
 
22
- from flwr.common.typing import Run
22
+ from flwr.common.typing import Run, UserConfig
23
23
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
24
24
 
25
25
 
@@ -161,7 +161,7 @@ class State(abc.ABC): # pylint: disable=R0904
161
161
  self,
162
162
  fab_id: str,
163
163
  fab_version: str,
164
- override_config: Dict[str, str],
164
+ override_config: UserConfig,
165
165
  ) -> int:
166
166
  """Create a new run for the specified `fab_id` and `fab_version`."""
167
167
 
@@ -81,6 +81,7 @@ class WorkflowState: # pylint: disable=R0902
81
81
  forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
82
82
  aggregate_ndarrays: NDArrays = field(default_factory=list)
83
83
  legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
84
+ failures: List[Exception] = field(default_factory=list)
84
85
 
85
86
 
86
87
  class SecAggPlusWorkflow:
@@ -394,6 +395,7 @@ class SecAggPlusWorkflow:
394
395
 
395
396
  for msg in msgs:
396
397
  if msg.has_error():
398
+ state.failures.append(Exception(msg.error))
397
399
  continue
398
400
  key_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
399
401
  node_id = msg.metadata.src_node_id
@@ -451,6 +453,9 @@ class SecAggPlusWorkflow:
451
453
  nid: [] for nid in state.active_node_ids
452
454
  } # dest node ID -> list of src node IDs
453
455
  for msg in msgs:
456
+ if msg.has_error():
457
+ state.failures.append(Exception(msg.error))
458
+ continue
454
459
  node_id = msg.metadata.src_node_id
455
460
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
456
461
  dst_lst = cast(List[int], res_dict[Key.DESTINATION_LIST])
@@ -515,6 +520,9 @@ class SecAggPlusWorkflow:
515
520
  # Sum collected masked vectors and compute active/dead node IDs
516
521
  masked_vector = None
517
522
  for msg in msgs:
523
+ if msg.has_error():
524
+ state.failures.append(Exception(msg.error))
525
+ continue
518
526
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
519
527
  bytes_list = cast(List[bytes], res_dict[Key.MASKED_PARAMETERS])
520
528
  client_masked_vec = [bytes_to_ndarray(b) for b in bytes_list]
@@ -528,6 +536,9 @@ class SecAggPlusWorkflow:
528
536
 
529
537
  # Backward compatibility with Strategy
530
538
  for msg in msgs:
539
+ if msg.has_error():
540
+ state.failures.append(Exception(msg.error))
541
+ continue
531
542
  fitres = compat.recordset_to_fitres(msg.content, True)
532
543
  proxy = state.nid_to_proxies[msg.metadata.src_node_id]
533
544
  state.legacy_results.append((proxy, fitres))
@@ -584,6 +595,9 @@ class SecAggPlusWorkflow:
584
595
  for nid in state.sampled_node_ids:
585
596
  collected_shares_dict[nid] = []
586
597
  for msg in msgs:
598
+ if msg.has_error():
599
+ state.failures.append(Exception(msg.error))
600
+ continue
587
601
  res_dict = msg.content.configs_records[RECORD_KEY_CONFIGS]
588
602
  nids = cast(List[int], res_dict[Key.NODE_ID_LIST])
589
603
  shares = cast(List[bytes], res_dict[Key.SHARE_LIST])
@@ -652,9 +666,11 @@ class SecAggPlusWorkflow:
652
666
  INFO,
653
667
  "aggregate_fit: received %s results and %s failures",
654
668
  len(results),
655
- 0,
669
+ len(state.failures),
670
+ )
671
+ aggregated_result = context.strategy.aggregate_fit(
672
+ current_round, results, state.failures # type: ignore
656
673
  )
657
- aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
658
674
  parameters_aggregated, metrics_aggregated = aggregated_result
659
675
 
660
676
  # Update the parameters and write history
@@ -28,7 +28,7 @@ else:
28
28
 
29
29
  To install the necessary dependencies, install `flwr` with the `simulation` extra:
30
30
 
31
- pip install -U flwr["simulation"]
31
+ pip install -U "flwr[simulation]"
32
32
  """
33
33
 
34
34
  def start_simulation(*args, **kwargs): # type: ignore
@@ -82,7 +82,7 @@ class RayActorClientProxy(ClientProxy):
82
82
 
83
83
  # Retrieve context
84
84
  context = self.proxy_state.retrieve_context(run_id=run_id)
85
- partition_id_str = context.node_config[PARTITION_ID_KEY]
85
+ partition_id_str = str(context.node_config[PARTITION_ID_KEY])
86
86
 
87
87
  try:
88
88
  self.actor_pool.submit_client_job(
@@ -25,15 +25,19 @@ from argparse import Namespace
25
25
  from logging import DEBUG, ERROR, INFO, WARNING
26
26
  from pathlib import Path
27
27
  from time import sleep
28
- from typing import Dict, List, Optional
28
+ from typing import List, Optional
29
29
 
30
30
  from flwr.cli.config_utils import load_and_validate
31
31
  from flwr.client import ClientApp
32
32
  from flwr.common import EventType, event, log
33
33
  from flwr.common.config import get_fused_config_from_dir, parse_config_args
34
34
  from flwr.common.constant import RUN_ID_NUM_BYTES
35
- from flwr.common.logger import set_logger_propagation, update_console_handler
36
- from flwr.common.typing import Run
35
+ from flwr.common.logger import (
36
+ set_logger_propagation,
37
+ update_console_handler,
38
+ warn_deprecated_feature_with_example,
39
+ )
40
+ from flwr.common.typing import Run, UserConfig
37
41
  from flwr.server.driver import Driver, InMemoryDriver
38
42
  from flwr.server.run_serverapp import run as run_server_app
39
43
  from flwr.server.server_app import ServerApp
@@ -93,6 +97,14 @@ def run_simulation_from_cli() -> None:
93
97
  """Run Simulation Engine from the CLI."""
94
98
  args = _parse_args_run_simulation().parse_args()
95
99
 
100
+ if args.enable_tf_gpu_growth:
101
+ warn_deprecated_feature_with_example(
102
+ "Passing `--enable-tf-gpu-growth` is deprecated.",
103
+ example_message="Instead, set the `TF_FORCE_GPU_ALLOW_GROWTH` environmnet "
104
+ "variable to true.",
105
+ code_example='TF_FORCE_GPU_ALLOW_GROWTH="true" flower-simulation <...>',
106
+ )
107
+
96
108
  # We are supporting two modes for the CLI entrypoint:
97
109
  # 1) Running an app dir containing a `pyproject.toml`
98
110
  # 2) Running any ClientApp and SeverApp w/o pyproject.toml being present
@@ -223,6 +235,15 @@ def run_simulation(
223
235
  When disabled, only INFO, WARNING and ERROR log messages will be shown. If
224
236
  enabled, DEBUG-level logs will be displayed.
225
237
  """
238
+ if enable_tf_gpu_growth:
239
+ warn_deprecated_feature_with_example(
240
+ "Passing `enable_tf_gpu_growth=True` is deprecated.",
241
+ example_message="Instead, set the `TF_FORCE_GPU_ALLOW_GROWTH` environmnet "
242
+ "variable to true.",
243
+ code_example='import os;os.environ["TF_FORCE_GPU_ALLOW_GROWTH"]="true"'
244
+ "\n\tflwr.simulation.run_simulationt(...)",
245
+ )
246
+
226
247
  _run_simulation(
227
248
  num_supernodes=num_supernodes,
228
249
  client_app=client_app,
@@ -238,7 +259,7 @@ def run_simulation(
238
259
  def run_serverapp_th(
239
260
  server_app_attr: Optional[str],
240
261
  server_app: Optional[ServerApp],
241
- server_app_run_config: Dict[str, str],
262
+ server_app_run_config: UserConfig,
242
263
  driver: Driver,
243
264
  app_dir: str,
244
265
  f_stop: threading.Event,
@@ -254,7 +275,7 @@ def run_serverapp_th(
254
275
  exception_event: threading.Event,
255
276
  _driver: Driver,
256
277
  _server_app_dir: str,
257
- _server_app_run_config: Dict[str, str],
278
+ _server_app_run_config: UserConfig,
258
279
  _server_app_attr: Optional[str],
259
280
  _server_app: Optional[ServerApp],
260
281
  ) -> None:
@@ -264,7 +285,7 @@ def run_serverapp_th(
264
285
  """
265
286
  try:
266
287
  if tf_gpu_growth:
267
- log(INFO, "Enabling GPU growth for Tensorflow on the main thread.")
288
+ log(INFO, "Enabling GPU growth for Tensorflow on the server thread.")
268
289
  enable_gpu_growth()
269
290
 
270
291
  # Run ServerApp
@@ -319,7 +340,7 @@ def _main_loop(
319
340
  client_app_attr: Optional[str] = None,
320
341
  server_app: Optional[ServerApp] = None,
321
342
  server_app_attr: Optional[str] = None,
322
- server_app_run_config: Optional[Dict[str, str]] = None,
343
+ server_app_run_config: Optional[UserConfig] = None,
323
344
  ) -> None:
324
345
  """Launch SuperLink with Simulation Engine, then ServerApp on a separate thread."""
325
346
  # Initialize StateFactory
@@ -395,7 +416,7 @@ def _run_simulation(
395
416
  backend_config: Optional[BackendConfig] = None,
396
417
  client_app_attr: Optional[str] = None,
397
418
  server_app_attr: Optional[str] = None,
398
- server_app_run_config: Optional[Dict[str, str]] = None,
419
+ server_app_run_config: Optional[UserConfig] = None,
399
420
  app_dir: str = "",
400
421
  flwr_dir: Optional[str] = None,
401
422
  run: Optional[Run] = None,
@@ -438,7 +459,7 @@ def _run_simulation(
438
459
  A path to a `ServerApp` module to be loaded: For example: `server:app` or
439
460
  `project.package.module:wrapper.app`."
440
461
 
441
- server_app_run_config : Optional[Dict[str, str]]
462
+ server_app_run_config : Optional[UserConfig]
442
463
  Config dictionary that parameterizes the run config. It will be made accesible
443
464
  to the ServerApp.
444
465
 
@@ -475,6 +496,14 @@ def _run_simulation(
475
496
  if "init_args" not in backend_config:
476
497
  backend_config["init_args"] = {}
477
498
 
499
+ # Set default client_resources if not passed
500
+ if "client_resources" not in backend_config:
501
+ backend_config["client_resources"] = {"num_cpus": 2, "num_gpus": 0}
502
+
503
+ # Initialization of backend config to enable GPU growth globally when set
504
+ if "actor" not in backend_config:
505
+ backend_config["actor"] = {"tensorflow": 0}
506
+
478
507
  # Set logging level
479
508
  logger = logging.getLogger("flwr")
480
509
  if verbose_logging:
@@ -580,8 +609,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
580
609
  parser.add_argument(
581
610
  "--backend-config",
582
611
  type=str,
583
- default='{"client_resources": {"num_cpus":2, "num_gpus":0.0},'
584
- '"actor": {"tensorflow": 0}}',
612
+ default="{}",
585
613
  help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
586
614
  "configure a backend. Values supported in <value> are those included by "
587
615
  "`flwr.common.typing.ConfigsRecordValues`. ",
flwr/superexec/app.py CHANGED
@@ -93,7 +93,9 @@ def _parse_args_run_superexec() -> argparse.ArgumentParser:
93
93
  )
94
94
  parser.add_argument(
95
95
  "--executor-config",
96
- help="Key-value pairs for the executor config, separated by commas.",
96
+ help="Key-value pairs for the executor config, separated by commas. "
97
+ 'For example:\n\n`--executor-config superlink="superlink:9091",'
98
+ 'root-certificates="certificates/superlink-ca.crt"`',
97
99
  )
98
100
  parser.add_argument(
99
101
  "--insecure",
@@ -163,11 +165,8 @@ def _load_executor(
163
165
  args: argparse.Namespace,
164
166
  ) -> Executor:
165
167
  """Get the executor plugin."""
166
- if args.executor_dir is not None:
167
- sys.path.insert(0, args.executor_dir)
168
-
169
168
  executor_ref: str = args.executor
170
- valid, error_msg = validate(executor_ref)
169
+ valid, error_msg = validate(executor_ref, project_dir=args.executor_dir)
171
170
  if not valid and error_msg:
172
171
  raise LoadExecutorError(error_msg) from None
173
172