flwr 1.16.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/cli/run/run.py +5 -9
- flwr/client/app.py +6 -4
- flwr/client/client_app.py +162 -99
- flwr/client/clientapp/app.py +2 -2
- flwr/client/grpc_client/connection.py +24 -21
- flwr/client/message_handler/message_handler.py +27 -27
- flwr/client/mod/__init__.py +2 -2
- flwr/client/mod/centraldp_mods.py +7 -7
- flwr/client/mod/comms_mods.py +16 -22
- flwr/client/mod/localdp_mod.py +4 -4
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +31 -31
- flwr/client/run_info_store.py +2 -2
- flwr/common/__init__.py +12 -4
- flwr/common/config.py +4 -4
- flwr/common/constant.py +6 -6
- flwr/common/context.py +4 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/logger.py +2 -2
- flwr/common/message.py +327 -102
- flwr/common/record/__init__.py +8 -4
- flwr/common/record/arrayrecord.py +626 -0
- flwr/common/record/{configsrecord.py → configrecord.py} +75 -29
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/{metricsrecord.py → metricrecord.py} +78 -32
- flwr/common/record/recorddict.py +288 -0
- flwr/common/recorddict_compat.py +410 -0
- flwr/common/secure_aggregation/secaggplus_constants.py +1 -1
- flwr/common/serde.py +66 -71
- flwr/common/typing.py +8 -8
- flwr/proto/exec_pb2.py +3 -3
- flwr/proto/exec_pb2.pyi +3 -3
- flwr/proto/message_pb2.py +12 -12
- flwr/proto/message_pb2.pyi +9 -9
- flwr/proto/recorddict_pb2.py +70 -0
- flwr/proto/{recordset_pb2.pyi → recorddict_pb2.pyi} +35 -35
- flwr/proto/run_pb2.py +31 -31
- flwr/proto/run_pb2.pyi +3 -3
- flwr/server/__init__.py +3 -1
- flwr/server/app.py +56 -1
- flwr/server/compat/__init__.py +2 -2
- flwr/server/compat/app.py +11 -11
- flwr/server/compat/app_utils.py +16 -16
- flwr/server/compat/{driver_client_proxy.py → grid_client_proxy.py} +39 -39
- flwr/server/fleet_event_log_interceptor.py +94 -0
- flwr/server/{driver → grid}/__init__.py +8 -7
- flwr/server/{driver/driver.py → grid/grid.py} +47 -18
- flwr/server/{driver/grpc_driver.py → grid/grpc_grid.py} +87 -64
- flwr/server/{driver/inmemory_driver.py → grid/inmemory_grid.py} +24 -34
- flwr/server/run_serverapp.py +4 -4
- flwr/server/server_app.py +38 -18
- flwr/server/serverapp/app.py +10 -10
- flwr/server/superlink/fleet/vce/backend/backend.py +2 -2
- flwr/server/superlink/fleet/vce/backend/raybackend.py +2 -2
- flwr/server/superlink/fleet/vce/vce_api.py +1 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +33 -8
- flwr/server/superlink/linkstate/linkstate.py +4 -4
- flwr/server/superlink/linkstate/sqlite_linkstate.py +61 -27
- flwr/server/superlink/linkstate/utils.py +93 -27
- flwr/server/superlink/{driver → serverappio}/__init__.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_grpc.py +1 -1
- flwr/server/superlink/{driver → serverappio}/serverappio_servicer.py +4 -4
- flwr/server/superlink/simulation/simulationio_servicer.py +2 -2
- flwr/server/typing.py +3 -3
- flwr/server/utils/validator.py +4 -4
- flwr/server/workflow/default_workflows.py +48 -57
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +48 -50
- flwr/simulation/app.py +2 -2
- flwr/simulation/ray_transport/ray_actor.py +4 -2
- flwr/simulation/ray_transport/ray_client_proxy.py +34 -32
- flwr/simulation/run_simulation.py +15 -15
- flwr/superexec/deployment.py +4 -4
- flwr/superexec/exec_event_log_interceptor.py +135 -0
- flwr/superexec/exec_grpc.py +10 -4
- flwr/superexec/exec_servicer.py +2 -2
- flwr/superexec/exec_user_auth_interceptor.py +18 -2
- flwr/superexec/executor.py +3 -3
- flwr/superexec/simulation.py +3 -3
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/METADATA +2 -2
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/RECORD +94 -92
- flwr/common/record/parametersrecord.py +0 -339
- flwr/common/record/recordset.py +0 -209
- flwr/common/recordset_compat.py +0 -418
- flwr/proto/recordset_pb2.py +0 -70
- /flwr/proto/{recordset_pb2_grpc.py → recorddict_pb2_grpc.py} +0 -0
- /flwr/proto/{recordset_pb2_grpc.pyi → recorddict_pb2_grpc.pyi} +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/LICENSE +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/WHEEL +0 -0
- {flwr-1.16.0.dist-info → flwr-1.17.0.dist-info}/entry_points.txt +0 -0
|
@@ -23,7 +23,7 @@ from flwr import common
|
|
|
23
23
|
from flwr.client import ClientFnExt
|
|
24
24
|
from flwr.client.client_app import ClientApp
|
|
25
25
|
from flwr.client.run_info_store import DeprecatedRunInfoStore
|
|
26
|
-
from flwr.common import DEFAULT_TTL, Message, Metadata,
|
|
26
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordDict, now
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
NUM_PARTITIONS_KEY,
|
|
29
29
|
PARTITION_ID_KEY,
|
|
@@ -31,15 +31,16 @@ from flwr.common.constant import (
|
|
|
31
31
|
MessageTypeLegacy,
|
|
32
32
|
)
|
|
33
33
|
from flwr.common.logger import log
|
|
34
|
-
from flwr.common.
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
34
|
+
from flwr.common.message import make_message
|
|
35
|
+
from flwr.common.recorddict_compat import (
|
|
36
|
+
evaluateins_to_recorddict,
|
|
37
|
+
fitins_to_recorddict,
|
|
38
|
+
getparametersins_to_recorddict,
|
|
39
|
+
getpropertiesins_to_recorddict,
|
|
40
|
+
recorddict_to_evaluateres,
|
|
41
|
+
recorddict_to_fitres,
|
|
42
|
+
recorddict_to_getparametersres,
|
|
43
|
+
recorddict_to_getpropertiesres,
|
|
43
44
|
)
|
|
44
45
|
from flwr.server.client_proxy import ClientProxy
|
|
45
46
|
from flwr.simulation.ray_transport.ray_actor import VirtualClientEngineActorPool
|
|
@@ -109,23 +110,24 @@ class RayActorClientProxy(ClientProxy):
|
|
|
109
110
|
|
|
110
111
|
return out_mssg
|
|
111
112
|
|
|
112
|
-
def
|
|
113
|
+
def _wrap_recorddict_in_message(
|
|
113
114
|
self,
|
|
114
|
-
|
|
115
|
+
recorddict: RecordDict,
|
|
115
116
|
message_type: str,
|
|
116
117
|
timeout: Optional[float],
|
|
117
118
|
group_id: Optional[int],
|
|
118
119
|
) -> Message:
|
|
119
|
-
"""Wrap a
|
|
120
|
-
return
|
|
121
|
-
content=
|
|
120
|
+
"""Wrap a RecordDict inside a Message."""
|
|
121
|
+
return make_message(
|
|
122
|
+
content=recorddict,
|
|
122
123
|
metadata=Metadata(
|
|
123
124
|
run_id=0,
|
|
124
125
|
message_id="",
|
|
125
126
|
group_id=str(group_id) if group_id is not None else "",
|
|
126
127
|
src_node_id=0,
|
|
127
128
|
dst_node_id=self.node_id,
|
|
128
|
-
|
|
129
|
+
reply_to_message_id="",
|
|
130
|
+
created_at=now().timestamp(),
|
|
129
131
|
ttl=timeout if timeout else DEFAULT_TTL,
|
|
130
132
|
message_type=message_type,
|
|
131
133
|
),
|
|
@@ -138,9 +140,9 @@ class RayActorClientProxy(ClientProxy):
|
|
|
138
140
|
group_id: Optional[int],
|
|
139
141
|
) -> common.GetPropertiesRes:
|
|
140
142
|
"""Return client's properties."""
|
|
141
|
-
|
|
142
|
-
message = self.
|
|
143
|
-
|
|
143
|
+
recorddict = getpropertiesins_to_recorddict(ins)
|
|
144
|
+
message = self._wrap_recorddict_in_message(
|
|
145
|
+
recorddict,
|
|
144
146
|
message_type=MessageTypeLegacy.GET_PROPERTIES,
|
|
145
147
|
timeout=timeout,
|
|
146
148
|
group_id=group_id,
|
|
@@ -148,7 +150,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
148
150
|
|
|
149
151
|
message_out = self._submit_job(message, timeout)
|
|
150
152
|
|
|
151
|
-
return
|
|
153
|
+
return recorddict_to_getpropertiesres(message_out.content)
|
|
152
154
|
|
|
153
155
|
def get_parameters(
|
|
154
156
|
self,
|
|
@@ -157,9 +159,9 @@ class RayActorClientProxy(ClientProxy):
|
|
|
157
159
|
group_id: Optional[int],
|
|
158
160
|
) -> common.GetParametersRes:
|
|
159
161
|
"""Return the current local model parameters."""
|
|
160
|
-
|
|
161
|
-
message = self.
|
|
162
|
-
|
|
162
|
+
recorddict = getparametersins_to_recorddict(ins)
|
|
163
|
+
message = self._wrap_recorddict_in_message(
|
|
164
|
+
recorddict,
|
|
163
165
|
message_type=MessageTypeLegacy.GET_PARAMETERS,
|
|
164
166
|
timeout=timeout,
|
|
165
167
|
group_id=group_id,
|
|
@@ -167,17 +169,17 @@ class RayActorClientProxy(ClientProxy):
|
|
|
167
169
|
|
|
168
170
|
message_out = self._submit_job(message, timeout)
|
|
169
171
|
|
|
170
|
-
return
|
|
172
|
+
return recorddict_to_getparametersres(message_out.content, keep_input=False)
|
|
171
173
|
|
|
172
174
|
def fit(
|
|
173
175
|
self, ins: common.FitIns, timeout: Optional[float], group_id: Optional[int]
|
|
174
176
|
) -> common.FitRes:
|
|
175
177
|
"""Train model parameters on the locally held dataset."""
|
|
176
|
-
|
|
178
|
+
recorddict = fitins_to_recorddict(
|
|
177
179
|
ins, keep_input=True
|
|
178
180
|
) # This must stay TRUE since ins are in-memory
|
|
179
|
-
message = self.
|
|
180
|
-
|
|
181
|
+
message = self._wrap_recorddict_in_message(
|
|
182
|
+
recorddict,
|
|
181
183
|
message_type=MessageType.TRAIN,
|
|
182
184
|
timeout=timeout,
|
|
183
185
|
group_id=group_id,
|
|
@@ -185,17 +187,17 @@ class RayActorClientProxy(ClientProxy):
|
|
|
185
187
|
|
|
186
188
|
message_out = self._submit_job(message, timeout)
|
|
187
189
|
|
|
188
|
-
return
|
|
190
|
+
return recorddict_to_fitres(message_out.content, keep_input=False)
|
|
189
191
|
|
|
190
192
|
def evaluate(
|
|
191
193
|
self, ins: common.EvaluateIns, timeout: Optional[float], group_id: Optional[int]
|
|
192
194
|
) -> common.EvaluateRes:
|
|
193
195
|
"""Evaluate model parameters on the locally held dataset."""
|
|
194
|
-
|
|
196
|
+
recorddict = evaluateins_to_recorddict(
|
|
195
197
|
ins, keep_input=True
|
|
196
198
|
) # This must stay TRUE since ins are in-memory
|
|
197
|
-
message = self.
|
|
198
|
-
|
|
199
|
+
message = self._wrap_recorddict_in_message(
|
|
200
|
+
recorddict,
|
|
199
201
|
message_type=MessageType.EVALUATE,
|
|
200
202
|
timeout=timeout,
|
|
201
203
|
group_id=group_id,
|
|
@@ -203,7 +205,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
203
205
|
|
|
204
206
|
message_out = self._submit_job(message, timeout)
|
|
205
207
|
|
|
206
|
-
return
|
|
208
|
+
return recorddict_to_evaluateres(message_out.content)
|
|
207
209
|
|
|
208
210
|
def reconnect(
|
|
209
211
|
self,
|
|
@@ -30,7 +30,7 @@ from typing import Any, Optional
|
|
|
30
30
|
from flwr.cli.config_utils import load_and_validate
|
|
31
31
|
from flwr.cli.utils import get_sha256_hash
|
|
32
32
|
from flwr.client import ClientApp
|
|
33
|
-
from flwr.common import Context, EventType,
|
|
33
|
+
from flwr.common import Context, EventType, RecordDict, event, log, now
|
|
34
34
|
from flwr.common.config import get_fused_config_from_dir, parse_config_args
|
|
35
35
|
from flwr.common.constant import RUN_ID_NUM_BYTES, Status
|
|
36
36
|
from flwr.common.logger import (
|
|
@@ -39,7 +39,7 @@ from flwr.common.logger import (
|
|
|
39
39
|
warn_deprecated_feature_with_example,
|
|
40
40
|
)
|
|
41
41
|
from flwr.common.typing import Run, RunStatus, UserConfig
|
|
42
|
-
from flwr.server.
|
|
42
|
+
from flwr.server.grid import Grid, InMemoryGrid
|
|
43
43
|
from flwr.server.run_serverapp import run as _run
|
|
44
44
|
from flwr.server.server_app import ServerApp
|
|
45
45
|
from flwr.server.superlink.fleet import vce
|
|
@@ -168,7 +168,7 @@ def run_simulation(
|
|
|
168
168
|
messages sent by the `ServerApp`.
|
|
169
169
|
|
|
170
170
|
num_supernodes : int
|
|
171
|
-
Number of nodes that run a ClientApp. They can be sampled by a
|
|
171
|
+
Number of nodes that run a ClientApp. They can be sampled by a Grid in the
|
|
172
172
|
ServerApp and receive a Message describing what the ClientApp should perform.
|
|
173
173
|
|
|
174
174
|
backend_name : str (default: ray)
|
|
@@ -180,7 +180,7 @@ def run_simulation(
|
|
|
180
180
|
for values parsed to initialisation of backend, `client_resources`
|
|
181
181
|
to define the resources for clients, and `actor` to define the actor
|
|
182
182
|
parameters. Values supported in <value> are those included by
|
|
183
|
-
`flwr.common.typing.
|
|
183
|
+
`flwr.common.typing.ConfigRecordValues`.
|
|
184
184
|
|
|
185
185
|
enable_tf_gpu_growth : bool (default: False)
|
|
186
186
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
|
@@ -225,7 +225,7 @@ def run_serverapp_th(
|
|
|
225
225
|
server_app_attr: Optional[str],
|
|
226
226
|
server_app: Optional[ServerApp],
|
|
227
227
|
server_app_run_config: UserConfig,
|
|
228
|
-
|
|
228
|
+
grid: Grid,
|
|
229
229
|
app_dir: str,
|
|
230
230
|
f_stop: threading.Event,
|
|
231
231
|
has_exception: threading.Event,
|
|
@@ -239,7 +239,7 @@ def run_serverapp_th(
|
|
|
239
239
|
tf_gpu_growth: bool,
|
|
240
240
|
stop_event: threading.Event,
|
|
241
241
|
exception_event: threading.Event,
|
|
242
|
-
|
|
242
|
+
_grid: Grid,
|
|
243
243
|
_server_app_dir: str,
|
|
244
244
|
_server_app_run_config: UserConfig,
|
|
245
245
|
_server_app_attr: Optional[str],
|
|
@@ -260,13 +260,13 @@ def run_serverapp_th(
|
|
|
260
260
|
run_id=run_id,
|
|
261
261
|
node_id=0,
|
|
262
262
|
node_config={},
|
|
263
|
-
state=
|
|
263
|
+
state=RecordDict(),
|
|
264
264
|
run_config=_server_app_run_config,
|
|
265
265
|
)
|
|
266
266
|
|
|
267
267
|
# Run ServerApp
|
|
268
268
|
updated_context = _run(
|
|
269
|
-
|
|
269
|
+
grid=_grid,
|
|
270
270
|
context=context,
|
|
271
271
|
server_app_dir=_server_app_dir,
|
|
272
272
|
server_app_attr=_server_app_attr,
|
|
@@ -291,7 +291,7 @@ def run_serverapp_th(
|
|
|
291
291
|
enable_tf_gpu_growth,
|
|
292
292
|
f_stop,
|
|
293
293
|
has_exception,
|
|
294
|
-
|
|
294
|
+
grid,
|
|
295
295
|
app_dir,
|
|
296
296
|
server_app_run_config,
|
|
297
297
|
server_app_attr,
|
|
@@ -333,7 +333,7 @@ def _main_loop(
|
|
|
333
333
|
run_id=run.run_id,
|
|
334
334
|
node_id=0,
|
|
335
335
|
node_config=UserConfig(),
|
|
336
|
-
state=
|
|
336
|
+
state=RecordDict(),
|
|
337
337
|
run_config=UserConfig(),
|
|
338
338
|
)
|
|
339
339
|
try:
|
|
@@ -347,9 +347,9 @@ def _main_loop(
|
|
|
347
347
|
if server_app_run_config is None:
|
|
348
348
|
server_app_run_config = {}
|
|
349
349
|
|
|
350
|
-
# Initialize
|
|
351
|
-
|
|
352
|
-
|
|
350
|
+
# Initialize Grid
|
|
351
|
+
grid = InMemoryGrid(state_factory=state_factory)
|
|
352
|
+
grid.set_run(run_id=run.run_id)
|
|
353
353
|
output_context_queue: Queue[Context] = Queue()
|
|
354
354
|
|
|
355
355
|
# Get and run ServerApp thread
|
|
@@ -357,7 +357,7 @@ def _main_loop(
|
|
|
357
357
|
server_app_attr=server_app_attr,
|
|
358
358
|
server_app=server_app,
|
|
359
359
|
server_app_run_config=server_app_run_config,
|
|
360
|
-
|
|
360
|
+
grid=grid,
|
|
361
361
|
app_dir=app_dir,
|
|
362
362
|
f_stop=f_stop,
|
|
363
363
|
has_exception=server_app_thread_has_exception,
|
|
@@ -546,7 +546,7 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
546
546
|
default="{}",
|
|
547
547
|
help='A JSON formatted stream, e.g \'{"<keyA>":<value>, "<keyB>":<value>}\' to '
|
|
548
548
|
"configure a backend. Values supported in <value> are those included by "
|
|
549
|
-
"`flwr.common.typing.
|
|
549
|
+
"`flwr.common.typing.ConfigRecordValues`. ",
|
|
550
550
|
)
|
|
551
551
|
parser.add_argument(
|
|
552
552
|
"--enable-tf-gpu-growth",
|
flwr/superexec/deployment.py
CHANGED
|
@@ -23,7 +23,7 @@ from typing import Optional
|
|
|
23
23
|
from typing_extensions import override
|
|
24
24
|
|
|
25
25
|
from flwr.cli.config_utils import get_fab_metadata
|
|
26
|
-
from flwr.common import
|
|
26
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
|
27
27
|
from flwr.common.constant import (
|
|
28
28
|
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS,
|
|
29
29
|
Status,
|
|
@@ -141,7 +141,7 @@ class DeploymentEngine(Executor):
|
|
|
141
141
|
fab_id, fab_version = get_fab_metadata(fab.content)
|
|
142
142
|
|
|
143
143
|
run_id = self.linkstate.create_run(
|
|
144
|
-
fab_id, fab_version, fab_hash, override_config,
|
|
144
|
+
fab_id, fab_version, fab_hash, override_config, ConfigRecord()
|
|
145
145
|
)
|
|
146
146
|
return run_id
|
|
147
147
|
|
|
@@ -149,7 +149,7 @@ class DeploymentEngine(Executor):
|
|
|
149
149
|
"""Register a Context for a Run."""
|
|
150
150
|
# Create an empty context for the Run
|
|
151
151
|
context = Context(
|
|
152
|
-
run_id=run_id, node_id=0, node_config={}, state=
|
|
152
|
+
run_id=run_id, node_id=0, node_config={}, state=RecordDict(), run_config={}
|
|
153
153
|
)
|
|
154
154
|
|
|
155
155
|
# Register the context at the LinkState
|
|
@@ -160,7 +160,7 @@ class DeploymentEngine(Executor):
|
|
|
160
160
|
self,
|
|
161
161
|
fab_file: bytes,
|
|
162
162
|
override_config: UserConfig,
|
|
163
|
-
federation_options:
|
|
163
|
+
federation_options: ConfigRecord,
|
|
164
164
|
) -> Optional[int]:
|
|
165
165
|
"""Start run using the Flower Deployment Engine."""
|
|
166
166
|
run_id = None
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower Exec API event log interceptor."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
from collections.abc import Iterator
|
|
19
|
+
from typing import Any, Callable, Union, cast
|
|
20
|
+
|
|
21
|
+
import grpc
|
|
22
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
23
|
+
|
|
24
|
+
from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
|
|
25
|
+
from flwr.common.typing import LogEntry
|
|
26
|
+
|
|
27
|
+
from .exec_user_auth_interceptor import shared_user_info
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ExecEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
31
|
+
"""Exec API interceptor for logging events."""
|
|
32
|
+
|
|
33
|
+
def __init__(self, log_plugin: EventLogWriterPlugin) -> None:
|
|
34
|
+
self.log_plugin = log_plugin
|
|
35
|
+
|
|
36
|
+
def intercept_service(
|
|
37
|
+
self,
|
|
38
|
+
continuation: Callable[[Any], Any],
|
|
39
|
+
handler_call_details: grpc.HandlerCallDetails,
|
|
40
|
+
) -> grpc.RpcMethodHandler:
|
|
41
|
+
"""Flower server interceptor logging logic.
|
|
42
|
+
|
|
43
|
+
Intercept all unary-unary/unary-stream calls from users and log the event.
|
|
44
|
+
Continue RPC call if event logger is enabled on the SuperLink, else, terminate
|
|
45
|
+
RPC call by setting context to abort.
|
|
46
|
+
"""
|
|
47
|
+
# One of the method handlers in
|
|
48
|
+
# `flwr.superexec.exec_servicer.ExecServicer`
|
|
49
|
+
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
|
|
50
|
+
method_name: str = handler_call_details.method
|
|
51
|
+
return self._generic_event_log_unary_method_handler(method_handler, method_name)
|
|
52
|
+
|
|
53
|
+
def _generic_event_log_unary_method_handler(
|
|
54
|
+
self, method_handler: grpc.RpcMethodHandler, method_name: str
|
|
55
|
+
) -> grpc.RpcMethodHandler:
|
|
56
|
+
def _generic_method_handler(
|
|
57
|
+
request: GrpcMessage,
|
|
58
|
+
context: grpc.ServicerContext,
|
|
59
|
+
) -> Union[GrpcMessage, Iterator[GrpcMessage], BaseException]:
|
|
60
|
+
log_entry: LogEntry
|
|
61
|
+
# Log before call
|
|
62
|
+
log_entry = self.log_plugin.compose_log_before_event(
|
|
63
|
+
request=request,
|
|
64
|
+
context=context,
|
|
65
|
+
user_info=shared_user_info.get(),
|
|
66
|
+
method_name=method_name,
|
|
67
|
+
)
|
|
68
|
+
self.log_plugin.write_log(log_entry)
|
|
69
|
+
|
|
70
|
+
# For unary-unary calls, log after the call immediately
|
|
71
|
+
if method_handler.unary_unary:
|
|
72
|
+
unary_response, error = None, None
|
|
73
|
+
try:
|
|
74
|
+
unary_response = cast(
|
|
75
|
+
GrpcMessage, method_handler.unary_unary(request, context)
|
|
76
|
+
)
|
|
77
|
+
except BaseException as e:
|
|
78
|
+
error = e
|
|
79
|
+
raise
|
|
80
|
+
finally:
|
|
81
|
+
log_entry = self.log_plugin.compose_log_after_event(
|
|
82
|
+
request=request,
|
|
83
|
+
context=context,
|
|
84
|
+
user_info=shared_user_info.get(),
|
|
85
|
+
method_name=method_name,
|
|
86
|
+
response=unary_response or error,
|
|
87
|
+
)
|
|
88
|
+
self.log_plugin.write_log(log_entry)
|
|
89
|
+
return unary_response
|
|
90
|
+
|
|
91
|
+
# For unary-stream calls, wrap the response iterator and write the event log
|
|
92
|
+
# after iteration completes
|
|
93
|
+
if method_handler.unary_stream:
|
|
94
|
+
response_iterator = cast(
|
|
95
|
+
Iterator[GrpcMessage],
|
|
96
|
+
method_handler.unary_stream(request, context),
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
def response_wrapper() -> Iterator[GrpcMessage]:
|
|
100
|
+
stream_response, error = None, None
|
|
101
|
+
try:
|
|
102
|
+
# pylint: disable=use-yield-from
|
|
103
|
+
for stream_response in response_iterator:
|
|
104
|
+
yield stream_response
|
|
105
|
+
except BaseException as e:
|
|
106
|
+
error = e
|
|
107
|
+
raise
|
|
108
|
+
finally:
|
|
109
|
+
# This block is executed after the client has consumed
|
|
110
|
+
# the entire stream, or if iteration is interrupted
|
|
111
|
+
log_entry = self.log_plugin.compose_log_after_event(
|
|
112
|
+
request=request,
|
|
113
|
+
context=context,
|
|
114
|
+
user_info=shared_user_info.get(),
|
|
115
|
+
method_name=method_name,
|
|
116
|
+
response=stream_response or error,
|
|
117
|
+
)
|
|
118
|
+
self.log_plugin.write_log(log_entry)
|
|
119
|
+
|
|
120
|
+
return response_wrapper()
|
|
121
|
+
|
|
122
|
+
raise RuntimeError() # This line is unreachable
|
|
123
|
+
|
|
124
|
+
if method_handler.unary_unary:
|
|
125
|
+
message_handler = grpc.unary_unary_rpc_method_handler
|
|
126
|
+
elif method_handler.unary_stream:
|
|
127
|
+
message_handler = grpc.unary_stream_rpc_method_handler
|
|
128
|
+
else:
|
|
129
|
+
# If the method type is not `unary_unary` or `unary_stream`, raise an error
|
|
130
|
+
raise NotImplementedError("This RPC method type is not supported.")
|
|
131
|
+
return message_handler(
|
|
132
|
+
_generic_method_handler,
|
|
133
|
+
request_deserializer=method_handler.request_deserializer,
|
|
134
|
+
response_serializer=method_handler.response_serializer,
|
|
135
|
+
)
|
flwr/superexec/exec_grpc.py
CHANGED
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""SuperExec gRPC API."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from collections.abc import Sequence
|
|
19
18
|
from logging import INFO
|
|
20
19
|
from typing import Optional
|
|
21
20
|
|
|
@@ -23,12 +22,14 @@ import grpc
|
|
|
23
22
|
|
|
24
23
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
25
24
|
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
25
|
+
from flwr.common.event_log_plugin import EventLogWriterPlugin
|
|
26
26
|
from flwr.common.grpc import generic_create_grpc_server
|
|
27
27
|
from flwr.common.logger import log
|
|
28
28
|
from flwr.common.typing import UserConfig
|
|
29
29
|
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
|
|
30
30
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
31
31
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
32
|
+
from flwr.superexec.exec_event_log_interceptor import ExecEventLogInterceptor
|
|
32
33
|
from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor
|
|
33
34
|
|
|
34
35
|
from .exec_servicer import ExecServicer
|
|
@@ -44,6 +45,7 @@ def run_exec_api_grpc(
|
|
|
44
45
|
certificates: Optional[tuple[bytes, bytes, bytes]],
|
|
45
46
|
config: UserConfig,
|
|
46
47
|
auth_plugin: Optional[ExecAuthPlugin] = None,
|
|
48
|
+
event_log_plugin: Optional[EventLogWriterPlugin] = None,
|
|
47
49
|
) -> grpc.Server:
|
|
48
50
|
"""Run Exec API (gRPC, request-response)."""
|
|
49
51
|
executor.set_config(config)
|
|
@@ -54,16 +56,20 @@ def run_exec_api_grpc(
|
|
|
54
56
|
executor=executor,
|
|
55
57
|
auth_plugin=auth_plugin,
|
|
56
58
|
)
|
|
57
|
-
interceptors:
|
|
59
|
+
interceptors: list[grpc.ServerInterceptor] = []
|
|
58
60
|
if auth_plugin is not None:
|
|
59
|
-
interceptors
|
|
61
|
+
interceptors.append(ExecUserAuthInterceptor(auth_plugin))
|
|
62
|
+
# Event log interceptor must be added after user auth interceptor
|
|
63
|
+
if event_log_plugin is not None:
|
|
64
|
+
interceptors.append(ExecEventLogInterceptor(event_log_plugin))
|
|
65
|
+
log(INFO, "Flower event logging enabled")
|
|
60
66
|
exec_add_servicer_to_server_fn = add_ExecServicer_to_server
|
|
61
67
|
exec_grpc_server = generic_create_grpc_server(
|
|
62
68
|
servicer_and_add_fn=(exec_servicer, exec_add_servicer_to_server_fn),
|
|
63
69
|
server_address=address,
|
|
64
70
|
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
|
|
65
71
|
certificates=certificates,
|
|
66
|
-
interceptors=interceptors,
|
|
72
|
+
interceptors=interceptors or None,
|
|
67
73
|
)
|
|
68
74
|
|
|
69
75
|
if auth_plugin is None:
|
flwr/superexec/exec_servicer.py
CHANGED
|
@@ -28,7 +28,7 @@ from flwr.common.auth_plugin import ExecAuthPlugin
|
|
|
28
28
|
from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
|
|
29
29
|
from flwr.common.logger import log
|
|
30
30
|
from flwr.common.serde import (
|
|
31
|
-
|
|
31
|
+
config_record_from_proto,
|
|
32
32
|
run_to_proto,
|
|
33
33
|
user_config_from_proto,
|
|
34
34
|
)
|
|
@@ -79,7 +79,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
79
79
|
run_id = self.executor.start_run(
|
|
80
80
|
request.fab.content,
|
|
81
81
|
user_config_from_proto(request.override_config),
|
|
82
|
-
|
|
82
|
+
config_record_from_proto(request.federation_options),
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
if run_id is None:
|
|
@@ -15,11 +15,13 @@
|
|
|
15
15
|
"""Flower Exec API interceptor."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
import contextvars
|
|
19
|
+
from typing import Any, Callable, Union, cast
|
|
19
20
|
|
|
20
21
|
import grpc
|
|
21
22
|
|
|
22
23
|
from flwr.common.auth_plugin import ExecAuthPlugin
|
|
24
|
+
from flwr.common.typing import UserInfo
|
|
23
25
|
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
|
|
24
26
|
GetAuthTokensRequest,
|
|
25
27
|
GetAuthTokensResponse,
|
|
@@ -43,6 +45,11 @@ Response = Union[
|
|
|
43
45
|
]
|
|
44
46
|
|
|
45
47
|
|
|
48
|
+
shared_user_info: contextvars.ContextVar[UserInfo] = contextvars.ContextVar(
|
|
49
|
+
"user_info", default=UserInfo(user_id=None, user_name=None)
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
46
53
|
class ExecUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
47
54
|
"""Exec API interceptor for user authentication."""
|
|
48
55
|
|
|
@@ -77,13 +84,22 @@ class ExecUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
77
84
|
) -> Response:
|
|
78
85
|
call = method_handler.unary_unary or method_handler.unary_stream
|
|
79
86
|
metadata = context.invocation_metadata()
|
|
87
|
+
|
|
88
|
+
# Intercept GetLoginDetails and GetAuthTokens requests, and return
|
|
89
|
+
# the response without authentication
|
|
80
90
|
if isinstance(request, (GetLoginDetailsRequest, GetAuthTokensRequest)):
|
|
81
91
|
return call(request, context) # type: ignore
|
|
82
92
|
|
|
83
|
-
|
|
93
|
+
# For other requests, check if the user is authenticated
|
|
94
|
+
valid_tokens, user_info = self.auth_plugin.validate_tokens_in_metadata(
|
|
95
|
+
metadata
|
|
96
|
+
)
|
|
84
97
|
if valid_tokens:
|
|
98
|
+
# Store user info in contextvars for authenticated users
|
|
99
|
+
shared_user_info.set(cast(UserInfo, user_info))
|
|
85
100
|
return call(request, context) # type: ignore
|
|
86
101
|
|
|
102
|
+
# If the user is not authenticated, refresh tokens
|
|
87
103
|
tokens = self.auth_plugin.refresh_tokens(context.invocation_metadata())
|
|
88
104
|
if tokens is not None:
|
|
89
105
|
context.send_initial_metadata(tokens)
|
flwr/superexec/executor.py
CHANGED
|
@@ -20,7 +20,7 @@ from dataclasses import dataclass, field
|
|
|
20
20
|
from subprocess import Popen
|
|
21
21
|
from typing import Optional
|
|
22
22
|
|
|
23
|
-
from flwr.common import
|
|
23
|
+
from flwr.common import ConfigRecord
|
|
24
24
|
from flwr.common.typing import UserConfig
|
|
25
25
|
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
|
|
26
26
|
from flwr.server.superlink.linkstate import LinkStateFactory
|
|
@@ -73,7 +73,7 @@ class Executor(ABC):
|
|
|
73
73
|
self,
|
|
74
74
|
fab_file: bytes,
|
|
75
75
|
override_config: UserConfig,
|
|
76
|
-
federation_options:
|
|
76
|
+
federation_options: ConfigRecord,
|
|
77
77
|
) -> Optional[int]:
|
|
78
78
|
"""Start a run using the given Flower FAB ID and version.
|
|
79
79
|
|
|
@@ -86,7 +86,7 @@ class Executor(ABC):
|
|
|
86
86
|
The Flower App Bundle file bytes.
|
|
87
87
|
override_config: UserConfig
|
|
88
88
|
The config overrides dict sent by the user (using `flwr run`).
|
|
89
|
-
federation_options:
|
|
89
|
+
federation_options: ConfigRecord
|
|
90
90
|
The federation options sent by the user (using `flwr run`).
|
|
91
91
|
|
|
92
92
|
Returns
|
flwr/superexec/simulation.py
CHANGED
|
@@ -22,7 +22,7 @@ from typing import Optional
|
|
|
22
22
|
from typing_extensions import override
|
|
23
23
|
|
|
24
24
|
from flwr.cli.config_utils import get_fab_metadata
|
|
25
|
-
from flwr.common import
|
|
25
|
+
from flwr.common import ConfigRecord, Context, RecordDict
|
|
26
26
|
from flwr.common.logger import log
|
|
27
27
|
from flwr.common.typing import Fab, UserConfig
|
|
28
28
|
from flwr.server.superlink.ffs import Ffs
|
|
@@ -76,7 +76,7 @@ class SimulationEngine(Executor):
|
|
|
76
76
|
self,
|
|
77
77
|
fab_file: bytes,
|
|
78
78
|
override_config: UserConfig,
|
|
79
|
-
federation_options:
|
|
79
|
+
federation_options: ConfigRecord,
|
|
80
80
|
) -> Optional[int]:
|
|
81
81
|
"""Start run using the Flower Simulation Engine."""
|
|
82
82
|
try:
|
|
@@ -104,7 +104,7 @@ class SimulationEngine(Executor):
|
|
|
104
104
|
run_id=run_id,
|
|
105
105
|
node_id=0,
|
|
106
106
|
node_config={},
|
|
107
|
-
state=
|
|
107
|
+
state=RecordDict(),
|
|
108
108
|
run_config={},
|
|
109
109
|
)
|
|
110
110
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: flwr
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.17.0
|
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
|
5
5
|
Home-page: https://flower.ai
|
|
6
6
|
License: Apache-2.0
|
|
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
|
|
|
63
63
|
<a href="https://flower.ai/">Website</a> |
|
|
64
64
|
<a href="https://flower.ai/blog">Blog</a> |
|
|
65
65
|
<a href="https://flower.ai/docs/">Docs</a> |
|
|
66
|
-
<a href="https://flower.ai/
|
|
66
|
+
<a href="https://flower.ai/events/flower-ai-summit-2025">Summit</a> |
|
|
67
67
|
<a href="https://flower.ai/join-slack">Slack</a>
|
|
68
68
|
<br /><br />
|
|
69
69
|
</p>
|