flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250321__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. flwr/cli/run/run.py +5 -9
  2. flwr/client/app.py +6 -4
  3. flwr/client/client_app.py +10 -12
  4. flwr/client/clientapp/app.py +2 -2
  5. flwr/client/grpc_client/connection.py +24 -21
  6. flwr/client/message_handler/message_handler.py +27 -27
  7. flwr/client/mod/__init__.py +2 -2
  8. flwr/client/mod/centraldp_mods.py +7 -7
  9. flwr/client/mod/comms_mods.py +16 -22
  10. flwr/client/mod/localdp_mod.py +4 -4
  11. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  12. flwr/client/run_info_store.py +2 -2
  13. flwr/common/__init__.py +12 -4
  14. flwr/common/config.py +4 -4
  15. flwr/common/constant.py +1 -1
  16. flwr/common/context.py +4 -4
  17. flwr/common/message.py +269 -101
  18. flwr/common/record/__init__.py +8 -4
  19. flwr/common/record/{parametersrecord.py → arrayrecord.py} +75 -32
  20. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  21. flwr/common/record/conversion_utils.py +1 -1
  22. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  23. flwr/common/record/recorddict.py +288 -0
  24. flwr/common/recorddict_compat.py +410 -0
  25. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  26. flwr/common/serde.py +66 -71
  27. flwr/common/typing.py +8 -8
  28. flwr/proto/exec_pb2.py +3 -3
  29. flwr/proto/exec_pb2.pyi +3 -3
  30. flwr/proto/message_pb2.py +12 -12
  31. flwr/proto/message_pb2.pyi +9 -9
  32. flwr/proto/recorddict_pb2.py +70 -0
  33. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  34. flwr/proto/run_pb2.py +31 -31
  35. flwr/proto/run_pb2.pyi +3 -3
  36. flwr/server/compat/grid_client_proxy.py +31 -31
  37. flwr/server/grid/grid.py +3 -3
  38. flwr/server/grid/grpc_grid.py +15 -23
  39. flwr/server/grid/inmemory_grid.py +14 -20
  40. flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
  41. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  42. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  43. flwr/server/superlink/linkstate/in_memory_linkstate.py +5 -5
  44. flwr/server/superlink/linkstate/linkstate.py +4 -4
  45. flwr/server/superlink/linkstate/sqlite_linkstate.py +21 -25
  46. flwr/server/superlink/linkstate/utils.py +18 -15
  47. flwr/server/superlink/serverappio/serverappio_servicer.py +3 -3
  48. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  49. flwr/server/utils/validator.py +4 -4
  50. flwr/server/workflow/default_workflows.py +34 -41
  51. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +37 -39
  52. flwr/simulation/app.py +2 -2
  53. flwr/simulation/ray_transport/ray_actor.py +4 -2
  54. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  55. flwr/simulation/run_simulation.py +5 -5
  56. flwr/superexec/deployment.py +4 -4
  57. flwr/superexec/exec_servicer.py +2 -2
  58. flwr/superexec/executor.py +3 -3
  59. flwr/superexec/simulation.py +3 -3
  60. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/METADATA +1 -1
  61. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/RECORD +66 -66
  62. flwr/common/record/recordset.py +0 -209
  63. flwr/common/recordset_compat.py +0 -418
  64. flwr/proto/recordset_pb2.py +0 -70
  65. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  66. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  67. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/LICENSE +0 -0
  68. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/WHEEL +0 -0
  69. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250321.dist-info}/entry_points.txt +0 -0
flwr/cli/run/run.py CHANGED
@@ -35,15 +35,11 @@ from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
35
35
  from flwr.common.config import (
36
36
  flatten_dict,
37
37
  parse_config_args,
38
- user_config_to_configsrecord,
38
+ user_config_to_configrecord,
39
39
  )
40
40
  from flwr.common.constant import CliOutputFormat
41
41
  from flwr.common.logger import print_json_error, redirect_output, restore_output
42
- from flwr.common.serde import (
43
- configs_record_to_proto,
44
- fab_to_proto,
45
- user_config_to_proto,
46
- )
42
+ from flwr.common.serde import config_record_to_proto, fab_to_proto, user_config_to_proto
47
43
  from flwr.common.typing import Fab
48
44
  from flwr.proto.exec_pb2 import StartRunRequest # pylint: disable=E0611
49
45
  from flwr.proto.exec_pb2_grpc import ExecStub
@@ -171,14 +167,14 @@ def _run_with_exec_api(
171
167
 
172
168
  fab = Fab(fab_hash, content)
173
169
 
174
- # Construct a `ConfigsRecord` out of a flattened `UserConfig`
170
+ # Construct a `ConfigRecord` out of a flattened `UserConfig`
175
171
  fed_conf = flatten_dict(federation_config.get("options", {}))
176
- c_record = user_config_to_configsrecord(fed_conf)
172
+ c_record = user_config_to_configrecord(fed_conf)
177
173
 
178
174
  req = StartRunRequest(
179
175
  fab=fab_to_proto(fab),
180
176
  override_config=user_config_to_proto(parse_config_args(config_overrides)),
181
- federation_options=configs_record_to_proto(c_record),
177
+ federation_options=config_record_to_proto(c_record),
182
178
  )
183
179
  with unauthenticated_exc_handler():
184
180
  res = stub.StartRun(req)
flwr/client/app.py CHANGED
@@ -495,8 +495,9 @@ def start_client_internal(
495
495
  context = run_info_store.retrieve_context(run_id=run_id)
496
496
  # Create an error reply message that will never be used to prevent
497
497
  # the used-before-assignment linting error
498
- reply_message = message.create_error_reply(
499
- error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
498
+ reply_message = Message(
499
+ Error(code=ErrorCode.UNKNOWN, reason="Unknown"),
500
+ reply_to=message,
500
501
  )
501
502
 
502
503
  # Handle app loading and task message
@@ -593,8 +594,9 @@ def start_client_internal(
593
594
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
594
595
 
595
596
  # Create error message
596
- reply_message = message.create_error_reply(
597
- error=Error(code=e_code, reason=reason)
597
+ reply_message = Message(
598
+ Error(code=e_code, reason=reason),
599
+ reply_to=message,
598
600
  )
599
601
  else:
600
602
  # No exception, update node state
flwr/client/client_app.py CHANGED
@@ -189,7 +189,7 @@ class ClientApp:
189
189
  >>> def train(message: Message, context: Context) -> Message:
190
190
  >>> print("Executing default train function")
191
191
  >>> # Create and return an echo reply message
192
- >>> return message.create_reply(content=message.content)
192
+ >>> return Message(message.content, reply_to=message)
193
193
 
194
194
  Registering a train function with a custom action name:
195
195
 
@@ -200,7 +200,7 @@ class ClientApp:
200
200
  >>> @app.train("custom_action")
201
201
  >>> def custom_action(message: Message, context: Context) -> Message:
202
202
  >>> print("Executing train function for custom action")
203
- >>> return message.create_reply(content=message.content)
203
+ >>> return Message(message.content, reply_to=message)
204
204
 
205
205
  Registering a train function with a function-specific Flower Mod:
206
206
 
@@ -213,7 +213,7 @@ class ClientApp:
213
213
  >>> def train(message: Message, context: Context) -> Message:
214
214
  >>> print("Executing train function with message size mod")
215
215
  >>> # Create and return an echo reply message
216
- >>> return message.create_reply(content=message.content)
216
+ >>> return Message(message.content, reply_to=message)
217
217
  """
218
218
  return _get_decorator(self, MessageType.TRAIN, action, mods)
219
219
 
@@ -244,7 +244,7 @@ class ClientApp:
244
244
  >>> def evaluate(message: Message, context: Context) -> Message:
245
245
  >>> print("Executing default evaluate function")
246
246
  >>> # Create and return an echo reply message
247
- >>> return message.create_reply(content=message.content)
247
+ >>> return Message(message.content, reply_to=message)
248
248
 
249
249
  Registering an evaluate function with a custom action name:
250
250
 
@@ -255,7 +255,7 @@ class ClientApp:
255
255
  >>> @app.evaluate("custom_action")
256
256
  >>> def custom_action(message: Message, context: Context) -> Message:
257
257
  >>> print("Executing evaluate function for custom action")
258
- >>> return message.create_reply(content=message.content)
258
+ >>> return Message(message.content, reply_to=message)
259
259
 
260
260
  Registering an evaluate function with a function-specific Flower Mod:
261
261
 
@@ -268,7 +268,7 @@ class ClientApp:
268
268
  >>> def evaluate(message: Message, context: Context) -> Message:
269
269
  >>> print("Executing evaluate function with message size mod")
270
270
  >>> # Create and return an echo reply message
271
- >>> return message.create_reply(content=message.content)
271
+ >>> return Message(message.content, reply_to=message)
272
272
  """
273
273
  return _get_decorator(self, MessageType.EVALUATE, action, mods)
274
274
 
@@ -299,7 +299,7 @@ class ClientApp:
299
299
  >>> def query(message: Message, context: Context) -> Message:
300
300
  >>> print("Executing default query function")
301
301
  >>> # Create and return an echo reply message
302
- >>> return message.create_reply(content=message.content)
302
+ >>> return Message(message.content, reply_to=message)
303
303
 
304
304
  Registering a query function with a custom action name:
305
305
 
@@ -310,7 +310,7 @@ class ClientApp:
310
310
  >>> @app.query("custom_action")
311
311
  >>> def custom_action(message: Message, context: Context) -> Message:
312
312
  >>> print("Executing query function for custom action")
313
- >>> return message.create_reply(content=message.content)
313
+ >>> return Message(message.content, reply_to=message)
314
314
 
315
315
  Registering a query function with a function-specific Flower Mod:
316
316
 
@@ -323,7 +323,7 @@ class ClientApp:
323
323
  >>> def query(message: Message, context: Context) -> Message:
324
324
  >>> print("Executing query function with message size mod")
325
325
  >>> # Create and return an echo reply message
326
- >>> return message.create_reply(content=message.content)
326
+ >>> return Message(message.content, reply_to=message)
327
327
  """
328
328
  return _get_decorator(self, MessageType.QUERY, action, mods)
329
329
 
@@ -454,8 +454,6 @@ def _registration_error(fn_name: str) -> ValueError:
454
454
  >>> def {fn_name}(message: Message, context: Context) -> Message:
455
455
  >>> print("ClientApp {fn_name} running")
456
456
  >>> # Create and return an echo reply message
457
- >>> return message.create_reply(
458
- >>> content=message.content()
459
- >>> )
457
+ >>> return Message(message.content, reply_to=message)
460
458
  """,
461
459
  )
@@ -152,8 +152,8 @@ def run_clientapp( # pylint: disable=R0914
152
152
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
153
153
 
154
154
  # Create error message
155
- reply_message = message.create_error_reply(
156
- error=Error(code=e_code, reason=reason)
155
+ reply_message = Message(
156
+ Error(code=e_code, reason=reason), reply_to=message
157
157
  )
158
158
 
159
159
  # Push Message and Context to SuperNode
@@ -28,16 +28,18 @@ from cryptography.hazmat.primitives.asymmetric import ec
28
28
  from flwr.common import (
29
29
  DEFAULT_TTL,
30
30
  GRPC_MAX_MESSAGE_LENGTH,
31
- ConfigsRecord,
31
+ ConfigRecord,
32
32
  Message,
33
33
  Metadata,
34
- RecordSet,
34
+ RecordDict,
35
+ now,
35
36
  )
36
- from flwr.common import recordset_compat as compat
37
+ from flwr.common import recorddict_compat as compat
37
38
  from flwr.common import serde
38
39
  from flwr.common.constant import MessageType, MessageTypeLegacy
39
40
  from flwr.common.grpc import create_channel, on_channel_state_change
40
41
  from flwr.common.logger import log
42
+ from flwr.common.message import make_message
41
43
  from flwr.common.retry_invoker import RetryInvoker
42
44
  from flwr.common.typing import Fab, Run
43
45
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
@@ -139,32 +141,32 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
139
141
  # Receive ServerMessage proto
140
142
  proto = next(server_message_iterator)
141
143
 
142
- # ServerMessage proto --> *Ins --> RecordSet
144
+ # ServerMessage proto --> *Ins --> RecordDict
143
145
  field = proto.WhichOneof("msg")
144
146
  message_type = ""
145
147
  if field == "get_properties_ins":
146
- recordset = compat.getpropertiesins_to_recordset(
148
+ recorddict = compat.getpropertiesins_to_recorddict(
147
149
  serde.get_properties_ins_from_proto(proto.get_properties_ins)
148
150
  )
149
151
  message_type = MessageTypeLegacy.GET_PROPERTIES
150
152
  elif field == "get_parameters_ins":
151
- recordset = compat.getparametersins_to_recordset(
153
+ recorddict = compat.getparametersins_to_recorddict(
152
154
  serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
153
155
  )
154
156
  message_type = MessageTypeLegacy.GET_PARAMETERS
155
157
  elif field == "fit_ins":
156
- recordset = compat.fitins_to_recordset(
158
+ recorddict = compat.fitins_to_recorddict(
157
159
  serde.fit_ins_from_proto(proto.fit_ins), False
158
160
  )
159
161
  message_type = MessageType.TRAIN
160
162
  elif field == "evaluate_ins":
161
- recordset = compat.evaluateins_to_recordset(
163
+ recorddict = compat.evaluateins_to_recorddict(
162
164
  serde.evaluate_ins_from_proto(proto.evaluate_ins), False
163
165
  )
164
166
  message_type = MessageType.EVALUATE
165
167
  elif field == "reconnect_ins":
166
- recordset = RecordSet()
167
- recordset.configs_records["config"] = ConfigsRecord(
168
+ recorddict = RecordDict()
169
+ recorddict.config_records["config"] = ConfigRecord(
168
170
  {"seconds": proto.reconnect_ins.seconds}
169
171
  )
170
172
  message_type = "reconnect"
@@ -175,45 +177,46 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
175
177
  )
176
178
 
177
179
  # Construct Message
178
- return Message(
180
+ return make_message(
179
181
  metadata=Metadata(
180
182
  run_id=0,
181
183
  message_id=str(uuid.uuid4()),
182
184
  src_node_id=0,
183
185
  dst_node_id=0,
184
- reply_to_message="",
186
+ reply_to_message_id="",
185
187
  group_id="",
188
+ created_at=now().timestamp(),
186
189
  ttl=DEFAULT_TTL,
187
190
  message_type=message_type,
188
191
  ),
189
- content=recordset,
192
+ content=recorddict,
190
193
  )
191
194
 
192
195
  def send(message: Message) -> None:
193
- # Retrieve RecordSet and message_type
194
- recordset = message.content
196
+ # Retrieve RecordDict and message_type
197
+ recorddict = message.content
195
198
  message_type = message.metadata.message_type
196
199
 
197
- # RecordSet --> *Res --> *Res proto -> ClientMessage proto
200
+ # RecordDict --> *Res --> *Res proto -> ClientMessage proto
198
201
  if message_type == MessageTypeLegacy.GET_PROPERTIES:
199
- getpropres = compat.recordset_to_getpropertiesres(recordset)
202
+ getpropres = compat.recorddict_to_getpropertiesres(recorddict)
200
203
  msg_proto = ClientMessage(
201
204
  get_properties_res=serde.get_properties_res_to_proto(getpropres)
202
205
  )
203
206
  elif message_type == MessageTypeLegacy.GET_PARAMETERS:
204
- getparamres = compat.recordset_to_getparametersres(recordset, False)
207
+ getparamres = compat.recorddict_to_getparametersres(recorddict, False)
205
208
  msg_proto = ClientMessage(
206
209
  get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
207
210
  )
208
211
  elif message_type == MessageType.TRAIN:
209
- fitres = compat.recordset_to_fitres(recordset, False)
212
+ fitres = compat.recorddict_to_fitres(recorddict, False)
210
213
  msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
211
214
  elif message_type == MessageType.EVALUATE:
212
- evalres = compat.recordset_to_evaluateres(recordset)
215
+ evalres = compat.recorddict_to_evaluateres(recorddict)
213
216
  msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
214
217
  elif message_type == "reconnect":
215
218
  reason = cast(
216
- Reason.ValueType, recordset.configs_records["config"]["reason"]
219
+ Reason.ValueType, recorddict.config_records["config"]["reason"]
217
220
  )
218
221
  msg_proto = ClientMessage(
219
222
  disconnect_res=ClientMessage.DisconnectRes(reason=reason)
@@ -26,17 +26,17 @@ from flwr.client.client import (
26
26
  )
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFnExt
29
- from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
29
+ from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
30
30
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
- from flwr.common.recordset_compat import (
32
- evaluateres_to_recordset,
33
- fitres_to_recordset,
34
- getparametersres_to_recordset,
35
- getpropertiesres_to_recordset,
36
- recordset_to_evaluateins,
37
- recordset_to_fitins,
38
- recordset_to_getparametersins,
39
- recordset_to_getpropertiesins,
31
+ from flwr.common.recorddict_compat import (
32
+ evaluateres_to_recorddict,
33
+ fitres_to_recorddict,
34
+ getparametersres_to_recorddict,
35
+ getpropertiesres_to_recorddict,
36
+ recorddict_to_evaluateins,
37
+ recorddict_to_fitins,
38
+ recorddict_to_getparametersins,
39
+ recorddict_to_getpropertiesins,
40
40
  )
41
41
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
42
  ClientMessage,
@@ -70,18 +70,18 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
70
70
  Number of seconds that the client should disconnect from the server.
71
71
  """
72
72
  if message.metadata.message_type == "reconnect":
73
- # Retrieve ReconnectIns from recordset
74
- recordset = message.content
75
- seconds = cast(int, recordset.configs_records["config"]["seconds"])
73
+ # Retrieve ReconnectIns from RecordDict
74
+ recorddict = message.content
75
+ seconds = cast(int, recorddict.config_records["config"]["seconds"])
76
76
  # Construct ReconnectIns and call _reconnect
77
77
  disconnect_msg, sleep_duration = _reconnect(
78
78
  ServerMessage.ReconnectIns(seconds=seconds)
79
79
  )
80
- # Store DisconnectRes in recordset
80
+ # Store DisconnectRes in RecordDict
81
81
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
- recordset = RecordSet()
83
- recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
- out_message = message.create_reply(recordset)
82
+ recorddict = RecordDict()
83
+ recorddict.config_records["config"] = ConfigRecord({"reason": reason})
84
+ out_message = Message(recorddict, reply_to=message)
85
85
  # Return Message and sleep duration
86
86
  return out_message, sleep_duration
87
87
 
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
111
111
  if message_type == MessageTypeLegacy.GET_PROPERTIES:
112
112
  get_properties_res = maybe_call_get_properties(
113
113
  client=client,
114
- get_properties_ins=recordset_to_getpropertiesins(message.content),
114
+ get_properties_ins=recorddict_to_getpropertiesins(message.content),
115
115
  )
116
- out_recordset = getpropertiesres_to_recordset(get_properties_res)
116
+ out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
117
117
  # Handle GetParametersIns
118
118
  elif message_type == MessageTypeLegacy.GET_PARAMETERS:
119
119
  get_parameters_res = maybe_call_get_parameters(
120
120
  client=client,
121
- get_parameters_ins=recordset_to_getparametersins(message.content),
121
+ get_parameters_ins=recorddict_to_getparametersins(message.content),
122
122
  )
123
- out_recordset = getparametersres_to_recordset(
123
+ out_recorddict = getparametersres_to_recorddict(
124
124
  get_parameters_res, keep_input=False
125
125
  )
126
126
  # Handle FitIns
127
127
  elif message_type == MessageType.TRAIN:
128
128
  fit_res = maybe_call_fit(
129
129
  client=client,
130
- fit_ins=recordset_to_fitins(message.content, keep_input=True),
130
+ fit_ins=recorddict_to_fitins(message.content, keep_input=True),
131
131
  )
132
- out_recordset = fitres_to_recordset(fit_res, keep_input=False)
132
+ out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
133
133
  # Handle EvaluateIns
134
134
  elif message_type == MessageType.EVALUATE:
135
135
  evaluate_res = maybe_call_evaluate(
136
136
  client=client,
137
- evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
137
+ evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
138
138
  )
139
- out_recordset = evaluateres_to_recordset(evaluate_res)
139
+ out_recorddict = evaluateres_to_recorddict(evaluate_res)
140
140
  else:
141
141
  raise ValueError(f"Invalid message type: {message_type}")
142
142
 
143
143
  # Return Message
144
- return message.create_reply(out_recordset)
144
+ return Message(out_recorddict, reply_to=message)
145
145
 
146
146
 
147
147
  def _reconnect(
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
167
167
  and out_meta.message_id == "" # This will be generated by the server
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
- and out_meta.reply_to_message == in_meta.message_id
170
+ and out_meta.reply_to_message_id == in_meta.message_id
171
171
  and out_meta.group_id == in_meta.group_id
172
172
  and out_meta.message_type == in_meta.message_type
173
173
  and out_meta.created_at > in_meta.created_at
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
- from .comms_mods import message_size_mod, parameters_size_mod
19
+ from .comms_mods import arrays_size_mod, message_size_mod
20
20
  from .localdp_mod import LocalDpMod
21
21
  from .secure_aggregation import secagg_mod, secaggplus_mod
22
22
  from .utils import make_ffn
@@ -24,10 +24,10 @@ from .utils import make_ffn
24
24
  __all__ = [
25
25
  "LocalDpMod",
26
26
  "adaptiveclipping_mod",
27
+ "arrays_size_mod",
27
28
  "fixedclipping_mod",
28
29
  "make_ffn",
29
30
  "message_size_mod",
30
- "parameters_size_mod",
31
31
  "secagg_mod",
32
32
  "secaggplus_mod",
33
33
  ]
@@ -19,7 +19,7 @@ from logging import INFO
19
19
 
20
20
  from flwr.client.typing import ClientAppCallable
21
21
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
22
- from flwr.common import recordset_compat as compat
22
+ from flwr.common import recorddict_compat as compat
23
23
  from flwr.common.constant import MessageType
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.differential_privacy import (
@@ -53,7 +53,7 @@ def fixedclipping_mod(
53
53
  """
54
54
  if msg.metadata.message_type != MessageType.TRAIN:
55
55
  return call_next(msg, ctxt)
56
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
56
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
57
57
  if KEY_CLIPPING_NORM not in fit_ins.config:
58
58
  raise KeyError(
59
59
  f"The {KEY_CLIPPING_NORM} value is not supplied by the "
@@ -71,7 +71,7 @@ def fixedclipping_mod(
71
71
  if out_msg.has_error():
72
72
  return out_msg
73
73
 
74
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
74
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
75
75
 
76
76
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
77
77
 
@@ -87,7 +87,7 @@ def fixedclipping_mod(
87
87
  )
88
88
 
89
89
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
90
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
90
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
91
91
  return out_msg
92
92
 
93
93
 
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
116
116
  if msg.metadata.message_type != MessageType.TRAIN:
117
117
  return call_next(msg, ctxt)
118
118
 
119
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
119
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
120
120
 
121
121
  if KEY_CLIPPING_NORM not in fit_ins.config:
122
122
  raise KeyError(
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
136
136
  if out_msg.has_error():
137
137
  return out_msg
138
138
 
139
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
139
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
140
140
 
141
141
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
142
142
 
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
155
155
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
156
156
 
157
157
  fit_res.metrics[KEY_NORM_BIT] = norm_bit
158
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
158
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
159
159
  return out_msg
@@ -34,47 +34,41 @@ def message_size_mod(
34
34
  """
35
35
  message_size_in_bytes = 0
36
36
 
37
- for p_record in msg.content.parameters_records.values():
38
- message_size_in_bytes += p_record.count_bytes()
39
-
40
- for c_record in msg.content.configs_records.values():
41
- message_size_in_bytes += c_record.count_bytes()
42
-
43
- for m_record in msg.content.metrics_records.values():
44
- message_size_in_bytes += m_record.count_bytes()
37
+ for record in msg.content.values():
38
+ message_size_in_bytes += record.count_bytes()
45
39
 
46
40
  log(INFO, "Message size: %i bytes", message_size_in_bytes)
47
41
 
48
42
  return call_next(msg, ctxt)
49
43
 
50
44
 
51
- def parameters_size_mod(
45
+ def arrays_size_mod(
52
46
  msg: Message, ctxt: Context, call_next: ClientAppCallable
53
47
  ) -> Message:
54
- """Parameters size mod.
48
+ """Arrays size mod.
55
49
 
56
- This mod logs the number of parameters transmitted in the message as well as their
57
- size in bytes.
50
+ This mod logs the number of array elements transmitted in ``ArrayRecord``s of
51
+ the message as well as their sizes in bytes.
58
52
  """
59
53
  model_size_stats = {}
60
- parameters_size_in_bytes = 0
61
- for record_name, p_record in msg.content.parameters_records.items():
62
- p_record_bytes = p_record.count_bytes()
63
- parameters_size_in_bytes += p_record_bytes
64
- parameter_count = 0
65
- for array in p_record.values():
66
- parameter_count += (
54
+ arrays_size_in_bytes = 0
55
+ for record_name, arr_record in msg.content.array_records.items():
56
+ arr_record_bytes = arr_record.count_bytes()
57
+ arrays_size_in_bytes += arr_record_bytes
58
+ element_count = 0
59
+ for array in arr_record.values():
60
+ element_count += (
67
61
  int(np.prod(array.shape)) if array.shape else array.numpy().size
68
62
  )
69
63
 
70
64
  model_size_stats[f"{record_name}"] = {
71
- "parameters": parameter_count,
72
- "bytes": p_record_bytes,
65
+ "elements": element_count,
66
+ "bytes": arr_record_bytes,
73
67
  }
74
68
 
75
69
  if model_size_stats:
76
70
  log(INFO, model_size_stats)
77
71
 
78
- log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
72
+ log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
79
73
 
80
74
  return call_next(msg, ctxt)
@@ -21,7 +21,7 @@ import numpy as np
21
21
 
22
22
  from flwr.client.typing import ClientAppCallable
23
23
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
24
- from flwr.common import recordset_compat as compat
24
+ from flwr.common import recorddict_compat as compat
25
25
  from flwr.common.constant import MessageType
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.differential_privacy import (
@@ -107,7 +107,7 @@ class LocalDpMod:
107
107
  if msg.metadata.message_type != MessageType.TRAIN:
108
108
  return call_next(msg, ctxt)
109
109
 
110
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
110
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
111
111
  server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
112
112
 
113
113
  # Call inner app
@@ -117,7 +117,7 @@ class LocalDpMod:
117
117
  if out_msg.has_error():
118
118
  return out_msg
119
119
 
120
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
120
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
121
121
 
122
122
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
123
123
 
@@ -149,5 +149,5 @@ class LocalDpMod:
149
149
  noise_value_sd,
150
150
  )
151
151
 
152
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
152
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
153
153
  return out_msg