flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -26,17 +26,17 @@ from flwr.client.client import (
26
26
  )
27
27
  from flwr.client.numpy_client import NumPyClient
28
28
  from flwr.client.typing import ClientFnExt
29
- from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordSet, log
29
+ from flwr.common import ConfigRecord, Context, Message, Metadata, RecordDict, log
30
30
  from flwr.common.constant import MessageType, MessageTypeLegacy
31
- from flwr.common.recordset_compat import (
32
- evaluateres_to_recordset,
33
- fitres_to_recordset,
34
- getparametersres_to_recordset,
35
- getpropertiesres_to_recordset,
36
- recordset_to_evaluateins,
37
- recordset_to_fitins,
38
- recordset_to_getparametersins,
39
- recordset_to_getpropertiesins,
31
+ from flwr.common.recorddict_compat import (
32
+ evaluateres_to_recorddict,
33
+ fitres_to_recorddict,
34
+ getparametersres_to_recorddict,
35
+ getpropertiesres_to_recorddict,
36
+ recorddict_to_evaluateins,
37
+ recorddict_to_fitins,
38
+ recorddict_to_getparametersins,
39
+ recorddict_to_getpropertiesins,
40
40
  )
41
41
  from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
42
42
  ClientMessage,
@@ -70,19 +70,19 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
70
70
  Number of seconds that the client should disconnect from the server.
71
71
  """
72
72
  if message.metadata.message_type == "reconnect":
73
- # Retrieve ReconnectIns from recordset
74
- recordset = message.content
75
- seconds = cast(int, recordset.configs_records["config"]["seconds"])
73
+ # Retrieve ReconnectIns from RecordDict
74
+ recorddict = message.content
75
+ seconds = cast(int, recorddict.config_records["config"]["seconds"])
76
76
  # Construct ReconnectIns and call _reconnect
77
77
  disconnect_msg, sleep_duration = _reconnect(
78
78
  ServerMessage.ReconnectIns(seconds=seconds)
79
79
  )
80
- # Store DisconnectRes in recordset
80
+ # Store DisconnectRes in RecordDict
81
81
  reason = cast(int, disconnect_msg.disconnect_res.reason)
82
- recordset = RecordSet()
83
- recordset.configs_records["config"] = ConfigsRecord({"reason": reason})
84
- out_message = message.create_reply(recordset)
85
- # Return TaskRes and sleep duration
82
+ recorddict = RecordDict()
83
+ recorddict.config_records["config"] = ConfigRecord({"reason": reason})
84
+ out_message = Message(recorddict, reply_to=message)
85
+ # Return Message and sleep duration
86
86
  return out_message, sleep_duration
87
87
 
88
88
  # Any other message
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
111
111
  if message_type == MessageTypeLegacy.GET_PROPERTIES:
112
112
  get_properties_res = maybe_call_get_properties(
113
113
  client=client,
114
- get_properties_ins=recordset_to_getpropertiesins(message.content),
114
+ get_properties_ins=recorddict_to_getpropertiesins(message.content),
115
115
  )
116
- out_recordset = getpropertiesres_to_recordset(get_properties_res)
116
+ out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
117
117
  # Handle GetParametersIns
118
118
  elif message_type == MessageTypeLegacy.GET_PARAMETERS:
119
119
  get_parameters_res = maybe_call_get_parameters(
120
120
  client=client,
121
- get_parameters_ins=recordset_to_getparametersins(message.content),
121
+ get_parameters_ins=recorddict_to_getparametersins(message.content),
122
122
  )
123
- out_recordset = getparametersres_to_recordset(
123
+ out_recorddict = getparametersres_to_recorddict(
124
124
  get_parameters_res, keep_input=False
125
125
  )
126
126
  # Handle FitIns
127
127
  elif message_type == MessageType.TRAIN:
128
128
  fit_res = maybe_call_fit(
129
129
  client=client,
130
- fit_ins=recordset_to_fitins(message.content, keep_input=True),
130
+ fit_ins=recorddict_to_fitins(message.content, keep_input=True),
131
131
  )
132
- out_recordset = fitres_to_recordset(fit_res, keep_input=False)
132
+ out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
133
133
  # Handle EvaluateIns
134
134
  elif message_type == MessageType.EVALUATE:
135
135
  evaluate_res = maybe_call_evaluate(
136
136
  client=client,
137
- evaluate_ins=recordset_to_evaluateins(message.content, keep_input=True),
137
+ evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
138
138
  )
139
- out_recordset = evaluateres_to_recordset(evaluate_res)
139
+ out_recorddict = evaluateres_to_recorddict(evaluate_res)
140
140
  else:
141
141
  raise ValueError(f"Invalid message type: {message_type}")
142
142
 
143
143
  # Return Message
144
- return message.create_reply(out_recordset)
144
+ return Message(out_recorddict, reply_to=message)
145
145
 
146
146
 
147
147
  def _reconnect(
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
167
167
  and out_meta.message_id == "" # This will be generated by the server
168
168
  and out_meta.src_node_id == in_meta.dst_node_id
169
169
  and out_meta.dst_node_id == in_meta.src_node_id
170
- and out_meta.reply_to_message == in_meta.message_id
170
+ and out_meta.reply_to_message_id == in_meta.message_id
171
171
  and out_meta.group_id == in_meta.group_id
172
172
  and out_meta.message_type == in_meta.message_type
173
173
  and out_meta.created_at > in_meta.created_at
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod
19
- from .comms_mods import message_size_mod, parameters_size_mod
19
+ from .comms_mods import arrays_size_mod, message_size_mod
20
20
  from .localdp_mod import LocalDpMod
21
21
  from .secure_aggregation import secagg_mod, secaggplus_mod
22
22
  from .utils import make_ffn
@@ -24,10 +24,10 @@ from .utils import make_ffn
24
24
  __all__ = [
25
25
  "LocalDpMod",
26
26
  "adaptiveclipping_mod",
27
+ "arrays_size_mod",
27
28
  "fixedclipping_mod",
28
29
  "make_ffn",
29
30
  "message_size_mod",
30
- "parameters_size_mod",
31
31
  "secagg_mod",
32
32
  "secaggplus_mod",
33
33
  ]
@@ -19,7 +19,7 @@ from logging import INFO
19
19
 
20
20
  from flwr.client.typing import ClientAppCallable
21
21
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
22
- from flwr.common import recordset_compat as compat
22
+ from flwr.common import recorddict_compat as compat
23
23
  from flwr.common.constant import MessageType
24
24
  from flwr.common.context import Context
25
25
  from flwr.common.differential_privacy import (
@@ -53,7 +53,7 @@ def fixedclipping_mod(
53
53
  """
54
54
  if msg.metadata.message_type != MessageType.TRAIN:
55
55
  return call_next(msg, ctxt)
56
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
56
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
57
57
  if KEY_CLIPPING_NORM not in fit_ins.config:
58
58
  raise KeyError(
59
59
  f"The {KEY_CLIPPING_NORM} value is not supplied by the "
@@ -71,7 +71,7 @@ def fixedclipping_mod(
71
71
  if out_msg.has_error():
72
72
  return out_msg
73
73
 
74
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
74
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
75
75
 
76
76
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
77
77
 
@@ -87,7 +87,7 @@ def fixedclipping_mod(
87
87
  )
88
88
 
89
89
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
90
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
90
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
91
91
  return out_msg
92
92
 
93
93
 
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
116
116
  if msg.metadata.message_type != MessageType.TRAIN:
117
117
  return call_next(msg, ctxt)
118
118
 
119
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
119
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
120
120
 
121
121
  if KEY_CLIPPING_NORM not in fit_ins.config:
122
122
  raise KeyError(
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
136
136
  if out_msg.has_error():
137
137
  return out_msg
138
138
 
139
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
139
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
140
140
 
141
141
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
142
142
 
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
155
155
  fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
156
156
 
157
157
  fit_res.metrics[KEY_NORM_BIT] = norm_bit
158
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
158
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
159
159
  return out_msg
@@ -34,47 +34,41 @@ def message_size_mod(
34
34
  """
35
35
  message_size_in_bytes = 0
36
36
 
37
- for p_record in msg.content.parameters_records.values():
38
- message_size_in_bytes += p_record.count_bytes()
39
-
40
- for c_record in msg.content.configs_records.values():
41
- message_size_in_bytes += c_record.count_bytes()
42
-
43
- for m_record in msg.content.metrics_records.values():
44
- message_size_in_bytes += m_record.count_bytes()
37
+ for record in msg.content.values():
38
+ message_size_in_bytes += record.count_bytes()
45
39
 
46
40
  log(INFO, "Message size: %i bytes", message_size_in_bytes)
47
41
 
48
42
  return call_next(msg, ctxt)
49
43
 
50
44
 
51
- def parameters_size_mod(
45
+ def arrays_size_mod(
52
46
  msg: Message, ctxt: Context, call_next: ClientAppCallable
53
47
  ) -> Message:
54
- """Parameters size mod.
48
+ """Arrays size mod.
55
49
 
56
- This mod logs the number of parameters transmitted in the message as well as their
57
- size in bytes.
50
+ This mod logs the number of array elements transmitted in ``ArrayRecord``s of
51
+ the message as well as their sizes in bytes.
58
52
  """
59
53
  model_size_stats = {}
60
- parameters_size_in_bytes = 0
61
- for record_name, p_record in msg.content.parameters_records.items():
62
- p_record_bytes = p_record.count_bytes()
63
- parameters_size_in_bytes += p_record_bytes
64
- parameter_count = 0
65
- for array in p_record.values():
66
- parameter_count += (
54
+ arrays_size_in_bytes = 0
55
+ for record_name, arr_record in msg.content.array_records.items():
56
+ arr_record_bytes = arr_record.count_bytes()
57
+ arrays_size_in_bytes += arr_record_bytes
58
+ element_count = 0
59
+ for array in arr_record.values():
60
+ element_count += (
67
61
  int(np.prod(array.shape)) if array.shape else array.numpy().size
68
62
  )
69
63
 
70
64
  model_size_stats[f"{record_name}"] = {
71
- "parameters": parameter_count,
72
- "bytes": p_record_bytes,
65
+ "elements": element_count,
66
+ "bytes": arr_record_bytes,
73
67
  }
74
68
 
75
69
  if model_size_stats:
76
70
  log(INFO, model_size_stats)
77
71
 
78
- log(INFO, "Total parameters transmitted: %i bytes", parameters_size_in_bytes)
72
+ log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
79
73
 
80
74
  return call_next(msg, ctxt)
@@ -21,7 +21,7 @@ import numpy as np
21
21
 
22
22
  from flwr.client.typing import ClientAppCallable
23
23
  from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
24
- from flwr.common import recordset_compat as compat
24
+ from flwr.common import recorddict_compat as compat
25
25
  from flwr.common.constant import MessageType
26
26
  from flwr.common.context import Context
27
27
  from flwr.common.differential_privacy import (
@@ -107,7 +107,7 @@ class LocalDpMod:
107
107
  if msg.metadata.message_type != MessageType.TRAIN:
108
108
  return call_next(msg, ctxt)
109
109
 
110
- fit_ins = compat.recordset_to_fitins(msg.content, keep_input=True)
110
+ fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
111
111
  server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
112
112
 
113
113
  # Call inner app
@@ -117,7 +117,7 @@ class LocalDpMod:
117
117
  if out_msg.has_error():
118
118
  return out_msg
119
119
 
120
- fit_res = compat.recordset_to_fitres(out_msg.content, keep_input=True)
120
+ fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
121
121
 
122
122
  client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
123
123
 
@@ -149,5 +149,5 @@ class LocalDpMod:
149
149
  noise_value_sd,
150
150
  )
151
151
 
152
- out_msg.content = compat.fitres_to_recordset(fit_res, keep_input=True)
152
+ out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
153
153
  return out_msg
@@ -22,15 +22,15 @@ from typing import Any, cast
22
22
 
23
23
  from flwr.client.typing import ClientAppCallable
24
24
  from flwr.common import (
25
- ConfigsRecord,
25
+ ConfigRecord,
26
26
  Context,
27
27
  Message,
28
28
  Parameters,
29
- RecordSet,
29
+ RecordDict,
30
30
  ndarray_to_bytes,
31
31
  parameters_to_ndarrays,
32
32
  )
33
- from flwr.common import recordset_compat as compat
33
+ from flwr.common import recorddict_compat as compat
34
34
  from flwr.common.constant import MessageType
35
35
  from flwr.common.logger import log
36
36
  from flwr.common.secure_aggregation.crypto.shamir import create_shares
@@ -63,7 +63,7 @@ from flwr.common.secure_aggregation.secaggplus_utils import (
63
63
  share_keys_plaintext_concat,
64
64
  share_keys_plaintext_separate,
65
65
  )
66
- from flwr.common.typing import ConfigsRecordValues
66
+ from flwr.common.typing import ConfigRecordValues
67
67
 
68
68
 
69
69
  @dataclass
@@ -97,7 +97,7 @@ class SecAggPlusState:
97
97
  ss2_dict: dict[int, bytes] = field(default_factory=dict)
98
98
  public_keys_dict: dict[int, tuple[bytes, bytes]] = field(default_factory=dict)
99
99
 
100
- def __init__(self, **kwargs: ConfigsRecordValues) -> None:
100
+ def __init__(self, **kwargs: ConfigRecordValues) -> None:
101
101
  for k, v in kwargs.items():
102
102
  if k.endswith(":V"):
103
103
  continue
@@ -115,7 +115,7 @@ class SecAggPlusState:
115
115
  new_v = dict(zip(keys, values))
116
116
  self.__setattr__(k, new_v)
117
117
 
118
- def to_dict(self) -> dict[str, ConfigsRecordValues]:
118
+ def to_dict(self) -> dict[str, ConfigRecordValues]:
119
119
  """Convert the state to a dictionary."""
120
120
  ret = vars(self)
121
121
  for k in list(ret.keys()):
@@ -144,13 +144,13 @@ def secaggplus_mod(
144
144
  return call_next(msg, ctxt)
145
145
 
146
146
  # Retrieve local state
147
- if RECORD_KEY_STATE not in ctxt.state.configs_records:
148
- ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord({})
149
- state_dict = ctxt.state.configs_records[RECORD_KEY_STATE]
147
+ if RECORD_KEY_STATE not in ctxt.state.config_records:
148
+ ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord({})
149
+ state_dict = ctxt.state.config_records[RECORD_KEY_STATE]
150
150
  state = SecAggPlusState(**state_dict)
151
151
 
152
152
  # Retrieve incoming configs
153
- configs = msg.content.configs_records[RECORD_KEY_CONFIGS]
153
+ configs = msg.content.config_records[RECORD_KEY_CONFIGS]
154
154
 
155
155
  # Check the validity of the next stage
156
156
  check_stage(state.current_stage, configs)
@@ -162,7 +162,7 @@ def secaggplus_mod(
162
162
  check_configs(state.current_stage, configs)
163
163
 
164
164
  # Execute
165
- out_content = RecordSet()
165
+ out_content = RecordDict()
166
166
  if state.current_stage == Stage.SETUP:
167
167
  state.nid = msg.metadata.dst_node_id
168
168
  res = _setup(state, configs)
@@ -171,31 +171,31 @@ def secaggplus_mod(
171
171
  elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
172
172
  out_msg = call_next(msg, ctxt)
173
173
  out_content = out_msg.content
174
- fitres = compat.recordset_to_fitres(out_content, keep_input=True)
174
+ fitres = compat.recorddict_to_fitres(out_content, keep_input=True)
175
175
  res = _collect_masked_vectors(
176
176
  state, configs, fitres.num_examples, fitres.parameters
177
177
  )
178
- for p_record in out_content.parameters_records.values():
179
- p_record.clear()
178
+ for arr_record in out_content.array_records.values():
179
+ arr_record.clear()
180
180
  elif state.current_stage == Stage.UNMASK:
181
181
  res = _unmask(state, configs)
182
182
  else:
183
183
  raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")
184
184
 
185
185
  # Save state
186
- ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())
186
+ ctxt.state.config_records[RECORD_KEY_STATE] = ConfigRecord(state.to_dict())
187
187
 
188
188
  # Return message
189
- out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
190
- return msg.create_reply(out_content)
189
+ out_content.config_records[RECORD_KEY_CONFIGS] = ConfigRecord(res, False)
190
+ return Message(out_content, reply_to=msg)
191
191
 
192
192
 
193
- def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
193
+ def check_stage(current_stage: str, configs: ConfigRecord) -> None:
194
194
  """Check the validity of the next stage."""
195
195
  # Check the existence of Config.STAGE
196
196
  if Key.STAGE not in configs:
197
197
  raise KeyError(
198
- f"The required key '{Key.STAGE}' is missing from the ConfigsRecord."
198
+ f"The required key '{Key.STAGE}' is missing from the ConfigRecord."
199
199
  )
200
200
 
201
201
  # Check the value type of the Config.STAGE
@@ -223,7 +223,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
223
223
 
224
224
 
225
225
  # pylint: disable-next=too-many-branches
226
- def check_configs(stage: str, configs: ConfigsRecord) -> None:
226
+ def check_configs(stage: str, configs: ConfigRecord) -> None:
227
227
  """Check the validity of the configs."""
228
228
  # Check configs for the setup stage
229
229
  if stage == Stage.SETUP:
@@ -239,7 +239,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
239
239
  if key not in configs:
240
240
  raise KeyError(
241
241
  f"Stage {Stage.SETUP}: the required key '{key}' is "
242
- "missing from the ConfigsRecord."
242
+ "missing from the ConfigRecord."
243
243
  )
244
244
  # Bool is a subclass of int in Python,
245
245
  # so `isinstance(v, int)` will return True even if v is a boolean.
@@ -272,7 +272,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
272
272
  raise KeyError(
273
273
  f"Stage {Stage.COLLECT_MASKED_VECTORS}: "
274
274
  f"the required key '{key}' is "
275
- "missing from the ConfigsRecord."
275
+ "missing from the ConfigRecord."
276
276
  )
277
277
  if not isinstance(configs[key], list) or any(
278
278
  elm
@@ -295,7 +295,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
295
295
  raise KeyError(
296
296
  f"Stage {Stage.UNMASK}: "
297
297
  f"the required key '{key}' is "
298
- "missing from the ConfigsRecord."
298
+ "missing from the ConfigRecord."
299
299
  )
300
300
  if not isinstance(configs[key], list) or any(
301
301
  elm
@@ -313,8 +313,8 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None:
313
313
 
314
314
 
315
315
  def _setup(
316
- state: SecAggPlusState, configs: ConfigsRecord
317
- ) -> dict[str, ConfigsRecordValues]:
316
+ state: SecAggPlusState, configs: ConfigRecord
317
+ ) -> dict[str, ConfigRecordValues]:
318
318
  # Assigning parameter values to object fields
319
319
  sec_agg_param_dict = configs
320
320
  state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER])
@@ -349,8 +349,8 @@ def _setup(
349
349
 
350
350
  # pylint: disable-next=too-many-locals
351
351
  def _share_keys(
352
- state: SecAggPlusState, configs: ConfigsRecord
353
- ) -> dict[str, ConfigsRecordValues]:
352
+ state: SecAggPlusState, configs: ConfigRecord
353
+ ) -> dict[str, ConfigRecordValues]:
354
354
  named_bytes_tuples = cast(dict[str, tuple[bytes, bytes]], configs)
355
355
  key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()}
356
356
  log(DEBUG, "Node %d: starting stage 1...", state.nid)
@@ -412,10 +412,10 @@ def _share_keys(
412
412
  # pylint: disable-next=too-many-locals
413
413
  def _collect_masked_vectors(
414
414
  state: SecAggPlusState,
415
- configs: ConfigsRecord,
415
+ configs: ConfigRecord,
416
416
  num_examples: int,
417
417
  updated_parameters: Parameters,
418
- ) -> dict[str, ConfigsRecordValues]:
418
+ ) -> dict[str, ConfigRecordValues]:
419
419
  log(DEBUG, "Node %d: starting stage 2...", state.nid)
420
420
  available_clients: list[int] = []
421
421
  ciphertexts = cast(list[bytes], configs[Key.CIPHERTEXT_LIST])
@@ -498,8 +498,8 @@ def _collect_masked_vectors(
498
498
 
499
499
 
500
500
  def _unmask(
501
- state: SecAggPlusState, configs: ConfigsRecord
502
- ) -> dict[str, ConfigsRecordValues]:
501
+ state: SecAggPlusState, configs: ConfigRecord
502
+ ) -> dict[str, ConfigRecordValues]:
503
503
  log(DEBUG, "Node %d: starting stage 3...", state.nid)
504
504
 
505
505
  active_nids = cast(list[int], configs[Key.ACTIVE_NODE_ID_LIST])
@@ -66,9 +66,7 @@ except ModuleNotFoundError:
66
66
 
67
67
  PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
68
68
  PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
69
- PATH_PULL_TASK_INS: str = "api/v0/fleet/pull-task-ins"
70
69
  PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
71
- PATH_PUSH_TASK_RES: str = "api/v0/fleet/push-task-res"
72
70
  PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
73
71
  PATH_PING: str = "api/v0/fleet/ping"
74
72
  PATH_GET_RUN: str = "/api/v0/fleet/get-run"
@@ -280,7 +278,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
280
278
  node = None
281
279
 
282
280
  def receive() -> Optional[Message]:
283
- """Receive next task from server."""
281
+ """Receive next Message from server."""
284
282
  # Get Node
285
283
  if node is None:
286
284
  log(ERROR, "Node instance missing")
@@ -309,11 +307,11 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
309
307
  if message_proto is not None:
310
308
  message = message_from_proto(message_proto)
311
309
  metadata = copy(message.metadata)
312
- log(INFO, "[Node] POST /%s: success", PATH_PULL_TASK_INS)
310
+ log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
313
311
  return message
314
312
 
315
313
  def send(message: Message) -> None:
316
- """Send task result back to server."""
314
+ """Send Message result back to server."""
317
315
  # Get Node
318
316
  if node is None:
319
317
  log(ERROR, "Node instance missing")
@@ -345,7 +343,7 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
345
343
  log(
346
344
  INFO,
347
345
  "[Node] POST /%s: success, created result %s",
348
- PATH_PUSH_TASK_RES,
346
+ PATH_PUSH_MESSAGES,
349
347
  res.results, # pylint: disable=no-member
350
348
  )
351
349
 
@@ -19,7 +19,7 @@ from dataclasses import dataclass
19
19
  from pathlib import Path
20
20
  from typing import Optional
21
21
 
22
- from flwr.common import Context, RecordSet
22
+ from flwr.common import Context, RecordDict
23
23
  from flwr.common.config import (
24
24
  get_fused_config,
25
25
  get_fused_config_from_dir,
@@ -86,7 +86,7 @@ class DeprecatedRunInfoStore:
86
86
  run_id=run_id,
87
87
  node_id=self.node_id,
88
88
  node_config=self.node_config,
89
- state=RecordSet(),
89
+ state=RecordDict(),
90
90
  run_config=initial_run_config.copy(),
91
91
  ),
92
92
  )
@@ -15,10 +15,8 @@
15
15
  """Flower SuperNode."""
16
16
 
17
17
 
18
- from .app import run_client_app as run_client_app
19
18
  from .app import run_supernode as run_supernode
20
19
 
21
20
  __all__ = [
22
- "run_client_app",
23
21
  "run_supernode",
24
22
  ]
@@ -16,7 +16,7 @@
16
16
 
17
17
 
18
18
  import argparse
19
- from logging import DEBUG, ERROR, INFO, WARN
19
+ from logging import DEBUG, INFO, WARN
20
20
  from pathlib import Path
21
21
  from typing import Optional
22
22
 
@@ -98,16 +98,6 @@ def run_supernode() -> None:
98
98
  )
99
99
 
100
100
 
101
- def run_client_app() -> None:
102
- """Run Flower client app."""
103
- event(EventType.RUN_CLIENT_APP_ENTER)
104
- log(
105
- ERROR,
106
- "The command `flower-client-app` has been replaced by `flwr run`.",
107
- )
108
- register_exit_handlers(event_type=EventType.RUN_CLIENT_APP_LEAVE)
109
-
110
-
111
101
  def _parse_args_run_supernode() -> argparse.ArgumentParser:
112
102
  """Parse flower-supernode command line arguments."""
113
103
  parser = argparse.ArgumentParser(
flwr/common/__init__.py CHANGED
@@ -31,9 +31,13 @@ from .parameter import ndarray_to_bytes as ndarray_to_bytes
31
31
  from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
32
32
  from .parameter import parameters_to_ndarrays as parameters_to_ndarrays
33
33
  from .record import Array as Array
34
+ from .record import ArrayRecord as ArrayRecord
35
+ from .record import ConfigRecord as ConfigRecord
34
36
  from .record import ConfigsRecord as ConfigsRecord
37
+ from .record import MetricRecord as MetricRecord
35
38
  from .record import MetricsRecord as MetricsRecord
36
39
  from .record import ParametersRecord as ParametersRecord
40
+ from .record import RecordDict as RecordDict
37
41
  from .record import RecordSet as RecordSet
38
42
  from .record import array_from_numpy as array_from_numpy
39
43
  from .telemetry import EventType as EventType
@@ -41,7 +45,7 @@ from .telemetry import event as event
41
45
  from .typing import ClientMessage as ClientMessage
42
46
  from .typing import Code as Code
43
47
  from .typing import Config as Config
44
- from .typing import ConfigsRecordValues as ConfigsRecordValues
48
+ from .typing import ConfigRecordValues as ConfigRecordValues
45
49
  from .typing import DisconnectRes as DisconnectRes
46
50
  from .typing import EvaluateIns as EvaluateIns
47
51
  from .typing import EvaluateRes as EvaluateRes
@@ -51,9 +55,9 @@ from .typing import GetParametersIns as GetParametersIns
51
55
  from .typing import GetParametersRes as GetParametersRes
52
56
  from .typing import GetPropertiesIns as GetPropertiesIns
53
57
  from .typing import GetPropertiesRes as GetPropertiesRes
58
+ from .typing import MetricRecordValues as MetricRecordValues
54
59
  from .typing import Metrics as Metrics
55
60
  from .typing import MetricsAggregationFn as MetricsAggregationFn
56
- from .typing import MetricsRecordValues as MetricsRecordValues
57
61
  from .typing import NDArray as NDArray
58
62
  from .typing import NDArrays as NDArrays
59
63
  from .typing import Parameters as Parameters
@@ -65,11 +69,13 @@ from .typing import Status as Status
65
69
 
66
70
  __all__ = [
67
71
  "Array",
72
+ "ArrayRecord",
68
73
  "ClientMessage",
69
74
  "Code",
70
75
  "Config",
76
+ "ConfigRecord",
77
+ "ConfigRecordValues",
71
78
  "ConfigsRecord",
72
- "ConfigsRecordValues",
73
79
  "Context",
74
80
  "DEFAULT_TTL",
75
81
  "DisconnectRes",
@@ -88,16 +94,18 @@ __all__ = [
88
94
  "MessageType",
89
95
  "MessageTypeLegacy",
90
96
  "Metadata",
97
+ "MetricRecord",
98
+ "MetricRecordValues",
91
99
  "Metrics",
92
100
  "MetricsAggregationFn",
93
101
  "MetricsRecord",
94
- "MetricsRecordValues",
95
102
  "NDArray",
96
103
  "NDArrays",
97
104
  "Parameters",
98
105
  "ParametersRecord",
99
106
  "Properties",
100
107
  "ReconnectIns",
108
+ "RecordDict",
101
109
  "RecordSet",
102
110
  "Scalar",
103
111
  "ServerMessage",