flwr 1.19.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. flwr/cli/build.py +15 -5
  2. flwr/cli/new/new.py +12 -4
  3. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  4. flwr/cli/new/templates/app/README.md.tpl +5 -0
  5. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +14 -3
  6. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  7. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  8. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  9. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  10. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  11. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  12. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  13. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  14. flwr/cli/run/run.py +45 -38
  15. flwr/cli/utils.py +12 -5
  16. flwr/client/grpc_adapter_client/connection.py +11 -4
  17. flwr/client/grpc_rere_client/connection.py +92 -117
  18. flwr/client/rest_client/connection.py +131 -164
  19. flwr/common/constant.py +3 -1
  20. flwr/common/exit/exit_code.py +16 -1
  21. flwr/common/grpc.py +12 -1
  22. flwr/common/{inflatable_grpc_utils.py → inflatable_protobuf_utils.py} +52 -10
  23. flwr/common/inflatable_utils.py +191 -24
  24. flwr/common/record/array.py +101 -22
  25. flwr/common/record/arraychunk.py +59 -0
  26. flwr/common/serde.py +0 -28
  27. flwr/compat/client/app.py +14 -31
  28. flwr/proto/appio_pb2.py +43 -0
  29. flwr/proto/appio_pb2.pyi +151 -0
  30. flwr/proto/appio_pb2_grpc.py +4 -0
  31. flwr/proto/appio_pb2_grpc.pyi +4 -0
  32. flwr/proto/clientappio_pb2.py +12 -19
  33. flwr/proto/clientappio_pb2.pyi +23 -101
  34. flwr/proto/clientappio_pb2_grpc.py +269 -28
  35. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  36. flwr/proto/fleet_pb2.py +12 -20
  37. flwr/proto/fleet_pb2.pyi +6 -36
  38. flwr/proto/serverappio_pb2.py +8 -31
  39. flwr/proto/serverappio_pb2.pyi +0 -152
  40. flwr/proto/serverappio_pb2_grpc.py +39 -38
  41. flwr/proto/serverappio_pb2_grpc.pyi +21 -20
  42. flwr/server/app.py +1 -1
  43. flwr/server/fleet_event_log_interceptor.py +4 -0
  44. flwr/server/grid/grpc_grid.py +91 -54
  45. flwr/server/serverapp/app.py +27 -17
  46. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +8 -0
  47. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +1 -1
  48. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  49. flwr/server/superlink/fleet/message_handler/message_handler.py +10 -16
  50. flwr/server/superlink/fleet/rest_rere/rest_api.py +1 -2
  51. flwr/server/superlink/serverappio/serverappio_grpc.py +1 -1
  52. flwr/server/superlink/serverappio/serverappio_servicer.py +35 -43
  53. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  54. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  55. flwr/server/superlink/utils.py +0 -35
  56. flwr/simulation/app.py +8 -0
  57. flwr/simulation/run_simulation.py +17 -0
  58. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  59. flwr/supercore/grpc_health/__init__.py +22 -0
  60. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  61. flwr/supercore/license_plugin/__init__.py +22 -0
  62. flwr/supercore/license_plugin/license_plugin.py +26 -0
  63. flwr/supercore/object_store/in_memory_object_store.py +31 -31
  64. flwr/supercore/object_store/object_store.py +20 -42
  65. flwr/supercore/object_store/utils.py +43 -0
  66. flwr/supercore/scheduler/__init__.py +22 -0
  67. flwr/supercore/scheduler/plugin.py +71 -0
  68. flwr/supercore/utils.py +32 -0
  69. flwr/superexec/deployment.py +1 -2
  70. flwr/superexec/exec_event_log_interceptor.py +4 -0
  71. flwr/superexec/exec_grpc.py +18 -2
  72. flwr/superexec/exec_license_interceptor.py +82 -0
  73. flwr/superexec/exec_servicer.py +10 -1
  74. flwr/superexec/exec_user_auth_interceptor.py +10 -2
  75. flwr/superexec/executor.py +1 -1
  76. flwr/superexec/simulation.py +1 -2
  77. flwr/supernode/cli/flower_supernode.py +0 -7
  78. flwr/supernode/cli/flwr_clientapp.py +10 -3
  79. flwr/supernode/nodestate/in_memory_nodestate.py +11 -2
  80. flwr/supernode/nodestate/nodestate.py +15 -0
  81. flwr/supernode/runtime/run_clientapp.py +110 -33
  82. flwr/supernode/scheduler/__init__.py +22 -0
  83. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  84. flwr/supernode/servicer/clientappio/__init__.py +1 -3
  85. flwr/supernode/servicer/clientappio/clientappio_servicer.py +223 -164
  86. flwr/supernode/start_client_internal.py +202 -104
  87. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/METADATA +2 -1
  88. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/RECORD +93 -78
  89. flwr/common/inflatable_rest_utils.py +0 -99
  90. flwr/{server/superlink → supercore}/ffs/__init__.py +0 -0
  91. flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  92. flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  93. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +0 -0
  94. {flwr-1.19.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +0 -0
flwr/common/inflatable_utils.py CHANGED
@@ -15,10 +15,14 @@
15
15
  """InflatableObject utilities."""
16
16
 
17
17
  import concurrent.futures
18
+ import os
18
19
  import random
19
20
  import threading
20
21
  import time
21
- from typing import Callable, Optional
22
+ from collections.abc import Iterable, Iterator
23
+ from typing import Callable, Optional, TypeVar
24
+
25
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
22
26
 
23
27
  from .constant import (
24
28
  HEAD_BODY_DIVIDER,
@@ -30,6 +34,7 @@ from .constant import (
30
34
  PULL_MAX_TIME,
31
35
  PULL_MAX_TRIES_PER_OBJECT,
32
36
  )
37
+ from .exit_handlers import add_exit_handler
33
38
  from .inflatable import (
34
39
  InflatableObject,
35
40
  UnexpectedObjectContentError,
@@ -37,12 +42,15 @@ from .inflatable import (
37
42
  get_object_head_values_from_object_content,
38
43
  get_object_id,
39
44
  is_valid_sha256_hash,
45
+ iterate_object_tree,
40
46
  )
41
47
  from .message import Message
42
48
  from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
49
+ from .record.arraychunk import ArrayChunk
43
50
 
44
51
  # Helper registry that maps names of classes to their type
45
52
  inflatable_class_registry: dict[str, type[InflatableObject]] = {
53
+ ArrayChunk.__qualname__: ArrayChunk,
46
54
  Array.__qualname__: Array,
47
55
  ArrayRecord.__qualname__: ArrayRecord,
48
56
  ConfigRecord.__qualname__: ConfigRecord,
@@ -51,6 +59,36 @@ inflatable_class_registry: dict[str, type[InflatableObject]] = {
51
59
  RecordDict.__qualname__: RecordDict,
52
60
  }
53
61
 
62
+ T = TypeVar("T", bound=InflatableObject)
63
+
64
+
65
+ # Allow thread pool executors to be shut down gracefully
66
+ _thread_pool_executors: set[concurrent.futures.ThreadPoolExecutor] = set()
67
+ _lock = threading.Lock()
68
+
69
+
70
+ def _shutdown_thread_pool_executors() -> None:
71
+ """Shutdown all thread pool executors gracefully."""
72
+ with _lock:
73
+ for executor in _thread_pool_executors:
74
+ executor.shutdown(wait=False, cancel_futures=True)
75
+ _thread_pool_executors.clear()
76
+
77
+
78
+ def _track_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
79
+ """Track a thread pool executor for graceful shutdown."""
80
+ with _lock:
81
+ _thread_pool_executors.add(executor)
82
+
83
+
84
+ def _untrack_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
85
+ """Untrack a thread pool executor."""
86
+ with _lock:
87
+ _thread_pool_executors.discard(executor)
88
+
89
+
90
+ add_exit_handler(_shutdown_thread_pool_executors)
91
+
54
92
 
55
93
  class ObjectUnavailableError(Exception):
56
94
  """Exception raised when an object has been pre-registered but is not yet
@@ -67,6 +105,13 @@ class ObjectIdNotPreregisteredError(Exception):
67
105
  super().__init__(f"Object with ID '{object_id}' could not be found.")
68
106
 
69
107
 
108
+ def get_num_workers(max_concurrent: int) -> int:
109
+ """Get number of workers based on the number of CPU cores and the maximum
110
+ allowed."""
111
+ num_cores = os.cpu_count() or 1
112
+ return min(max_concurrent, num_cores)
113
+
114
+
70
115
  def push_objects(
71
116
  objects: dict[str, InflatableObject],
72
117
  push_object_fn: Callable[[str, bytes], None],
@@ -95,27 +140,73 @@ def push_objects(
95
140
  max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
96
141
  The maximum number of concurrent pushes to perform.
97
142
  """
98
- if object_ids_to_push is not None:
99
- # Filter objects to push only those with IDs in the set
100
- objects = {k: v for k, v in objects.items() if k in object_ids_to_push}
101
-
102
143
  lock = threading.Lock()
103
144
 
104
- def push(obj_id: str) -> None:
145
+ def iter_dict_items() -> Iterator[tuple[str, bytes]]:
146
+ """Iterate over the dictionary items."""
147
+ for obj_id in list(objects.keys()):
148
+ # Skip the object if no need to push it
149
+ if object_ids_to_push is not None and obj_id not in object_ids_to_push:
150
+ continue
151
+
152
+ # Deflate the object content
153
+ object_content = objects[obj_id].deflate()
154
+ if not keep_objects:
155
+ with lock:
156
+ del objects[obj_id]
157
+
158
+ yield obj_id, object_content
159
+
160
+ push_object_contents_from_iterable(
161
+ iter_dict_items(),
162
+ push_object_fn,
163
+ max_concurrent_pushes=max_concurrent_pushes,
164
+ )
165
+
166
+
167
+ def push_object_contents_from_iterable(
168
+ object_contents: Iterable[tuple[str, bytes]],
169
+ push_object_fn: Callable[[str, bytes], None],
170
+ *,
171
+ max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
172
+ ) -> None:
173
+ """Push multiple object contents to the servicer.
174
+
175
+ Parameters
176
+ ----------
177
+ object_contents : Iterable[tuple[str, bytes]]
178
+ An iterable of `(object_id, object_content)` pairs.
179
+ `object_id` is the object ID, and `object_content` is the object content.
180
+ push_object_fn : Callable[[str, bytes], None]
181
+ A function that takes an object ID and its content as bytes, and pushes
182
+ it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
183
+ if the object ID is not pre-registered.
184
+ max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
185
+ The maximum number of concurrent pushes to perform.
186
+ """
187
+
188
+ def push(args: tuple[str, bytes]) -> None:
105
189
  """Push a single object."""
106
- object_content = objects[obj_id].deflate()
107
- if not keep_objects:
108
- with lock:
109
- del objects[obj_id]
110
- push_object_fn(obj_id, object_content)
190
+ obj_id, obj_content = args
191
+ # Push the object using the provided function
192
+ push_object_fn(obj_id, obj_content)
193
+
194
+ # Push all object contents concurrently
195
+ num_workers = get_num_workers(max_concurrent_pushes)
196
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
197
+ # Ensure that the thread pool executors are tracked for graceful shutdown
198
+ _track_executor(executor)
111
199
 
112
- with concurrent.futures.ThreadPoolExecutor(
113
- max_workers=max_concurrent_pushes
114
- ) as executor:
115
- list(executor.map(push, list(objects.keys())))
200
+ # Submit push tasks for each object content
201
+ executor.map(push, object_contents) # Non-blocking map
116
202
 
203
+ # The context manager will block until all submitted tasks have completed
117
204
 
118
- def pull_objects( # pylint: disable=too-many-arguments
205
+ # Remove the executor from the list of tracked executors
206
+ _untrack_executor(executor)
207
+
208
+
209
+ def pull_objects( # pylint: disable=too-many-arguments,too-many-locals
119
210
  object_ids: list[str],
120
211
  pull_object_fn: Callable[[str], bytes],
121
212
  *,
@@ -207,16 +298,20 @@ def pull_objects( # pylint: disable=too-many-arguments
207
298
  return
208
299
 
209
300
  # Submit all pull tasks concurrently
210
- with concurrent.futures.ThreadPoolExecutor(
211
- max_workers=max_concurrent_pulls
212
- ) as executor:
213
- futures = {
214
- executor.submit(pull_with_retries, obj_id): obj_id for obj_id in object_ids
215
- }
301
+ num_workers = get_num_workers(max_concurrent_pulls)
302
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
303
+ # Ensure that the thread pool executors are tracked for graceful shutdown
304
+ _track_executor(executor)
305
+
306
+ # Submit pull tasks for each object ID
307
+ executor.map(pull_with_retries, object_ids) # Non-blocking map
308
+
309
+ # The context manager will block until all submitted tasks have completed
216
310
 
217
- # Wait for completion
218
- concurrent.futures.wait(futures)
311
+ # Remove the executor from the list of tracked executors
312
+ _untrack_executor(executor)
219
313
 
314
+ # If an error occurred during pulling, raise it
220
315
  if err_to_raise is not None:
221
316
  raise err_to_raise
222
317
 
@@ -339,3 +434,75 @@ def validate_object_content(content: bytes) -> None:
339
434
  raise UnexpectedObjectContentError(
340
435
  object_id=get_object_id(content), reason=str(err)
341
436
  ) from err
437
+
438
+
439
+ def pull_and_inflate_object_from_tree( # pylint: disable=R0913
440
+ object_tree: ObjectTree,
441
+ pull_object_fn: Callable[[str], bytes],
442
+ confirm_object_received_fn: Callable[[str], None],
443
+ *,
444
+ return_type: type[T] = InflatableObject, # type: ignore
445
+ max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
446
+ max_time: Optional[float] = PULL_MAX_TIME,
447
+ max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
448
+ initial_backoff: float = PULL_INITIAL_BACKOFF,
449
+ backoff_cap: float = PULL_BACKOFF_CAP,
450
+ ) -> T:
451
+ """Pull and inflate the head object from the provided object tree.
452
+
453
+ Parameters
454
+ ----------
455
+ object_tree : ObjectTree
456
+ The object tree containing the object ID and its descendants.
457
+ pull_object_fn : Callable[[str], bytes]
458
+ A function that takes an object ID and returns the object content as bytes.
459
+ confirm_object_received_fn : Callable[[str], None]
460
+ A function to confirm that the object has been received.
461
+ return_type : type[T] (default: InflatableObject)
462
+ The type of the object to return. Must be a subclass of `InflatableObject`.
463
+ max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
464
+ The maximum number of concurrent pulls to perform.
465
+ max_time : Optional[float] (default: PULL_MAX_TIME)
466
+ The maximum time to wait for all pulls to complete. If `None`, waits
467
+ indefinitely.
468
+ max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
469
+ The maximum number of attempts to pull each object. If `None`, pulls
470
+ indefinitely until the object is available.
471
+ initial_backoff : float (default: PULL_INITIAL_BACKOFF)
472
+ The initial backoff time in seconds for retrying pulls after an
473
+ `ObjectUnavailableError`.
474
+ backoff_cap : float (default: PULL_BACKOFF_CAP)
475
+ The maximum backoff time in seconds. Backoff times will not exceed this value.
476
+
477
+ Returns
478
+ -------
479
+ T
480
+ An instance of the specified return type containing the inflated object.
481
+ """
482
+ # Pull the main object and all its descendants
483
+ pulled_object_contents = pull_objects(
484
+ [tree.object_id for tree in iterate_object_tree(object_tree)],
485
+ pull_object_fn,
486
+ max_concurrent_pulls=max_concurrent_pulls,
487
+ max_time=max_time,
488
+ max_tries_per_object=max_tries_per_object,
489
+ initial_backoff=initial_backoff,
490
+ backoff_cap=backoff_cap,
491
+ )
492
+
493
+ # Confirm that all objects were pulled
494
+ confirm_object_received_fn(object_tree.object_id)
495
+
496
+ # Inflate the main object
497
+ inflated_object = inflate_object_from_contents(
498
+ object_tree.object_id, pulled_object_contents, keep_object_contents=False
499
+ )
500
+
501
+ # Check if the inflated object is of the expected type
502
+ if not isinstance(inflated_object, return_type):
503
+ raise TypeError(
504
+ f"Expected object of type {return_type.__name__}, "
505
+ f"but got {type(inflated_object).__name__}."
506
+ )
507
+
508
+ return inflated_object
flwr/common/record/array.py CHANGED
@@ -17,6 +17,7 @@
17
17
 
18
18
  from __future__ import annotations
19
19
 
20
+ import json
20
21
  import sys
21
22
  from dataclasses import dataclass
22
23
  from io import BytesIO
@@ -24,11 +25,15 @@ from typing import TYPE_CHECKING, Any, cast, overload
24
25
 
25
26
  import numpy as np
26
27
 
27
- from flwr.proto.recorddict_pb2 import Array as ArrayProto # pylint: disable=E0611
28
-
29
- from ..constant import SType
30
- from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
28
+ from ..constant import MAX_ARRAY_CHUNK_SIZE, SType
29
+ from ..inflatable import (
30
+ InflatableObject,
31
+ add_header_to_object_body,
32
+ get_object_body,
33
+ get_object_children_ids_from_object_content,
34
+ )
31
35
  from ..typing import NDArray
36
+ from .arraychunk import ArrayChunk
32
37
 
33
38
  if TYPE_CHECKING:
34
39
  import torch
@@ -252,16 +257,56 @@ class Array(InflatableObject):
252
257
  ndarray_deserialized = np.load(bytes_io, allow_pickle=False)
253
258
  return cast(NDArray, ndarray_deserialized)
254
259
 
260
+ @property
261
+ def children(self) -> dict[str, InflatableObject]:
262
+ """Return a dictionary of ArrayChunks with their Object IDs as keys."""
263
+ return dict(self.slice_array())
264
+
265
+ def slice_array(self) -> list[tuple[str, InflatableObject]]:
266
+ """Slice Array data and construct a list of ArrayChunks."""
267
+ # Return cached chunks if they exist
268
+ if "_chunks" in self.__dict__:
269
+ return cast(list[tuple[str, InflatableObject]], self.__dict__["_chunks"])
270
+
271
+ # Chunks are not children as some of them may be identical
272
+ chunks: list[tuple[str, InflatableObject]] = []
273
+ # memoryview allows for zero-copy slicing
274
+ data_view = memoryview(self.data)
275
+ for start in range(0, len(data_view), MAX_ARRAY_CHUNK_SIZE):
276
+ end = min(start + MAX_ARRAY_CHUNK_SIZE, len(data_view))
277
+ ac = ArrayChunk(data_view[start:end])
278
+ chunks.append((ac.object_id, ac))
279
+
280
+ # Cache the chunks for future use
281
+ self.__dict__["_chunks"] = chunks
282
+ return chunks
283
+
255
284
  def deflate(self) -> bytes:
256
285
  """Deflate the Array."""
257
- array_proto = ArrayProto(
258
- dtype=self.dtype,
259
- shape=self.shape,
260
- stype=self.stype,
261
- data=self.data,
262
- )
263
-
264
- obj_body = array_proto.SerializeToString(deterministic=True)
286
+ array_metadata: dict[str, str | tuple[int, ...] | list[int]] = {}
287
+
288
+ # We want to record all object_id even if repeated
289
+ # it can happend that chunks carry the exact same data
290
+ # for example when the array has only zeros
291
+ children_list = self.slice_array()
292
+ # Let's not save the entire object_id but a mapping to those
293
+ # that will be carried in the object head
294
+ # (replace a long object_id with a single scalar)
295
+ unique_children = list(self.children.keys())
296
+ arraychunk_ids = [unique_children.index(ch_id) for ch_id, _ in children_list]
297
+
298
+ # The deflated Array carries everything but the data
299
+ # The `arraychunk_ids` will be used during Array inflation
300
+ # to rematerialize the data from ArrayChunk objects.
301
+ array_metadata = {
302
+ "dtype": self.dtype,
303
+ "shape": self.shape,
304
+ "stype": self.stype,
305
+ "arraychunk_ids": arraychunk_ids,
306
+ }
307
+
308
+ # Serialize metadata dict
309
+ obj_body = json.dumps(array_metadata).encode("utf-8")
265
310
  return add_header_to_object_body(object_body=obj_body, obj=self)
266
311
 
267
312
  @classmethod
@@ -276,26 +321,55 @@ class Array(InflatableObject):
276
321
  The deflated object content of the Array.
277
322
 
278
323
  children : Optional[dict[str, InflatableObject]] (default: None)
279
- Must be ``None``. ``Array`` does not support child objects.
280
- Providing any children will raise a ``ValueError``.
324
+ Must be ``None``. ``Array`` must have child objects.
325
+ Providing no children will raise a ``ValueError``.
281
326
 
282
327
  Returns
283
328
  -------
284
329
  Array
285
330
  The inflated Array.
286
331
  """
287
- if children:
288
- raise ValueError("`Array` objects do not have children.")
332
+ if children is None:
333
+ children = {}
289
334
 
290
335
  obj_body = get_object_body(object_content, cls)
291
- proto_array = ArrayProto.FromString(obj_body)
292
- return cls(
293
- dtype=proto_array.dtype,
294
- shape=tuple(proto_array.shape),
295
- stype=proto_array.stype,
296
- data=proto_array.data,
336
+
337
+ # Extract children IDs from head
338
+ children_ids = get_object_children_ids_from_object_content(object_content)
339
+ # Decode the Array body
340
+ array_metadata: dict[str, str | tuple[int, ...] | list[int]] = json.loads(
341
+ obj_body.decode(encoding="utf-8")
297
342
  )
298
343
 
344
+ # Verify children ids in body match those passed for inflation
345
+ chunk_ids_indices = cast(list[int], array_metadata["arraychunk_ids"])
346
+ # Convert indices back to IDs
347
+ chunk_ids = [children_ids[i] for i in chunk_ids_indices]
348
+ # Check consistency
349
+ unique_arrayschunks = set(chunk_ids)
350
+ children_obj_ids = set(children.keys())
351
+ if unique_arrayschunks != children_obj_ids:
352
+ raise ValueError(
353
+ "Unexpected set of `children`. "
354
+ f"Expected {unique_arrayschunks} but got {children_obj_ids}."
355
+ )
356
+
357
+ # Materialize Array with empty data
358
+ array = cls(
359
+ dtype=cast(str, array_metadata["dtype"]),
360
+ shape=cast(tuple[int], tuple(array_metadata["shape"])),
361
+ stype=cast(str, array_metadata["stype"]),
362
+ data=b"",
363
+ )
364
+
365
+ # Now inject data from chunks
366
+ buff = bytearray()
367
+ for ch_id in chunk_ids:
368
+ buff += cast(ArrayChunk, children[ch_id]).data
369
+
370
+ array.data = bytes(buff)
371
+ return array
372
+
299
373
  @property
300
374
  def object_id(self) -> str:
301
375
  """Get object ID."""
@@ -320,4 +394,9 @@ class Array(InflatableObject):
320
394
  if name in ("dtype", "shape", "stype", "data"):
321
395
  # Mark as dirty if any of the main attributes are set
322
396
  self.is_dirty = True
397
+ # Clear cached object ID
398
+ self.__dict__.pop("_object_id", None)
399
+ # Clear cached chunks if data is set
400
+ if name == "data":
401
+ self.__dict__.pop("_chunks", None)
323
402
  super().__setattr__(name, value)
flwr/common/record/arraychunk.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """ArrayChunk."""
16
+
17
+
18
+ from __future__ import annotations
19
+
20
+ from dataclasses import dataclass
21
+
22
+ from ..inflatable import InflatableObject, add_header_to_object_body, get_object_body
23
+
24
+
25
+ @dataclass
26
+ class ArrayChunk(InflatableObject):
27
+ """ArrayChunk type."""
28
+
29
+ data: memoryview
30
+
31
+ def deflate(self) -> bytes:
32
+ """Deflate the ArrayChunk."""
33
+ return add_header_to_object_body(object_body=self.data, obj=self)
34
+
35
+ @classmethod
36
+ def inflate(
37
+ cls, object_content: bytes, children: dict[str, InflatableObject] | None = None
38
+ ) -> ArrayChunk:
39
+ """Inflate an ArrayChunk from bytes.
40
+
41
+ Parameters
42
+ ----------
43
+ object_content : bytes
44
+ The deflated object content of the ArrayChunk.
45
+
46
+ children : Optional[dict[str, InflatableObject]] (default: None)
47
+ Must be ``None``. ``ArrayChunk`` does not support child objects.
48
+ Providing any children will raise a ``ValueError``.
49
+
50
+ Returns
51
+ -------
52
+ ArrayChunk
53
+ The inflated ArrayChunk.
54
+ """
55
+ if children:
56
+ raise ValueError("`ArrayChunk` objects do not have children.")
57
+
58
+ obj_body = get_object_body(object_content, cls)
59
+ return cls(data=memoryview(obj_body))
flwr/common/serde.py CHANGED
@@ -19,7 +19,6 @@ from collections import OrderedDict
19
19
  from typing import Any, cast
20
20
 
21
21
  # pylint: disable=E0611
22
- from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
23
22
  from flwr.proto.fab_pb2 import Fab as ProtoFab
24
23
  from flwr.proto.message_pb2 import Context as ProtoContext
25
24
  from flwr.proto.message_pb2 import Message as ProtoMessage
@@ -653,33 +652,6 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
653
652
  return run
654
653
 
655
654
 
656
- # === ClientApp status messages ===
657
-
658
-
659
- def clientappstatus_to_proto(
660
- status: typing.ClientAppOutputStatus,
661
- ) -> ClientAppOutputStatus:
662
- """Serialize `ClientAppOutputStatus` to ProtoBuf."""
663
- code = ClientAppOutputCode.SUCCESS
664
- if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
665
- code = ClientAppOutputCode.DEADLINE_EXCEEDED
666
- if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
667
- code = ClientAppOutputCode.UNKNOWN_ERROR
668
- return ClientAppOutputStatus(code=code, message=status.message)
669
-
670
-
671
- def clientappstatus_from_proto(
672
- msg: ClientAppOutputStatus,
673
- ) -> typing.ClientAppOutputStatus:
674
- """Deserialize `ClientAppOutputStatus` from ProtoBuf."""
675
- code = typing.ClientAppOutputCode.SUCCESS
676
- if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
677
- code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
678
- if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
679
- code = typing.ClientAppOutputCode.UNKNOWN_ERROR
680
- return typing.ClientAppOutputStatus(code=code, message=msg.message)
681
-
682
-
683
655
  # === Run status ===
684
656
 
685
657
 
flwr/compat/client/app.py CHANGED
@@ -29,8 +29,6 @@ from flwr.cli.config_utils import get_fab_metadata
29
29
  from flwr.cli.install import install_from_fab
30
30
  from flwr.client.client import Client
31
31
  from flwr.client.client_app import ClientApp, LoadClientAppError
32
- from flwr.client.grpc_adapter_client.connection import grpc_adapter
33
- from flwr.client.grpc_rere_client.connection import grpc_request_response
34
32
  from flwr.client.message_handler.message_handler import handle_control_message
35
33
  from flwr.client.numpy_client import NumPyClient
36
34
  from flwr.client.run_info_store import DeprecatedRunInfoStore
@@ -39,10 +37,7 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH, Context, EventType, Message, ev
39
37
  from flwr.common.address import parse_address
40
38
  from flwr.common.constant import (
41
39
  MAX_RETRY_DELAY,
42
- TRANSPORT_TYPE_GRPC_ADAPTER,
43
40
  TRANSPORT_TYPE_GRPC_BIDI,
44
- TRANSPORT_TYPE_GRPC_RERE,
45
- TRANSPORT_TYPE_REST,
46
41
  TRANSPORT_TYPES,
47
42
  ErrorCode,
48
43
  )
@@ -121,10 +116,8 @@ def start_client(
121
116
  Starts an insecure gRPC connection when True. Enables HTTPS connection
122
117
  when False, using system certificates if `root_certificates` is None.
123
118
  transport : Optional[str] (default: None)
124
- Configure the transport layer. Allowed values:
125
- - 'grpc-bidi': gRPC, bidirectional streaming
126
- - 'grpc-rere': gRPC, request-response (experimental)
127
- - 'rest': HTTP (experimental)
119
+ **[Deprecated]** This argument is no longer supported and will be
120
+ removed in a future release.
128
121
  authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
129
122
  Tuple containing the elliptic curve private key and public key for
130
123
  authentication from the cryptography library.
@@ -180,6 +173,12 @@ def start_client(
180
173
  )
181
174
  warn_deprecated_feature(name=msg)
182
175
 
176
+ if transport is not None and transport != "grpc-bidi":
177
+ raise ValueError(
178
+ f"Transport type {transport} is not supported. "
179
+ "Use 'grpc-bidi' or None (default) instead."
180
+ )
181
+
183
182
  event(EventType.START_CLIENT_ENTER)
184
183
  start_client_internal(
185
184
  server_address=server_address,
@@ -429,7 +428,7 @@ def start_client_internal(
429
428
 
430
429
  run: Run = runs[run_id]
431
430
  if get_fab is not None and run.fab_hash:
432
- fab = get_fab(run.fab_hash, run_id)
431
+ fab = get_fab(run.fab_hash, run_id) # pylint: disable=E1102
433
432
  # If `ClientApp` runs in the same process, install the FAB
434
433
  install_from_fab(fab.content, flwr_path, True)
435
434
  fab_id, fab_version = get_fab_metadata(fab.content)
@@ -573,10 +572,8 @@ def start_numpy_client(
573
572
  Starts an insecure gRPC connection when True. Enables HTTPS connection
574
573
  when False, using system certificates if `root_certificates` is None.
575
574
  transport : Optional[str] (default: None)
576
- Configure the transport layer. Allowed values:
577
- - 'grpc-bidi': gRPC, bidirectional streaming
578
- - 'grpc-rere': gRPC, request-response (experimental)
579
- - 'rest': HTTP (experimental)
575
+ **[Deprecated]** This argument is no longer supported and will be
576
+ removed in a future release.
580
577
 
581
578
  Examples
582
579
  --------
@@ -672,23 +669,9 @@ def _init_connection(transport: Optional[str], server_address: str) -> tuple[
672
669
  if transport is None:
673
670
  transport = TRANSPORT_TYPE_GRPC_BIDI
674
671
 
675
- # Use either gRPC bidirectional streaming or REST request/response
676
- if transport == TRANSPORT_TYPE_REST:
677
- try:
678
- from requests.exceptions import ConnectionError as RequestsConnectionError
679
-
680
- from flwr.client.rest_client.connection import http_request_response
681
- except ModuleNotFoundError:
682
- flwr_exit(ExitCode.COMMON_MISSING_EXTRA_REST)
683
- if server_address[:4] != "http":
684
- flwr_exit(ExitCode.SUPERNODE_REST_ADDRESS_INVALID)
685
- connection, error_type = http_request_response, RequestsConnectionError
686
- elif transport == TRANSPORT_TYPE_GRPC_RERE:
687
- connection, error_type = grpc_request_response, RpcError
688
- elif transport == TRANSPORT_TYPE_GRPC_ADAPTER:
689
- connection, error_type = grpc_adapter, RpcError
690
- elif transport == TRANSPORT_TYPE_GRPC_BIDI:
691
- connection, error_type = grpc_connection, RpcError # type: ignore[assignment]
672
+ # Use gRPC bidirectional streaming
673
+ if transport == TRANSPORT_TYPE_GRPC_BIDI:
674
+ connection, error_type = grpc_connection, RpcError
692
675
  else:
693
676
  raise ValueError(
694
677
  f"Unknown transport type: {transport} (possible: {TRANSPORT_TYPES})"