flwr-nightly 1.17.0.dev20250319__py3-none-any.whl → 1.17.0.dev20250320__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/client/app.py +6 -4
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +23 -20
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +5 -5
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +2 -0
- flwr/common/context.py +4 -4
- flwr/common/message.py +269 -101
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/configsrecord.py +2 -2
- flwr/common/record/metricsrecord.py +1 -1
- flwr/common/record/parametersrecord.py +1 -1
- flwr/common/record/{recordset.py → recorddict.py} +57 -17
- flwr/common/{recordset_compat.py → recorddict_compat.py} +105 -105
- flwr/common/serde.py +33 -37
- flwr/proto/exec_pb2.py +32 -32
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +2 -2
- flwr/proto/run_pb2.py +32 -32
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/compat/grid_client_proxy.py +30 -30
- flwr/server/grid/grid.py +3 -3
- flwr/server/grid/grpc_grid.py +15 -23
- flwr/server/grid/inmemory_grid.py +14 -20
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +1 -1
- flwr/server/superlink/linkstate/sqlite_linkstate.py +14 -18
- flwr/server/superlink/linkstate/utils.py +10 -7
- flwr/server/superlink/serverappio/serverappio_servicer.py +1 -1
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +7 -7
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +3 -3
- flwr/superexec/deployment.py +2 -2
- flwr/superexec/simulation.py +2 -2
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/METADATA +1 -1
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/RECORD +49 -49
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.17.0.dev20250319.dist-info → flwr_nightly-1.17.0.dev20250320.dist-info}/entry_points.txt +0 -0
flwr/client/app.py
CHANGED
@@ -495,8 +495,9 @@ def start_client_internal(
|
|
495
495
|
context = run_info_store.retrieve_context(run_id=run_id)
|
496
496
|
# Create an error reply message that will never be used to prevent
|
497
497
|
# the used-before-assignment linting error
|
498
|
-
reply_message =
|
499
|
-
|
498
|
+
reply_message = Message(
|
499
|
+
Error(code=ErrorCode.UNKNOWN, reason="Unknown"),
|
500
|
+
reply_to=message,
|
500
501
|
)
|
501
502
|
|
502
503
|
# Handle app loading and task message
|
@@ -593,8 +594,9 @@ def start_client_internal(
|
|
593
594
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
594
595
|
|
595
596
|
# Create error message
|
596
|
-
reply_message =
|
597
|
-
|
597
|
+
reply_message = Message(
|
598
|
+
Error(code=e_code, reason=reason),
|
599
|
+
reply_to=message,
|
598
600
|
)
|
599
601
|
else:
|
600
602
|
# No exception, update node state
|
flwr/client/clientapp/app.py
CHANGED
@@ -152,8 +152,8 @@ def run_clientapp( # pylint: disable=R0914
|
|
152
152
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
153
153
|
|
154
154
|
# Create error message
|
155
|
-
reply_message =
|
156
|
-
|
155
|
+
reply_message = Message(
|
156
|
+
Error(code=e_code, reason=reason), reply_to=message
|
157
157
|
)
|
158
158
|
|
159
159
|
# Push Message and Context to SuperNode
|
@@ -31,13 +31,15 @@ from flwr.common import (
|
|
31
31
|
ConfigsRecord,
|
32
32
|
Message,
|
33
33
|
Metadata,
|
34
|
-
|
34
|
+
RecordDict,
|
35
|
+
now,
|
35
36
|
)
|
36
|
-
from flwr.common import
|
37
|
+
from flwr.common import recorddict_compat as compat
|
37
38
|
from flwr.common import serde
|
38
39
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
39
40
|
from flwr.common.grpc import create_channel, on_channel_state_change
|
40
41
|
from flwr.common.logger import log
|
42
|
+
from flwr.common.message import make_message
|
41
43
|
from flwr.common.retry_invoker import RetryInvoker
|
42
44
|
from flwr.common.typing import Fab, Run
|
43
45
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
@@ -139,32 +141,32 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
139
141
|
# Receive ServerMessage proto
|
140
142
|
proto = next(server_message_iterator)
|
141
143
|
|
142
|
-
# ServerMessage proto --> *Ins -->
|
144
|
+
# ServerMessage proto --> *Ins --> RecordDict
|
143
145
|
field = proto.WhichOneof("msg")
|
144
146
|
message_type = ""
|
145
147
|
if field == "get_properties_ins":
|
146
|
-
|
148
|
+
recorddict = compat.getpropertiesins_to_recorddict(
|
147
149
|
serde.get_properties_ins_from_proto(proto.get_properties_ins)
|
148
150
|
)
|
149
151
|
message_type = MessageTypeLegacy.GET_PROPERTIES
|
150
152
|
elif field == "get_parameters_ins":
|
151
|
-
|
153
|
+
recorddict = compat.getparametersins_to_recorddict(
|
152
154
|
serde.get_parameters_ins_from_proto(proto.get_parameters_ins)
|
153
155
|
)
|
154
156
|
message_type = MessageTypeLegacy.GET_PARAMETERS
|
155
157
|
elif field == "fit_ins":
|
156
|
-
|
158
|
+
recorddict = compat.fitins_to_recorddict(
|
157
159
|
serde.fit_ins_from_proto(proto.fit_ins), False
|
158
160
|
)
|
159
161
|
message_type = MessageType.TRAIN
|
160
162
|
elif field == "evaluate_ins":
|
161
|
-
|
163
|
+
recorddict = compat.evaluateins_to_recorddict(
|
162
164
|
serde.evaluate_ins_from_proto(proto.evaluate_ins), False
|
163
165
|
)
|
164
166
|
message_type = MessageType.EVALUATE
|
165
167
|
elif field == "reconnect_ins":
|
166
|
-
|
167
|
-
|
168
|
+
recorddict = RecordDict()
|
169
|
+
recorddict.configs_records["config"] = ConfigsRecord(
|
168
170
|
{"seconds": proto.reconnect_ins.seconds}
|
169
171
|
)
|
170
172
|
message_type = "reconnect"
|
@@ -175,45 +177,46 @@ def grpc_connection( # pylint: disable=R0913,R0915,too-many-positional-argument
|
|
175
177
|
)
|
176
178
|
|
177
179
|
# Construct Message
|
178
|
-
return
|
180
|
+
return make_message(
|
179
181
|
metadata=Metadata(
|
180
182
|
run_id=0,
|
181
183
|
message_id=str(uuid.uuid4()),
|
182
184
|
src_node_id=0,
|
183
185
|
dst_node_id=0,
|
184
|
-
|
186
|
+
reply_to_message_id="",
|
185
187
|
group_id="",
|
188
|
+
created_at=now().timestamp(),
|
186
189
|
ttl=DEFAULT_TTL,
|
187
190
|
message_type=message_type,
|
188
191
|
),
|
189
|
-
content=
|
192
|
+
content=recorddict,
|
190
193
|
)
|
191
194
|
|
192
195
|
def send(message: Message) -> None:
|
193
|
-
# Retrieve
|
194
|
-
|
196
|
+
# Retrieve RecordDict and message_type
|
197
|
+
recorddict = message.content
|
195
198
|
message_type = message.metadata.message_type
|
196
199
|
|
197
|
-
#
|
200
|
+
# RecordDict --> *Res --> *Res proto -> ClientMessage proto
|
198
201
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
199
|
-
getpropres = compat.
|
202
|
+
getpropres = compat.recorddict_to_getpropertiesres(recorddict)
|
200
203
|
msg_proto = ClientMessage(
|
201
204
|
get_properties_res=serde.get_properties_res_to_proto(getpropres)
|
202
205
|
)
|
203
206
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
204
|
-
getparamres = compat.
|
207
|
+
getparamres = compat.recorddict_to_getparametersres(recorddict, False)
|
205
208
|
msg_proto = ClientMessage(
|
206
209
|
get_parameters_res=serde.get_parameters_res_to_proto(getparamres)
|
207
210
|
)
|
208
211
|
elif message_type == MessageType.TRAIN:
|
209
|
-
fitres = compat.
|
212
|
+
fitres = compat.recorddict_to_fitres(recorddict, False)
|
210
213
|
msg_proto = ClientMessage(fit_res=serde.fit_res_to_proto(fitres))
|
211
214
|
elif message_type == MessageType.EVALUATE:
|
212
|
-
evalres = compat.
|
215
|
+
evalres = compat.recorddict_to_evaluateres(recorddict)
|
213
216
|
msg_proto = ClientMessage(evaluate_res=serde.evaluate_res_to_proto(evalres))
|
214
217
|
elif message_type == "reconnect":
|
215
218
|
reason = cast(
|
216
|
-
Reason.ValueType,
|
219
|
+
Reason.ValueType, recorddict.configs_records["config"]["reason"]
|
217
220
|
)
|
218
221
|
msg_proto = ClientMessage(
|
219
222
|
disconnect_res=ClientMessage.DisconnectRes(reason=reason)
|
@@ -26,17 +26,17 @@ from flwr.client.client import (
|
|
26
26
|
)
|
27
27
|
from flwr.client.numpy_client import NumPyClient
|
28
28
|
from flwr.client.typing import ClientFnExt
|
29
|
-
from flwr.common import ConfigsRecord, Context, Message, Metadata,
|
29
|
+
from flwr.common import ConfigsRecord, Context, Message, Metadata, RecordDict, log
|
30
30
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
31
|
-
from flwr.common.
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
31
|
+
from flwr.common.recorddict_compat import (
|
32
|
+
evaluateres_to_recorddict,
|
33
|
+
fitres_to_recorddict,
|
34
|
+
getparametersres_to_recorddict,
|
35
|
+
getpropertiesres_to_recorddict,
|
36
|
+
recorddict_to_evaluateins,
|
37
|
+
recorddict_to_fitins,
|
38
|
+
recorddict_to_getparametersins,
|
39
|
+
recorddict_to_getpropertiesins,
|
40
40
|
)
|
41
41
|
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
|
42
42
|
ClientMessage,
|
@@ -70,18 +70,18 @@ def handle_control_message(message: Message) -> tuple[Optional[Message], int]:
|
|
70
70
|
Number of seconds that the client should disconnect from the server.
|
71
71
|
"""
|
72
72
|
if message.metadata.message_type == "reconnect":
|
73
|
-
# Retrieve ReconnectIns from
|
74
|
-
|
75
|
-
seconds = cast(int,
|
73
|
+
# Retrieve ReconnectIns from RecordDict
|
74
|
+
recorddict = message.content
|
75
|
+
seconds = cast(int, recorddict.configs_records["config"]["seconds"])
|
76
76
|
# Construct ReconnectIns and call _reconnect
|
77
77
|
disconnect_msg, sleep_duration = _reconnect(
|
78
78
|
ServerMessage.ReconnectIns(seconds=seconds)
|
79
79
|
)
|
80
|
-
# Store DisconnectRes in
|
80
|
+
# Store DisconnectRes in RecordDict
|
81
81
|
reason = cast(int, disconnect_msg.disconnect_res.reason)
|
82
|
-
|
83
|
-
|
84
|
-
out_message = message
|
82
|
+
recorddict = RecordDict()
|
83
|
+
recorddict.configs_records["config"] = ConfigsRecord({"reason": reason})
|
84
|
+
out_message = Message(recorddict, reply_to=message)
|
85
85
|
# Return Message and sleep duration
|
86
86
|
return out_message, sleep_duration
|
87
87
|
|
@@ -111,37 +111,37 @@ def handle_legacy_message_from_msgtype(
|
|
111
111
|
if message_type == MessageTypeLegacy.GET_PROPERTIES:
|
112
112
|
get_properties_res = maybe_call_get_properties(
|
113
113
|
client=client,
|
114
|
-
get_properties_ins=
|
114
|
+
get_properties_ins=recorddict_to_getpropertiesins(message.content),
|
115
115
|
)
|
116
|
-
|
116
|
+
out_recorddict = getpropertiesres_to_recorddict(get_properties_res)
|
117
117
|
# Handle GetParametersIns
|
118
118
|
elif message_type == MessageTypeLegacy.GET_PARAMETERS:
|
119
119
|
get_parameters_res = maybe_call_get_parameters(
|
120
120
|
client=client,
|
121
|
-
get_parameters_ins=
|
121
|
+
get_parameters_ins=recorddict_to_getparametersins(message.content),
|
122
122
|
)
|
123
|
-
|
123
|
+
out_recorddict = getparametersres_to_recorddict(
|
124
124
|
get_parameters_res, keep_input=False
|
125
125
|
)
|
126
126
|
# Handle FitIns
|
127
127
|
elif message_type == MessageType.TRAIN:
|
128
128
|
fit_res = maybe_call_fit(
|
129
129
|
client=client,
|
130
|
-
fit_ins=
|
130
|
+
fit_ins=recorddict_to_fitins(message.content, keep_input=True),
|
131
131
|
)
|
132
|
-
|
132
|
+
out_recorddict = fitres_to_recorddict(fit_res, keep_input=False)
|
133
133
|
# Handle EvaluateIns
|
134
134
|
elif message_type == MessageType.EVALUATE:
|
135
135
|
evaluate_res = maybe_call_evaluate(
|
136
136
|
client=client,
|
137
|
-
evaluate_ins=
|
137
|
+
evaluate_ins=recorddict_to_evaluateins(message.content, keep_input=True),
|
138
138
|
)
|
139
|
-
|
139
|
+
out_recorddict = evaluateres_to_recorddict(evaluate_res)
|
140
140
|
else:
|
141
141
|
raise ValueError(f"Invalid message type: {message_type}")
|
142
142
|
|
143
143
|
# Return Message
|
144
|
-
return message
|
144
|
+
return Message(out_recorddict, reply_to=message)
|
145
145
|
|
146
146
|
|
147
147
|
def _reconnect(
|
@@ -167,7 +167,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
167
167
|
and out_meta.message_id == "" # This will be generated by the server
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
170
|
-
and out_meta.
|
170
|
+
and out_meta.reply_to_message_id == in_meta.message_id
|
171
171
|
and out_meta.group_id == in_meta.group_id
|
172
172
|
and out_meta.message_type == in_meta.message_type
|
173
173
|
and out_meta.created_at > in_meta.created_at
|
@@ -19,7 +19,7 @@ from logging import INFO
|
|
19
19
|
|
20
20
|
from flwr.client.typing import ClientAppCallable
|
21
21
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
22
|
-
from flwr.common import
|
22
|
+
from flwr.common import recorddict_compat as compat
|
23
23
|
from flwr.common.constant import MessageType
|
24
24
|
from flwr.common.context import Context
|
25
25
|
from flwr.common.differential_privacy import (
|
@@ -53,7 +53,7 @@ def fixedclipping_mod(
|
|
53
53
|
"""
|
54
54
|
if msg.metadata.message_type != MessageType.TRAIN:
|
55
55
|
return call_next(msg, ctxt)
|
56
|
-
fit_ins = compat.
|
56
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
57
57
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
58
58
|
raise KeyError(
|
59
59
|
f"The {KEY_CLIPPING_NORM} value is not supplied by the "
|
@@ -71,7 +71,7 @@ def fixedclipping_mod(
|
|
71
71
|
if out_msg.has_error():
|
72
72
|
return out_msg
|
73
73
|
|
74
|
-
fit_res = compat.
|
74
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
75
75
|
|
76
76
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
77
77
|
|
@@ -87,7 +87,7 @@ def fixedclipping_mod(
|
|
87
87
|
)
|
88
88
|
|
89
89
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
90
|
-
out_msg.content = compat.
|
90
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
91
91
|
return out_msg
|
92
92
|
|
93
93
|
|
@@ -116,7 +116,7 @@ def adaptiveclipping_mod(
|
|
116
116
|
if msg.metadata.message_type != MessageType.TRAIN:
|
117
117
|
return call_next(msg, ctxt)
|
118
118
|
|
119
|
-
fit_ins = compat.
|
119
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
120
120
|
|
121
121
|
if KEY_CLIPPING_NORM not in fit_ins.config:
|
122
122
|
raise KeyError(
|
@@ -136,7 +136,7 @@ def adaptiveclipping_mod(
|
|
136
136
|
if out_msg.has_error():
|
137
137
|
return out_msg
|
138
138
|
|
139
|
-
fit_res = compat.
|
139
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
140
140
|
|
141
141
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
142
142
|
|
@@ -155,5 +155,5 @@ def adaptiveclipping_mod(
|
|
155
155
|
fit_res.parameters = ndarrays_to_parameters(client_to_server_params)
|
156
156
|
|
157
157
|
fit_res.metrics[KEY_NORM_BIT] = norm_bit
|
158
|
-
out_msg.content = compat.
|
158
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
159
159
|
return out_msg
|
flwr/client/mod/localdp_mod.py
CHANGED
@@ -21,7 +21,7 @@ import numpy as np
|
|
21
21
|
|
22
22
|
from flwr.client.typing import ClientAppCallable
|
23
23
|
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
|
24
|
-
from flwr.common import
|
24
|
+
from flwr.common import recorddict_compat as compat
|
25
25
|
from flwr.common.constant import MessageType
|
26
26
|
from flwr.common.context import Context
|
27
27
|
from flwr.common.differential_privacy import (
|
@@ -107,7 +107,7 @@ class LocalDpMod:
|
|
107
107
|
if msg.metadata.message_type != MessageType.TRAIN:
|
108
108
|
return call_next(msg, ctxt)
|
109
109
|
|
110
|
-
fit_ins = compat.
|
110
|
+
fit_ins = compat.recorddict_to_fitins(msg.content, keep_input=True)
|
111
111
|
server_to_client_params = parameters_to_ndarrays(fit_ins.parameters)
|
112
112
|
|
113
113
|
# Call inner app
|
@@ -117,7 +117,7 @@ class LocalDpMod:
|
|
117
117
|
if out_msg.has_error():
|
118
118
|
return out_msg
|
119
119
|
|
120
|
-
fit_res = compat.
|
120
|
+
fit_res = compat.recorddict_to_fitres(out_msg.content, keep_input=True)
|
121
121
|
|
122
122
|
client_to_server_params = parameters_to_ndarrays(fit_res.parameters)
|
123
123
|
|
@@ -149,5 +149,5 @@ class LocalDpMod:
|
|
149
149
|
noise_value_sd,
|
150
150
|
)
|
151
151
|
|
152
|
-
out_msg.content = compat.
|
152
|
+
out_msg.content = compat.fitres_to_recorddict(fit_res, keep_input=True)
|
153
153
|
return out_msg
|
@@ -26,11 +26,11 @@ from flwr.common import (
|
|
26
26
|
Context,
|
27
27
|
Message,
|
28
28
|
Parameters,
|
29
|
-
|
29
|
+
RecordDict,
|
30
30
|
ndarray_to_bytes,
|
31
31
|
parameters_to_ndarrays,
|
32
32
|
)
|
33
|
-
from flwr.common import
|
33
|
+
from flwr.common import recorddict_compat as compat
|
34
34
|
from flwr.common.constant import MessageType
|
35
35
|
from flwr.common.logger import log
|
36
36
|
from flwr.common.secure_aggregation.crypto.shamir import create_shares
|
@@ -162,7 +162,7 @@ def secaggplus_mod(
|
|
162
162
|
check_configs(state.current_stage, configs)
|
163
163
|
|
164
164
|
# Execute
|
165
|
-
out_content =
|
165
|
+
out_content = RecordDict()
|
166
166
|
if state.current_stage == Stage.SETUP:
|
167
167
|
state.nid = msg.metadata.dst_node_id
|
168
168
|
res = _setup(state, configs)
|
@@ -171,7 +171,7 @@ def secaggplus_mod(
|
|
171
171
|
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
|
172
172
|
out_msg = call_next(msg, ctxt)
|
173
173
|
out_content = out_msg.content
|
174
|
-
fitres = compat.
|
174
|
+
fitres = compat.recorddict_to_fitres(out_content, keep_input=True)
|
175
175
|
res = _collect_masked_vectors(
|
176
176
|
state, configs, fitres.num_examples, fitres.parameters
|
177
177
|
)
|
@@ -187,7 +187,7 @@ def secaggplus_mod(
|
|
187
187
|
|
188
188
|
# Return message
|
189
189
|
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
|
190
|
-
return
|
190
|
+
return Message(out_content, reply_to=msg)
|
191
191
|
|
192
192
|
|
193
193
|
def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
|
flwr/client/run_info_store.py
CHANGED
@@ -19,7 +19,7 @@ from dataclasses import dataclass
|
|
19
19
|
from pathlib import Path
|
20
20
|
from typing import Optional
|
21
21
|
|
22
|
-
from flwr.common import Context,
|
22
|
+
from flwr.common import Context, RecordDict
|
23
23
|
from flwr.common.config import (
|
24
24
|
get_fused_config,
|
25
25
|
get_fused_config_from_dir,
|
@@ -86,7 +86,7 @@ class DeprecatedRunInfoStore:
|
|
86
86
|
run_id=run_id,
|
87
87
|
node_id=self.node_id,
|
88
88
|
node_config=self.node_config,
|
89
|
-
state=
|
89
|
+
state=RecordDict(),
|
90
90
|
run_config=initial_run_config.copy(),
|
91
91
|
),
|
92
92
|
)
|
flwr/common/__init__.py
CHANGED
@@ -34,6 +34,7 @@ from .record import Array as Array
|
|
34
34
|
from .record import ConfigsRecord as ConfigsRecord
|
35
35
|
from .record import MetricsRecord as MetricsRecord
|
36
36
|
from .record import ParametersRecord as ParametersRecord
|
37
|
+
from .record import RecordDict as RecordDict
|
37
38
|
from .record import RecordSet as RecordSet
|
38
39
|
from .record import array_from_numpy as array_from_numpy
|
39
40
|
from .telemetry import EventType as EventType
|
@@ -98,6 +99,7 @@ __all__ = [
|
|
98
99
|
"ParametersRecord",
|
99
100
|
"Properties",
|
100
101
|
"ReconnectIns",
|
102
|
+
"RecordDict",
|
101
103
|
"RecordSet",
|
102
104
|
"Scalar",
|
103
105
|
"ServerMessage",
|
flwr/common/context.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
from dataclasses import dataclass
|
19
19
|
|
20
|
-
from .record import
|
20
|
+
from .record import RecordDict
|
21
21
|
from .typing import UserConfig
|
22
22
|
|
23
23
|
|
@@ -34,7 +34,7 @@ class Context:
|
|
34
34
|
node_config : UserConfig
|
35
35
|
A config (key/value mapping) unique to the node and independent of the
|
36
36
|
`run_config`. This config persists across all runs this node participates in.
|
37
|
-
state :
|
37
|
+
state : RecordDict
|
38
38
|
Holds records added by the entity in a given `run_id` and that will stay local.
|
39
39
|
This means that the data it holds will never leave the system it's running from.
|
40
40
|
This can be used as an intermediate storage or scratchpad when
|
@@ -50,7 +50,7 @@ class Context:
|
|
50
50
|
run_id: int
|
51
51
|
node_id: int
|
52
52
|
node_config: UserConfig
|
53
|
-
state:
|
53
|
+
state: RecordDict
|
54
54
|
run_config: UserConfig
|
55
55
|
|
56
56
|
def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
|
@@ -58,7 +58,7 @@ class Context:
|
|
58
58
|
run_id: int,
|
59
59
|
node_id: int,
|
60
60
|
node_config: UserConfig,
|
61
|
-
state:
|
61
|
+
state: RecordDict,
|
62
62
|
run_config: UserConfig,
|
63
63
|
) -> None:
|
64
64
|
self.run_id = run_id
|