flwr-nightly 1.19.0.dev20250610__py3-none-any.whl → 1.19.0.dev20250612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. flwr/client/grpc_rere_client/connection.py +48 -29
  2. flwr/client/grpc_rere_client/grpc_adapter.py +8 -0
  3. flwr/client/rest_client/connection.py +138 -27
  4. flwr/common/auth_plugin/auth_plugin.py +6 -4
  5. flwr/common/event_log_plugin/event_log_plugin.py +3 -3
  6. flwr/common/inflatable.py +70 -1
  7. flwr/common/inflatable_grpc_utils.py +1 -1
  8. flwr/common/inflatable_rest_utils.py +99 -0
  9. flwr/common/serde.py +2 -0
  10. flwr/common/typing.py +5 -3
  11. flwr/proto/fleet_pb2.py +12 -16
  12. flwr/proto/fleet_pb2.pyi +4 -19
  13. flwr/proto/fleet_pb2_grpc.py +34 -0
  14. flwr/proto/fleet_pb2_grpc.pyi +13 -0
  15. flwr/proto/message_pb2.py +15 -9
  16. flwr/proto/message_pb2.pyi +41 -0
  17. flwr/proto/run_pb2.py +24 -24
  18. flwr/proto/run_pb2.pyi +4 -1
  19. flwr/proto/serverappio_pb2.py +22 -26
  20. flwr/proto/serverappio_pb2.pyi +4 -19
  21. flwr/proto/serverappio_pb2_grpc.py +34 -0
  22. flwr/proto/serverappio_pb2_grpc.pyi +13 -0
  23. flwr/server/fleet_event_log_interceptor.py +2 -2
  24. flwr/server/grid/grpc_grid.py +20 -9
  25. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +25 -0
  26. flwr/server/superlink/fleet/message_handler/message_handler.py +33 -2
  27. flwr/server/superlink/fleet/rest_rere/rest_api.py +56 -2
  28. flwr/server/superlink/linkstate/in_memory_linkstate.py +17 -2
  29. flwr/server/superlink/linkstate/linkstate.py +6 -2
  30. flwr/server/superlink/linkstate/sqlite_linkstate.py +19 -7
  31. flwr/server/superlink/serverappio/serverappio_servicer.py +65 -29
  32. flwr/server/superlink/simulation/simulationio_servicer.py +2 -1
  33. flwr/server/superlink/utils.py +23 -10
  34. flwr/supercore/object_store/in_memory_object_store.py +160 -33
  35. flwr/supercore/object_store/object_store.py +54 -7
  36. flwr/superexec/deployment.py +6 -2
  37. flwr/superexec/exec_event_log_interceptor.py +4 -4
  38. flwr/superexec/exec_servicer.py +4 -1
  39. flwr/superexec/exec_user_auth_interceptor.py +11 -11
  40. flwr/superexec/executor.py +4 -0
  41. flwr/superexec/simulation.py +7 -1
  42. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/METADATA +1 -1
  43. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/RECORD +45 -44
  44. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/WHEEL +0 -0
  45. {flwr_nightly-1.19.0.dev20250610.dist-info → flwr_nightly-1.19.0.dev20250612.dist-info}/entry_points.txt +0 -0
@@ -14,6 +14,7 @@
14
14
  # ==============================================================================
15
15
  """Contextmanager for a gRPC request-response channel to the Flower server."""
16
16
 
17
+
17
18
  from collections.abc import Iterator, Sequence
18
19
  from contextlib import contextmanager
19
20
  from copy import copy
@@ -30,7 +31,11 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
30
31
  from flwr.common.constant import HEARTBEAT_CALL_TIMEOUT, HEARTBEAT_DEFAULT_INTERVAL
31
32
  from flwr.common.grpc import create_channel, on_channel_state_change
32
33
  from flwr.common.heartbeat import HeartbeatSender
33
- from flwr.common.inflatable import get_all_nested_objects
34
+ from flwr.common.inflatable import (
35
+ get_all_nested_objects,
36
+ get_object_tree,
37
+ no_object_id_recompute,
38
+ )
34
39
  from flwr.common.inflatable_grpc_utils import (
35
40
  make_pull_object_fn_grpc,
36
41
  make_push_object_fn_grpc,
@@ -62,7 +67,9 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
62
67
  SendNodeHeartbeatRequest,
63
68
  SendNodeHeartbeatResponse,
64
69
  )
65
- from flwr.proto.message_pb2 import ObjectIDs # pylint: disable=E0611
70
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
71
+ ConfirmMessageReceivedRequest,
72
+ )
66
73
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
67
74
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
68
75
 
@@ -269,14 +276,23 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
269
276
 
270
277
  if message_proto:
271
278
  msg_id = message_proto.metadata.message_id
279
+ run_id = message_proto.metadata.run_id
272
280
  all_object_contents = pull_objects(
273
281
  list(response.objects_to_pull[msg_id].object_ids) + [msg_id],
274
282
  pull_object_fn=make_pull_object_fn_grpc(
275
283
  pull_object_grpc=stub.PullObject,
276
284
  node=node,
277
- run_id=message_proto.metadata.run_id,
285
+ run_id=run_id,
278
286
  ),
279
287
  )
288
+
289
+ # Confirm that the message has been received
290
+ stub.ConfirmMessageReceived(
291
+ ConfirmMessageReceivedRequest(
292
+ node=node, run_id=run_id, message_object_id=msg_id
293
+ )
294
+ )
295
+
280
296
  in_message = cast(
281
297
  Message, inflate_object_from_contents(msg_id, all_object_contents)
282
298
  )
@@ -311,33 +327,36 @@ def grpc_request_response( # pylint: disable=R0913,R0914,R0915,R0917
311
327
  log(ERROR, "Invalid out message")
312
328
  return
313
329
 
314
- # Get all nested objects
315
- all_objects = get_all_nested_objects(message)
316
- all_object_ids = list(all_objects.keys())
317
- msg_id = all_object_ids[-1] # Last object is the message itself
318
- descendant_ids = all_object_ids[:-1] # All but the last object are descendants
319
-
320
- # Serialize Message
321
- message_proto = message_to_proto(message=remove_content_from_message(message))
322
- request = PushMessagesRequest(
323
- node=node,
324
- messages_list=[message_proto],
325
- msg_to_descendant_mapping={msg_id: ObjectIDs(object_ids=descendant_ids)},
326
- )
327
- response: PushMessagesResponse = stub.PushMessages(request=request)
328
-
329
- if response.objects_to_push:
330
- objs_to_push = set(response.objects_to_push[message.object_id].object_ids)
331
- push_objects(
332
- all_objects,
333
- push_object_fn=make_push_object_fn_grpc(
334
- push_object_grpc=stub.PushObject,
335
- node=node,
336
- run_id=message.metadata.run_id,
337
- ),
338
- object_ids_to_push=objs_to_push,
330
+ with no_object_id_recompute():
331
+ # Get all nested objects
332
+ all_objects = get_all_nested_objects(message)
333
+ object_tree = get_object_tree(message)
334
+
335
+ # Serialize Message
336
+ message_proto = message_to_proto(
337
+ message=remove_content_from_message(message)
338
+ )
339
+ request = PushMessagesRequest(
340
+ node=node,
341
+ messages_list=[message_proto],
342
+ message_object_trees=[object_tree],
339
343
  )
340
- log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
344
+ response: PushMessagesResponse = stub.PushMessages(request=request)
345
+
346
+ if response.objects_to_push:
347
+ objs_to_push = set(
348
+ response.objects_to_push[message.object_id].object_ids
349
+ )
350
+ push_objects(
351
+ all_objects,
352
+ push_object_fn=make_push_object_fn_grpc(
353
+ push_object_grpc=stub.PushObject,
354
+ node=node,
355
+ run_id=message.metadata.run_id,
356
+ ),
357
+ object_ids_to_push=objs_to_push,
358
+ )
359
+ log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
341
360
 
342
361
  # Cleanup
343
362
  metadata = None
@@ -50,6 +50,8 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
50
50
  SendNodeHeartbeatResponse,
51
51
  )
52
52
  from flwr.proto.message_pb2 import ( # pylint: disable=E0611
53
+ ConfirmMessageReceivedRequest,
54
+ ConfirmMessageReceivedResponse,
53
55
  PullObjectRequest,
54
56
  PullObjectResponse,
55
57
  PushObjectRequest,
@@ -169,3 +171,9 @@ class GrpcAdapter:
169
171
  ) -> PullObjectResponse:
170
172
  """."""
171
173
  return self._send_and_receive(request, PullObjectResponse, **kwargs)
174
+
175
def ConfirmMessageReceived(  # pylint: disable=C0103
    self, request: ConfirmMessageReceivedRequest, **kwargs: Any
) -> ConfirmMessageReceivedResponse:
    """Send a ConfirmMessageReceived request through the adapter.

    Delegates to ``self._send_and_receive``, asking it to produce a
    ``ConfirmMessageReceivedResponse``; extra keyword arguments are forwarded
    unchanged.
    """
    expected_response_type = ConfirmMessageReceivedResponse
    return self._send_and_receive(request, expected_response_type, **kwargs)
@@ -14,12 +14,11 @@
14
14
  # ==============================================================================
15
15
  """Contextmanager for a REST request-response channel to the Flower server."""
16
16
 
17
-
18
17
  from collections.abc import Iterator
19
18
  from contextlib import contextmanager
20
19
  from copy import copy
21
- from logging import ERROR, INFO, WARN
22
- from typing import Callable, Optional, TypeVar, Union
20
+ from logging import DEBUG, ERROR, INFO, WARN
21
+ from typing import Callable, Optional, TypeVar, Union, cast
23
22
 
24
23
  from cryptography.hazmat.primitives.asymmetric import ec
25
24
  from google.protobuf.message import Message as GrpcMessage
@@ -31,10 +30,24 @@ from flwr.common import GRPC_MAX_MESSAGE_LENGTH
31
30
  from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
32
31
  from flwr.common.exit import ExitCode, flwr_exit
33
32
  from flwr.common.heartbeat import HeartbeatSender
33
+ from flwr.common.inflatable import (
34
+ get_all_nested_objects,
35
+ get_object_tree,
36
+ no_object_id_recompute,
37
+ )
38
+ from flwr.common.inflatable_rest_utils import (
39
+ make_pull_object_fn_rest,
40
+ make_push_object_fn_rest,
41
+ )
42
+ from flwr.common.inflatable_utils import (
43
+ inflate_object_from_contents,
44
+ pull_objects,
45
+ push_objects,
46
+ )
34
47
  from flwr.common.logger import log
35
- from flwr.common.message import Message
48
+ from flwr.common.message import Message, remove_content_from_message
36
49
  from flwr.common.retry_invoker import RetryInvoker
37
- from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
50
+ from flwr.common.serde import message_to_proto, run_from_proto
38
51
  from flwr.common.typing import Fab, Run
39
52
  from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
40
53
  from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
@@ -51,6 +64,14 @@ from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
51
64
  SendNodeHeartbeatRequest,
52
65
  SendNodeHeartbeatResponse,
53
66
  )
67
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
68
+ ConfirmMessageReceivedRequest,
69
+ ConfirmMessageReceivedResponse,
70
+ PullObjectRequest,
71
+ PullObjectResponse,
72
+ PushObjectRequest,
73
+ PushObjectResponse,
74
+ )
54
75
  from flwr.proto.node_pb2 import Node # pylint: disable=E0611
55
76
  from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
56
77
 
@@ -64,9 +85,12 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
64
85
  PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
65
86
  PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
66
87
  PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
88
+ PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
89
+ PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
67
90
  PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
68
91
  PATH_GET_RUN: str = "/api/v0/fleet/get-run"
69
92
  PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
93
+ PATH_CONFIRM_MESSAGE_RECEIVED: str = "/api/v0/fleet/confirm-message-received"
70
94
 
71
95
  T = TypeVar("T", bound=GrpcMessage)
72
96
 
@@ -296,14 +320,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
296
320
  ):
297
321
  message_proto = None
298
322
 
299
- # Return the Message if available
300
- nonlocal metadata
301
- message = None
302
- if message_proto is not None:
303
- message = message_from_proto(message_proto)
304
- metadata = copy(message.metadata)
323
+ # Construct the Message
324
+ in_message: Optional[Message] = None
325
+
326
+ if message_proto:
305
327
  log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
306
- return message
328
+ msg_id = message_proto.metadata.message_id
329
+ run_id = message_proto.metadata.run_id
330
+
331
def fn(request: PullObjectRequest) -> PullObjectResponse:
    """POST a PullObjectRequest to the pull-object endpoint and return the reply.

    Raises
    ------
    ValueError
        If ``_request`` returns None for this call.
    """
    res = _request(
        req=request, res_type=PullObjectResponse, api_path=PATH_PULL_OBJECT
    )
    if res is None:
        # Name the response type actually requested here (previously this
        # message wrongly referred to PushObjectResponse).
        raise ValueError("PullObjectResponse is None.")
    return res
338
+
339
+ try:
340
+ all_object_contents = pull_objects(
341
+ list(res.objects_to_pull[msg_id].object_ids) + [msg_id],
342
+ pull_object_fn=make_pull_object_fn_rest(
343
+ pull_object_rest=fn,
344
+ node=node,
345
+ run_id=run_id,
346
+ ),
347
+ )
348
+
349
+ # Confirm that the message has been received
350
+ _request(
351
+ req=ConfirmMessageReceivedRequest(
352
+ node=node, run_id=run_id, message_object_id=msg_id
353
+ ),
354
+ res_type=ConfirmMessageReceivedResponse,
355
+ api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
356
+ )
357
+ except ValueError as e:
358
+ log(
359
+ ERROR,
360
+ "Pulling objects failed. Potential irrecoverable error: %s",
361
+ str(e),
362
+ )
363
+ in_message = cast(
364
+ Message, inflate_object_from_contents(msg_id, all_object_contents)
365
+ )
366
+ # The deflated message doesn't contain the message_id (its own object_id)
367
+ # Inject
368
+ in_message.metadata.__dict__["_message_id"] = msg_id
369
+
370
+ # Remember `metadata` of the in message
371
+ nonlocal metadata
372
+ metadata = copy(in_message.metadata) if in_message else None
373
+
374
+ return in_message
307
375
 
308
376
  def send(message: Message) -> None:
309
377
  """Send Message result back to server."""
@@ -318,29 +386,72 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
318
386
  log(ERROR, "No current message")
319
387
  return
320
388
 
389
+ # Set message_id
390
+ message.metadata.__dict__["_message_id"] = message.object_id
321
391
  # Validate out message
322
392
  if not validate_out_message(message, metadata):
323
393
  log(ERROR, "Invalid out message")
324
394
  return
325
- metadata = None
326
395
 
327
- # Serialize ProtoBuf to bytes
328
- message_proto = message_to_proto(message=message)
396
+ with no_object_id_recompute():
397
+ # Get all nested objects
398
+ all_objects = get_all_nested_objects(message)
399
+ object_tree = get_object_tree(message)
329
400
 
330
- # Serialize ProtoBuf to bytes
331
- req = PushMessagesRequest(node=node, messages_list=[message_proto])
401
+ # Serialize Message
402
+ message_proto = message_to_proto(
403
+ message=remove_content_from_message(message)
404
+ )
405
+ req = PushMessagesRequest(
406
+ node=node,
407
+ messages_list=[message_proto],
408
+ message_object_trees=[object_tree],
409
+ )
332
410
 
333
- # Send the request
334
- res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
335
- if res is None:
336
- return
411
+ # Send the request
412
+ res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
413
+ if res:
414
+ log(
415
+ INFO,
416
+ "[Node] POST /%s: success, created result %s",
417
+ PATH_PUSH_MESSAGES,
418
+ res.results, # pylint: disable=no-member
419
+ )
420
+
421
+ if res and res.objects_to_push:
422
+ objs_to_push = set(res.objects_to_push[message.object_id].object_ids)
423
+
424
def fn(request: PushObjectRequest) -> PushObjectResponse:
    """POST a PushObjectRequest to the push-object endpoint and return the reply.

    Raises
    ------
    ValueError
        If ``_request`` returns None for this call.
    """
    response = _request(
        req=request,
        res_type=PushObjectResponse,
        api_path=PATH_PUSH_OBJECT,
    )
    if response is not None:
        return response
    raise ValueError("PushObjectResponse is None.")
433
+
434
+ try:
435
+ push_objects(
436
+ all_objects,
437
+ push_object_fn=make_push_object_fn_rest(
438
+ push_object_rest=fn,
439
+ node=node,
440
+ run_id=message_proto.metadata.run_id,
441
+ ),
442
+ object_ids_to_push=objs_to_push,
443
+ )
444
+ log(DEBUG, "Pushed %s objects to servicer.", len(objs_to_push))
445
+ except ValueError as e:
446
+ log(
447
+ ERROR,
448
+ "Pushing objects failed. Potential irrecoverable error: %s",
449
+ str(e),
450
+ )
451
+ log(ERROR, str(e))
337
452
 
338
- log(
339
- INFO,
340
- "[Node] POST /%s: success, created result %s",
341
- PATH_PUSH_MESSAGES,
342
- res.results, # pylint: disable=no-member
343
- )
453
+ # Cleanup
454
+ metadata = None
344
455
 
345
456
  def get_run(run_id: int) -> Run:
346
457
  # Construct the request
@@ -20,7 +20,7 @@ from collections.abc import Sequence
20
20
  from pathlib import Path
21
21
  from typing import Optional, Union
22
22
 
23
- from flwr.common.typing import UserInfo
23
+ from flwr.common.typing import AccountInfo
24
24
  from flwr.proto.exec_pb2_grpc import ExecStub
25
25
 
26
26
  from ..typing import UserAuthCredentials, UserAuthLoginDetails
@@ -53,7 +53,7 @@ class ExecAuthPlugin(ABC):
53
53
  @abstractmethod
54
54
  def validate_tokens_in_metadata(
55
55
  self, metadata: Sequence[tuple[str, Union[str, bytes]]]
56
- ) -> tuple[bool, Optional[UserInfo]]:
56
+ ) -> tuple[bool, Optional[AccountInfo]]:
57
57
  """Validate authentication tokens in the provided metadata."""
58
58
 
59
59
  @abstractmethod
@@ -63,7 +63,9 @@ class ExecAuthPlugin(ABC):
63
63
  @abstractmethod
64
64
  def refresh_tokens(
65
65
  self, metadata: Sequence[tuple[str, Union[str, bytes]]]
66
- ) -> tuple[Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[UserInfo]]:
66
+ ) -> tuple[
67
+ Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
68
+ ]:
67
69
  """Refresh authentication tokens in the provided metadata."""
68
70
 
69
71
 
@@ -84,7 +86,7 @@ class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
84
86
  """Abstract constructor."""
85
87
 
86
88
  @abstractmethod
87
- def verify_user_authorization(self, user_info: UserInfo) -> bool:
89
+ def verify_user_authorization(self, account_info: AccountInfo) -> bool:
88
90
  """Verify user authorization request."""
89
91
 
90
92
 
@@ -21,7 +21,7 @@ from typing import Optional, Union
21
21
  import grpc
22
22
  from google.protobuf.message import Message as GrpcMessage
23
23
 
24
- from flwr.common.typing import LogEntry, UserInfo
24
+ from flwr.common.typing import AccountInfo, LogEntry
25
25
 
26
26
 
27
27
  class EventLogWriterPlugin(ABC):
@@ -36,7 +36,7 @@ class EventLogWriterPlugin(ABC):
36
36
  self,
37
37
  request: GrpcMessage,
38
38
  context: grpc.ServicerContext,
39
- user_info: Optional[UserInfo],
39
+ account_info: Optional[AccountInfo],
40
40
  method_name: str,
41
41
  ) -> LogEntry:
42
42
  """Compose pre-event log entry from the provided request and context."""
@@ -46,7 +46,7 @@ class EventLogWriterPlugin(ABC):
46
46
  self,
47
47
  request: GrpcMessage,
48
48
  context: grpc.ServicerContext,
49
- user_info: Optional[UserInfo],
49
+ account_info: Optional[AccountInfo],
50
50
  method_name: str,
51
51
  response: Optional[Union[GrpcMessage, BaseException]],
52
52
  ) -> LogEntry:
flwr/common/inflatable.py CHANGED
@@ -18,8 +18,13 @@
18
18
  from __future__ import annotations
19
19
 
20
20
  import hashlib
21
+ import threading
22
+ from collections.abc import Iterator
23
+ from contextlib import contextmanager
21
24
  from typing import TypeVar, cast
22
25
 
26
+ from flwr.proto.message_pb2 import ObjectTree # pylint: disable=E0611
27
+
23
28
  from .constant import HEAD_BODY_DIVIDER, HEAD_VALUE_DIVIDER
24
29
 
25
30
 
@@ -33,6 +38,33 @@ class UnexpectedObjectContentError(Exception):
33
38
  )
34
39
 
35
40
 
41
# Per-thread state consulted by ``InflatableObject.object_id``. The attributes
# are assigned by ``no_object_id_recompute``; until then the accessors below
# fall back to their defaults.
_ctx = threading.local()


def _is_recompute_enabled() -> bool:
    """Return whether object IDs may be recomputed on this thread (default True)."""
    try:
        return _ctx.recompute_object_id_enabled
    except AttributeError:
        return True


def _get_computed_object_ids() -> set[str]:
    """Return the set of object IDs recorded for this thread.

    If no such set has been stored on the thread-local yet, a new empty set is
    returned (and not stored).
    """
    try:
        return _ctx.computed_object_ids
    except AttributeError:
        return set()


@contextmanager
def no_object_id_recompute() -> Iterator[None]:
    """Disable object-ID recomputation on the current thread for the ``with`` body.

    On entry the flag is switched off and a fresh, empty set of computed IDs is
    installed. On exit, even when an exception propagates, both previous values
    are restored, so nested uses unwind back to the outer state.
    """
    snapshot = (_is_recompute_enabled(), _get_computed_object_ids())
    _ctx.recompute_object_id_enabled, _ctx.computed_object_ids = False, set()
    try:
        yield
    finally:
        _ctx.recompute_object_id_enabled, _ctx.computed_object_ids = snapshot
66
+
67
+
36
68
  class InflatableObject:
37
69
  """Base class for inflatable objects."""
38
70
 
@@ -65,8 +97,23 @@ class InflatableObject:
65
97
  @property
66
98
  def object_id(self) -> str:
67
99
  """Get object_id."""
100
+ # If recomputing object ID is disabled and the object ID is already computed,
101
+ # return the cached object ID.
102
+ if (
103
+ not _is_recompute_enabled()
104
+ and (obj_id := self.__dict__.get("_object_id"))
105
+ in _get_computed_object_ids()
106
+ ):
107
+ return cast(str, obj_id)
108
+
68
109
  if self.is_dirty or "_object_id" not in self.__dict__:
69
- self.__dict__["_object_id"] = get_object_id(self.deflate())
110
+ obj_id = get_object_id(self.deflate())
111
+ self.__dict__["_object_id"] = obj_id
112
+
113
+ # If recomputing object ID is disabled, add the object ID to the set of
114
+ # computed object IDs to avoid recomputing it within the context.
115
+ if not _is_recompute_enabled():
116
+ _get_computed_object_ids().add(obj_id)
70
117
  return cast(str, self.__dict__["_object_id"])
71
118
 
72
119
  @property
@@ -219,3 +266,25 @@ def get_all_nested_objects(obj: InflatableObject) -> dict[str, InflatableObject]
219
266
  ret[obj.object_id] = obj
220
267
 
221
268
  return ret
269
+
270
+
271
def get_object_tree(obj: InflatableObject) -> ObjectTree:
    """Build an ``ObjectTree`` mirroring the nesting of ``obj``.

    Each node stores the corresponding object's ``object_id``. Its children are
    the trees of ``obj.children`` values, built recursively in the mapping's
    iteration order. An object whose ``children`` is falsy becomes a leaf.
    """
    child_mapping = obj.children or {}
    # Build the subtrees before reading obj.object_id, as the original did.
    subtrees = [get_object_tree(child) for child in child_mapping.values()]
    return ObjectTree(object_id=obj.object_id, children=subtrees)
278
+
279
+
280
def iterate_object_tree(
    tree: ObjectTree,
) -> Iterator[ObjectTree]:
    """Yield every node of ``tree`` in post-order.

    Children are visited left to right, and each node is yielded only after
    all of its descendants. The traversal uses an explicit stack instead of
    recursion.
    """
    exhausted = object()
    # Each entry pairs a node with an iterator over its not-yet-visited children.
    pending = [(tree, iter(tree.children))]
    while pending:
        node, remaining_children = pending[-1]
        next_child = next(remaining_children, exhausted)
        if next_child is exhausted:
            pending.pop()
            yield node
        else:
            pending.append((next_child, iter(next_child.children)))
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- """InflatableObject utils."""
15
+ """InflatableObject gRPC utils."""
16
16
 
17
17
 
18
18
  from typing import Callable
@@ -0,0 +1,99 @@
1
+ # Copyright 2025 Flower Labs GmbH. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """InflatableObject REST utils."""
16
+
17
+
18
+ from typing import Callable
19
+
20
+ from flwr.proto.message_pb2 import ( # pylint: disable=E0611
21
+ PullObjectRequest,
22
+ PullObjectResponse,
23
+ PushObjectRequest,
24
+ PushObjectResponse,
25
+ )
26
+ from flwr.proto.node_pb2 import Node # pylint: disable=E0611
27
+
28
+ from .inflatable_utils import ObjectIdNotPreregisteredError, ObjectUnavailableError
29
+
30
+
31
def make_pull_object_fn_rest(
    pull_object_rest: Callable[[PullObjectRequest], PullObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str], bytes]:
    """Create a pull object function that uses REST to pull objects.

    Parameters
    ----------
    pull_object_rest : Callable[[PullObjectRequest], PullObjectResponse]
        A function that makes a POST request against the `/pull-object` REST endpoint
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str], bytes]
        A function that takes an object ID and returns the object content as bytes.
        The function raises `ObjectIdNotPreregisteredError` if the object ID is not
        pre-registered, or `ObjectUnavailableError` if the object is not yet available.
    """

    def pull_object_fn(object_id: str) -> bytes:
        # `node` and `run_id` are captured from the enclosing call.
        request = PullObjectRequest(node=node, run_id=run_id, object_id=object_id)
        response: PullObjectResponse = pull_object_rest(request)
        # Check "found" before "available": an unknown ID is reported as not
        # pre-registered rather than as merely unavailable.
        if not response.object_found:
            raise ObjectIdNotPreregisteredError(object_id)
        if not response.object_available:
            raise ObjectUnavailableError(object_id)
        return response.object_content

    return pull_object_fn
65
+
66
+
67
def make_push_object_fn_rest(
    push_object_rest: Callable[[PushObjectRequest], PushObjectResponse],
    node: Node,
    run_id: int,
) -> Callable[[str, bytes], None]:
    """Create a push object function that uses REST to push objects.

    Parameters
    ----------
    push_object_rest : Callable[[PushObjectRequest], PushObjectResponse]
        A function that makes a POST request against the `/push-object` REST endpoint
    node : Node
        The node making the request.
    run_id : int
        The run ID for the current operation.

    Returns
    -------
    Callable[[str, bytes], None]
        A function that takes an object ID and its content as bytes, and pushes it
        to the servicer. The function raises `ObjectIdNotPreregisteredError` if
        the response reports that the object was not stored.
    """

    def push_object_fn(object_id: str, object_content: bytes) -> None:
        # `node` and `run_id` are captured from the enclosing call.
        request = PushObjectRequest(
            node=node, run_id=run_id, object_id=object_id, object_content=object_content
        )
        response: PushObjectResponse = push_object_rest(request)
        # A push that was not stored is surfaced as a pre-registration error.
        if not response.stored:
            raise ObjectIdNotPreregisteredError(object_id)

    return push_object_fn
flwr/common/serde.py CHANGED
@@ -630,6 +630,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
630
630
  running_at=run.running_at,
631
631
  finished_at=run.finished_at,
632
632
  status=run_status_to_proto(run.status),
633
+ flwr_aid=run.flwr_aid,
633
634
  )
634
635
  return proto
635
636
 
@@ -647,6 +648,7 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
647
648
  running_at=run_proto.running_at,
648
649
  finished_at=run_proto.finished_at,
649
650
  status=run_status_from_proto(run_proto.status),
651
+ flwr_aid=run_proto.flwr_aid,
650
652
  )
651
653
  return run
652
654
 
flwr/common/typing.py CHANGED
@@ -230,6 +230,7 @@ class Run: # pylint: disable=too-many-instance-attributes
230
230
  running_at: str
231
231
  finished_at: str
232
232
  status: RunStatus
233
+ flwr_aid: str
233
234
 
234
235
  @classmethod
235
236
  def create_empty(cls, run_id: int) -> "Run":
@@ -245,6 +246,7 @@ class Run: # pylint: disable=too-many-instance-attributes
245
246
  running_at="",
246
247
  finished_at="",
247
248
  status=RunStatus(status="", sub_status="", details=""),
249
+ flwr_aid="",
248
250
  )
249
251
 
250
252
 
@@ -289,11 +291,11 @@ class UserAuthCredentials:
289
291
 
290
292
 
291
293
  @dataclass
292
- class UserInfo:
294
+ class AccountInfo:
293
295
  """User information for event log."""
294
296
 
295
- user_id: Optional[str]
296
- user_name: Optional[str]
297
+ flwr_aid: Optional[str]
298
+ account_name: Optional[str]
297
299
 
298
300
 
299
301
  @dataclass