flwr 1.15.2__py3-none-any.whl → 1.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. flwr/cli/build.py +2 -0
  2. flwr/cli/log.py +20 -21
  3. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
  4. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
  5. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
  6. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
  7. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  8. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
  9. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  10. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
  11. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  12. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  13. flwr/cli/run/run.py +5 -9
  14. flwr/client/app.py +6 -4
  15. flwr/client/client_app.py +260 -86
  16. flwr/client/clientapp/app.py +6 -2
  17. flwr/client/grpc_client/connection.py +24 -21
  18. flwr/client/message_handler/message_handler.py +28 -28
  19. flwr/client/mod/__init__.py +2 -2
  20. flwr/client/mod/centraldp_mods.py +7 -7
  21. flwr/client/mod/comms_mods.py +16 -22
  22. flwr/client/mod/localdp_mod.py +4 -4
  23. flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
  24. flwr/client/rest_client/connection.py +4 -6
  25. flwr/client/run_info_store.py +2 -2
  26. flwr/client/supernode/__init__.py +0 -2
  27. flwr/client/supernode/app.py +1 -11
  28. flwr/common/__init__.py +12 -4
  29. flwr/common/address.py +35 -0
  30. flwr/common/args.py +8 -2
  31. flwr/common/auth_plugin/auth_plugin.py +2 -1
  32. flwr/common/config.py +4 -4
  33. flwr/common/constant.py +16 -0
  34. flwr/common/context.py +4 -4
  35. flwr/common/event_log_plugin/__init__.py +22 -0
  36. flwr/common/event_log_plugin/event_log_plugin.py +60 -0
  37. flwr/common/grpc.py +1 -1
  38. flwr/common/logger.py +2 -2
  39. flwr/common/message.py +338 -102
  40. flwr/common/object_ref.py +0 -10
  41. flwr/common/record/__init__.py +8 -4
  42. flwr/common/record/arrayrecord.py +626 -0
  43. flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
  44. flwr/common/record/conversion_utils.py +9 -18
  45. flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
  46. flwr/common/record/recorddict.py +288 -0
  47. flwr/common/recorddict_compat.py +410 -0
  48. flwr/common/secure_aggregation/quantization.py +5 -1
  49. flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
  50. flwr/common/serde.py +67 -190
  51. flwr/common/telemetry.py +0 -10
  52. flwr/common/typing.py +44 -8
  53. flwr/proto/exec_pb2.py +3 -3
  54. flwr/proto/exec_pb2.pyi +3 -3
  55. flwr/proto/message_pb2.py +12 -12
  56. flwr/proto/message_pb2.pyi +9 -9
  57. flwr/proto/recorddict_pb2.py +70 -0
  58. flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
  59. flwr/proto/run_pb2.py +31 -31
  60. flwr/proto/run_pb2.pyi +3 -3
  61. flwr/server/__init__.py +3 -1
  62. flwr/server/app.py +74 -3
  63. flwr/server/compat/__init__.py +2 -2
  64. flwr/server/compat/app.py +15 -12
  65. flwr/server/compat/app_utils.py +26 -18
  66. flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +41 -41
  67. flwr/server/fleet_event_log_interceptor.py +94 -0
  68. flwr/server/{driver → grid}/__init__.py +8 -7
  69. flwr/server/{driver/driver.py → grid/grid.py} +48 -19
  70. flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +88 -56
  71. flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +41 -54
  72. flwr/server/run_serverapp.py +6 -17
  73. flwr/server/server_app.py +126 -33
  74. flwr/server/serverapp/app.py +10 -10
  75. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
  76. flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
  77. flwr/server/superlink/fleet/vce/backend/backend.py +3 -3
  78. flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
  79. flwr/server/superlink/fleet/vce/vce_api.py +33 -38
  80. flwr/server/superlink/linkstate/in_memory_linkstate.py +171 -132
  81. flwr/server/superlink/linkstate/linkstate.py +51 -64
  82. flwr/server/superlink/linkstate/sqlite_linkstate.py +253 -285
  83. flwr/server/superlink/linkstate/utils.py +171 -133
  84. flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
  85. flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
  86. flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +27 -29
  87. flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
  88. flwr/server/typing.py +3 -3
  89. flwr/server/utils/__init__.py +2 -2
  90. flwr/server/utils/validator.py +53 -68
  91. flwr/server/workflow/default_workflows.py +52 -58
  92. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
  93. flwr/simulation/app.py +2 -2
  94. flwr/simulation/ray_transport/ray_actor.py +4 -2
  95. flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
  96. flwr/simulation/run_simulation.py +15 -15
  97. flwr/superexec/app.py +0 -14
  98. flwr/superexec/deployment.py +4 -4
  99. flwr/superexec/exec_event_log_interceptor.py +135 -0
  100. flwr/superexec/exec_grpc.py +10 -4
  101. flwr/superexec/exec_servicer.py +6 -6
  102. flwr/superexec/exec_user_auth_interceptor.py +22 -4
  103. flwr/superexec/executor.py +3 -3
  104. flwr/superexec/simulation.py +3 -3
  105. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/METADATA +5 -5
  106. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/RECORD +111 -112
  107. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -3
  108. flwr/client/message_handler/task_handler.py +0 -37
  109. flwr/common/record/parametersrecord.py +0 -204
  110. flwr/common/record/recordset.py +0 -202
  111. flwr/common/recordset_compat.py +0 -418
  112. flwr/proto/recordset_pb2.py +0 -70
  113. flwr/proto/task_pb2.py +0 -33
  114. flwr/proto/task_pb2.pyi +0 -100
  115. flwr/proto/task_pb2_grpc.py +0 -4
  116. flwr/proto/task_pb2_grpc.pyi +0 -4
  117. /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
  118. /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
  119. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
  120. {flwr-1.15.2.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
@@ -16,93 +16,78 @@
16
16
 
17
17
 
18
18
  import time
19
- from typing import Union
20
19
 
20
+ from flwr.common import Message
21
21
  from flwr.common.constant import SUPERLINK_NODE_ID
22
- from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
23
22
 
24
23
 
25
- # pylint: disable-next=too-many-branches,too-many-statements
26
- def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str]:
27
- """Validate a TaskIns or TaskRes."""
24
+ # pylint: disable-next=too-many-branches
25
+ def validate_message(message: Message, is_reply_message: bool) -> list[str]:
26
+ """Validate a Message."""
28
27
  validation_errors = []
28
+ metadata = message.metadata
29
29
 
30
- if tasks_ins_res.task_id != "":
31
- validation_errors.append("non-empty `task_id`")
32
-
33
- if not tasks_ins_res.HasField("task"):
34
- validation_errors.append("`task` does not set field `task`")
30
+ if metadata.message_id != "":
31
+ validation_errors.append("non-empty `metadata.message_id`")
35
32
 
36
33
  # Created/delivered/TTL/Pushed
37
34
  if (
38
- tasks_ins_res.task.created_at < 1711497600.0
39
- ): # unix timestamp of 27 March 2024 00h:00m:00s UTC
35
+ metadata.created_at < 1740700800.0
36
+ ): # unix timestamp of 28 February 2025 00h:00m:00s UTC
40
37
  validation_errors.append(
41
- "`created_at` must be a float that records the unix timestamp "
38
+ "`metadata.created_at` must be a float that records the unix timestamp "
42
39
  "in seconds when the message was created."
43
40
  )
44
- if tasks_ins_res.task.delivered_at != "":
45
- validation_errors.append("`delivered_at` must be an empty str")
46
- if tasks_ins_res.task.ttl <= 0:
47
- validation_errors.append("`ttl` must be higher than zero")
41
+ if metadata.delivered_at != "":
42
+ validation_errors.append("`metadata.delivered_at` must be an empty str")
43
+ if metadata.ttl <= 0:
44
+ validation_errors.append("`metadata.ttl` must be higher than zero")
48
45
 
49
46
  # Verify TTL and created_at time
50
47
  current_time = time.time()
51
- if tasks_ins_res.task.created_at + tasks_ins_res.task.ttl <= current_time:
52
- validation_errors.append("Task TTL has expired")
53
-
54
- # TaskIns specific
55
- if isinstance(tasks_ins_res, TaskIns):
56
- # Task producer
57
- if not tasks_ins_res.task.HasField("producer"):
58
- validation_errors.append("`producer` does not set field `producer`")
59
- if tasks_ins_res.task.producer.node_id != SUPERLINK_NODE_ID:
60
- validation_errors.append(f"`producer.node_id` is not {SUPERLINK_NODE_ID}")
61
-
62
- # Task consumer
63
- if not tasks_ins_res.task.HasField("consumer"):
64
- validation_errors.append("`consumer` does not set field `consumer`")
65
- if tasks_ins_res.task.consumer.node_id == SUPERLINK_NODE_ID:
66
- validation_errors.append("consumer MUST provide a valid `node_id`")
48
+ if metadata.created_at + metadata.ttl <= current_time:
49
+ validation_errors.append("Message TTL has expired")
67
50
 
68
- # Content check
69
- if tasks_ins_res.task.task_type == "":
70
- validation_errors.append("`task_type` MUST be set")
71
- if not (
72
- tasks_ins_res.task.HasField("recordset")
73
- ^ tasks_ins_res.task.HasField("error")
74
- ):
75
- validation_errors.append("Either `recordset` or `error` MUST be set")
51
+ # Source node is set and is not zero
52
+ if not metadata.src_node_id:
53
+ validation_errors.append("`metadata.src_node_id` is not set.")
76
54
 
77
- # Ancestors
78
- if len(tasks_ins_res.task.ancestry) != 0:
79
- validation_errors.append("`ancestry` is not empty")
55
+ # Destination node is set and is not zero
56
+ if not metadata.dst_node_id:
57
+ validation_errors.append("`metadata.dst_node_id` is not set.")
80
58
 
81
- # TaskRes specific
82
- if isinstance(tasks_ins_res, TaskRes):
83
- # Task producer
84
- if not tasks_ins_res.task.HasField("producer"):
85
- validation_errors.append("`producer` does not set field `producer`")
86
- if tasks_ins_res.task.producer.node_id == SUPERLINK_NODE_ID:
87
- validation_errors.append("producer MUST provide a valid `node_id`")
59
+ # Message type
60
+ if metadata.message_type == "":
61
+ validation_errors.append("`metadata.message_type` MUST be set")
88
62
 
89
- # Task consumer
90
- if not tasks_ins_res.task.HasField("consumer"):
91
- validation_errors.append("`consumer` does not set field `consumer`")
92
- if tasks_ins_res.task.consumer.node_id != SUPERLINK_NODE_ID:
93
- validation_errors.append(f"consumer is not {SUPERLINK_NODE_ID}")
94
-
95
- # Content check
96
- if tasks_ins_res.task.task_type == "":
97
- validation_errors.append("`task_type` MUST be set")
98
- if not (
99
- tasks_ins_res.task.HasField("recordset")
100
- ^ tasks_ins_res.task.HasField("error")
101
- ):
102
- validation_errors.append("Either `recordset` or `error` MUST be set")
63
+ # Content
64
+ if not message.has_content() != message.has_error():
65
+ validation_errors.append(
66
+ "Either message `content` or `error` MUST be set (but not both)"
67
+ )
103
68
 
104
- # Ancestors
105
- if len(tasks_ins_res.task.ancestry) == 0:
106
- validation_errors.append("`ancestry` is empty")
69
+ # Link response to the original message
70
+ if not is_reply_message:
71
+ if metadata.reply_to_message_id != "":
72
+ validation_errors.append("`metadata.reply_to_message_id` MUST not be set.")
73
+ if metadata.src_node_id != SUPERLINK_NODE_ID:
74
+ validation_errors.append(
75
+ f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
76
+ )
77
+ if metadata.dst_node_id == SUPERLINK_NODE_ID:
78
+ validation_errors.append(
79
+ f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
80
+ )
81
+ else:
82
+ if metadata.reply_to_message_id == "":
83
+ validation_errors.append("`metadata.reply_to_message_id` MUST be set.")
84
+ if metadata.src_node_id == SUPERLINK_NODE_ID:
85
+ validation_errors.append(
86
+ f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
87
+ )
88
+ if metadata.dst_node_id != SUPERLINK_NODE_ID:
89
+ validation_errors.append(
90
+ f"`metadata.dst_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
91
+ )
107
92
 
108
93
  return validation_errors
@@ -20,15 +20,16 @@ import timeit
20
20
  from logging import INFO, WARN
21
21
  from typing import Optional, Union, cast
22
22
 
23
- import flwr.common.recordset_compat as compat
23
+ import flwr.common.recorddict_compat as compat
24
24
  from flwr.common import (
25
+ ArrayRecord,
25
26
  Code,
26
- ConfigsRecord,
27
+ ConfigRecord,
27
28
  Context,
28
29
  EvaluateRes,
29
30
  FitRes,
30
31
  GetParametersIns,
31
- ParametersRecord,
32
+ Message,
32
33
  log,
33
34
  )
34
35
  from flwr.common.constant import MessageType, MessageTypeLegacy
@@ -36,7 +37,7 @@ from flwr.common.constant import MessageType, MessageTypeLegacy
36
37
  from ..client_proxy import ClientProxy
37
38
  from ..compat.app_utils import start_update_client_manager_thread
38
39
  from ..compat.legacy_context import LegacyContext
39
- from ..driver import Driver
40
+ from ..grid import Grid
40
41
  from ..typing import Workflow
41
42
  from .constant import MAIN_CONFIGS_RECORD, MAIN_PARAMS_RECORD, Key
42
43
 
@@ -56,7 +57,7 @@ class DefaultWorkflow:
56
57
  self.fit_workflow: Workflow = fit_workflow
57
58
  self.evaluate_workflow: Workflow = evaluate_workflow
58
59
 
59
- def __call__(self, driver: Driver, context: Context) -> None:
60
+ def __call__(self, grid: Grid, context: Context) -> None:
60
61
  """Execute the workflow."""
61
62
  if not isinstance(context, LegacyContext):
62
63
  raise TypeError(
@@ -64,19 +65,22 @@ class DefaultWorkflow:
64
65
  )
65
66
 
66
67
  # Start the thread updating nodes
67
- thread, f_stop = start_update_client_manager_thread(
68
- driver, context.client_manager
68
+ thread, f_stop, c_done = start_update_client_manager_thread(
69
+ grid, context.client_manager
69
70
  )
70
71
 
72
+ # Wait until the node registration is done
73
+ c_done.wait()
74
+
71
75
  # Initialize parameters
72
76
  log(INFO, "[INIT]")
73
- default_init_params_workflow(driver, context)
77
+ default_init_params_workflow(grid, context)
74
78
 
75
79
  # Run federated learning for num_rounds
76
80
  start_time = timeit.default_timer()
77
- cfg = ConfigsRecord()
81
+ cfg = ConfigRecord()
78
82
  cfg[Key.START_TIME] = start_time
79
- context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
83
+ context.state.config_records[MAIN_CONFIGS_RECORD] = cfg
80
84
 
81
85
  for current_round in range(1, context.config.num_rounds + 1):
82
86
  log(INFO, "")
@@ -84,13 +88,13 @@ class DefaultWorkflow:
84
88
  cfg[Key.CURRENT_ROUND] = current_round
85
89
 
86
90
  # Fit round
87
- self.fit_workflow(driver, context)
91
+ self.fit_workflow(grid, context)
88
92
 
89
93
  # Centralized evaluation
90
- default_centralized_evaluation_workflow(driver, context)
94
+ default_centralized_evaluation_workflow(grid, context)
91
95
 
92
96
  # Evaluate round
93
- self.evaluate_workflow(driver, context)
97
+ self.evaluate_workflow(grid, context)
94
98
 
95
99
  # Bookkeeping and log results
96
100
  end_time = timeit.default_timer()
@@ -116,7 +120,7 @@ class DefaultWorkflow:
116
120
  thread.join()
117
121
 
118
122
 
119
- def default_init_params_workflow(driver: Driver, context: Context) -> None:
123
+ def default_init_params_workflow(grid: Grid, context: Context) -> None:
120
124
  """Execute the default workflow for parameters initialization."""
121
125
  if not isinstance(context, LegacyContext):
122
126
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
@@ -126,21 +130,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
126
130
  )
127
131
  if parameters is not None:
128
132
  log(INFO, "Using initial global parameters provided by strategy")
129
- paramsrecord = compat.parameters_to_parametersrecord(
130
- parameters, keep_input=True
131
- )
133
+ arr_record = compat.parameters_to_arrayrecord(parameters, keep_input=True)
132
134
  else:
133
135
  # Get initial parameters from one of the clients
134
136
  log(INFO, "Requesting initial parameters from one random client")
135
137
  random_client = context.client_manager.sample(1)[0]
136
138
  # Send GetParametersIns and get the response
137
- content = compat.getparametersins_to_recordset(GetParametersIns({}))
138
- messages = driver.send_and_receive(
139
+ content = compat.getparametersins_to_recorddict(GetParametersIns({}))
140
+ messages = grid.send_and_receive(
139
141
  [
140
- driver.create_message(
142
+ Message(
141
143
  content=content,
142
- message_type=MessageTypeLegacy.GET_PARAMETERS,
143
144
  dst_node_id=random_client.node_id,
145
+ message_type=MessageTypeLegacy.GET_PARAMETERS,
144
146
  group_id="0",
145
147
  )
146
148
  ]
@@ -149,26 +151,26 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
149
151
 
150
152
  if (
151
153
  msg.has_content()
152
- and compat._extract_status_from_recordset( # pylint: disable=W0212
154
+ and compat._extract_status_from_recorddict( # pylint: disable=W0212
153
155
  "getparametersres", msg.content
154
156
  ).code
155
157
  == Code.OK
156
158
  ):
157
159
  log(INFO, "Received initial parameters from one random client")
158
- paramsrecord = next(iter(msg.content.parameters_records.values()))
160
+ arr_record = next(iter(msg.content.array_records.values()))
159
161
  else:
160
162
  log(
161
163
  WARN,
162
164
  "Failed to receive initial parameters from the client."
163
165
  " Empty initial parameters will be used.",
164
166
  )
165
- paramsrecord = ParametersRecord()
167
+ arr_record = ArrayRecord()
166
168
 
167
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
169
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
168
170
 
169
171
  # Evaluate initial parameters
170
172
  log(INFO, "Starting evaluation of initial global parameters")
171
- parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
173
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
172
174
  res = context.strategy.evaluate(0, parameters=parameters)
173
175
  if res is not None:
174
176
  log(
@@ -183,19 +185,19 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
183
185
  log(INFO, "Evaluation returned no results (`None`)")
184
186
 
185
187
 
186
- def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None:
188
+ def default_centralized_evaluation_workflow(_: Grid, context: Context) -> None:
187
189
  """Execute the default workflow for centralized evaluation."""
188
190
  if not isinstance(context, LegacyContext):
189
191
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
190
192
 
191
193
  # Retrieve current_round and start_time from the context
192
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
194
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
193
195
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
194
196
  start_time = cast(float, cfg[Key.START_TIME])
195
197
 
196
198
  # Centralized evaluation
197
- parameters = compat.parametersrecord_to_parameters(
198
- record=context.state.parameters_records[MAIN_PARAMS_RECORD],
199
+ parameters = compat.arrayrecord_to_parameters(
200
+ record=context.state.array_records[MAIN_PARAMS_RECORD],
199
201
  keep_input=True,
200
202
  )
201
203
  res_cen = context.strategy.evaluate(current_round, parameters=parameters)
@@ -215,20 +217,16 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None
215
217
  )
216
218
 
217
219
 
218
- def default_fit_workflow( # pylint: disable=R0914
219
- driver: Driver, context: Context
220
- ) -> None:
220
+ def default_fit_workflow(grid: Grid, context: Context) -> None: # pylint: disable=R0914
221
221
  """Execute the default workflow for a single fit round."""
222
222
  if not isinstance(context, LegacyContext):
223
223
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
224
224
 
225
225
  # Get current_round and parameters
226
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
226
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
227
227
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
228
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
229
- parameters = compat.parametersrecord_to_parameters(
230
- parametersrecord, keep_input=True
231
- )
228
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
229
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
232
230
 
233
231
  # Get clients and their respective instructions from strategy
234
232
  client_instructions = context.strategy.configure_fit(
@@ -252,10 +250,10 @@ def default_fit_workflow( # pylint: disable=R0914
252
250
 
253
251
  # Build out messages
254
252
  out_messages = [
255
- driver.create_message(
256
- content=compat.fitins_to_recordset(fitins, True),
257
- message_type=MessageType.TRAIN,
253
+ Message(
254
+ content=compat.fitins_to_recorddict(fitins, True),
258
255
  dst_node_id=proxy.node_id,
256
+ message_type=MessageType.TRAIN,
259
257
  group_id=str(current_round),
260
258
  )
261
259
  for proxy, fitins in client_instructions
@@ -263,7 +261,7 @@ def default_fit_workflow( # pylint: disable=R0914
263
261
 
264
262
  # Send instructions to clients and
265
263
  # collect `fit` results from all clients participating in this round
266
- messages = list(driver.send_and_receive(out_messages))
264
+ messages = list(grid.send_and_receive(out_messages))
267
265
  del out_messages
268
266
  num_failures = len([msg for msg in messages if msg.has_error()])
269
267
 
@@ -281,7 +279,7 @@ def default_fit_workflow( # pylint: disable=R0914
281
279
  for msg in messages:
282
280
  if msg.has_content():
283
281
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
284
- fitres = compat.recordset_to_fitres(msg.content, False)
282
+ fitres = compat.recorddict_to_fitres(msg.content, False)
285
283
  if fitres.status.code == Code.OK:
286
284
  results.append((proxy, fitres))
287
285
  else:
@@ -294,28 +292,24 @@ def default_fit_workflow( # pylint: disable=R0914
294
292
 
295
293
  # Update the parameters and write history
296
294
  if parameters_aggregated:
297
- paramsrecord = compat.parameters_to_parametersrecord(
298
- parameters_aggregated, True
299
- )
300
- context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
295
+ arr_record = compat.parameters_to_arrayrecord(parameters_aggregated, True)
296
+ context.state.array_records[MAIN_PARAMS_RECORD] = arr_record
301
297
  context.history.add_metrics_distributed_fit(
302
298
  server_round=current_round, metrics=metrics_aggregated
303
299
  )
304
300
 
305
301
 
306
302
  # pylint: disable-next=R0914
307
- def default_evaluate_workflow(driver: Driver, context: Context) -> None:
303
+ def default_evaluate_workflow(grid: Grid, context: Context) -> None:
308
304
  """Execute the default workflow for a single evaluate round."""
309
305
  if not isinstance(context, LegacyContext):
310
306
  raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
311
307
 
312
308
  # Get current_round and parameters
313
- cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
309
+ cfg = context.state.config_records[MAIN_CONFIGS_RECORD]
314
310
  current_round = cast(int, cfg[Key.CURRENT_ROUND])
315
- parametersrecord = context.state.parameters_records[MAIN_PARAMS_RECORD]
316
- parameters = compat.parametersrecord_to_parameters(
317
- parametersrecord, keep_input=True
318
- )
311
+ arr_record = context.state.array_records[MAIN_PARAMS_RECORD]
312
+ parameters = compat.arrayrecord_to_parameters(arr_record, keep_input=True)
319
313
 
320
314
  # Get clients and their respective instructions from strategy
321
315
  client_instructions = context.strategy.configure_evaluate(
@@ -338,10 +332,10 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
338
332
 
339
333
  # Build out messages
340
334
  out_messages = [
341
- driver.create_message(
342
- content=compat.evaluateins_to_recordset(evalins, True),
343
- message_type=MessageType.EVALUATE,
335
+ Message(
336
+ content=compat.evaluateins_to_recorddict(evalins, True),
344
337
  dst_node_id=proxy.node_id,
338
+ message_type=MessageType.EVALUATE,
345
339
  group_id=str(current_round),
346
340
  )
347
341
  for proxy, evalins in client_instructions
@@ -349,7 +343,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
349
343
 
350
344
  # Send instructions to clients and
351
345
  # collect `evaluate` results from all clients participating in this round
352
- messages = list(driver.send_and_receive(out_messages))
346
+ messages = list(grid.send_and_receive(out_messages))
353
347
  del out_messages
354
348
  num_failures = len([msg for msg in messages if msg.has_error()])
355
349
 
@@ -367,7 +361,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
367
361
  for msg in messages:
368
362
  if msg.has_content():
369
363
  proxy = node_id_to_proxy[msg.metadata.src_node_id]
370
- evalres = compat.recordset_to_evaluateres(msg.content)
364
+ evalres = compat.recorddict_to_evaluateres(msg.content)
371
365
  if evalres.status.code == Code.OK:
372
366
  results.append((proxy, evalres))
373
367
  else: