flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. flwr/client/app.py +6 -4
  2. flwr/client/clientapp/app.py +2 -2
  3. flwr/client/grpc_client/connection.py +23 -20
  4. flwr/client/message_handler/message_handler.py +27 -27
  5. flwr/client/mod/centraldp_mods.py +7 -7
  6. flwr/client/mod/localdp_mod.py +4 -4
  7. flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
  8. flwr/client/run_info_store.py +2 -2
  9. flwr/common/__init__.py +2 -0
  10. flwr/common/context.py +4 -4
  11. flwr/common/message.py +269 -101
  12. flwr/common/record/__init__.py +2 -1
  13. flwr/common/record/configsrecord.py +2 -2
  14. flwr/common/record/metricsrecord.py +1 -1
  15. flwr/common/record/parametersrecord.py +1 -1
  16. flwr/common/record/{recordset.py → recorddict.py} +57 -17
  17. flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
  18. flwr/common/serde.py +33 -37
  19. flwr/proto/exec_pb2.py +32 -32
  20. flwr/proto/exec_pb2.pyi +3 -3
  21. flwr/proto/message_pb2.py +12 -12
  22. flwr/proto/message_pb2.pyi +9 -9
  23. flwr/proto/recorddict_pb2.py +70 -0
  24. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
  25. flwr/proto/run_pb2.py +32 -32
  26. flwr/proto/run_pb2.pyi +3 -3
  27. flwr/server/compat/grid_client_proxy.py +30 -30
  28. flwr/server/grid/grid.py +3 -3
  29. flwr/server/grid/grpc_grid.py +15 -23
  30. flwr/server/grid/inmemory_grid.py +14 -20
  31. flwr/server/superlink/fleet/vce/vce_api.py +1 -3
  32. flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -1
  33. flwr/server/superlink/linkstate/sqlite_linkstate.py +14 -18
  34. flwr/server/superlink/linkstate/utils.py +10 -7
  35. flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
  36. flwr/server/utils/validator.py +4 -4
  37. flwr/server/workflow/default_workflows.py +7 -7
  38. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
  39. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  40. flwr/simulation/run_simulation.py +3 -3
  41. flwr/superexec/deployment.py +2 -2
  42. flwr/superexec/simulation.py +2 -2
  43. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
  44. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +49 -49
  45. flwr/proto/recordset_pb2.py +0 -70
  46. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  47. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  48. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
  49. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
  50. {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
flwr/client/app.py CHANGED
@@ -495,8 +495,9 @@ def start_client_internal(
495
495
  context = run_info_store.retrieve_context(run_id=run_id)
496
496
  # Create an error reply message that will never be used to prevent
497
497
  # the used-before-assignment linting error
498
- reply_message = message.create_error_reply(
499
- error=Error(code=ErrorCode.UNKNOWN, reason="Unknown")
498
+ reply_message = Message(
499
+ Error(code=ErrorCode.UNKNOWN, reason="Unknown"),
500
+ reply_to=message,
500
501
  )
501
502
 
502
503
  # Handle app loading and task message
@@ -593,8 +594,9 @@ def start_client_internal(
593
594
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
594
595
 
595
596
  # Create error message
596
- reply_message = message.create_error_reply(
597
- error=Error(code=e_code, reason=reason)
597
+ reply_message = Message(
598
+ Error(code=e_code, reason=reason),
599
+ reply_to=message,
598
600
  )
599
601
  else:
600
602
  # No exception, update node state
@@ -152,8 +152,8 @@ def run_clientapp( # pylint: disable=R0914
152
152
  log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
153
153
 
154
154
  # Create error message
155
- reply_message = message.create_error_reply(
156
- error=Error(code=e_code, reason=reason)
155
+ reply_message = Message(
156
+ Error(code=e_code, reason=reason), reply_to=message
157
157
  )
158
158
 
159
159
  # Push Message and Context to SuperNode
@@ -31,13 +31,15 @@ from flwr.common import (
31
31
  ConfigsRecord,
32
32
  Message,
33
33
  Metadata,
34
- RecordSet,
34
+ RecordDict,
35
+ now,
35
36
  )
36
- from flwr.common import recordset_compat as compat
37
+ from flwr.common import recorddict_compat as compat
37
38
  from flwr.common import serde
38
39
  from flwr.common.constant import MessageType, MessageTypeLegacy
39
40
  from flwr.common.grpc import create_channel, on_channel_state_change
40
41
  from flwr.common.logger import log
42
+ from flwr.common.message import make_message
41
43
  from flwr.common.retry_invoker import RetryInvoker
42
44
  from flwr.common.typing import Fab, Run
43
45
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
@@ -139,32 +141,32 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
139
141
  # Receive ServerMessage proto
140
142
  proto = next(server_message_iterator)
141
143
 
142
- # ServerMessage proto --> *Ins --> RecordSet
144
+ # ServerMessage proto --> *Ins --> RecordDict
143
145
  field = proto.WhichOneof("msg")
144
146
  message_type = ""
145
147
  if field == "get_properties_ins":
146
- recordset = compat.getpropertiesins_to_recordset(
148
+ recorddict = compat.getpropertiesins_to_recorddict(
147
149
  serde.get_properties_ins_from_proto(proto.get_properties_ins)
148
150
  )
149
151
  message_type = MessageTypeLegacy.GET_PROPERTIES
150
152
  elif field == "get_parameters_ins":
151
- recordset = compat.getparametersins_to_recordset(
153
+ recorddict = compat.getparametersins_to_recorddict(
152
154
  serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
153
155
  )
154
156
  message_type = MessageTypeLegacy.GET_PARAMETERS
155
157
  elif field == "fit_ins":
156
- recordset = compat.fitins_to_recordset(
158
+ recorddict = compat.fitins_to_recorddict(
157
159
  serde.fit_ins_from_proto(proto.fit_ins), False
158
160
  )
159
161
  message_type = MessageType.TRAIN
160
162
  elif field == "evaluate_ins":
161
- recordset = compat.evaluateins_to_recordset(
163
+ recorddict = compat.evaluateins_to_recorddict(
162
164
  serde.evaluate_ins_from_proto(proto.evaluate_ins), False
163
165
  )
164
166
  message_type = MessageType.EVALUATE
165
167
  elif field == "reconnect_ins":
166
- recordset = RecordSet()
167
- recordset.configs_records["config"] = ConfigsRecord(
168
+ recorddict = RecordDict()
169
+ recorddict.configs_records["config"] = ConfigsRecord(
168
170
  {"seconds": proto.reconnect_ins.seconds}
169
171
  )
170
172
  message_type = "reconnect"
@@ -175,45 +177,46 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
175
177
  )
176
178
 
177
179
  # Construct Message
178
- return Message(
180
+ return make_message(
179
181
  metadata=Metadata(
180
182
  run_id=0,
181
183
  message_id=str(uuid.uuid4()),
182
184
  src_node_id=0,
183
185
  dst_node_id=0,
184
- reply_to_message="",
186
+ reply_to_message_id="",
185
187
  group_id="",
188
+ created_at=now().timestamp(),
186
189
  ttl=DEFAULT_TTL,
187
190
  message_type=message_type,
188
191
  ),
189
- content=recordset,
192
+ content=recorddict,
190
193
  )
191
194
 
192
195
  def send(message: Message) -> None:
193
- # Retrieve RecordSet and message_type
194
- recordset = message.content
196
+ # Retrieve RecordDict and message_type
197
+ recorddict = message.content
195
198
  message_type = message.metadata.message_type
196
199
 
197
- # RecordSet --> *Res --> *Res proto -> ClientMessage proto
200
+ # RecordDict --> *Res --> *Res proto -> ClientMessage proto
198
201
  if message_type == MessageTypeLegacy.GET_PROPERTIES:
199
- getpropres = compat.recordset_to_getpropertiesres(recordset)
202
+ getpropres = compat.recorddict_to_getpropertiesres(recorddict)
200
203
  msg_proto = ClientMessage(
201
204
  get_properties_res=serde.get_properties_res_to_proto(getpropres)
202
205
  )
203
206
  elif message_type == MessageTypeLegacy.GET_PARAMETERS:
204
- getparamres = compat.recordset_to_getparametersres(recordset, False)
207
+ getparamres = compat.recorddict_to_getparametersres(recorddict, False)
205
208
  msg_proto = ClientMessage(
206
209
  get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
207
210
  )
208
211
  elif message_type == MessageType.TRAIN:
209
- fitres = compat.recordset_to_fitres(recordset, False)
212
+ fitres = compat.recorddict_to_fitres(recorddict, False)
210
213
  msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
211
214
  elif message_type == MessageType.EVALUATE:
212
- evalres = compat.recordset_to_evaluateres(recordset)
215
+ evalres = compat.recorddict_to_evaluateres(recorddict)
213
216
  msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
214
217
  elif message_type == "reconnect":
215
218
  reason = cast(
216
- Reason.ValueType, recordset.configs_records["config"]["reason"]
219
+ Reason.ValueType, recorddict.configs_records["config"]["reason"]
217
220
  )
218
221
  msg_proto = ClientMessage(
219
222
  disconnect_res=ClientMessage.DisconnectRes(reason=reason)
@@ -26,17 +26,17 @@ from flwr.client.client import (
26
26
  )
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFnExt
29
- from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
29
+ from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordDict, log
30
30
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
- from flwr.common.recordset_compat import (
32
- evaluateres_to_recordset,
33
- fitres_to_recordset,
34
- getparametersres_to_recordset,
35
- getpropertiesres_to_recordset,
36
- recordset_to_evaluateins,
37
- recordset_to_fitins,
38
- recordset_to_getparametersins,
39
- recordset_to_getpropertiesins,
31
+ from flwr.common.recorddict_compat import (
32
+ evaluateres_to_recorddict,
33
+ fitres_to_recorddict,
34
+ getparametersres_to_recorddict,
35
+ getpropertiesres_to_recorddict,
36
+ recorddict_to_evaluateins,
37
+ recorddict_to_fitins,
38
+ recorddict_to_getparametersins,
39
+ recorddict_to_getpropertiesins,
40
40
  )
41
41
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
42
  ClientMessage,
@@ -70,18 +70,18 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
70
70
  Number of seconds that the client should disconnect from the server.
71
71
  """
72
72
  if message.metadata.message_type == "reconnect":
73
- # Retrieve ReconnectIns from recordset
74
- recordset = message.content
75
- seconds = cast(int, recordset.configs_records["config"]["seconds"])
73
+ # Retrieve ReconnectIns from RecordDict
74
+ recorddict = message.content
75
+ seconds = cast(int, recorddict.configs_records["config"]["seconds"])
76
76
  # Construct ReconnectIns and call _reconnect
77
77
  disconnect_msg, sleep_duration = _reconnect(
78
78
  ServerMessage.ReconnectIns(seconds=seconds)
79
79
  )
80
- # Store DisconnectRes in recordset
80
+ # Store DisconnectRes in RecordDict
81
81
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
- recordset = RecordSet()
83
- recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
- out_message = message.create_reply(recordset)
82
+ recorddict = RecordDict()
83
+ recorddict.configs_records["config"] = ConfigsRecord({"reason": reason})
84
+ out_message = Message(recorddict, reply_to=message)
85
85
  # Return Message and sleep duration
86
86
  return out_message, sleep_duration
87
87
 
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
111
111
  if message_type == MessageTypeLegacy.GET_PROPERTIES:
112
112
  get_properties_res = maybe_call_get_properties(
113
113
  client=client,
114
- get_properties_ins=recordset_to_getpropertiesins(message.content),
114
+ get_properties_ins=recorddict_to_getpropertiesins(message.content),
115
115
  )
116
- out_recordset = getpropertiesres_to_recordset(get_properties_res)
116
+ out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
117
117
  # Handle GetParametersIns
118
118
  elif message_type == MessageTypeLegacy.GET_PARAMETERS:
119
119
  get_parameters_res = maybe_call_get_parameters(
120
120
  client=client,
121
- get_parameters_ins=recordset_to_getparametersins(message.content),
121
+ get_parameters_ins=recorddict_to_getparametersins(message.content),
122
122
  )
123
- out_recordset = getparametersres_to_recordset(
123
+ out_recorddict = getparametersres_to_recorddict(
124
124
  get_parameters_res, keep_input=False
125
125
  )
126
126
  # Handle FitIns
127
127
  elif message_type == MessageType.TRAIN:
128
128
  fit_res = maybe_call_fit(
129
129
  client=client,
130
- fit_ins=recordset_to_fitins(message.content, keep_input=True),
130
+ fit_ins=recorddict_to_fitins(message.content, keep_input=True),
131
131
  )
132
- out_recordset = fitres_to_recordset(fit_res, keep_input=False)
132
+ out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
133
133
  # Handle EvaluateIns
134
134
  elif message_type == MessageType.EVALUATE:
135
135
  evaluate_res = maybe_call_evaluate(
136
136
  client=client,
137
- evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
137
+ evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
138
138
  )
139
- out_recordset = evaluateres_to_recordset(evaluate_res)
139
+ out_recorddict = evaluateres_to_recorddict(evaluate_res)
140
140
  else:
141
141
  raise ValueError(f"Invalid message type: {message_type}")
142
142
 
143
143
  # Return Message
144
- return message.create_reply(out_recordset)
144
+ return Message(out_recorddict, reply_to=message)
145
145
 
146
146
 
147
147
  def _reconnect(
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
167
167
  and out_meta.message_id == "" # This will be generated by the server
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
- and out_meta.reply_to_message == in_meta.message_id
170
+ and out_meta.reply_to_message_id == in_meta.message_id
171
171
  and out_meta.group_id == in_meta.group_id
172
172
  and out_meta.message_type == in_meta.message_type
173
173
  and out_meta.created_at > in_meta.created_at
@@ -19,7 +19,7 @@ from logging import INFO
19
19
 
20
20
  from flwr.client.typing import ClientAppCallable
21
21
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
22
- from flwr.common import recordset_compat as compat
22
+ from flwr.common import recorddict_compat as compat
23
23
  from flwr.common.constant import MessageType
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.differential_privacy import (
@@ -53,7 +53,7 @@ def fixedclipping_mod(
53
53
  """
54
54
  if msg.metadata.message_type != MessageType.TRAIN:
55
55
  return call_next(msg, ctxt)
56
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
56
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
57
57
  if KEY_CLIPPING_NORM not in fit_ins.config:
58
58
  raise KeyError(
59
59
  f"The {KEY_CLIPPING_NORM} value is not supplied by the "
@@ -71,7 +71,7 @@ def fixedclipping_mod(
71
71
  if out_msg.has_error():
72
72
  return out_msg
73
73
 
74
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
74
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
75
75
 
76
76
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
77
77
 
@@ -87,7 +87,7 @@ def fixedclipping_mod(
87
87
  )
88
88
 
89
89
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
90
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
90
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
91
91
  return out_msg
92
92
 
93
93
 
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
116
116
  if msg.metadata.message_type != MessageType.TRAIN:
117
117
  return call_next(msg, ctxt)
118
118
 
119
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
119
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
120
120
 
121
121
  if KEY_CLIPPING_NORM not in fit_ins.config:
122
122
  raise KeyError(
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
136
136
  if out_msg.has_error():
137
137
  return out_msg
138
138
 
139
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
139
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
140
140
 
141
141
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
142
142
 
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
155
155
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
156
156
 
157
157
  fit_res.metrics[KEY_NORM_BIT] = norm_bit
158
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
158
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
159
159
  return out_msg
@@ -21,7 +21,7 @@ import numpy as np
21
21
 
22
22
  from flwr.client.typing import ClientAppCallable
23
23
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
24
- from flwr.common import recordset_compat as compat
24
+ from flwr.common import recorddict_compat as compat
25
25
  from flwr.common.constant import MessageType
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.differential_privacy import (
@@ -107,7 +107,7 @@ class LocalDpMod:
107
107
  if msg.metadata.message_type != MessageType.TRAIN:
108
108
  return call_next(msg, ctxt)
109
109
 
110
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
110
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
111
111
  server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
112
112
 
113
113
  # Call inner app
@@ -117,7 +117,7 @@ class LocalDpMod:
117
117
  if out_msg.has_error():
118
118
  return out_msg
119
119
 
120
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
120
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
121
121
 
122
122
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
123
123
 
@@ -149,5 +149,5 @@ class LocalDpMod:
149
149
  noise_value_sd,
150
150
  )
151
151
 
152
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
152
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
153
153
  return out_msg
@@ -26,11 +26,11 @@ from flwr.common import (
26
26
  Context,
27
27
  Message,
28
28
  Parameters,
29
- RecordSet,
29
+ RecordDict,
30
30
  ndarray_to_bytes,
31
31
  parameters_to_ndarrays,
32
32
  )
33
- from flwr.common import recordset_compat as compat
33
+ from flwr.common import recorddict_compat as compat
34
34
  from flwr.common.constant import MessageType
35
35
  from flwr.common.logger import log
36
36
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
@@ -162,7 +162,7 @@ def secaggplus_mod(
162
162
  check_configs(state.current_stage, configs)
163
163
 
164
164
  # Execute
165
- out_content = RecordSet()
165
+ out_content = RecordDict()
166
166
  if state.current_stage == Stage.SETUP:
167
167
  state.nid = msg.metadata.dst_node_id
168
168
  res = _setup(state, configs)
@@ -171,7 +171,7 @@ def secaggplus_mod(
171
171
  elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
172
172
  out_msg = call_next(msg, ctxt)
173
173
  out_content = out_msg.content
174
- fitres = compat.recordset_to_fitres(out_content, keep_input=True)
174
+ fitres = compat.recorddict_to_fitres(out_content, keep_input=True)
175
175
  res = _collect_masked_vectors(
176
176
  state, configs, fitres.num_examples, fitres.parameters
177
177
  )
@@ -187,7 +187,7 @@ def secaggplus_mod(
187
187
 
188
188
  # Return message
189
189
  out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
- return msg.create_reply(out_content)
190
+ return Message(out_content, reply_to=msg)
191
191
 
192
192
 
193
193
  def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
@@ -19,7 +19,7 @@ from dataclasses import dataclass
19
19
  from pathlib import Path
20
20
  from typing import Optional
21
21
 
22
- from flwr.common import Context, RecordSet
22
+ from flwr.common import Context, RecordDict
23
23
  from flwr.common.config import (
24
24
  get_fused_config,
25
25
  get_fused_config_from_dir,
@@ -86,7 +86,7 @@ class DeprecatedRunInfoStore:
86
86
  run_id=run_id,
87
87
  node_id=self.node_id,
88
88
  node_config=self.node_config,
89
- state=RecordSet(),
89
+ state=RecordDict(),
90
90
  run_config=initial_run_config.copy(),
91
91
  ),
92
92
  )
flwr/common/__init__.py CHANGED
@@ -34,6 +34,7 @@ from .record import Array as Array
34
34
  from .record import ConfigsRecord as ConfigsRecord
35
35
  from .record import MetricsRecord as MetricsRecord
36
36
  from .record import ParametersRecord as ParametersRecord
37
+ from .record import RecordDict as RecordDict
37
38
  from .record import RecordSet as RecordSet
38
39
  from .record import array_from_numpy as array_from_numpy
39
40
  from .telemetry import EventType as EventType
@@ -98,6 +99,7 @@ __all__ = [
98
99
  "ParametersRecord",
99
100
  "Properties",
100
101
  "ReconnectIns",
102
+ "RecordDict",
101
103
  "RecordSet",
102
104
  "Scalar",
103
105
  "ServerMessage",
flwr/common/context.py CHANGED
@@ -17,7 +17,7 @@
17
17
 
18
18
  from dataclasses import dataclass
19
19
 
20
- from .record import RecordSet
20
+ from .record import RecordDict
21
21
  from .typing import UserConfig
22
22
 
23
23
 
@@ -34,7 +34,7 @@ class Context:
34
34
  node_config : UserConfig
35
35
  A config (key/value mapping) unique to the node and independent of the
36
36
  `run_config`. This config persists across all runs this node participates in.
37
- state : RecordSet
37
+ state : RecordDict
38
38
  Holds records added by the entity in a given `run_id` and that will stay local.
39
39
  This means that the data it holds will never leave the system it's running from.
40
40
  This can be used as an intermediate storage or scratchpad when
@@ -50,7 +50,7 @@ class Context:
50
50
  run_id: int
51
51
  node_id: int
52
52
  node_config: UserConfig
53
- state: RecordSet
53
+ state: RecordDict
54
54
  run_config: UserConfig
55
55
 
56
56
  def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
@@ -58,7 +58,7 @@ class Context:
58
58
  run_id: int,
59
59
  node_id: int,
60
60
  node_config: UserConfig,
61
- state: RecordSet,
61
+ state: RecordDict,
62
62
  run_config: UserConfig,
63
63
  ) -> None:
64
64
  self.run_id = run_id