flwr-nightly 1.19.0.dev20250526__py3-none-any.whl → 1.19.0.dev20250528__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. flwr/cli/log.py +3 -3
  2. flwr/cli/login/login.py +3 -7
  3. flwr/cli/ls.py +3 -3
  4. flwr/cli/run/run.py +2 -6
  5. flwr/cli/stop.py +2 -2
  6. flwr/cli/utils.py +5 -4
  7. flwr/client/grpc_rere_client/connection.py +2 -0
  8. flwr/client/message_handler/message_handler.py +1 -1
  9. flwr/common/auth_plugin/__init__.py +2 -0
  10. flwr/common/auth_plugin/auth_plugin.py +18 -0
  11. flwr/common/constant.py +3 -0
  12. flwr/common/inflatable.py +33 -2
  13. flwr/common/message.py +5 -1
  14. flwr/common/record/array.py +38 -1
  15. flwr/common/record/arrayrecord.py +34 -0
  16. flwr/common/serde.py +6 -1
  17. flwr/compat/client/app.py +9 -151
  18. flwr/proto/fleet_pb2.py +25 -13
  19. flwr/proto/fleet_pb2.pyi +60 -3
  20. flwr/proto/message_pb2.py +22 -19
  21. flwr/proto/message_pb2.pyi +25 -2
  22. flwr/proto/serverappio_pb2.py +31 -19
  23. flwr/proto/serverappio_pb2.pyi +60 -3
  24. flwr/server/app.py +44 -1
  25. flwr/server/grid/grpc_grid.py +2 -1
  26. flwr/server/grid/inmemory_grid.py +5 -4
  27. flwr/server/superlink/fleet/message_handler/message_handler.py +1 -2
  28. flwr/server/superlink/fleet/vce/vce_api.py +3 -0
  29. flwr/server/superlink/linkstate/in_memory_linkstate.py +14 -25
  30. flwr/server/superlink/linkstate/linkstate.py +9 -10
  31. flwr/server/superlink/linkstate/sqlite_linkstate.py +11 -21
  32. flwr/server/superlink/linkstate/utils.py +23 -23
  33. flwr/server/superlink/serverappio/serverappio_servicer.py +6 -10
  34. flwr/server/utils/validator.py +2 -2
  35. flwr/supercore/object_store/in_memory_object_store.py +30 -4
  36. flwr/supercore/object_store/object_store.py +48 -1
  37. flwr/superexec/exec_servicer.py +1 -2
  38. {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/METADATA +1 -1
  39. {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/RECORD +41 -41
  40. {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/WHEEL +0 -0
  41. {flwr_nightly-1.19.0.dev20250526.dist-info → flwr_nightly-1.19.0.dev20250528.dist-info}/entry_points.txt +0 -0
flwr/cli/log.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import log as logger
35
35
  from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
36
36
  from flwr.proto.exec_pb2_grpc import ExecStub
37
37
 
38
- from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler
38
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
39
 
40
40
 
41
41
  class AllLogsRetrieved(BaseException):
@@ -95,7 +95,7 @@ def stream_logs(
95
95
  latest_timestamp = 0.0
96
96
  res = None
97
97
  try:
98
- with unauthenticated_exc_handler():
98
+ with flwr_cli_grpc_exc_handler():
99
99
  for res in stub.StreamLogs(req, timeout=duration):
100
100
  print(res.log_output, end="")
101
101
  raise AllLogsRetrieved()
@@ -116,7 +116,7 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
116
116
  req = StreamLogsRequest(run_id=run_id, after_timestamp=0.0)
117
117
 
118
118
  try:
119
- with unauthenticated_exc_handler():
119
+ with flwr_cli_grpc_exc_handler():
120
120
  # Enforce timeout for graceful exit
121
121
  for res in stub.StreamLogs(req, timeout=timeout):
122
122
  print(res.log_output)
flwr/cli/login/login.py CHANGED
@@ -35,11 +35,7 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
35
35
  )
36
36
  from flwr.proto.exec_pb2_grpc import ExecStub
37
37
 
38
- from ..utils import (
39
- init_channel,
40
- try_obtain_cli_auth_plugin,
41
- unauthenticated_exc_handler,
42
- )
38
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
43
39
 
44
40
 
45
41
  def login( # pylint: disable=R0914
@@ -96,7 +92,7 @@ def login( # pylint: disable=R0914
96
92
  stub = ExecStub(channel)
97
93
 
98
94
  login_request = GetLoginDetailsRequest()
99
- with unauthenticated_exc_handler():
95
+ with flwr_cli_grpc_exc_handler():
100
96
  login_response: GetLoginDetailsResponse = stub.GetLoginDetails(login_request)
101
97
 
102
98
  # Get the auth plugin
@@ -120,7 +116,7 @@ def login( # pylint: disable=R0914
120
116
  expires_in=login_response.expires_in,
121
117
  interval=login_response.interval,
122
118
  )
123
- with unauthenticated_exc_handler():
119
+ with flwr_cli_grpc_exc_handler():
124
120
  credentials = auth_plugin.login(details, stub)
125
121
 
126
122
  # Store the tokens
flwr/cli/ls.py CHANGED
@@ -44,7 +44,7 @@ from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
44
44
  )
45
45
  from flwr.proto.exec_pb2_grpc import ExecStub
46
46
 
47
- from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler
47
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
48
48
 
49
49
  _RunListType = tuple[int, str, str, str, str, str, str, str, str]
50
50
 
@@ -305,7 +305,7 @@ def _list_runs(
305
305
  output_format: str = CliOutputFormat.DEFAULT,
306
306
  ) -> None:
307
307
  """List all runs."""
308
- with unauthenticated_exc_handler():
308
+ with flwr_cli_grpc_exc_handler():
309
309
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest())
310
310
  run_dict = {run_id: run_from_proto(proto) for run_id, proto in res.run_dict.items()}
311
311
 
@@ -322,7 +322,7 @@ def _display_one_run(
322
322
  output_format: str = CliOutputFormat.DEFAULT,
323
323
  ) -> None:
324
324
  """Display information about a specific run."""
325
- with unauthenticated_exc_handler():
325
+ with flwr_cli_grpc_exc_handler():
326
326
  res: ListRunsResponse = stub.ListRuns(ListRunsRequest(run_id=run_id))
327
327
  if not res.run_dict:
328
328
  raise ValueError(f"Run ID {run_id} not found")
flwr/cli/run/run.py CHANGED
@@ -45,11 +45,7 @@ from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
45
45
  from flwr.proto.exec_pb2_grpc import ExecStub
46
46
 
47
47
  from ..log import start_stream
48
- from ..utils import (
49
- init_channel,
50
- try_obtain_cli_auth_plugin,
51
- unauthenticated_exc_handler,
52
- )
48
+ from ..utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
53
49
 
54
50
  CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)
55
51
 
@@ -172,7 +168,7 @@ def _run_with_exec_api(
172
168
  override_config=user_config_to_proto(parse_config_args(config_overrides)),
173
169
  federation_options=config_record_to_proto(c_record),
174
170
  )
175
- with unauthenticated_exc_handler():
171
+ with flwr_cli_grpc_exc_handler():
176
172
  res = stub.StartRun(req)
177
173
 
178
174
  if res.HasField("run_id"):
flwr/cli/stop.py CHANGED
@@ -35,7 +35,7 @@ from flwr.common.logger import print_json_error, redirect_output, restore_output
35
35
  from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
36
36
  from flwr.proto.exec_pb2_grpc import ExecStub
37
37
 
38
- from .utils import init_channel, try_obtain_cli_auth_plugin, unauthenticated_exc_handler
38
+ from .utils import flwr_cli_grpc_exc_handler, init_channel, try_obtain_cli_auth_plugin
39
39
 
40
40
 
41
41
  def stop( # pylint: disable=R0914
@@ -122,7 +122,7 @@ def stop( # pylint: disable=R0914
122
122
 
123
123
  def _stop_run(stub: ExecStub, run_id: int, output_format: str) -> None:
124
124
  """Stop a run."""
125
- with unauthenticated_exc_handler():
125
+ with flwr_cli_grpc_exc_handler():
126
126
  response: StopRunResponse = stub.StopRun(request=StopRunRequest(run_id=run_id))
127
127
  if response.success:
128
128
  typer.secho(f"✅ Run {run_id} successfully stopped.", fg=typer.colors.GREEN)
flwr/cli/utils.py CHANGED
@@ -288,11 +288,12 @@ def init_channel(
288
288
 
289
289
 
290
290
  @contextmanager
291
- def unauthenticated_exc_handler() -> Iterator[None]:
292
- """Context manager to handle gRPC UNAUTHENTICATED errors.
291
+ def flwr_cli_grpc_exc_handler() -> Iterator[None]:
292
+ """Context manager to handle specific gRPC errors.
293
293
 
294
- It catches grpc.RpcError exceptions with UNAUTHENTICATED status, informs the user,
295
- and exits the application. All other exceptions will be allowed to escape.
294
+ It catches grpc.RpcError exceptions with UNAUTHENTICATED and UNIMPLEMENTED statuses,
295
+ informs the user, and exits the application. All other exceptions will be allowed to
296
+ escape.
296
297
  """
297
298
  try:
298
299
  yield
@@ -279,6 +279,8 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
279
279
  log(ERROR, "No current message")
280
280
  return
281
281
 
282
+ # Set message_id
283
+ message.metadata.__dict__["_message_id"] = message.object_id
282
284
  # Validate out message
283
285
  if not validate_out_message(message, metadata):
284
286
  log(ERROR, "Invalid out message")
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
164
164
  in_meta = in_message_metadata
165
165
  if ( # pylint: disable-next=too-many-boolean-expressions
166
166
  out_meta.run_id == in_meta.run_id
167
- and out_meta.message_id == "" # This will be generated by the server
167
+ and out_meta.message_id == out_message.object_id # Should match the object id
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
170
  and out_meta.reply_to_message_id == in_meta.message_id
@@ -17,8 +17,10 @@
17
17
 
18
18
  from .auth_plugin import CliAuthPlugin as CliAuthPlugin
19
19
  from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
20
+ from .auth_plugin import ExecAuthzPlugin as ExecAuthzPlugin
20
21
 
21
22
  __all__ = [
22
23
  "CliAuthPlugin",
23
24
  "ExecAuthPlugin",
25
+ "ExecAuthzPlugin",
24
26
  ]
@@ -64,6 +64,24 @@ class ExecAuthPlugin(ABC):
64
64
  """Refresh authentication tokens in the provided metadata."""
65
65
 
66
66
 
67
+ class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
68
+ """Abstract Flower Authorization Plugin class for ExecServicer.
69
+
70
+ Parameters
71
+ ----------
72
+ user_authz_config_path : Path
73
+ Path to the YAML file containing the authorization configuration.
74
+ """
75
+
76
+ @abstractmethod
77
+ def __init__(self, user_authz_config_path: Path, verify_tls_cert: bool):
78
+ """Abstract constructor."""
79
+
80
+ @abstractmethod
81
+ def verify_user_authorization(self, user_info: UserInfo) -> bool:
82
+ """Verify user authorization request."""
83
+
84
+
67
85
  class CliAuthPlugin(ABC):
68
86
  """Abstract Flower Auth Plugin class for CLI.
69
87
 
flwr/common/constant.py CHANGED
@@ -115,6 +115,9 @@ AUTH_TYPE_YAML_KEY = "auth_type" # For key name in YAML file
115
115
  ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
116
116
  REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"
117
117
 
118
+ # Constants for user authorization
119
+ AUTHZ_TYPE_YAML_KEY = "authz_type" # For key name in YAML file
120
+
118
121
  # Constants for node authentication
119
122
  PUBLIC_KEY_HEADER = "flwr-public-key-bin" # Must end with "-bin" for binary data
120
123
  SIGNATURE_HEADER = "flwr-signature-bin" # Must end with "-bin" for binary data
flwr/common/inflatable.py CHANGED
@@ -18,7 +18,7 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import hashlib
21
- from typing import TypeVar
21
+ from typing import TypeVar, cast
22
22
 
23
23
  from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
24
24
 
@@ -55,13 +55,24 @@ class InflatableObject:
55
55
  @property
56
56
  def object_id(self) -> str:
57
57
  """Get object_id."""
58
- return get_object_id(self.deflate())
58
+ if self.is_dirty or "_object_id" not in self.__dict__:
59
+ self.__dict__["_object_id"] = get_object_id(self.deflate())
60
+ return cast(str, self.__dict__["_object_id"])
59
61
 
60
62
  @property
61
63
  def children(self) -> dict[str, InflatableObject] | None:
62
64
  """Get all child objects as a dictionary or None if there are no children."""
63
65
  return None
64
66
 
67
+ @property
68
+ def is_dirty(self) -> bool:
69
+ """Check if the object is dirty after the last deflation.
70
+
71
+ An object is considered dirty if its content has changed since the last its
72
+ object ID was computed.
73
+ """
74
+ return True
75
+
65
76
 
66
77
  T = TypeVar("T", bound=InflatableObject)
67
78
 
@@ -178,3 +189,23 @@ def get_object_head_values_from_object_content(
178
189
  obj_type, children_str, body_len = head.split(HEAD_VALUE_DIVIDER)
179
190
  children_ids = children_str.split(",") if children_str else []
180
191
  return obj_type, children_ids, int(body_len)
192
+
193
+
194
+ def _get_descendants_object_ids_recursively(obj: InflatableObject) -> set[str]:
195
+
196
+ descendants: set[str] = set()
197
+ if children := obj.children:
198
+ for child in children.values():
199
+ descendants |= _get_descendants_object_ids_recursively(child)
200
+
201
+ descendants.add(obj.object_id)
202
+
203
+ return descendants
204
+
205
+
206
+ def get_desdendant_object_ids(obj: InflatableObject) -> set[str]:
207
+ """Get a set of object IDs of all descendants."""
208
+ descendants = _get_descendants_object_ids_recursively(obj)
209
+ # Exclude Object ID of parent object
210
+ descendants.discard(obj.object_id)
211
+ return descendants
flwr/common/message.py CHANGED
@@ -23,6 +23,7 @@ from typing import Any, cast, overload
23
23
  from flwr.common.date import now
24
24
  from flwr.common.logger import warn_deprecated_feature
25
25
  from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
26
+ from flwr.proto.message_pb2 import Metadata as ProtoMetadata # pylint: disable=E0611
26
27
 
27
28
  from ..app.error import Error
28
29
  from ..app.metadata import Metadata
@@ -351,9 +352,12 @@ class Message(InflatableObject):
351
352
 
352
353
  def deflate(self) -> bytes:
353
354
  """Deflate message."""
355
+ # Exclude message_id from serialization
356
+ proto_metadata: ProtoMetadata = metadata_to_proto(self.metadata)
357
+ proto_metadata.message_id = ""
354
358
  # Store message metadata and error in object body
355
359
  obj_body = ProtoMessage(
356
- metadata=metadata_to_proto(self.metadata),
360
+ metadata=proto_metadata,
357
361
  content=None,
358
362
  error=error_to_proto(self.error) if self.has_error() else None,
359
363
  ).SerializeToString(deterministic=True)
@@ -107,10 +107,21 @@ class Array(InflatableObject):
107
107
  """
108
108
 
109
109
  dtype: str
110
- shape: list[int]
111
110
  stype: str
112
111
  data: bytes
113
112
 
113
+ @property
114
+ def shape(self) -> list[int]:
115
+ """Get the shape of the array."""
116
+ self.is_dirty = True # Mark as dirty when shape is accessed
117
+ return cast(list[int], self.__dict__["_shape"])
118
+
119
+ @shape.setter
120
+ def shape(self, value: list[int]) -> None:
121
+ """Set the shape of the array."""
122
+ self.is_dirty = True # Mark as dirty when shape is set
123
+ self.__dict__["_shape"] = value
124
+
114
125
  @overload
115
126
  def __init__( # noqa: E704
116
127
  self, dtype: str, shape: list[int], stype: str, data: bytes
@@ -295,3 +306,29 @@ class Array(InflatableObject):
295
306
  stype=proto_array.stype,
296
307
  data=proto_array.data,
297
308
  )
309
+
310
+ @property
311
+ def object_id(self) -> str:
312
+ """Get object ID."""
313
+ ret = super().object_id
314
+ self.is_dirty = False # Reset dirty flag
315
+ return ret
316
+
317
+ @property
318
+ def is_dirty(self) -> bool:
319
+ """Check if the object is dirty after the last deflation."""
320
+ if "_is_dirty" not in self.__dict__:
321
+ self.__dict__["_is_dirty"] = True
322
+ return cast(bool, self.__dict__["_is_dirty"])
323
+
324
+ @is_dirty.setter
325
+ def is_dirty(self, value: bool) -> None:
326
+ """Set the dirty flag."""
327
+ self.__dict__["_is_dirty"] = value
328
+
329
+ def __setattr__(self, name: str, value: Any) -> None:
330
+ """Set attribute with special handling for dirty state."""
331
+ if name in ("dtype", "stype", "data"):
332
+ # Mark as dirty if any of the main attributes are set
333
+ self.is_dirty = True
334
+ super().__setattr__(name, value)
@@ -429,6 +429,40 @@ class ArrayRecord(TypedDict[str, Array], InflatableObject):
429
429
  )
430
430
  )
431
431
 
432
+ @property
433
+ def object_id(self) -> str:
434
+ """Get object ID."""
435
+ ret = super().object_id
436
+ self.is_dirty = False # Reset dirty flag
437
+ return ret
438
+
439
+ @property
440
+ def is_dirty(self) -> bool:
441
+ """Check if the object is dirty after the last deflation."""
442
+ if "_is_dirty" not in self.__dict__:
443
+ self.__dict__["_is_dirty"] = True
444
+
445
+ if not self.__dict__["_is_dirty"]:
446
+ if any(v.is_dirty for v in self.values()):
447
+ # If any Array is dirty, mark the record as dirty
448
+ self.__dict__["_is_dirty"] = True
449
+ return cast(bool, self.__dict__["_is_dirty"])
450
+
451
+ @is_dirty.setter
452
+ def is_dirty(self, value: bool) -> None:
453
+ """Set the dirty flag."""
454
+ self.__dict__["_is_dirty"] = value
455
+
456
+ def __setitem__(self, key: str, value: Array) -> None:
457
+ """Set item and mark the record as dirty."""
458
+ self.is_dirty = True # Mark as dirty when setting an item
459
+ super().__setitem__(key, value)
460
+
461
+ def __delitem__(self, key: str) -> None:
462
+ """Delete item and mark the record as dirty."""
463
+ self.is_dirty = True # Mark as dirty when deleting an item
464
+ super().__delitem__(key)
465
+
432
466
 
433
467
  class ParametersRecord(ArrayRecord):
434
468
  """Deprecated class ``ParametersRecord``, use ``ArrayRecord`` instead.
flwr/common/serde.py CHANGED
@@ -378,7 +378,12 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
378
378
 
379
379
  def array_to_proto(array: Array) -> ProtoArray:
380
380
  """Serialize Array to ProtoBuf."""
381
- return ProtoArray(**vars(array))
381
+ return ProtoArray(
382
+ dtype=array.dtype,
383
+ shape=array.shape,
384
+ stype=array.stype,
385
+ data=array.data,
386
+ )
382
387
 
383
388
 
384
389
  def array_from_proto(array_proto: ProtoArray) -> Array:
flwr/compat/client/app.py CHANGED
@@ -15,18 +15,12 @@
15
15
  """Flower client app."""
16
16
 
17
17
 
18
- import multiprocessing
19
- import os
20
- import sys
21
- import threading
22
18
  import time
23
19
  from contextlib import AbstractContextManager
24
20
  from logging import ERROR, INFO, WARN
25
- from os import urandom
26
21
  from pathlib import Path
27
- from typing import Callable, Optional, Union, cast
22
+ from typing import Callable, Optional, Union
28
23
 
29
- import grpc
30
24
  from cryptography.hazmat.primitives.asymmetric import ec
31
25
  from grpc import RpcError
32
26
 
@@ -35,11 +29,6 @@ from flwr.cli.config_utils import get_fab_metadata
35
29
  from flwr.cli.install import install_from_fab
36
30
  from flwr.client.client import Client
37
31
  from flwr.client.client_app import ClientApp, LoadClientAppError
38
- from flwr.client.clientapp.app import flwr_clientapp
39
- from flwr.client.clientapp.clientappio_servicer import (
40
- ClientAppInputs,
41
- ClientAppIoServicer,
42
- )
43
32
  from flwr.client.grpc_adapter_client.connection import grpc_adapter
44
33
  from flwr.client.grpc_rere_client.connection import grpc_request_response
45
34
  from flwr.client.message_handler.message_handler import handle_control_message
@@ -49,13 +38,7 @@ from flwr.client.typing import ClientFnExt
49
38
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, event
50
39
  from flwr.common.address import parse_address
51
40
  from flwr.common.constant import (
52
- CLIENT_OCTET,
53
- CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS,
54
- ISOLATION_MODE_PROCESS,
55
- ISOLATION_MODE_SUBPROCESS,
56
41
  MAX_RETRY_DELAY,
57
- RUN_ID_NUM_BYTES,
58
- SERVER_OCTET,
59
42
  TRANSPORT_TYPE_GRPC_ADAPTER,
60
43
  TRANSPORT_TYPE_GRPC_BIDI,
61
44
  TRANSPORT_TYPE_GRPC_RERE,
@@ -64,12 +47,10 @@ from flwr.common.constant import (
64
47
  ErrorCode,
65
48
  )
66
49
  from flwr.common.exit import ExitCode, flwr_exit
67
- from flwr.common.grpc import generic_create_grpc_server
68
50
  from flwr.common.logger import log, warn_deprecated_feature
69
51
  from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
70
52
  from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
71
53
  from flwr.compat.client.grpc_client.connection import grpc_connection
72
- from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
73
54
  from flwr.supernode.nodestate import NodeStateFactory
74
55
 
75
56
 
@@ -238,8 +219,6 @@ def start_client_internal(
238
219
  max_retries: Optional[int] = None,
239
220
  max_wait_time: Optional[float] = None,
240
221
  flwr_path: Optional[Path] = None,
241
- isolation: Optional[str] = None,
242
- clientappio_api_address: Optional[str] = CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS,
243
222
  ) -> None:
244
223
  """Start a Flower client node which connects to a Flower server.
245
224
 
@@ -292,17 +271,6 @@ def start_client_internal(
292
271
  If set to None, there is no limit to the total time.
293
272
  flwr_path: Optional[Path] (default: None)
294
273
  The fully resolved path containing installed Flower Apps.
295
- isolation : Optional[str] (default: None)
296
- Isolation mode for `ClientApp`. Possible values are `subprocess` and
297
- `process`. Defaults to `None`, which runs the `ClientApp` in the same process
298
- as the SuperNode. If `subprocess`, the `ClientApp` runs in a subprocess started
299
- by the SueprNode and communicates using gRPC at the address
300
- `clientappio_api_address`. If `process`, the `ClientApp` runs in a separate
301
- isolated process and communicates using gRPC at the address
302
- `clientappio_api_address`.
303
- clientappio_api_address : Optional[str]
304
- (default: `CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS`)
305
- The SuperNode gRPC server address.
306
274
  """
307
275
  if insecure is None:
308
276
  insecure = root_certificates is None
@@ -328,18 +296,6 @@ def start_client_internal(
328
296
 
329
297
  load_client_app_fn = _load_client_app
330
298
 
331
- if isolation:
332
- if clientappio_api_address is None:
333
- raise ValueError(
334
- f"`clientappio_api_address` required when `isolation` is "
335
- f"{ISOLATION_MODE_SUBPROCESS} or {ISOLATION_MODE_PROCESS}",
336
- )
337
- _clientappio_grpc_server, clientappio_servicer = run_clientappio_api_grpc(
338
- address=clientappio_api_address,
339
- certificates=None,
340
- )
341
- clientappio_api_address = cast(str, clientappio_api_address)
342
-
343
299
  # At this point, only `load_client_app_fn` should be used
344
300
  # Both `client` and `client_fn` must not be used directly
345
301
 
@@ -390,7 +346,6 @@ def start_client_internal(
390
346
  run_info_store: Optional[DeprecatedRunInfoStore] = None
391
347
  state_factory = NodeStateFactory()
392
348
  state = state_factory.state()
393
- mp_spawn_context = multiprocessing.get_context("spawn")
394
349
 
395
350
  runs: dict[int, Run] = {}
396
351
 
@@ -475,9 +430,8 @@ def start_client_internal(
475
430
  run: Run = runs[run_id]
476
431
  if get_fab is not None and run.fab_hash:
477
432
  fab = get_fab(run.fab_hash, run_id)
478
- if not isolation:
479
- # If `ClientApp` runs in the same process, install the FAB
480
- install_from_fab(fab.content, flwr_path, True)
433
+ # If `ClientApp` runs in the same process, install the FAB
434
+ install_from_fab(fab.content, flwr_path, True)
481
435
  fab_id, fab_version = get_fab_metadata(fab.content)
482
436
  else:
483
437
  fab = None
@@ -504,73 +458,13 @@ def start_client_internal(
504
458
 
505
459
  # Handle app loading and task message
506
460
  try:
507
- if isolation:
508
- # Two isolation modes:
509
- # 1. `subprocess`: SuperNode is starting the ClientApp
510
- # process as a subprocess.
511
- # 2. `process`: ClientApp process gets started separately
512
- # (via `flwr-clientapp`), for example, in a separate
513
- # Docker container.
514
-
515
- # Generate SuperNode token
516
- token = int.from_bytes(urandom(RUN_ID_NUM_BYTES), "little")
517
-
518
- # Mode 1: SuperNode starts ClientApp as subprocess
519
- start_subprocess = isolation == ISOLATION_MODE_SUBPROCESS
520
-
521
- # Share Message and Context with servicer
522
- clientappio_servicer.set_inputs(
523
- clientapp_input=ClientAppInputs(
524
- message=message,
525
- context=context,
526
- run=run,
527
- fab=fab,
528
- token=token,
529
- ),
530
- token_returned=start_subprocess,
531
- )
532
-
533
- if start_subprocess:
534
- _octet, _colon, _port = (
535
- clientappio_api_address.rpartition(":")
536
- )
537
- io_address = (
538
- f"{CLIENT_OCTET}:{_port}"
539
- if _octet == SERVER_OCTET
540
- else clientappio_api_address
541
- )
542
- # Start ClientApp subprocess
543
- command = [
544
- "flwr-clientapp",
545
- "--clientappio-api-address",
546
- io_address,
547
- "--token",
548
- str(token),
549
- ]
550
- command.append("--insecure")
551
-
552
- proc = mp_spawn_context.Process(
553
- target=_run_flwr_clientapp,
554
- args=(command, os.getpid()),
555
- daemon=True,
556
- )
557
- proc.start()
558
- proc.join()
559
- else:
560
- # Wait for output to become available
561
- while not clientappio_servicer.has_outputs():
562
- time.sleep(0.1)
563
-
564
- outputs = clientappio_servicer.get_outputs()
565
- reply_message, context = outputs.message, outputs.context
566
- else:
567
- # Load ClientApp instance
568
- client_app: ClientApp = load_client_app_fn(
569
- fab_id, fab_version, run.fab_hash
570
- )
461
+ # Load ClientApp instance
462
+ client_app: ClientApp = load_client_app_fn(
463
+ fab_id, fab_version, run.fab_hash
464
+ )
571
465
 
572
- # Execute ClientApp
573
- reply_message = client_app(message=message, context=context)
466
+ # Execute ClientApp
467
+ reply_message = client_app(message=message, context=context)
574
468
  except Exception as ex: # pylint: disable=broad-exception-caught
575
469
 
576
470
  # Legacy grpc-bidi
@@ -801,39 +695,3 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
801
695
  )
802
696
 
803
697
  return connection, address, error_type
804
-
805
-
806
- def _run_flwr_clientapp(args: list[str], main_pid: int) -> None:
807
- # Monitor the main process in case of SIGKILL
808
- def main_process_monitor() -> None:
809
- while True:
810
- time.sleep(1)
811
- if os.getppid() != main_pid:
812
- os.kill(os.getpid(), 9)
813
-
814
- threading.Thread(target=main_process_monitor, daemon=True).start()
815
-
816
- # Run the command
817
- sys.argv = args
818
- flwr_clientapp()
819
-
820
-
821
- def run_clientappio_api_grpc(
822
- address: str,
823
- certificates: Optional[tuple[bytes, bytes, bytes]],
824
- ) -> tuple[grpc.Server, ClientAppIoServicer]:
825
- """Run ClientAppIo API gRPC server."""
826
- clientappio_servicer: grpc.Server = ClientAppIoServicer()
827
- clientappio_add_servicer_to_server_fn = add_ClientAppIoServicer_to_server
828
- clientappio_grpc_server = generic_create_grpc_server(
829
- servicer_and_add_fn=(
830
- clientappio_servicer,
831
- clientappio_add_servicer_to_server_fn,
832
- ),
833
- server_address=address,
834
- max_message_length=GRPC_MAX_MESSAGE_LENGTH,
835
- certificates=certificates,
836
- )
837
- log(INFO, "Starting Flower ClientAppIo gRPC server on %s", address)
838
- clientappio_grpc_server.start()
839
- return clientappio_grpc_server, clientappio_servicer