flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. flwr/app/__init__.py +15 -0
  2. flwr/app/error.py +68 -0
  3. flwr/app/metadata.py +223 -0
  4. flwr/cli/build.py +94 -59
  5. flwr/cli/log.py +3 -3
  6. flwr/cli/login/login.py +3 -7
  7. flwr/cli/ls.py +15 -36
  8. flwr/cli/new/new.py +12 -4
  9. flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
  10. flwr/cli/new/templates/app/README.md.tpl +5 -0
  11. flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
  12. flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
  13. flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
  14. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
  15. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
  16. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
  17. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
  18. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
  19. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
  20. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
  21. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
  22. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
  23. flwr/cli/run/run.py +48 -49
  24. flwr/cli/stop.py +2 -2
  25. flwr/cli/utils.py +38 -5
  26. flwr/client/__init__.py +2 -2
  27. flwr/client/client_app.py +1 -1
  28. flwr/client/clientapp/__init__.py +0 -7
  29. flwr/client/grpc_adapter_client/connection.py +15 -8
  30. flwr/client/grpc_rere_client/connection.py +142 -97
  31. flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
  32. flwr/client/message_handler/message_handler.py +1 -1
  33. flwr/client/mod/comms_mods.py +36 -17
  34. flwr/client/rest_client/connection.py +176 -103
  35. flwr/clientapp/__init__.py +15 -0
  36. flwr/common/__init__.py +2 -2
  37. flwr/common/auth_plugin/__init__.py +2 -0
  38. flwr/common/auth_plugin/auth_plugin.py +29 -3
  39. flwr/common/constant.py +39 -8
  40. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  41. flwr/common/exit/exit_code.py +16 -1
  42. flwr/common/exit_handlers.py +30 -0
  43. flwr/common/grpc.py +12 -1
  44. flwr/common/heartbeat.py +165 -0
  45. flwr/common/inflatable.py +290 -0
  46. flwr/common/inflatable_protobuf_utils.py +141 -0
  47. flwr/common/inflatable_utils.py +508 -0
  48. flwr/common/message.py +110 -242
  49. flwr/common/record/__init__.py +2 -1
  50. flwr/common/record/array.py +402 -0
  51. flwr/common/record/arraychunk.py +59 -0
  52. flwr/common/record/arrayrecord.py +103 -225
  53. flwr/common/record/configrecord.py +59 -4
  54. flwr/common/record/conversion_utils.py +1 -1
  55. flwr/common/record/metricrecord.py +55 -4
  56. flwr/common/record/recorddict.py +69 -1
  57. flwr/common/recorddict_compat.py +2 -2
  58. flwr/common/retry_invoker.py +5 -1
  59. flwr/common/serde.py +59 -211
  60. flwr/common/serde_utils.py +175 -0
  61. flwr/common/typing.py +5 -3
  62. flwr/compat/__init__.py +15 -0
  63. flwr/compat/client/__init__.py +15 -0
  64. flwr/{client → compat/client}/app.py +28 -185
  65. flwr/compat/common/__init__.py +15 -0
  66. flwr/compat/server/__init__.py +15 -0
  67. flwr/compat/server/app.py +174 -0
  68. flwr/compat/simulation/__init__.py +15 -0
  69. flwr/proto/appio_pb2.py +43 -0
  70. flwr/proto/appio_pb2.pyi +151 -0
  71. flwr/proto/appio_pb2_grpc.py +4 -0
  72. flwr/proto/appio_pb2_grpc.pyi +4 -0
  73. flwr/proto/clientappio_pb2.py +12 -19
  74. flwr/proto/clientappio_pb2.pyi +23 -101
  75. flwr/proto/clientappio_pb2_grpc.py +269 -28
  76. flwr/proto/clientappio_pb2_grpc.pyi +114 -20
  77. flwr/proto/fleet_pb2.py +24 -27
  78. flwr/proto/fleet_pb2.pyi +19 -35
  79. flwr/proto/fleet_pb2_grpc.py +117 -13
  80. flwr/proto/fleet_pb2_grpc.pyi +47 -6
  81. flwr/proto/heartbeat_pb2.py +33 -0
  82. flwr/proto/heartbeat_pb2.pyi +66 -0
  83. flwr/proto/heartbeat_pb2_grpc.py +4 -0
  84. flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
  85. flwr/proto/message_pb2.py +28 -11
  86. flwr/proto/message_pb2.pyi +125 -0
  87. flwr/proto/recorddict_pb2.py +16 -28
  88. flwr/proto/recorddict_pb2.pyi +46 -64
  89. flwr/proto/run_pb2.py +24 -32
  90. flwr/proto/run_pb2.pyi +4 -52
  91. flwr/proto/serverappio_pb2.py +9 -23
  92. flwr/proto/serverappio_pb2.pyi +0 -110
  93. flwr/proto/serverappio_pb2_grpc.py +177 -72
  94. flwr/proto/serverappio_pb2_grpc.pyi +75 -33
  95. flwr/proto/simulationio_pb2.py +12 -11
  96. flwr/proto/simulationio_pb2_grpc.py +35 -0
  97. flwr/proto/simulationio_pb2_grpc.pyi +14 -0
  98. flwr/server/__init__.py +1 -1
  99. flwr/server/app.py +69 -187
  100. flwr/server/compat/app_utils.py +50 -28
  101. flwr/server/fleet_event_log_interceptor.py +6 -2
  102. flwr/server/grid/grpc_grid.py +148 -41
  103. flwr/server/grid/inmemory_grid.py +5 -4
  104. flwr/server/serverapp/app.py +45 -17
  105. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
  106. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
  107. flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
  108. flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
  109. flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
  110. flwr/server/superlink/fleet/vce/vce_api.py +6 -3
  111. flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
  112. flwr/server/superlink/linkstate/linkstate.py +53 -20
  113. flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
  114. flwr/server/superlink/linkstate/utils.py +33 -29
  115. flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
  116. flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
  117. flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
  118. flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
  119. flwr/server/superlink/utils.py +9 -2
  120. flwr/server/utils/validator.py +2 -2
  121. flwr/serverapp/__init__.py +15 -0
  122. flwr/simulation/app.py +25 -0
  123. flwr/simulation/run_simulation.py +17 -0
  124. flwr/supercore/__init__.py +15 -0
  125. flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
  126. flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
  127. flwr/supercore/grpc_health/__init__.py +22 -0
  128. flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
  129. flwr/supercore/license_plugin/__init__.py +22 -0
  130. flwr/supercore/license_plugin/license_plugin.py +26 -0
  131. flwr/supercore/object_store/__init__.py +24 -0
  132. flwr/supercore/object_store/in_memory_object_store.py +229 -0
  133. flwr/supercore/object_store/object_store.py +170 -0
  134. flwr/supercore/object_store/object_store_factory.py +44 -0
  135. flwr/supercore/object_store/utils.py +43 -0
  136. flwr/supercore/scheduler/__init__.py +22 -0
  137. flwr/supercore/scheduler/plugin.py +71 -0
  138. flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
  139. flwr/superexec/deployment.py +7 -4
  140. flwr/superexec/exec_event_log_interceptor.py +8 -4
  141. flwr/superexec/exec_grpc.py +25 -5
  142. flwr/superexec/exec_license_interceptor.py +82 -0
  143. flwr/superexec/exec_servicer.py +135 -24
  144. flwr/superexec/exec_user_auth_interceptor.py +45 -8
  145. flwr/superexec/executor.py +5 -1
  146. flwr/superexec/simulation.py +8 -3
  147. flwr/superlink/__init__.py +15 -0
  148. flwr/{client/supernode → supernode}/__init__.py +0 -7
  149. flwr/supernode/cli/__init__.py +24 -0
  150. flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
  151. flwr/supernode/cli/flwr_clientapp.py +88 -0
  152. flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
  153. flwr/supernode/nodestate/nodestate.py +227 -0
  154. flwr/supernode/runtime/__init__.py +15 -0
  155. flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
  156. flwr/supernode/scheduler/__init__.py +22 -0
  157. flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
  158. flwr/supernode/servicer/__init__.py +15 -0
  159. flwr/supernode/servicer/clientappio/__init__.py +22 -0
  160. flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
  161. flwr/supernode/start_client_internal.py +589 -0
  162. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
  163. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
  164. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
  165. {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
  166. flwr/client/clientapp/clientappio_servicer.py +0 -244
  167. flwr/client/heartbeat.py +0 -74
  168. flwr/client/nodestate/in_memory_nodestate.py +0 -38
  169. /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
  170. /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
  171. /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
  172. /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
  173. /flwr/{client → supernode}/nodestate/__init__.py +0 -0
  174. /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
@@ -0,0 +1,508 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject utilities."""
16
+
17
+ import concurrent.futures
18
+ import os
19
+ import random
20
+ import threading
21
+ import time
22
+ from collections.abc import Iterable, Iterator
23
+ from typing import Callable, Optional, TypeVar
24
+
25
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
26
+
27
+ from .constant import (
28
+ HEAD_BODY_DIVIDER,
29
+ HEAD_VALUE_DIVIDER,
30
+ MAX_CONCURRENT_PULLS,
31
+ MAX_CONCURRENT_PUSHES,
32
+ PULL_BACKOFF_CAP,
33
+ PULL_INITIAL_BACKOFF,
34
+ PULL_MAX_TIME,
35
+ PULL_MAX_TRIES_PER_OBJECT,
36
+ )
37
+ from .exit_handlers import add_exit_handler
38
+ from .inflatable import (
39
+ InflatableObject,
40
+ UnexpectedObjectContentError,
41
+ _get_object_head,
42
+ get_object_head_values_from_object_content,
43
+ get_object_id,
44
+ is_valid_sha256_hash,
45
+ iterate_object_tree,
46
+ )
47
+ from .message import Message
48
+ from .record import Array, ArrayRecord, ConfigRecord, MetricRecord, RecordDict
49
+ from .record.arraychunk import ArrayChunk
50
+
51
# Map each supported class's ``__qualname__`` to the class itself. A deflated
# object records its type by name in its head, and this lookup table turns that
# name back into the class whose ``inflate`` should be called.
inflatable_class_registry: dict[str, type[InflatableObject]] = {
    registered_cls.__qualname__: registered_cls
    for registered_cls in (
        ArrayChunk,
        Array,
        ArrayRecord,
        ConfigRecord,
        Message,
        MetricRecord,
        RecordDict,
    )
}

# Generic type used to type the return value of
# ``pull_and_inflate_object_from_tree``
T = TypeVar("T", bound=InflatableObject)


# Executors currently in use, so they can be shut down gracefully on exit.
# All access to the set goes through ``_lock``.
_thread_pool_executors: set[concurrent.futures.ThreadPoolExecutor] = set()
_lock = threading.Lock()
68
+
69
+
70
def _shutdown_thread_pool_executors() -> None:
    """Shut down every tracked executor without waiting, then forget them all.

    Futures that have not started yet are cancelled (``cancel_futures=True``);
    because ``wait=False`` is used, this call does not block on tasks that are
    already running.
    """
    with _lock:
        tracked = list(_thread_pool_executors)
        for pool in tracked:
            pool.shutdown(wait=False, cancel_futures=True)
        _thread_pool_executors.clear()
76
+
77
+
78
def _track_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
    """Register ``executor`` so that ``_shutdown_thread_pool_executors`` sees it."""
    with _lock:
        _thread_pool_executors.update((executor,))
82
+
83
+
84
def _untrack_executor(executor: concurrent.futures.ThreadPoolExecutor) -> None:
    """Stop tracking ``executor``; does nothing if it is not tracked."""
    with _lock:
        if executor in _thread_pool_executors:
            _thread_pool_executors.remove(executor)


# Register the shutdown hook with flwr's exit handlers (presumably invoked when
# the process exits or is terminated; see ``add_exit_handler``)
add_exit_handler(_shutdown_thread_pool_executors)
91
+
92
+
93
class ObjectUnavailableError(Exception):
    """Raised when a pre-registered object has no content available yet.

    Parameters
    ----------
    object_id : str
        ID of the object whose content is not yet available.
    """

    def __init__(self, object_id: str):
        message = f"Object with ID '{object_id}' is not yet available."
        super().__init__(message)
99
+
100
+
101
class ObjectIdNotPreregisteredError(Exception):
    """Raised when an object ID was never pre-registered.

    Parameters
    ----------
    object_id : str
        The object ID that could not be found.
    """

    def __init__(self, object_id: str):
        message = f"Object with ID '{object_id}' could not be found."
        super().__init__(message)
106
+
107
+
108
def get_num_workers(max_concurrent: int) -> int:
    """Return the number of worker threads to use.

    The result is ``max_concurrent`` capped at the number of CPU cores reported
    by ``os.cpu_count()``. The core count is treated as 1 when it is unknown.
    """
    available_cores = os.cpu_count()
    if not available_cores:
        available_cores = 1
    return max_concurrent if max_concurrent <= available_cores else available_cores
113
+
114
+
115
def push_objects(
    objects: dict[str, InflatableObject],
    push_object_fn: Callable[[str, bytes], None],
    *,
    object_ids_to_push: Optional[set[str]] = None,
    keep_objects: bool = False,
    max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
) -> None:
    """Deflate objects and push their contents to the servicer.

    Parameters
    ----------
    objects : dict[str, InflatableObject]
        Objects to push, keyed by object ID.
    push_object_fn : Callable[[str, bytes], None]
        Called with an object ID and its deflated content to push it to the
        servicer. It should raise `ObjectIdNotPreregisteredError` if the object
        ID is not pre-registered.
    object_ids_to_push : Optional[set[str]] (default: None)
        Restrict pushing to these object IDs. When `None`, every object in
        `objects` is pushed.
    keep_objects : bool (default: False)
        When `False`, each object is removed from `objects` right after it has
        been deflated, to limit memory usage. When `True`, `objects` is left
        untouched.
    max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
        Upper bound on the number of concurrent pushes.
    """
    deletion_lock = threading.Lock()

    def deflate_selected() -> Iterator[tuple[str, bytes]]:
        """Yield `(object_id, deflated_content)` for every object to push."""
        # Iterate over a snapshot of the keys: entries may be removed below
        for object_id in [*objects]:
            wanted = object_ids_to_push is None or object_id in object_ids_to_push
            if not wanted:
                continue

            deflated = objects[object_id].deflate()
            if not keep_objects:
                # Guard the removal so it stays safe whichever thread consumes
                # this generator
                with deletion_lock:
                    objects.pop(object_id)

            yield object_id, deflated

    push_object_contents_from_iterable(
        deflate_selected(),
        push_object_fn,
        max_concurrent_pushes=max_concurrent_pushes,
    )
165
+
166
+
167
def push_object_contents_from_iterable(
    object_contents: Iterable[tuple[str, bytes]],
    push_object_fn: Callable[[str, bytes], None],
    *,
    max_concurrent_pushes: int = MAX_CONCURRENT_PUSHES,
) -> None:
    """Push multiple object contents to the servicer.

    Parameters
    ----------
    object_contents : Iterable[tuple[str, bytes]]
        An iterable of `(object_id, object_content)` pairs.
        `object_id` is the object ID, and `object_content` is the object content.
    push_object_fn : Callable[[str, bytes], None]
        A function that takes an object ID and its content as bytes, and pushes
        it to the servicer. This function should raise `ObjectIdNotPreregisteredError`
        if the object ID is not pre-registered.
    max_concurrent_pushes : int (default: MAX_CONCURRENT_PUSHES)
        The maximum number of concurrent pushes to perform.

    Raises
    ------
    Exception
        The first exception raised by `push_object_fn` (in the order of
        `object_contents`), or any exception raised while iterating
        `object_contents`. Pushes that have not started yet are cancelled. The
        error is re-raised only after the pushes already in flight have
        finished.
    """

    def push(args: tuple[str, bytes]) -> None:
        """Push a single object."""
        obj_id, obj_content = args
        push_object_fn(obj_id, obj_content)

    num_workers = get_num_workers(max_concurrent_pushes)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
    # Track the executor so that it can be shut down gracefully on exit
    _track_executor(executor)
    try:
        with executor:  # Leaving the block waits for all started pushes
            try:
                # `Executor.map` submits every task up front (consuming
                # `object_contents` in this thread). A task's exception only
                # surfaces when its result is retrieved, so drain the result
                # iterator. Otherwise push failures would be silently dropped.
                for _ in executor.map(push, object_contents):
                    pass
            except BaseException:
                # Don't start pushes that are still pending after a failure
                executor.shutdown(wait=False, cancel_futures=True)
                raise
    finally:
        # Untrack even when a push or the iterable itself raised
        _untrack_executor(executor)
207
+
208
+
209
def pull_objects(  # pylint: disable=too-many-arguments,too-many-locals
    object_ids: list[str],
    pull_object_fn: Callable[[str], bytes],
    *,
    max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
    max_time: Optional[float] = PULL_MAX_TIME,
    max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
    initial_backoff: float = PULL_INITIAL_BACKOFF,
    backoff_cap: float = PULL_BACKOFF_CAP,
) -> dict[str, bytes]:
    """Pull multiple objects from the servicer.

    Parameters
    ----------
    object_ids : list[str]
        A list of object IDs to pull.
    pull_object_fn : Callable[[str], bytes]
        A function that takes an object ID and returns the object content as bytes.
        The function should raise `ObjectUnavailableError` if the object is not yet
        available, or `ObjectIdNotPreregisteredError` if the object ID is not
        pre-registered.
    max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
        The maximum number of concurrent pulls to perform.
    max_time : Optional[float] (default: PULL_MAX_TIME)
        The maximum time to wait for all pulls to complete. If `None`, waits
        indefinitely.
    max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
        The maximum number of attempts to pull each object. If `None`, pulls
        indefinitely until the object is available.
    initial_backoff : float (default: PULL_INITIAL_BACKOFF)
        The initial backoff time in seconds for retrying pulls after an
        `ObjectUnavailableError`.
    backoff_cap : float (default: PULL_BACKOFF_CAP)
        The maximum backoff time in seconds. Backoff times will not exceed this value.

    Returns
    -------
    dict[str, bytes]
        A dictionary where keys are object IDs and values are the pulled
        object contents.

    Raises
    ------
    ObjectUnavailableError
        If an object is still unavailable after `max_tries_per_object`
        attempts, or once `max_time` has elapsed.
    ObjectIdNotPreregisteredError
        If `pull_object_fn` reports that an object ID is not pre-registered.
    Exception
        Any other exception raised by `pull_object_fn` is re-raised as-is.

    Only the first error encountered is raised. As soon as any pull fails,
    all other pulls stop early.
    """
    if max_tries_per_object is None:
        max_tries_per_object = int(1e9)
    if max_time is None:
        max_time = float("inf")

    results: dict[str, bytes] = {}
    results_lock = threading.Lock()
    err_to_raise: Optional[Exception] = None
    early_stop = threading.Event()
    start = time.monotonic()

    def fail(err: Exception) -> None:
        """Record the first error and signal every other pull to stop."""
        nonlocal err_to_raise
        early_stop.set()
        with results_lock:
            if err_to_raise is None:
                err_to_raise = err

    def pull_with_retries(object_id: str) -> None:
        """Attempt to pull a single object with retry and backoff."""
        tries = 0
        delay = initial_backoff

        while not early_stop.is_set():
            try:
                object_content = pull_object_fn(object_id)
            except ObjectUnavailableError as err:
                tries += 1
                elapsed = time.monotonic() - start
                if tries >= max_tries_per_object or elapsed >= max_time:
                    # Stop all work if one object exhausts retries
                    fail(err)
                    return

                # Apply exponential backoff with ±20% jitter, but never sleep
                # past the overall `max_time` deadline
                sleep_time = delay * (1 + random.uniform(-0.2, 0.2))
                early_stop.wait(min(sleep_time, max(0.0, max_time - elapsed)))
                delay = min(delay * 2, backoff_cap)
            except ObjectIdNotPreregisteredError as err:
                # Permanent failure: object ID is invalid
                fail(err)
                return
            except Exception as err:  # pylint: disable=broad-except
                # Unexpected failure (e.g. a transport error). This runs in a
                # worker thread whose result is never read, so record it here.
                # Otherwise it would vanish and a partial result would be
                # returned.
                fail(err)
                return
            else:
                with results_lock:
                    results[object_id] = object_content
                return

    num_workers = get_num_workers(max_concurrent_pulls)
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
    # Track the executor so that it can be shut down gracefully on exit
    _track_executor(executor)
    try:
        with executor:  # Leaving the block waits for all submitted pulls
            # `pull_with_retries` handles every `Exception` itself, so the
            # result iterator does not need to be consumed
            executor.map(pull_with_retries, object_ids)
    finally:
        # Untrack even if submitting or waiting was interrupted
        _untrack_executor(executor)

    # If an error occurred during pulling, raise it
    if err_to_raise is not None:
        raise err_to_raise

    return results
319
+
320
+
321
def inflate_object_from_contents(
    object_id: str,
    object_contents: dict[str, bytes],
    *,
    keep_object_contents: bool = False,
    objects: Optional[dict[str, InflatableObject]] = None,
) -> InflatableObject:
    """Recursively rebuild an object (and its descendants) from deflated contents.

    Parameters
    ----------
    object_id : str
        ID of the object to inflate.
    object_contents : dict[str, bytes]
        Deflated contents keyed by object ID. It must contain the object and
        all of its descendants.
    keep_object_contents : bool (default: False)
        When `False`, each content entry is deleted from `object_contents` once
        it has been parsed, to save memory. When `True`, the dictionary is left
        untouched.
    objects : Optional[dict[str, InflatableObject]] (default: None)
        Internal cache of objects inflated so far, shared across the recursive
        calls so that an object referenced more than once is inflated once.
        Callers normally leave this as `None`.

    Returns
    -------
    InflatableObject
        The inflated object.
    """
    inflated_cache: dict[str, InflatableObject] = {} if objects is None else objects

    # Membership test (not `.get`): inflated objects may define `__len__` and
    # thus be falsy when empty
    if object_id in inflated_cache:
        return inflated_cache[object_id]

    content = object_contents[object_id]
    # The head names the object's class and lists the IDs of its children
    type_name, child_ids, _ = get_object_head_values_from_object_content(
        object_content=content
    )

    if not keep_object_contents:
        del object_contents[object_id]

    obj_cls = inflatable_class_registry[type_name]

    # Children are inflated first because the parent's `inflate` needs them
    children: dict[str, InflatableObject] = {
        child_id: inflate_object_from_contents(
            child_id,
            object_contents,
            keep_object_contents=keep_object_contents,
            objects=inflated_cache,
        )
        for child_id in child_ids
    }

    inflated = obj_cls.inflate(content, children=children)
    del content  # Drop the local reference to the raw bytes
    inflated_cache[object_id] = inflated
    return inflated
387
+
388
+
389
def validate_object_content(content: bytes) -> None:
    """Check that `content` is a well-formed deflated `InflatableObject`.

    All of the following checks must pass:

    - `content` contains `HEAD_BODY_DIVIDER`.
    - The head splits on `HEAD_VALUE_DIVIDER` into exactly three fields: the
      object type, the comma-separated child object IDs, and the body length.
    - Every non-empty child ID is a valid SHA-256 hash.
    - The object type is registered in `inflatable_class_registry`, or is
      the test-only `"CustomDataClass"`.
    - The declared body length equals the actual body length.

    Raises
    ------
    UnexpectedObjectContentError
        If any check fails. The underlying `ValueError` (which also covers
        UTF-8 decoding errors and a non-integer body length) is chained as the
        cause.
    """
    try:
        if HEAD_BODY_DIVIDER not in content:
            raise ValueError(
                "Unexpected format for object content. Head and body "
                "could not be split."
            )

        head = _get_object_head(content)

        # Expected head layout: <object_type> <children_ids> <object_body_len>
        fields = head.decode(encoding="utf-8").split(HEAD_VALUE_DIVIDER)
        if len(fields) != 3:
            raise ValueError("Unexpected format for object head.")
        obj_type, children_str, body_len = fields

        for child_id in children_str.split(","):
            # An empty entry means "no children"; only real IDs are checked
            if child_id and not is_valid_sha256_hash(child_id):
                raise ValueError(
                    f"Detected invalid object ID ({child_id}) in children."
                )

        # "CustomDataClass" is accepted to allow for the class used in tests
        is_known_type = (
            obj_type in inflatable_class_registry or obj_type == "CustomDataClass"
        )
        if not is_known_type:
            raise ValueError(f"Object of type {obj_type} is not supported.")

        actual_body_len = len(content) - len(head) - len(HEAD_BODY_DIVIDER)
        if int(body_len) != actual_body_len:
            raise ValueError(
                f"Object content length expected {body_len} bytes but got "
                f"{actual_body_len} bytes."
            )
    except ValueError as err:
        raise UnexpectedObjectContentError(
            object_id=get_object_id(content), reason=str(err)
        ) from err
437
+
438
+
439
def pull_and_inflate_object_from_tree(  # pylint: disable=R0913
    object_tree: ObjectTree,
    pull_object_fn: Callable[[str], bytes],
    confirm_object_received_fn: Callable[[str], None],
    *,
    return_type: type[T] = InflatableObject,  # type: ignore
    max_concurrent_pulls: int = MAX_CONCURRENT_PULLS,
    max_time: Optional[float] = PULL_MAX_TIME,
    max_tries_per_object: Optional[int] = PULL_MAX_TRIES_PER_OBJECT,
    initial_backoff: float = PULL_INITIAL_BACKOFF,
    backoff_cap: float = PULL_BACKOFF_CAP,
) -> T:
    """Pull and inflate the head object from the provided object tree.

    Parameters
    ----------
    object_tree : ObjectTree
        The object tree containing the object ID and its descendants.
    pull_object_fn : Callable[[str], bytes]
        A function that takes an object ID and returns the object content as bytes.
    confirm_object_received_fn : Callable[[str], None]
        A function to confirm that the object has been received. It is called
        with the head object's ID after all contents were pulled and before
        inflation.
    return_type : type[T] (default: InflatableObject)
        The type of the object to return. Must be a subclass of `InflatableObject`.
    max_concurrent_pulls : int (default: MAX_CONCURRENT_PULLS)
        The maximum number of concurrent pulls to perform.
    max_time : Optional[float] (default: PULL_MAX_TIME)
        The maximum time to wait for all pulls to complete. If `None`, waits
        indefinitely.
    max_tries_per_object : Optional[int] (default: PULL_MAX_TRIES_PER_OBJECT)
        The maximum number of attempts to pull each object. If `None`, pulls
        indefinitely until the object is available.
    initial_backoff : float (default: PULL_INITIAL_BACKOFF)
        The initial backoff time in seconds for retrying pulls after an
        `ObjectUnavailableError`.
    backoff_cap : float (default: PULL_BACKOFF_CAP)
        The maximum backoff time in seconds. Backoff times will not exceed this value.

    Returns
    -------
    T
        An instance of the specified return type containing the inflated object.

    Raises
    ------
    TypeError
        If the inflated object is not an instance of `return_type`.
    Exception
        Any error raised by `pull_objects` while pulling the contents.
    """
    # Collect the IDs of the head object and all its descendants. The same ID
    # may appear more than once in the tree. `dict.fromkeys` drops repeats
    # while keeping traversal order, so each object is pulled only once.
    object_ids = list(
        dict.fromkeys(tree.object_id for tree in iterate_object_tree(object_tree))
    )

    # Pull the main object and all its descendants
    pulled_object_contents = pull_objects(
        object_ids,
        pull_object_fn,
        max_concurrent_pulls=max_concurrent_pulls,
        max_time=max_time,
        max_tries_per_object=max_tries_per_object,
        initial_backoff=initial_backoff,
        backoff_cap=backoff_cap,
    )

    # Confirm that all objects were pulled
    confirm_object_received_fn(object_tree.object_id)

    # Inflate the main object
    inflated_object = inflate_object_from_contents(
        object_tree.object_id, pulled_object_contents, keep_object_contents=False
    )

    # Check if the inflated object is of the expected type
    if not isinstance(inflated_object, return_type):
        raise TypeError(
            f"Expected object of type {return_type.__name__}, "
            f"but got {type(inflated_object).__name__}."
        )

    return inflated_object