flwr-nightly 1.11.0.dev20240813__py3-none-any.whl → 1.11.0.dev20240821__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flwr-nightly might be problematic. (The original page linked to more details at this point; that link was not preserved in this copy.)

Files changed (58):
  1. flwr/cli/config_utils.py +2 -2
  2. flwr/cli/install.py +3 -1
  3. flwr/cli/run/run.py +15 -11
  4. flwr/client/app.py +132 -14
  5. flwr/client/clientapp/__init__.py +22 -0
  6. flwr/client/clientapp/app.py +233 -0
  7. flwr/client/clientapp/clientappio_servicer.py +244 -0
  8. flwr/client/clientapp/utils.py +108 -0
  9. flwr/client/grpc_rere_client/connection.py +9 -1
  10. flwr/client/node_state.py +17 -4
  11. flwr/client/rest_client/connection.py +16 -3
  12. flwr/client/supernode/__init__.py +0 -2
  13. flwr/client/supernode/app.py +37 -122
  14. flwr/common/__init__.py +4 -0
  15. flwr/common/config.py +31 -10
  16. flwr/common/record/configsrecord.py +49 -15
  17. flwr/common/record/metricsrecord.py +54 -14
  18. flwr/common/record/parametersrecord.py +84 -17
  19. flwr/common/record/recordset.py +80 -8
  20. flwr/common/record/typeddict.py +20 -58
  21. flwr/common/recordset_compat.py +6 -6
  22. flwr/common/serde.py +24 -2
  23. flwr/common/typing.py +1 -0
  24. flwr/proto/clientappio_pb2.py +17 -13
  25. flwr/proto/clientappio_pb2.pyi +24 -2
  26. flwr/proto/clientappio_pb2_grpc.py +34 -0
  27. flwr/proto/clientappio_pb2_grpc.pyi +13 -0
  28. flwr/proto/exec_pb2.py +16 -15
  29. flwr/proto/exec_pb2.pyi +7 -4
  30. flwr/proto/message_pb2.py +2 -2
  31. flwr/proto/message_pb2.pyi +4 -1
  32. flwr/server/app.py +15 -0
  33. flwr/server/driver/grpc_driver.py +1 -0
  34. flwr/server/run_serverapp.py +18 -2
  35. flwr/server/server.py +3 -1
  36. flwr/server/superlink/driver/driver_grpc.py +3 -0
  37. flwr/server/superlink/driver/driver_servicer.py +32 -4
  38. flwr/server/superlink/ffs/disk_ffs.py +6 -3
  39. flwr/server/superlink/ffs/ffs.py +3 -3
  40. flwr/server/superlink/ffs/ffs_factory.py +47 -0
  41. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +12 -4
  42. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +8 -2
  43. flwr/server/superlink/fleet/message_handler/message_handler.py +16 -1
  44. flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -2
  45. flwr/server/superlink/fleet/vce/vce_api.py +2 -2
  46. flwr/server/superlink/state/in_memory_state.py +7 -5
  47. flwr/server/superlink/state/sqlite_state.py +17 -7
  48. flwr/server/superlink/state/state.py +4 -3
  49. flwr/server/workflow/default_workflows.py +3 -1
  50. flwr/simulation/run_simulation.py +5 -67
  51. flwr/superexec/app.py +3 -3
  52. flwr/superexec/deployment.py +8 -9
  53. flwr/superexec/exec_servicer.py +1 -1
  54. {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/METADATA +2 -2
  55. {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/RECORD +58 -53
  56. {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/entry_points.txt +1 -1
  57. {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/LICENSE +0 -0
  58. {flwr_nightly-1.11.0.dev20240813.dist-info → flwr_nightly-1.11.0.dev20240821.dist-info}/WHEEL +0 -0
flwr/proto/exec_pb2.pyi CHANGED
@@ -3,6 +3,7 @@
3
3
  isort:skip_file
4
4
  """
5
5
  import builtins
6
+ import flwr.proto.fab_pb2
6
7
  import flwr.proto.transport_pb2
7
8
  import google.protobuf.descriptor
8
9
  import google.protobuf.internal.containers
@@ -44,21 +45,23 @@ class StartRunRequest(google.protobuf.message.Message):
44
45
  def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
45
46
  def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
46
47
 
47
- FAB_FILE_FIELD_NUMBER: builtins.int
48
+ FAB_FIELD_NUMBER: builtins.int
48
49
  OVERRIDE_CONFIG_FIELD_NUMBER: builtins.int
49
50
  FEDERATION_CONFIG_FIELD_NUMBER: builtins.int
50
- fab_file: builtins.bytes
51
+ @property
52
+ def fab(self) -> flwr.proto.fab_pb2.Fab: ...
51
53
  @property
52
54
  def override_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
53
55
  @property
54
56
  def federation_config(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, flwr.proto.transport_pb2.Scalar]: ...
55
57
  def __init__(self,
56
58
  *,
57
- fab_file: builtins.bytes = ...,
59
+ fab: typing.Optional[flwr.proto.fab_pb2.Fab] = ...,
58
60
  override_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
59
61
  federation_config: typing.Optional[typing.Mapping[typing.Text, flwr.proto.transport_pb2.Scalar]] = ...,
60
62
  ) -> None: ...
61
- def ClearField(self, field_name: typing_extensions.Literal["fab_file",b"fab_file","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
63
+ def HasField(self, field_name: typing_extensions.Literal["fab",b"fab"]) -> builtins.bool: ...
64
+ def ClearField(self, field_name: typing_extensions.Literal["fab",b"fab","federation_config",b"federation_config","override_config",b"override_config"]) -> None: ...
62
65
  global___StartRunRequest = StartRunRequest
63
66
 
64
67
  class StartRunResponse(google.protobuf.message.Message):
flwr/proto/message_pb2.py CHANGED
@@ -17,7 +17,7 @@ from flwr.proto import recordset_pb2 as flwr_dot_proto_dot_recordset__pb2
17
17
  from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2
18
18
 
19
19
 
20
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xa7\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\tb\x06proto3')
20
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x66lwr/proto/message.proto\x12\nflwr.proto\x1a\x16\x66lwr/proto/error.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"{\n\x07Message\x12&\n\x08metadata\x18\x01 \x01(\x0b\x32\x14.flwr.proto.Metadata\x12&\n\x07\x63ontent\x18\x02 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12 \n\x05\x65rror\x18\x03 \x01(\x0b\x32\x11.flwr.proto.Error\"\xbf\x02\n\x07\x43ontext\x12\x0f\n\x07node_id\x18\x01 \x01(\x12\x12\x38\n\x0bnode_config\x18\x02 \x03(\x0b\x32#.flwr.proto.Context.NodeConfigEntry\x12$\n\x05state\x18\x03 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12\x36\n\nrun_config\x18\x04 \x03(\x0b\x32\".flwr.proto.Context.RunConfigEntry\x1a\x45\n\x0fNodeConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\x1a\x44\n\x0eRunConfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.flwr.proto.Scalar:\x02\x38\x01\"\xbb\x01\n\x08Metadata\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\x13\n\x0bsrc_node_id\x18\x03 \x01(\x12\x12\x13\n\x0b\x64st_node_id\x18\x04 \x01(\x12\x12\x18\n\x10reply_to_message\x18\x05 \x01(\t\x12\x10\n\x08group_id\x18\x06 \x01(\t\x12\x0b\n\x03ttl\x18\x07 \x01(\x01\x12\x14\n\x0cmessage_type\x18\x08 \x01(\t\x12\x12\n\ncreated_at\x18\t \x01(\x01\x62\x06proto3')
21
21
 
22
22
  _globals = globals()
23
23
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -37,5 +37,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
37
37
  _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_start=497
38
38
  _globals['_CONTEXT_RUNCONFIGENTRY']._serialized_end=565
39
39
  _globals['_METADATA']._serialized_start=568
40
- _globals['_METADATA']._serialized_end=735
40
+ _globals['_METADATA']._serialized_end=755
41
41
  # @@protoc_insertion_point(module_scope)
@@ -99,6 +99,7 @@ class Metadata(google.protobuf.message.Message):
99
99
  GROUP_ID_FIELD_NUMBER: builtins.int
100
100
  TTL_FIELD_NUMBER: builtins.int
101
101
  MESSAGE_TYPE_FIELD_NUMBER: builtins.int
102
+ CREATED_AT_FIELD_NUMBER: builtins.int
102
103
  run_id: builtins.int
103
104
  message_id: typing.Text
104
105
  src_node_id: builtins.int
@@ -107,6 +108,7 @@ class Metadata(google.protobuf.message.Message):
107
108
  group_id: typing.Text
108
109
  ttl: builtins.float
109
110
  message_type: typing.Text
111
+ created_at: builtins.float
110
112
  def __init__(self,
111
113
  *,
112
114
  run_id: builtins.int = ...,
@@ -117,6 +119,7 @@ class Metadata(google.protobuf.message.Message):
117
119
  group_id: typing.Text = ...,
118
120
  ttl: builtins.float = ...,
119
121
  message_type: typing.Text = ...,
122
+ created_at: builtins.float = ...,
120
123
  ) -> None: ...
121
- def ClearField(self, field_name: typing_extensions.Literal["dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
124
+ def ClearField(self, field_name: typing_extensions.Literal["created_at",b"created_at","dst_node_id",b"dst_node_id","group_id",b"group_id","message_id",b"message_id","message_type",b"message_type","reply_to_message",b"reply_to_message","run_id",b"run_id","src_node_id",b"src_node_id","ttl",b"ttl"]) -> None: ...
122
125
  global___Metadata = Metadata
flwr/server/app.py CHANGED
@@ -34,6 +34,7 @@ from cryptography.hazmat.primitives.serialization import (
34
34
 
35
35
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
36
36
  from flwr.common.address import parse_address
37
+ from flwr.common.config import get_flwr_dir
37
38
  from flwr.common.constant import (
38
39
  MISSING_EXTRA_REST,
39
40
  TRANSPORT_TYPE_GRPC_ADAPTER,
@@ -57,6 +58,7 @@ from .server import Server, init_defaults, run_fl
57
58
  from .server_config import ServerConfig
58
59
  from .strategy import Strategy
59
60
  from .superlink.driver.driver_grpc import run_driver_api_grpc
61
+ from .superlink.ffs.ffs_factory import FfsFactory
60
62
  from .superlink.fleet.grpc_adapter.grpc_adapter_servicer import GrpcAdapterServicer
61
63
  from .superlink.fleet.grpc_bidi.grpc_server import (
62
64
  generic_create_grpc_server,
@@ -72,6 +74,7 @@ ADDRESS_FLEET_API_GRPC_BIDI = "[::]:8080" # IPv6 to keep start_server compatibl
72
74
  ADDRESS_FLEET_API_REST = "0.0.0.0:9093"
73
75
 
74
76
  DATABASE = ":flwr-in-memory-state:"
77
+ BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
75
78
 
76
79
 
77
80
  def start_server( # pylint: disable=too-many-arguments,too-many-locals
@@ -211,10 +214,14 @@ def run_superlink() -> None:
211
214
  # Initialize StateFactory
212
215
  state_factory = StateFactory(args.database)
213
216
 
217
+ # Initialize FfsFactory
218
+ ffs_factory = FfsFactory(args.storage_dir)
219
+
214
220
  # Start Driver API
215
221
  driver_server: grpc.Server = run_driver_api_grpc(
216
222
  address=driver_address,
217
223
  state_factory=state_factory,
224
+ ffs_factory=ffs_factory,
218
225
  certificates=certificates,
219
226
  )
220
227
 
@@ -294,6 +301,7 @@ def run_superlink() -> None:
294
301
  fleet_server = _run_fleet_api_grpc_rere(
295
302
  address=fleet_address,
296
303
  state_factory=state_factory,
304
+ ffs_factory=ffs_factory,
297
305
  certificates=certificates,
298
306
  interceptors=interceptors,
299
307
  )
@@ -480,6 +488,7 @@ def _try_obtain_certificates(
480
488
  def _run_fleet_api_grpc_rere(
481
489
  address: str,
482
490
  state_factory: StateFactory,
491
+ ffs_factory: FfsFactory,
483
492
  certificates: Optional[Tuple[bytes, bytes, bytes]],
484
493
  interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
485
494
  ) -> grpc.Server:
@@ -487,6 +496,7 @@ def _run_fleet_api_grpc_rere(
487
496
  # Create Fleet API gRPC server
488
497
  fleet_servicer = FleetServicer(
489
498
  state_factory=state_factory,
499
+ ffs_factory=ffs_factory,
490
500
  )
491
501
  fleet_add_servicer_to_server_fn = add_FleetServicer_to_server
492
502
  fleet_grpc_server = generic_create_grpc_server(
@@ -610,6 +620,11 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
610
620
  "Flower will just create a state in memory.",
611
621
  default=DATABASE,
612
622
  )
623
+ parser.add_argument(
624
+ "--storage-dir",
625
+ help="The base directory to store the objects for the Flower File System.",
626
+ default=BASE_DIR,
627
+ )
613
628
  parser.add_argument(
614
629
  "--auth-list-public-keys",
615
630
  type=str,
flwr/server/driver/grpc_driver.py CHANGED
@@ -131,6 +131,7 @@ class GrpcDriver(Driver):
131
131
  run_id=res.run.run_id,
132
132
  fab_id=res.run.fab_id,
133
133
  fab_version=res.run.fab_version,
134
+ fab_hash=res.run.fab_hash,
134
135
  override_config=user_config_from_proto(res.run.override_config),
135
136
  )
136
137
 
flwr/server/run_serverapp.py CHANGED
@@ -21,6 +21,8 @@ from logging import DEBUG, INFO, WARN
21
21
  from pathlib import Path
22
22
  from typing import Optional
23
23
 
24
+ from flwr.cli.config_utils import get_fab_metadata
25
+ from flwr.cli.install import install_from_fab
24
26
  from flwr.common import Context, EventType, RecordSet, event
25
27
  from flwr.common.config import (
26
28
  get_flwr_dir,
@@ -36,6 +38,7 @@ from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
36
38
  CreateRunRequest,
37
39
  CreateRunResponse,
38
40
  )
41
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
39
42
 
40
43
  from .driver import Driver
41
44
  from .driver.grpc_driver import GrpcDriver
@@ -87,7 +90,8 @@ def run(
87
90
  log(DEBUG, "ServerApp finished running.")
88
91
 
89
92
 
90
- def run_server_app() -> None: # pylint: disable=too-many-branches
93
+ # pylint: disable-next=too-many-branches,too-many-statements,too-many-locals
94
+ def run_server_app() -> None:
91
95
  """Run Flower server app."""
92
96
  event(EventType.RUN_SERVER_APP_ENTER)
93
97
 
@@ -164,7 +168,19 @@ def run_server_app() -> None: # pylint: disable=too-many-branches
164
168
  )
165
169
  flwr_dir = get_flwr_dir(args.flwr_dir)
166
170
  run_ = driver.run
167
- app_path = str(get_project_dir(run_.fab_id, run_.fab_version, flwr_dir))
171
+ if run_.fab_hash:
172
+ fab_req = GetFabRequest(hash_str=run_.fab_hash)
173
+ # pylint: disable-next=W0212
174
+ fab_res: GetFabResponse = driver._stub.GetFab(fab_req)
175
+ if fab_res.fab.hash_str != run_.fab_hash:
176
+ raise ValueError("FAB hashes don't match.")
177
+
178
+ install_from_fab(fab_res.fab.content, flwr_dir, True)
179
+ fab_id, fab_version = get_fab_metadata(fab_res.fab.content)
180
+ else:
181
+ fab_id, fab_version = run_.fab_id, run_.fab_version
182
+
183
+ app_path = str(get_project_dir(fab_id, fab_version, flwr_dir))
168
184
  config = get_project_config(app_path)
169
185
  else:
170
186
  # User provided `app_dir`, but not `--run-id`
flwr/server/server.py CHANGED
@@ -91,7 +91,7 @@ class Server:
91
91
  # Initialize parameters
92
92
  log(INFO, "[INIT]")
93
93
  self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout)
94
- log(INFO, "Evaluating initial global parameters")
94
+ log(INFO, "Starting evaluation of initial global parameters")
95
95
  res = self.strategy.evaluate(0, parameters=self.parameters)
96
96
  if res is not None:
97
97
  log(
@@ -102,6 +102,8 @@ class Server:
102
102
  )
103
103
  history.add_loss_centralized(server_round=0, loss=res[0])
104
104
  history.add_metrics_centralized(server_round=0, metrics=res[1])
105
+ else:
106
+ log(INFO, "Evaluation returned no results (`None`)")
105
107
 
106
108
  # Run federated learning for num_rounds
107
109
  start_time = timeit.default_timer()
flwr/server/superlink/driver/driver_grpc.py CHANGED
@@ -24,6 +24,7 @@ from flwr.common.logger import log
24
24
  from flwr.proto.driver_pb2_grpc import ( # pylint: disable=E0611
25
25
  add_DriverServicer_to_server,
26
26
  )
27
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
27
28
  from flwr.server.superlink.state import StateFactory
28
29
 
29
30
  from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server
@@ -33,12 +34,14 @@ from .driver_servicer import DriverServicer
33
34
  def run_driver_api_grpc(
34
35
  address: str,
35
36
  state_factory: StateFactory,
37
+ ffs_factory: FfsFactory,
36
38
  certificates: Optional[Tuple[bytes, bytes, bytes]],
37
39
  ) -> grpc.Server:
38
40
  """Run Driver API (gRPC, request-response)."""
39
41
  # Create Driver API gRPC server
40
42
  driver_servicer: grpc.Server = DriverServicer(
41
43
  state_factory=state_factory,
44
+ ffs_factory=ffs_factory,
42
45
  )
43
46
  driver_add_servicer_to_server_fn = add_DriverServicer_to_server
44
47
  driver_grpc_server = generic_create_grpc_server(
flwr/server/superlink/driver/driver_servicer.py CHANGED
@@ -23,7 +23,13 @@ from uuid import UUID
23
23
  import grpc
24
24
 
25
25
  from flwr.common.logger import log
26
- from flwr.common.serde import user_config_from_proto, user_config_to_proto
26
+ from flwr.common.serde import (
27
+ fab_from_proto,
28
+ fab_to_proto,
29
+ user_config_from_proto,
30
+ user_config_to_proto,
31
+ )
32
+ from flwr.common.typing import Fab
27
33
  from flwr.proto import driver_pb2_grpc # pylint: disable=E0611
28
34
  from flwr.proto.driver_pb2 import ( # pylint: disable=E0611
29
35
  CreateRunRequest,
@@ -43,6 +49,8 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
43
49
  Run,
44
50
  )
45
51
  from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
52
+ from flwr.server.superlink.ffs.ffs import Ffs
53
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
46
54
  from flwr.server.superlink.state import State, StateFactory
47
55
  from flwr.server.utils.validator import validate_task_ins_or_res
48
56
 
@@ -50,8 +58,9 @@ from flwr.server.utils.validator import validate_task_ins_or_res
50
58
  class DriverServicer(driver_pb2_grpc.DriverServicer):
51
59
  """Driver API servicer."""
52
60
 
53
- def __init__(self, state_factory: StateFactory) -> None:
61
+ def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
54
62
  self.state_factory = state_factory
63
+ self.ffs_factory = ffs_factory
55
64
 
56
65
  def GetNodes(
57
66
  self, request: GetNodesRequest, context: grpc.ServicerContext
@@ -71,9 +80,20 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
71
80
  """Create run ID."""
72
81
  log(DEBUG, "DriverServicer.CreateRun")
73
82
  state: State = self.state_factory.state()
83
+ if request.HasField("fab"):
84
+ fab = fab_from_proto(request.fab)
85
+ ffs: Ffs = self.ffs_factory.ffs()
86
+ fab_hash = ffs.put(fab.content, {})
87
+ _raise_if(
88
+ fab_hash != fab.hash_str,
89
+ f"FAB ({fab.hash_str}) hash from request doesn't match contents",
90
+ )
91
+ else:
92
+ fab_hash = ""
74
93
  run_id = state.create_run(
75
94
  request.fab_id,
76
95
  request.fab_version,
96
+ fab_hash,
77
97
  user_config_from_proto(request.override_config),
78
98
  )
79
99
  return CreateRunResponse(run_id=run_id)
@@ -161,14 +181,22 @@ class DriverServicer(driver_pb2_grpc.DriverServicer):
161
181
  fab_id=run.fab_id,
162
182
  fab_version=run.fab_version,
163
183
  override_config=user_config_to_proto(run.override_config),
184
+ fab_hash=run.fab_hash,
164
185
  )
165
186
  )
166
187
 
167
188
  def GetFab(
168
189
  self, request: GetFabRequest, context: grpc.ServicerContext
169
190
  ) -> GetFabResponse:
170
- """Will be implemented later."""
171
- raise NotImplementedError
191
+ """Get FAB from Ffs."""
192
+ log(DEBUG, "DriverServicer.GetFab")
193
+
194
+ ffs: Ffs = self.ffs_factory.ffs()
195
+ if result := ffs.get(request.hash_str):
196
+ fab = Fab(request.hash_str, result[0])
197
+ return GetFabResponse(fab=fab_to_proto(fab))
198
+
199
+ raise ValueError(f"Found no FAB with hash: {request.hash_str}")
172
200
 
173
201
 
174
202
  def _raise_if(validation_error: bool, detail: str) -> None:
flwr/server/superlink/ffs/disk_ffs.py CHANGED
@@ -17,7 +17,7 @@
17
17
  import hashlib
18
18
  import json
19
19
  from pathlib import Path
20
- from typing import Dict, List, Tuple
20
+ from typing import Dict, List, Optional, Tuple
21
21
 
22
22
  from flwr.server.superlink.ffs.ffs import Ffs
23
23
 
@@ -58,7 +58,7 @@ class DiskFfs(Ffs): # pylint: disable=R0904
58
58
 
59
59
  return content_hash
60
60
 
61
- def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
61
+ def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
62
62
  """Return tuple containing the object content and metadata.
63
63
 
64
64
  Parameters
@@ -68,9 +68,12 @@ class DiskFfs(Ffs): # pylint: disable=R0904
68
68
 
69
69
  Returns
70
70
  -------
71
- Tuple[bytes, Dict[str, str]]
71
+ Optional[Tuple[bytes, Dict[str, str]]]
72
72
  A tuple containing the object content and metadata.
73
73
  """
74
+ if not (self.base_dir / key).exists():
75
+ return None
76
+
74
77
  content = (self.base_dir / key).read_bytes()
75
78
  meta = json.loads((self.base_dir / f"{key}.META").read_text())
76
79
 
flwr/server/superlink/ffs/ffs.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import abc
19
- from typing import Dict, List, Tuple
19
+ from typing import Dict, List, Optional, Tuple
20
20
 
21
21
 
22
22
  class Ffs(abc.ABC): # pylint: disable=R0904
@@ -40,7 +40,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
40
40
  """
41
41
 
42
42
  @abc.abstractmethod
43
- def get(self, key: str) -> Tuple[bytes, Dict[str, str]]:
43
+ def get(self, key: str) -> Optional[Tuple[bytes, Dict[str, str]]]:
44
44
  """Return tuple containing the object content and metadata.
45
45
 
46
46
  Parameters
@@ -50,7 +50,7 @@ class Ffs(abc.ABC): # pylint: disable=R0904
50
50
 
51
51
  Returns
52
52
  -------
53
- Tuple[bytes, Dict[str, str]]
53
+ Optional[Tuple[bytes, Dict[str, str]]]
54
54
  A tuple containing the object content and metadata.
55
55
  """
56
56
 
flwr/server/superlink/ffs/ffs_factory.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Factory class that creates Ffs instances."""
16
+
17
+
18
+ from logging import DEBUG
19
+ from typing import Optional
20
+
21
+ from flwr.common.logger import log
22
+
23
+ from .disk_ffs import DiskFfs
24
+ from .ffs import Ffs
25
+
26
+
27
+ class FfsFactory:
28
+ """Factory class that creates Ffs instances.
29
+
30
+ Parameters
31
+ ----------
32
+ base_dir : str
33
+ The base directory used by DiskFfs to store objects.
34
+ """
35
+
36
+ def __init__(self, base_dir: str) -> None:
37
+ self.base_dir = base_dir
38
+ self.ffs_instance: Optional[Ffs] = None
39
+
40
+ def ffs(self) -> Ffs:
41
+ """Return a Ffs instance and create it, if necessary."""
42
+ if not self.ffs_instance:
43
+ log(DEBUG, "Initializing DiskFfs")
44
+ self.ffs_instance = DiskFfs(self.base_dir)
45
+
46
+ log(DEBUG, "Using DiskFfs")
47
+ return self.ffs_instance
flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py CHANGED
@@ -35,6 +35,7 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
35
35
  PushTaskResResponse,
36
36
  )
37
37
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
38
+ from flwr.server.superlink.ffs.ffs_factory import FfsFactory
38
39
  from flwr.server.superlink.fleet.message_handler import message_handler
39
40
  from flwr.server.superlink.state import StateFactory
40
41
 
@@ -42,18 +43,21 @@ from flwr.server.superlink.state import StateFactory
42
43
  class FleetServicer(fleet_pb2_grpc.FleetServicer):
43
44
  """Fleet API servicer."""
44
45
 
45
- def __init__(self, state_factory: StateFactory) -> None:
46
+ def __init__(self, state_factory: StateFactory, ffs_factory: FfsFactory) -> None:
46
47
  self.state_factory = state_factory
48
+ self.ffs_factory = ffs_factory
47
49
 
48
50
  def CreateNode(
49
51
  self, request: CreateNodeRequest, context: grpc.ServicerContext
50
52
  ) -> CreateNodeResponse:
51
53
  """."""
52
54
  log(INFO, "FleetServicer.CreateNode")
53
- return message_handler.create_node(
55
+ response = message_handler.create_node(
54
56
  request=request,
55
57
  state=self.state_factory.state(),
56
58
  )
59
+ log(INFO, "FleetServicer: Created node_id=%s", response.node.node_id)
60
+ return response
57
61
 
58
62
  def DeleteNode(
59
63
  self, request: DeleteNodeRequest, context: grpc.ServicerContext
@@ -106,5 +110,9 @@ class FleetServicer(fleet_pb2_grpc.FleetServicer):
106
110
  def GetFab(
107
111
  self, request: GetFabRequest, context: grpc.ServicerContext
108
112
  ) -> GetFabResponse:
109
- """Will be implemented later."""
110
- raise NotImplementedError
113
+ """Get FAB."""
114
+ log(DEBUG, "DriverServicer.GetFab")
115
+ return message_handler.get_fab(
116
+ request=request,
117
+ ffs=self.ffs_factory.ffs(),
118
+ )
flwr/server/superlink/fleet/grpc_rere/server_interceptor.py CHANGED
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import base64
19
- from logging import WARNING
19
+ from logging import INFO, WARNING
20
20
  from typing import Any, Callable, Optional, Sequence, Tuple, Union
21
21
 
22
22
  import grpc
@@ -128,9 +128,15 @@ class AuthenticateServerInterceptor(grpc.ServerInterceptor): # type: ignore
128
128
  context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
129
129
 
130
130
  if isinstance(request, CreateNodeRequest):
131
- return self._create_authenticated_node(
131
+ response = self._create_authenticated_node(
132
132
  client_public_key_bytes, request, context
133
133
  )
134
+ log(
135
+ INFO,
136
+ "AuthenticateServerInterceptor: Created node_id=%s",
137
+ response.node.node_id,
138
+ )
139
+ return response
134
140
 
135
141
  # Verify hmac value
136
142
  hmac_value = base64.urlsafe_b64decode(
flwr/server/superlink/fleet/message_handler/message_handler.py CHANGED
@@ -19,7 +19,9 @@ import time
19
19
  from typing import List, Optional
20
20
  from uuid import UUID
21
21
 
22
- from flwr.common.serde import user_config_to_proto
22
+ from flwr.common.serde import fab_to_proto, user_config_to_proto
23
+ from flwr.common.typing import Fab
24
+ from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
23
25
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
24
26
  CreateNodeRequest,
25
27
  CreateNodeResponse,
@@ -40,6 +42,7 @@ from flwr.proto.run_pb2 import ( # pylint: disable=E0611
40
42
  Run,
41
43
  )
42
44
  from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
45
+ from flwr.server.superlink.ffs.ffs import Ffs
43
46
  from flwr.server.superlink.state import State
44
47
 
45
48
 
@@ -124,5 +127,17 @@ def get_run(
124
127
  fab_id=run.fab_id,
125
128
  fab_version=run.fab_version,
126
129
  override_config=user_config_to_proto(run.override_config),
130
+ fab_hash=run.fab_hash,
127
131
  )
128
132
  )
133
+
134
+
135
+ def get_fab(
136
+ request: GetFabRequest, ffs: Ffs # pylint: disable=W0613
137
+ ) -> GetFabResponse:
138
+ """Get FAB."""
139
+ if result := ffs.get(request.hash_str):
140
+ fab = Fab(request.hash_str, result[0])
141
+ return GetFabResponse(fab=fab_to_proto(fab))
142
+
143
+ raise ValueError(f"Found no FAB with hash: {request.hash_str}")
flwr/server/superlink/fleet/vce/backend/raybackend.py CHANGED
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Ray backend for the Fleet API using the Simulation Engine."""
16
16
 
17
+ import sys
17
18
  from logging import DEBUG, ERROR
18
19
  from typing import Callable, Dict, Tuple, Union
19
20
 
@@ -111,8 +112,10 @@ class RayBackend(Backend):
111
112
  if backend_config.get(self.init_args_key):
112
113
  for k, v in backend_config[self.init_args_key].items():
113
114
  ray_init_args[k] = v
114
-
115
- ray.init(**ray_init_args)
115
+ ray.init(
116
+ runtime_env={"env_vars": {"PYTHONPATH": ":".join(sys.path)}},
117
+ **ray_init_args,
118
+ )
116
119
 
117
120
  @property
118
121
  def num_workers(self) -> int:
flwr/server/superlink/fleet/vce/vce_api.py CHANGED
@@ -27,8 +27,8 @@ from time import sleep
27
27
  from typing import Callable, Dict, Optional
28
28
 
29
29
  from flwr.client.client_app import ClientApp, ClientAppException, LoadClientAppError
30
+ from flwr.client.clientapp.utils import get_load_client_app_fn
30
31
  from flwr.client.node_state import NodeState
31
- from flwr.client.supernode.app import _get_load_client_app_fn
32
32
  from flwr.common.constant import (
33
33
  NUM_PARTITIONS_KEY,
34
34
  PARTITION_ID_KEY,
@@ -345,7 +345,7 @@ def start_vce(
345
345
  def _load() -> ClientApp:
346
346
 
347
347
  if client_app_attr:
348
- app = _get_load_client_app_fn(
348
+ app = get_load_client_app_fn(
349
349
  default_app_ref=client_app_attr,
350
350
  app_path=app_dir,
351
351
  flwr_dir=flwr_dir,
flwr/server/superlink/state/in_memory_state.py CHANGED
@@ -277,11 +277,12 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
277
277
 
278
278
  def create_run(
279
279
  self,
280
- fab_id: str,
281
- fab_version: str,
280
+ fab_id: Optional[str],
281
+ fab_version: Optional[str],
282
+ fab_hash: Optional[str],
282
283
  override_config: UserConfig,
283
284
  ) -> int:
284
- """Create a new run for the specified `fab_id` and `fab_version`."""
285
+ """Create a new run for the specified `fab_hash`."""
285
286
  # Sample a random int64 as run_id
286
287
  with self.lock:
287
288
  run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
@@ -289,8 +290,9 @@ class InMemoryState(State): # pylint: disable=R0902,R0904
289
290
  if run_id not in self.run_ids:
290
291
  self.run_ids[run_id] = Run(
291
292
  run_id=run_id,
292
- fab_id=fab_id,
293
- fab_version=fab_version,
293
+ fab_id=fab_id if fab_id else "",
294
+ fab_version=fab_version if fab_version else "",
295
+ fab_hash=fab_hash if fab_hash else "",
294
296
  override_config=override_config,
295
297
  )
296
298
  return run_id