flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. flwr/cli/run/run.py +5 -9
  2. flwr/client/app.py +6 -4
  3. flwr/client/client_app.py +10 -12
  4. flwr/client/clientapp/app.py +2 -2
  5. flwr/client/grpc_client/connection.py +24 -21
  6. flwr/client/message_handler/message_handler.py +27 -27
  7. flwr/client/mod/__init__.py +2 -2
  8. flwr/client/mod/centraldp_mods.py +7 -7
  9. flwr/client/mod/comms_mods.py +16 -22
  10. flwr/client/mod/localdp_mod.py +4 -4
  11. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  12. flwr/client/run_info_store.py +2 -2
  13. flwr/common/__init__.py +12 -4
  14. flwr/common/config.py +4 -4
  15. flwr/common/constant.py +1 -1
  16. flwr/common/context.py +4 -4
  17. flwr/common/message.py +269 -101
  18. flwr/common/record/__init__.py +8 -4
  19. flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
  20. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  21. flwr/common/record/conversion_utils.py +1 -1
  22. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  23. flwr/common/record/recorddict.py +288 -0
  24. flwr/common/recorddict_compat.py +410 -0
  25. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  26. flwr/common/serde.py +66 -71
  27. flwr/common/typing.py +8 -8
  28. flwr/proto/exec_pb2.py +3 -3
  29. flwr/proto/exec_pb2.pyi +3 -3
  30. flwr/proto/message_pb2.py +12 -12
  31. flwr/proto/message_pb2.pyi +9 -9
  32. flwr/proto/recorddict_pb2.py +70 -0
  33. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  34. flwr/proto/run_pb2.py +31 -31
  35. flwr/proto/run_pb2.pyi +3 -3
  36. flwr/server/compat/grid_client_proxy.py +31 -31
  37. flwr/server/grid/grid.py +3 -3
  38. flwr/server/grid/grpc_grid.py +15 -23
  39. flwr/server/grid/inmemory_grid.py +14 -20
  40. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  41. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  42. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  43. flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
  44. flwr/server/superlink/linkstate/linkstate.py +4 -4
  45. flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
  46. flwr/server/superlink/linkstate/utils.py +18 -15
  47. flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
  48. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  49. flwr/server/utils/validator.py +4 -4
  50. flwr/server/workflow/default_workflows.py +34 -41
  51. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
  52. flwr/simulation/app.py +2 -2
  53. flwr/simulation/ray_transport/ray_actor.py +4 -2
  54. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  55. flwr/simulation/run_simulation.py +5 -5
  56. flwr/superexec/deployment.py +4 -4
  57. flwr/superexec/exec_servicer.py +2 -2
  58. flwr/superexec/executor.py +3 -3
  59. flwr/superexec/simulation.py +3 -3
  60. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
  61. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
  62. flwr/common/record/recordset.py +0 -209
  63. flwr/common/recordset_compat.py +0 -418
  64. flwr/proto/recordset_pb2.py +0 -70
  65. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  66. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  67. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
  68. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
  69. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
@@ -130,9 +130,7 @@ def worker(
130
130
  e_code = ErrorCode.UNKNOWN
131
131
 
132
132
  reason = str(type(ex)) + ":<'" + str(ex) + "'>"
133
- out_mssg = message.create_error_reply(
134
- error=Error(code=e_code, reason=reason)
135
- )
133
+ out_mssg = Message(Error(code=e_code, reason=reason), reply_to=message)
136
134
 
137
135
  finally:
138
136
  if out_mssg:
@@ -32,7 +32,7 @@ from flwr.common.constant import (
32
32
  SUPERLINK_NODE_ID,
33
33
  Status,
34
34
  )
35
- from flwr.common.record import ConfigsRecord
35
+ from flwr.common.record import ConfigRecord
36
36
  from flwr.common.typing import Run, RunStatus, UserConfig
37
37
  from flwr.server.superlink.linkstate.linkstate import LinkState
38
38
  from flwr.server.utils import validate_message
@@ -69,7 +69,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
69
69
  # Map run_id to RunRecord
70
70
  self.run_ids: dict[int, RunRecord] = {}
71
71
  self.contexts: dict[int, Context] = {}
72
- self.federation_options: dict[int, ConfigsRecord] = {}
72
+ self.federation_options: dict[int, ConfigRecord] = {}
73
73
  self.message_ins_store: dict[UUID, Message] = {}
74
74
  self.message_res_store: dict[UUID, Message] = {}
75
75
  self.message_ins_id_to_message_res_id: dict[UUID, UUID] = {}
@@ -158,7 +158,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
158
158
  res_metadata = message.metadata
159
159
  with self.lock:
160
160
  # Check if the Message it is replying to exists and is valid
161
- msg_ins_id = res_metadata.reply_to_message
161
+ msg_ins_id = res_metadata.reply_to_message_id
162
162
  msg_ins = self.message_ins_store.get(UUID(msg_ins_id))
163
163
 
164
164
  # Ensure that dst_node_id of original Message matches the src_node_id of
@@ -399,7 +399,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
399
399
  fab_version: Optional[str],
400
400
  fab_hash: Optional[str],
401
401
  override_config: UserConfig,
402
- federation_options: ConfigsRecord,
402
+ federation_options: ConfigRecord,
403
403
  ) -> int:
404
404
  """Create a new run for the specified `fab_hash`."""
405
405
  # Sample a random int64 as run_id
@@ -528,7 +528,7 @@ class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
528
528
 
529
529
  return pending_run_id
530
530
 
531
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
531
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
532
532
  """Retrieve the federation options for the specified `run_id`."""
533
533
  with self.lock:
534
534
  if run_id not in self.run_ids:
@@ -20,7 +20,7 @@ from typing import Optional
20
20
  from uuid import UUID
21
21
 
22
22
  from flwr.common import Context, Message
23
- from flwr.common.record import ConfigsRecord
23
+ from flwr.common.record import ConfigRecord
24
24
  from flwr.common.typing import Run, RunStatus, UserConfig
25
25
 
26
26
 
@@ -164,7 +164,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
164
164
  fab_version: Optional[str],
165
165
  fab_hash: Optional[str],
166
166
  override_config: UserConfig,
167
- federation_options: ConfigsRecord,
167
+ federation_options: ConfigRecord,
168
168
  ) -> int:
169
169
  """Create a new run for the specified `fab_hash`."""
170
170
 
@@ -236,7 +236,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
236
236
  """
237
237
 
238
238
  @abc.abstractmethod
239
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
239
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
240
240
  """Retrieve the federation options for the specified `run_id`.
241
241
 
242
242
  Parameters
@@ -246,7 +246,7 @@ class LinkState(abc.ABC): # pylint: disable=R0904
246
246
 
247
247
  Returns
248
248
  -------
249
- Optional[ConfigsRecord]
249
+ Optional[ConfigRecord]
250
250
  The federation options for the run if it exists; None otherwise.
251
251
  """
252
252
 
@@ -35,18 +35,19 @@ from flwr.common.constant import (
35
35
  SUPERLINK_NODE_ID,
36
36
  Status,
37
37
  )
38
- from flwr.common.record import ConfigsRecord
38
+ from flwr.common.message import make_message
39
+ from flwr.common.record import ConfigRecord
39
40
  from flwr.common.serde import (
40
41
  error_from_proto,
41
42
  error_to_proto,
42
- recordset_from_proto,
43
- recordset_to_proto,
43
+ recorddict_from_proto,
44
+ recorddict_to_proto,
44
45
  )
45
46
  from flwr.common.typing import Run, RunStatus, UserConfig
46
47
 
47
48
  # pylint: disable=E0611
48
49
  from flwr.proto.error_pb2 import Error as ProtoError
49
- from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
50
+ from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
50
51
 
51
52
  # pylint: enable=E0611
52
53
  from flwr.server.utils.validator import validate_message
@@ -54,8 +55,8 @@ from flwr.server.utils.validator import validate_message
54
55
  from .linkstate import LinkState
55
56
  from .utils import (
56
57
  check_node_availability_for_in_message,
57
- configsrecord_from_bytes,
58
- configsrecord_to_bytes,
58
+ configrecord_from_bytes,
59
+ configrecord_to_bytes,
59
60
  context_from_bytes,
60
61
  context_to_bytes,
61
62
  convert_sint64_to_uint64,
@@ -131,7 +132,7 @@ CREATE TABLE IF NOT EXISTS message_ins(
131
132
  run_id INTEGER,
132
133
  src_node_id INTEGER,
133
134
  dst_node_id INTEGER,
134
- reply_to_message TEXT,
135
+ reply_to_message_id TEXT,
135
136
  created_at REAL,
136
137
  delivered_at TEXT,
137
138
  ttl REAL,
@@ -150,7 +151,7 @@ CREATE TABLE IF NOT EXISTS message_res(
150
151
  run_id INTEGER,
151
152
  src_node_id INTEGER,
152
153
  dst_node_id INTEGER,
153
- reply_to_message TEXT,
154
+ reply_to_message_id TEXT,
154
155
  created_at REAL,
155
156
  delivered_at TEXT,
156
157
  ttl REAL,
@@ -373,7 +374,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
373
374
  return None
374
375
 
375
376
  res_metadata = message.metadata
376
- msg_ins_id = res_metadata.reply_to_message
377
+ msg_ins_id = res_metadata.reply_to_message_id
377
378
  msg_ins = self.get_valid_message_ins(msg_ins_id)
378
379
  if msg_ins is None:
379
380
  log(
@@ -495,7 +496,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
495
496
  query = f"""
496
497
  SELECT *
497
498
  FROM message_res
498
- WHERE reply_to_message IN ({",".join(["?"] * len(message_ids))})
499
+ WHERE reply_to_message_id IN ({",".join(["?"] * len(message_ids))})
499
500
  AND delivered_at = "";
500
501
  """
501
502
  rows = self.query(query, tuple(str(message_id) for message_id in message_ids))
@@ -568,7 +569,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
568
569
  # Delete reply Message
569
570
  query_2 = f"""
570
571
  DELETE FROM message_res
571
- WHERE reply_to_message IN ({placeholders});
572
+ WHERE reply_to_message_id IN ({placeholders});
572
573
  """
573
574
 
574
575
  with self.conn:
@@ -726,7 +727,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
726
727
  fab_version: Optional[str],
727
728
  fab_hash: Optional[str],
728
729
  override_config: UserConfig,
729
- federation_options: ConfigsRecord,
730
+ federation_options: ConfigRecord,
730
731
  ) -> int:
731
732
  """Create a new run for the specified `fab_id` and `fab_version`."""
732
733
  # Sample a random int64 as run_id
@@ -752,7 +753,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
752
753
  fab_version,
753
754
  fab_hash,
754
755
  override_config_json,
755
- configsrecord_to_bytes(federation_options),
756
+ configrecord_to_bytes(federation_options),
756
757
  ]
757
758
  data += [
758
759
  now().isoformat(),
@@ -910,7 +911,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
910
911
 
911
912
  return pending_run_id
912
913
 
913
- def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
914
+ def get_federation_options(self, run_id: int) -> Optional[ConfigRecord]:
914
915
  """Retrieve the federation options for the specified `run_id`."""
915
916
  # Convert the uint64 value to sint64 for SQLite
916
917
  sint64_run_id = convert_uint64_to_sint64(run_id)
@@ -923,7 +924,7 @@ class SqliteLinkState(LinkState): # pylint: disable=R0904
923
924
  return None
924
925
 
925
926
  row = rows[0]
926
- return configsrecord_from_bytes(row["federation_options"])
927
+ return configrecord_from_bytes(row["federation_options"])
927
928
 
928
929
  def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
929
930
  """Acknowledge a ping received from a node, serving as a heartbeat.
@@ -1064,7 +1065,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1064
1065
  "run_id": message.metadata.run_id,
1065
1066
  "src_node_id": message.metadata.src_node_id,
1066
1067
  "dst_node_id": message.metadata.dst_node_id,
1067
- "reply_to_message": message.metadata.reply_to_message,
1068
+ "reply_to_message_id": message.metadata.reply_to_message_id,
1068
1069
  "created_at": message.metadata.created_at,
1069
1070
  "delivered_at": message.metadata.delivered_at,
1070
1071
  "ttl": message.metadata.ttl,
@@ -1074,7 +1075,7 @@ def message_to_dict(message: Message) -> dict[str, Any]:
1074
1075
  }
1075
1076
 
1076
1077
  if message.has_content():
1077
- result["content"] = recordset_to_proto(message.content).SerializeToString()
1078
+ result["content"] = recorddict_to_proto(message.content).SerializeToString()
1078
1079
  else:
1079
1080
  result["error"] = error_to_proto(message.error).SerializeToString()
1080
1081
 
@@ -1085,20 +1086,15 @@ def dict_to_message(message_dict: dict[str, Any]) -> Message:
1085
1086
  """Transform dict to Message."""
1086
1087
  content, error = None, None
1087
1088
  if (b_content := message_dict.pop("content")) is not None:
1088
- content = recordset_from_proto(ProtoRecordSet.FromString(b_content))
1089
+ content = recorddict_from_proto(ProtoRecordDict.FromString(b_content))
1089
1090
  if (b_error := message_dict.pop("error")) is not None:
1090
1091
  error = error_from_proto(ProtoError.FromString(b_error))
1091
1092
 
1092
1093
  # Metadata constructor doesn't allow passing created_at. We set it later
1093
1094
  metadata = Metadata(
1094
- **{
1095
- k: v
1096
- for k, v in message_dict.items()
1097
- if k not in ["created_at", "delivered_at"]
1098
- }
1095
+ **{k: v for k, v in message_dict.items() if k not in ["delivered_at"]}
1099
1096
  )
1100
- msg = Message(metadata=metadata, content=content, error=error)
1101
- msg.metadata.__dict__["_created_at"] = message_dict["created_at"]
1097
+ msg = make_message(metadata=metadata, content=content, error=error)
1102
1098
  msg.metadata.delivered_at = message_dict["delivered_at"]
1103
1099
  return msg
1104
1100
 
@@ -19,7 +19,7 @@ from os import urandom
19
19
  from typing import Optional
20
20
  from uuid import UUID, uuid4
21
21
 
22
- from flwr.common import ConfigsRecord, Context, Error, Message, Metadata, now, serde
22
+ from flwr.common import ConfigRecord, Context, Error, Message, Metadata, now, serde
23
23
  from flwr.common.constant import (
24
24
  SUPERLINK_NODE_ID,
25
25
  ErrorCode,
@@ -27,11 +27,12 @@ from flwr.common.constant import (
27
27
  Status,
28
28
  SubStatus,
29
29
  )
30
+ from flwr.common.message import make_message
30
31
  from flwr.common.typing import RunStatus
31
32
 
32
33
  # pylint: disable=E0611
33
34
  from flwr.proto.message_pb2 import Context as ProtoContext
34
- from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
35
+ from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
35
36
 
36
37
  # pylint: enable=E0611
37
38
  VALID_RUN_STATUS_TRANSITIONS = {
@@ -171,15 +172,15 @@ def context_from_bytes(context_bytes: bytes) -> Context:
171
172
  return serde.context_from_proto(ProtoContext.FromString(context_bytes))
172
173
 
173
174
 
174
- def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes:
175
- """Serialize a `ConfigsRecord` to bytes."""
176
- return serde.configs_record_to_proto(configs_record).SerializeToString()
175
+ def configrecord_to_bytes(config_record: ConfigRecord) -> bytes:
176
+ """Serialize a `ConfigRecord` to bytes."""
177
+ return serde.config_record_to_proto(config_record).SerializeToString()
177
178
 
178
179
 
179
- def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord:
180
- """Deserialize `ConfigsRecord` from bytes."""
181
- return serde.configs_record_from_proto(
182
- ProtoConfigsRecord.FromString(configsrecord_bytes)
180
+ def configrecord_from_bytes(configrecord_bytes: bytes) -> ConfigRecord:
181
+ """Deserialize `ConfigRecord` from bytes."""
182
+ return serde.config_record_from_proto(
183
+ ProtoConfigRecord.FromString(configrecord_bytes)
183
184
  )
184
185
 
185
186
 
@@ -247,13 +248,14 @@ def create_message_error_unavailable_res_message(
247
248
  message_id=str(uuid4()),
248
249
  src_node_id=SUPERLINK_NODE_ID,
249
250
  dst_node_id=SUPERLINK_NODE_ID,
250
- reply_to_message=ins_metadata.message_id,
251
+ reply_to_message_id=ins_metadata.message_id,
251
252
  group_id=ins_metadata.group_id,
252
253
  message_type=ins_metadata.message_type,
254
+ created_at=current_time,
253
255
  ttl=ttl,
254
256
  )
255
257
 
256
- return Message(
258
+ return make_message(
257
259
  metadata=metadata,
258
260
  error=Error(
259
261
  code=(
@@ -270,7 +272,7 @@ def create_message_error_unavailable_res_message(
270
272
  )
271
273
 
272
274
 
273
- def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Message:
275
+ def create_message_error_unavailable_ins_message(reply_to_message_id: UUID) -> Message:
274
276
  """Error to indicate that the enquired Message had expired before reply arrived or
275
277
  that it isn't found."""
276
278
  metadata = Metadata(
@@ -278,13 +280,14 @@ def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Mess
278
280
  message_id=str(uuid4()),
279
281
  src_node_id=SUPERLINK_NODE_ID,
280
282
  dst_node_id=SUPERLINK_NODE_ID,
281
- reply_to_message=str(reply_to_message),
283
+ reply_to_message_id=str(reply_to_message_id),
282
284
  group_id="", # Unknown
283
285
  message_type=MessageType.SYSTEM,
286
+ created_at=now().timestamp(),
284
287
  ttl=0,
285
288
  )
286
289
 
287
- return Message(
290
+ return make_message(
288
291
  metadata=metadata,
289
292
  error=Error(
290
293
  code=ErrorCode.MESSAGE_UNAVAILABLE,
@@ -372,7 +375,7 @@ def verify_found_message_replies(
372
375
  ret_dict: dict[UUID, Message] = {}
373
376
  current = current_time if current_time else now().timestamp()
374
377
  for message_res in found_message_res_list:
375
- message_ins_id = UUID(message_res.metadata.reply_to_message)
378
+ message_ins_id = UUID(message_res.metadata.reply_to_message_id)
376
379
  if update_set:
377
380
  inquired_message_ids.remove(message_ins_id)
378
381
  # Check if the reply Message has expired
@@ -22,7 +22,7 @@ from uuid import UUID
22
22
 
23
23
  import grpc
24
24
 
25
- from flwr.common import ConfigsRecord, Message
25
+ from flwr.common import ConfigRecord, Message
26
26
  from flwr.common.constant import SUPERLINK_NODE_ID, Status
27
27
  from flwr.common.logger import log
28
28
  from flwr.common.serde import (
@@ -127,7 +127,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
127
127
  request.fab_version,
128
128
  fab_hash,
129
129
  user_config_from_proto(request.override_config),
130
- ConfigsRecord(),
130
+ ConfigRecord(),
131
131
  )
132
132
  return CreateRunResponse(run_id=run_id)
133
133
 
@@ -206,7 +206,7 @@ class ServerAppIoServicer(serverappio_pb2_grpc.ServerAppIoServicer):
206
206
 
207
207
  # Delete the instruction Messages and their replies if found
208
208
  message_ins_ids_to_delete = {
209
- UUID(msg_res.metadata.reply_to_message) for msg_res in messages_res
209
+ UUID(msg_res.metadata.reply_to_message_id) for msg_res in messages_res
210
210
  }
211
211
 
212
212
  state.delete_messages(message_ins_ids=message_ins_ids_to_delete)
@@ -24,7 +24,7 @@ from grpc import ServicerContext
24
24
  from flwr.common.constant import Status
25
25
  from flwr.common.logger import log
26
26
  from flwr.common.serde import (
27
- configs_record_to_proto,
27
+ config_record_to_proto,
28
28
  context_from_proto,
29
29
  context_to_proto,
30
30
  fab_to_proto,
@@ -182,5 +182,5 @@ class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer):
182
182
  )
183
183
  return GetFederationOptionsResponse()
184
184
  return GetFederationOptionsResponse(
185
- federation_options=configs_record_to_proto(federation_options)
185
+ federation_options=config_record_to_proto(federation_options)
186
186
  )
@@ -68,8 +68,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
68
68
 
69
69
  # Link respose to original message
70
70
  if not is_reply_message:
71
- if metadata.reply_to_message != "":
72
- validation_errors.append("`metadata.reply_to_message` MUST not be set.")
71
+ if metadata.reply_to_message_id != "":
72
+ validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
73
73
  if metadata.src_node_id != SUPERLINK_NODE_ID:
74
74
  validation_errors.append(
75
75
  f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
@@ -79,8 +79,8 @@ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
79
79
  f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
80
80
  )
81
81
  else:
82
- if metadata.reply_to_message == "":
83
- validation_errors.append("`metadata.reply_to_message` MUST be set.")
82
+ if metadata.reply_to_message_id == "":
83
+ validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
84
84
  if metadata.src_node_id == SUPERLINK_NODE_ID:
85
85
  validation_errors.append(
86
86
  f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
@@ -20,15 +20,16 @@ import timeit
20
20
  from logging import INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
+ ArrayRecord,
25
26
  Code,
26
- ConfigsRecord,
27
+ ConfigRecord,
27
28
  Context,
28
29
  EvaluateRes,
29
30
  FitRes,
30
31
  GetParametersIns,
31
- ParametersRecord,
32
+ Message,
32
33
  log,
33
34
  )
34
35
  from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -77,9 +78,9 @@ class DefaultWorkflow:
77
78
 
78
79
  # Run federated learning for num_rounds
79
80
  start_time = timeit.default_timer()
80
- cfg = ConfigsRecord()
81
+ cfg = ConfigRecord()
81
82
  cfg[Key.START_TIME] = start_time
82
- context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
83
+ context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
83
84
 
84
85
  for current_round in range(1, context.config.num_rounds + 1):
85
86
  log(INFO, "")
@@ -129,21 +130,19 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
129
130
  )
130
131
  if parameters is not None:
131
132
  log(INFO, "Using initial global parameters provided by strategy")
132
- paramsrecord = compat.parameters_to_parametersrecord(
133
- parameters, keep_input=True
134
- )
133
+ arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
135
134
  else:
136
135
  # Get initial parameters from one of the clients
137
136
  log(INFO, "Requesting initial parameters from one random client")
138
137
  random_client = context.client_manager.sample(1)[0]
139
138
  # Send GetParametersIns and get the response
140
- content = compat.getparametersins_to_recordset(GetParametersIns({}))
139
+ content = compat.getparametersins_to_recorddict(GetParametersIns({}))
141
140
  messages = grid.send_and_receive(
142
141
  [
143
- grid.create_message(
142
+ Message(
144
143
  content=content,
145
- message_type=MessageTypeLegacy.GET_PARAMETERS,
146
144
  dst_node_id=random_client.node_id,
145
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
147
146
  group_id="0",
148
147
  )
149
148
  ]
@@ -152,26 +151,26 @@ def default_init_params_workflow(grid: Grid, context: Context) -> None:
152
151
 
153
152
  if (
154
153
  msg.has_content()
155
- and compat._extract_status_from_recordset( # pylint: disable=W0212
154
+ and compat._extract_status_from_recorddict( # pylint: disable=W0212
156
155
  "getparametersres", msg.content
157
156
  ).code
158
157
  == Code.OK
159
158
  ):
160
159
  log(INFO, "Received initial parameters from one random client")
161
- paramsrecord = next(iter(msg.content.parameters_records.values()))
160
+ arr_record = next(iter(msg.content.array_records.values()))
162
161
  else:
163
162
  log(
164
163
  WARN,
165
164
  "Failed to receive initial parameters from the client."
166
165
  " Empty initial parameters will be used.",
167
166
  )
168
- paramsrecord = ParametersRecord()
167
+ arr_record = ArrayRecord()
169
168
 
170
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
169
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
171
170
 
172
171
  # Evaluate initial parameters
173
172
  log(INFO, "Starting evaluation of initial global parameters")
174
- parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
173
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
175
174
  res = context.strategy.evaluate(0, parameters=parameters)
176
175
  if res is not None:
177
176
  log(
@@ -192,13 +191,13 @@ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
192
191
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
193
192
 
194
193
  # Retrieve current_round and start_time from the context
195
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
194
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
196
195
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
197
196
  start_time = cast(float, cfg[Key.START_TIME])
198
197
 
199
198
  # Centralized evaluation
200
- parameters = compat.parametersrecord_to_parameters(
201
- record=context.state.parameters_records[MAIN_PARAMS_RECORD],
199
+ parameters = compat.arrayrecord_to_parameters(
200
+ record=context.state.array_records[MAIN_PARAMS_RECORD],
202
201
  keep_input=True,
203
202
  )
204
203
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -224,12 +223,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
224
223
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
225
224
 
226
225
  # Get current_round and parameters
227
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
226
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
228
227
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
229
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
230
- parameters = compat.parametersrecord_to_parameters(
231
- parametersrecord, keep_input=True
232
- )
228
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
229
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
233
230
 
234
231
  # Get clients and their respective instructions from strategy
235
232
  client_instructions = context.strategy.configure_fit(
@@ -253,10 +250,10 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
253
250
 
254
251
  # Build out messages
255
252
  out_messages = [
256
- grid.create_message(
257
- content=compat.fitins_to_recordset(fitins, True),
258
- message_type=MessageType.TRAIN,
253
+ Message(
254
+ content=compat.fitins_to_recorddict(fitins, True),
259
255
  dst_node_id=proxy.node_id,
256
+ message_type=MessageType.TRAIN,
260
257
  group_id=str(current_round),
261
258
  )
262
259
  for proxy, fitins in client_instructions
@@ -282,7 +279,7 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
282
279
  for msg in messages:
283
280
  if msg.has_content():
284
281
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
285
- fitres = compat.recordset_to_fitres(msg.content, False)
282
+ fitres = compat.recorddict_to_fitres(msg.content, False)
286
283
  if fitres.status.code == Code.OK:
287
284
  results.append((proxy, fitres))
288
285
  else:
@@ -295,10 +292,8 @@ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disab
295
292
 
296
293
  # Update the parameters and write history
297
294
  if parameters_aggregated:
298
- paramsrecord = compat.parameters_to_parametersrecord(
299
- parameters_aggregated, True
300
- )
301
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
295
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
296
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
302
297
  context.history.add_metrics_distributed_fit(
303
298
  server_round=current_round, metrics=metrics_aggregated
304
299
  )
@@ -311,12 +306,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
311
306
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
312
307
 
313
308
  # Get current_round and parameters
314
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
309
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
315
310
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
316
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
317
- parameters = compat.parametersrecord_to_parameters(
318
- parametersrecord, keep_input=True
319
- )
311
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
312
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
320
313
 
321
314
  # Get clients and their respective instructions from strategy
322
315
  client_instructions = context.strategy.configure_evaluate(
@@ -339,10 +332,10 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
339
332
 
340
333
  # Build out messages
341
334
  out_messages = [
342
- grid.create_message(
343
- content=compat.evaluateins_to_recordset(evalins, True),
344
- message_type=MessageType.EVALUATE,
335
+ Message(
336
+ content=compat.evaluateins_to_recorddict(evalins, True),
345
337
  dst_node_id=proxy.node_id,
338
+ message_type=MessageType.EVALUATE,
346
339
  group_id=str(current_round),
347
340
  )
348
341
  for proxy, evalins in client_instructions
@@ -368,7 +361,7 @@ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
368
361
  for msg in messages:
369
362
  if msg.has_content():
370
363
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
371
- evalres = compat.recordset_to_evaluateres(msg.content)
364
+ evalres = compat.recorddict_to_evaluateres(msg.content)
372
365
  if evalres.status.code == Code.OK:
373
366
  results.append((proxy, evalres))
374
367
  else: