flwr 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/build.py +2 -0
- flwr/cli/log.py +20 -21
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
- flwr/client/client_app.py +147 -36
- flwr/client/clientapp/app.py +4 -0
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/rest_client/connection.py +4 -6
- flwr/client/supernode/__init__.py +0 -2
- flwr/client/supernode/app.py +1 -11
- flwr/common/address.py +35 -0
- flwr/common/args.py +8 -2
- flwr/common/auth_plugin/auth_plugin.py +2 -1
- flwr/common/constant.py +16 -0
- flwr/common/event_log_plugin/__init__.py +22 -0
- flwr/common/event_log_plugin/event_log_plugin.py +60 -0
- flwr/common/grpc.py +1 -1
- flwr/common/message.py +18 -7
- flwr/common/object_ref.py +0 -10
- flwr/common/record/conversion_utils.py +8 -17
- flwr/common/record/parametersrecord.py +151 -16
- flwr/common/record/recordset.py +95 -88
- flwr/common/secure_aggregation/quantization.py +5 -1
- flwr/common/serde.py +8 -126
- flwr/common/telemetry.py +0 -10
- flwr/common/typing.py +36 -0
- flwr/server/app.py +18 -2
- flwr/server/compat/app.py +4 -1
- flwr/server/compat/app_utils.py +10 -2
- flwr/server/compat/driver_client_proxy.py +2 -2
- flwr/server/driver/driver.py +1 -1
- flwr/server/driver/grpc_driver.py +10 -1
- flwr/server/driver/inmemory_driver.py +17 -20
- flwr/server/run_serverapp.py +2 -13
- flwr/server/server_app.py +93 -20
- flwr/server/superlink/driver/serverappio_servicer.py +25 -27
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +8 -12
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/vce_api.py +32 -35
- flwr/server/superlink/linkstate/in_memory_linkstate.py +140 -126
- flwr/server/superlink/linkstate/linkstate.py +47 -60
- flwr/server/superlink/linkstate/sqlite_linkstate.py +210 -276
- flwr/server/superlink/linkstate/utils.py +91 -119
- flwr/server/utils/__init__.py +2 -2
- flwr/server/utils/validator.py +53 -68
- flwr/server/workflow/default_workflows.py +4 -1
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +3 -3
- flwr/superexec/app.py +0 -14
- flwr/superexec/exec_servicer.py +4 -4
- flwr/superexec/exec_user_auth_interceptor.py +5 -3
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/METADATA +4 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/RECORD +63 -66
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/entry_points.txt +0 -3
- flwr/client/message_handler/task_handler.py +0 -37
- flwr/proto/task_pb2.py +0 -33
- flwr/proto/task_pb2.pyi +0 -100
- flwr/proto/task_pb2_grpc.py +0 -4
- flwr/proto/task_pb2_grpc.pyi +0 -4
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/LICENSE +0 -0
- {flwr-1.15.2.dist-info → flwr-1.16.0.dist-info}/WHEEL +0 -0
|
@@ -15,21 +15,17 @@
|
|
|
15
15
|
"""Utility functions for State."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
from logging import ERROR
|
|
19
18
|
from os import urandom
|
|
20
|
-
from typing import Optional
|
|
19
|
+
from typing import Optional
|
|
21
20
|
from uuid import UUID, uuid4
|
|
22
21
|
|
|
23
|
-
from flwr.common import ConfigsRecord, Context,
|
|
22
|
+
from flwr.common import ConfigsRecord, Context, Error, Message, Metadata, now, serde
|
|
24
23
|
from flwr.common.constant import SUPERLINK_NODE_ID, ErrorCode, Status, SubStatus
|
|
25
24
|
from flwr.common.typing import RunStatus
|
|
26
25
|
|
|
27
26
|
# pylint: disable=E0611
|
|
28
|
-
from flwr.proto.error_pb2 import Error
|
|
29
27
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
30
|
-
from flwr.proto.node_pb2 import Node
|
|
31
28
|
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
|
|
32
|
-
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
33
29
|
|
|
34
30
|
# pylint: enable=E0611
|
|
35
31
|
|
|
@@ -235,165 +231,141 @@ def has_valid_sub_status(status: RunStatus) -> bool:
|
|
|
235
231
|
return status.sub_status == ""
|
|
236
232
|
|
|
237
233
|
|
|
238
|
-
def
|
|
239
|
-
"""Generate
|
|
240
|
-
|
|
241
|
-
Parameters
|
|
242
|
-
----------
|
|
243
|
-
taskins_id : Union[str, UUID]
|
|
244
|
-
The ID of the unavailable TaskIns.
|
|
245
|
-
|
|
246
|
-
Returns
|
|
247
|
-
-------
|
|
248
|
-
TaskRes
|
|
249
|
-
A TaskRes with an error code MESSAGE_UNAVAILABLE to indicate that the
|
|
250
|
-
inquired TaskIns ID cannot be found (due to non-existence or expiration).
|
|
251
|
-
"""
|
|
234
|
+
def create_message_error_unavailable_res_message(ins_metadata: Metadata) -> Message:
|
|
235
|
+
"""Generate an error Message that the SuperLink returns carrying the specified
|
|
236
|
+
error."""
|
|
252
237
|
current_time = now().timestamp()
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
ancestry=[str(taskins_id)],
|
|
264
|
-
task_type="", # Unknown message type
|
|
265
|
-
error=Error(
|
|
266
|
-
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
267
|
-
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
268
|
-
),
|
|
269
|
-
),
|
|
238
|
+
ttl = max(ins_metadata.ttl - (current_time - ins_metadata.created_at), 0)
|
|
239
|
+
metadata = Metadata(
|
|
240
|
+
run_id=ins_metadata.run_id,
|
|
241
|
+
message_id=str(uuid4()),
|
|
242
|
+
src_node_id=SUPERLINK_NODE_ID,
|
|
243
|
+
dst_node_id=SUPERLINK_NODE_ID,
|
|
244
|
+
reply_to_message=ins_metadata.message_id,
|
|
245
|
+
group_id=ins_metadata.group_id,
|
|
246
|
+
message_type=ins_metadata.message_type,
|
|
247
|
+
ttl=ttl,
|
|
270
248
|
)
|
|
271
249
|
|
|
250
|
+
return Message(
|
|
251
|
+
metadata=metadata,
|
|
252
|
+
error=Error(
|
|
253
|
+
code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
|
|
254
|
+
reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
255
|
+
),
|
|
256
|
+
)
|
|
272
257
|
|
|
273
|
-
def create_taskres_for_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
|
|
274
|
-
"""Generate a TaskRes with a reply message unavailable error from a TaskIns.
|
|
275
258
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
259
|
+
def create_message_error_unavailable_ins_message(reply_to_message: UUID) -> Message:
|
|
260
|
+
"""Error to indicate that the enquired Message had expired before reply arrived or
|
|
261
|
+
that it isn't found."""
|
|
262
|
+
metadata = Metadata(
|
|
263
|
+
run_id=0, # Unknown
|
|
264
|
+
message_id=str(uuid4()),
|
|
265
|
+
src_node_id=SUPERLINK_NODE_ID,
|
|
266
|
+
dst_node_id=SUPERLINK_NODE_ID,
|
|
267
|
+
reply_to_message=str(reply_to_message),
|
|
268
|
+
group_id="", # Unknown
|
|
269
|
+
message_type="", # Unknown
|
|
270
|
+
ttl=0,
|
|
271
|
+
)
|
|
280
272
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
"""
|
|
287
|
-
current_time = now().timestamp()
|
|
288
|
-
ttl = ref_taskins.task.ttl - (current_time - ref_taskins.task.created_at)
|
|
289
|
-
if ttl < 0:
|
|
290
|
-
log(ERROR, "Creating TaskRes for TaskIns that exceeds its TTL.")
|
|
291
|
-
ttl = 0
|
|
292
|
-
return TaskRes(
|
|
293
|
-
task_id=str(uuid4()),
|
|
294
|
-
group_id=ref_taskins.group_id,
|
|
295
|
-
run_id=ref_taskins.run_id,
|
|
296
|
-
task=Task(
|
|
297
|
-
# This function is only called by SuperLink, and thus it's the producer.
|
|
298
|
-
producer=Node(node_id=SUPERLINK_NODE_ID),
|
|
299
|
-
consumer=Node(node_id=SUPERLINK_NODE_ID),
|
|
300
|
-
created_at=current_time,
|
|
301
|
-
ttl=ttl,
|
|
302
|
-
ancestry=[ref_taskins.task_id],
|
|
303
|
-
task_type=ref_taskins.task.task_type,
|
|
304
|
-
error=Error(
|
|
305
|
-
code=ErrorCode.REPLY_MESSAGE_UNAVAILABLE,
|
|
306
|
-
reason=REPLY_MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
307
|
-
),
|
|
273
|
+
return Message(
|
|
274
|
+
metadata=metadata,
|
|
275
|
+
error=Error(
|
|
276
|
+
code=ErrorCode.MESSAGE_UNAVAILABLE,
|
|
277
|
+
reason=MESSAGE_UNAVAILABLE_ERROR_REASON,
|
|
308
278
|
),
|
|
309
279
|
)
|
|
310
280
|
|
|
311
281
|
|
|
312
|
-
def
|
|
313
|
-
"""Check if the
|
|
314
|
-
return
|
|
282
|
+
def message_ttl_has_expired(message_metadata: Metadata, current_time: float) -> bool:
|
|
283
|
+
"""Check if the Message has expired."""
|
|
284
|
+
return message_metadata.ttl + message_metadata.created_at < current_time
|
|
315
285
|
|
|
316
286
|
|
|
317
|
-
def
|
|
318
|
-
|
|
319
|
-
|
|
287
|
+
def verify_message_ids(
|
|
288
|
+
inquired_message_ids: set[UUID],
|
|
289
|
+
found_message_ins_dict: dict[UUID, Message],
|
|
320
290
|
current_time: Optional[float] = None,
|
|
321
291
|
update_set: bool = True,
|
|
322
|
-
) -> dict[UUID,
|
|
323
|
-
"""Verify found
|
|
292
|
+
) -> dict[UUID, Message]:
|
|
293
|
+
"""Verify found Messages and generate error Messages for invalid ones.
|
|
324
294
|
|
|
325
295
|
Parameters
|
|
326
296
|
----------
|
|
327
|
-
|
|
328
|
-
Set of
|
|
329
|
-
|
|
330
|
-
Dictionary containing all found
|
|
297
|
+
inquired_message_ids : set[UUID]
|
|
298
|
+
Set of Message IDs for which to generate error Message if invalid.
|
|
299
|
+
found_message_ins_dict : dict[UUID, Message]
|
|
300
|
+
Dictionary containing all found Message indexed by their IDs.
|
|
331
301
|
current_time : Optional[float] (default: None)
|
|
332
302
|
The current time to check for expiration. If set to `None`, the current time
|
|
333
303
|
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
334
304
|
update_set : bool (default: True)
|
|
335
|
-
If True, the `
|
|
305
|
+
If True, the `inquired_message_ids` will be updated to remove invalid ones,
|
|
336
306
|
by default True.
|
|
337
307
|
|
|
338
308
|
Returns
|
|
339
309
|
-------
|
|
340
|
-
dict[UUID,
|
|
341
|
-
A dictionary of error
|
|
310
|
+
dict[UUID, Message]
|
|
311
|
+
A dictionary of error Message indexed by the corresponding ID of the message
|
|
312
|
+
they are a reply of.
|
|
342
313
|
"""
|
|
343
314
|
ret_dict = {}
|
|
344
315
|
current = current_time if current_time else now().timestamp()
|
|
345
|
-
for
|
|
346
|
-
# Generate error
|
|
347
|
-
|
|
348
|
-
if
|
|
316
|
+
for message_id in list(inquired_message_ids):
|
|
317
|
+
# Generate error message if the inquired message doesn't exist or has expired
|
|
318
|
+
message_ins = found_message_ins_dict.get(message_id)
|
|
319
|
+
if message_ins is None or message_ttl_has_expired(
|
|
320
|
+
message_ins.metadata, current
|
|
321
|
+
):
|
|
349
322
|
if update_set:
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
ret_dict[
|
|
323
|
+
inquired_message_ids.remove(message_id)
|
|
324
|
+
message_res = create_message_error_unavailable_ins_message(message_id)
|
|
325
|
+
ret_dict[message_id] = message_res
|
|
353
326
|
return ret_dict
|
|
354
327
|
|
|
355
328
|
|
|
356
|
-
def
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
329
|
+
def verify_found_message_replies(
|
|
330
|
+
inquired_message_ids: set[UUID],
|
|
331
|
+
found_message_ins_dict: dict[UUID, Message],
|
|
332
|
+
found_message_res_list: list[Message],
|
|
360
333
|
current_time: Optional[float] = None,
|
|
361
334
|
update_set: bool = True,
|
|
362
|
-
) -> dict[UUID,
|
|
363
|
-
"""Verify found
|
|
335
|
+
) -> dict[UUID, Message]:
|
|
336
|
+
"""Verify found Message replies and generate error Message for invalid ones.
|
|
364
337
|
|
|
365
338
|
Parameters
|
|
366
339
|
----------
|
|
367
|
-
|
|
368
|
-
Set of
|
|
369
|
-
|
|
370
|
-
Dictionary containing all found
|
|
371
|
-
|
|
372
|
-
List of found
|
|
340
|
+
inquired_message_ids : set[UUID]
|
|
341
|
+
Set of Message IDs for which to generate error Message if invalid.
|
|
342
|
+
found_message_ins_dict : dict[UUID, Message]
|
|
343
|
+
Dictionary containing all found instruction Messages indexed by their IDs.
|
|
344
|
+
found_message_res_list : dict[Message, Message]
|
|
345
|
+
List of found Message to be verified.
|
|
373
346
|
current_time : Optional[float] (default: None)
|
|
374
347
|
The current time to check for expiration. If set to `None`, the current time
|
|
375
348
|
will automatically be set to the current timestamp using `now().timestamp()`.
|
|
376
349
|
update_set : bool (default: True)
|
|
377
|
-
If True, the `
|
|
378
|
-
that have a
|
|
350
|
+
If True, the `inquired_message_ids` will be updated to remove ones
|
|
351
|
+
that have a reply Message, by default True.
|
|
379
352
|
|
|
380
353
|
Returns
|
|
381
354
|
-------
|
|
382
|
-
dict[UUID,
|
|
383
|
-
A dictionary of
|
|
355
|
+
dict[UUID, Message]
|
|
356
|
+
A dictionary of Message indexed by the corresponding Message ID.
|
|
384
357
|
"""
|
|
385
|
-
ret_dict: dict[UUID,
|
|
358
|
+
ret_dict: dict[UUID, Message] = {}
|
|
386
359
|
current = current_time if current_time else now().timestamp()
|
|
387
|
-
for
|
|
388
|
-
|
|
360
|
+
for message_res in found_message_res_list:
|
|
361
|
+
message_ins_id = UUID(message_res.metadata.reply_to_message)
|
|
389
362
|
if update_set:
|
|
390
|
-
|
|
391
|
-
# Check if the
|
|
392
|
-
if
|
|
393
|
-
# No need to insert the error
|
|
394
|
-
|
|
395
|
-
|
|
363
|
+
inquired_message_ids.remove(message_ins_id)
|
|
364
|
+
# Check if the reply Message has expired
|
|
365
|
+
if message_ttl_has_expired(message_res.metadata, current):
|
|
366
|
+
# No need to insert the error Message
|
|
367
|
+
message_res = create_message_error_unavailable_res_message(
|
|
368
|
+
found_message_ins_dict[message_ins_id].metadata
|
|
396
369
|
)
|
|
397
|
-
|
|
398
|
-
ret_dict[taskins_id] = taskres
|
|
370
|
+
ret_dict[message_ins_id] = message_res
|
|
399
371
|
return ret_dict
|
flwr/server/utils/__init__.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .tensorboard import tensorboard as tensorboard
|
|
19
|
-
from .validator import
|
|
19
|
+
from .validator import validate_message as validate_message
|
|
20
20
|
|
|
21
21
|
__all__ = [
|
|
22
22
|
"tensorboard",
|
|
23
|
-
"
|
|
23
|
+
"validate_message",
|
|
24
24
|
]
|
flwr/server/utils/validator.py
CHANGED
|
@@ -16,93 +16,78 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import time
|
|
19
|
-
from typing import Union
|
|
20
19
|
|
|
20
|
+
from flwr.common import Message
|
|
21
21
|
from flwr.common.constant import SUPERLINK_NODE_ID
|
|
22
|
-
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
|
|
23
22
|
|
|
24
23
|
|
|
25
|
-
# pylint: disable-next=too-many-branches
|
|
26
|
-
def
|
|
27
|
-
"""Validate a
|
|
24
|
+
# pylint: disable-next=too-many-branches
|
|
25
|
+
def validate_message(message: Message, is_reply_message: bool) -> list[str]:
|
|
26
|
+
"""Validate a Message."""
|
|
28
27
|
validation_errors = []
|
|
28
|
+
metadata = message.metadata
|
|
29
29
|
|
|
30
|
-
if
|
|
31
|
-
validation_errors.append("non-empty `
|
|
32
|
-
|
|
33
|
-
if not tasks_ins_res.HasField("task"):
|
|
34
|
-
validation_errors.append("`task` does not set field `task`")
|
|
30
|
+
if metadata.message_id != "":
|
|
31
|
+
validation_errors.append("non-empty `metadata.message_id`")
|
|
35
32
|
|
|
36
33
|
# Created/delivered/TTL/Pushed
|
|
37
34
|
if (
|
|
38
|
-
|
|
39
|
-
): # unix timestamp of
|
|
35
|
+
metadata.created_at < 1740700800.0
|
|
36
|
+
): # unix timestamp of 28 February 2025 00h:00m:00s UTC
|
|
40
37
|
validation_errors.append(
|
|
41
|
-
"`created_at` must be a float that records the unix timestamp "
|
|
38
|
+
"`metadata.created_at` must be a float that records the unix timestamp "
|
|
42
39
|
"in seconds when the message was created."
|
|
43
40
|
)
|
|
44
|
-
if
|
|
45
|
-
validation_errors.append("`delivered_at` must be an empty str")
|
|
46
|
-
if
|
|
47
|
-
validation_errors.append("`ttl` must be higher than zero")
|
|
41
|
+
if metadata.delivered_at != "":
|
|
42
|
+
validation_errors.append("`metadata.delivered_at` must be an empty str")
|
|
43
|
+
if metadata.ttl <= 0:
|
|
44
|
+
validation_errors.append("`metadata.ttl` must be higher than zero")
|
|
48
45
|
|
|
49
46
|
# Verify TTL and created_at time
|
|
50
47
|
current_time = time.time()
|
|
51
|
-
if
|
|
52
|
-
validation_errors.append("
|
|
53
|
-
|
|
54
|
-
# TaskIns specific
|
|
55
|
-
if isinstance(tasks_ins_res, TaskIns):
|
|
56
|
-
# Task producer
|
|
57
|
-
if not tasks_ins_res.task.HasField("producer"):
|
|
58
|
-
validation_errors.append("`producer` does not set field `producer`")
|
|
59
|
-
if tasks_ins_res.task.producer.node_id != SUPERLINK_NODE_ID:
|
|
60
|
-
validation_errors.append(f"`producer.node_id` is not {SUPERLINK_NODE_ID}")
|
|
61
|
-
|
|
62
|
-
# Task consumer
|
|
63
|
-
if not tasks_ins_res.task.HasField("consumer"):
|
|
64
|
-
validation_errors.append("`consumer` does not set field `consumer`")
|
|
65
|
-
if tasks_ins_res.task.consumer.node_id == SUPERLINK_NODE_ID:
|
|
66
|
-
validation_errors.append("consumer MUST provide a valid `node_id`")
|
|
48
|
+
if metadata.created_at + metadata.ttl <= current_time:
|
|
49
|
+
validation_errors.append("Message TTL has expired")
|
|
67
50
|
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
if not (
|
|
72
|
-
tasks_ins_res.task.HasField("recordset")
|
|
73
|
-
^ tasks_ins_res.task.HasField("error")
|
|
74
|
-
):
|
|
75
|
-
validation_errors.append("Either `recordset` or `error` MUST be set")
|
|
51
|
+
# Source node is set and is not zero
|
|
52
|
+
if not metadata.src_node_id:
|
|
53
|
+
validation_errors.append("`metadata.src_node_id` is not set.")
|
|
76
54
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
55
|
+
# Destination node is set and is not zero
|
|
56
|
+
if not metadata.dst_node_id:
|
|
57
|
+
validation_errors.append("`metadata.dst_node_id` is not set.")
|
|
80
58
|
|
|
81
|
-
#
|
|
82
|
-
if
|
|
83
|
-
|
|
84
|
-
if not tasks_ins_res.task.HasField("producer"):
|
|
85
|
-
validation_errors.append("`producer` does not set field `producer`")
|
|
86
|
-
if tasks_ins_res.task.producer.node_id == SUPERLINK_NODE_ID:
|
|
87
|
-
validation_errors.append("producer MUST provide a valid `node_id`")
|
|
59
|
+
# Message type
|
|
60
|
+
if metadata.message_type == "":
|
|
61
|
+
validation_errors.append("`metadata.message_type` MUST be set")
|
|
88
62
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
# Content check
|
|
96
|
-
if tasks_ins_res.task.task_type == "":
|
|
97
|
-
validation_errors.append("`task_type` MUST be set")
|
|
98
|
-
if not (
|
|
99
|
-
tasks_ins_res.task.HasField("recordset")
|
|
100
|
-
^ tasks_ins_res.task.HasField("error")
|
|
101
|
-
):
|
|
102
|
-
validation_errors.append("Either `recordset` or `error` MUST be set")
|
|
63
|
+
# Content
|
|
64
|
+
if not message.has_content() != message.has_error():
|
|
65
|
+
validation_errors.append(
|
|
66
|
+
"Either message `content` or `error` MUST be set (but not both)"
|
|
67
|
+
)
|
|
103
68
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
69
|
+
# Link respose to original message
|
|
70
|
+
if not is_reply_message:
|
|
71
|
+
if metadata.reply_to_message != "":
|
|
72
|
+
validation_errors.append("`metadata.reply_to_message` MUST not be set.")
|
|
73
|
+
if metadata.src_node_id != SUPERLINK_NODE_ID:
|
|
74
|
+
validation_errors.append(
|
|
75
|
+
f"`metadata.src_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
76
|
+
)
|
|
77
|
+
if metadata.dst_node_id == SUPERLINK_NODE_ID:
|
|
78
|
+
validation_errors.append(
|
|
79
|
+
f"`metadata.dst_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
80
|
+
)
|
|
81
|
+
else:
|
|
82
|
+
if metadata.reply_to_message == "":
|
|
83
|
+
validation_errors.append("`metadata.reply_to_message` MUST be set.")
|
|
84
|
+
if metadata.src_node_id == SUPERLINK_NODE_ID:
|
|
85
|
+
validation_errors.append(
|
|
86
|
+
f"`metadata.src_node_id` is {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
87
|
+
)
|
|
88
|
+
if metadata.dst_node_id != SUPERLINK_NODE_ID:
|
|
89
|
+
validation_errors.append(
|
|
90
|
+
f"`metadata.dst_node_id` is not {SUPERLINK_NODE_ID} (SuperLink node ID)"
|
|
91
|
+
)
|
|
107
92
|
|
|
108
93
|
return validation_errors
|
|
@@ -64,10 +64,13 @@ class DefaultWorkflow:
|
|
|
64
64
|
)
|
|
65
65
|
|
|
66
66
|
# Start the thread updating nodes
|
|
67
|
-
thread, f_stop = start_update_client_manager_thread(
|
|
67
|
+
thread, f_stop, c_done = start_update_client_manager_thread(
|
|
68
68
|
driver, context.client_manager
|
|
69
69
|
)
|
|
70
70
|
|
|
71
|
+
# Wait until the node registration done
|
|
72
|
+
c_done.wait()
|
|
73
|
+
|
|
71
74
|
# Initialize parameters
|
|
72
75
|
log(INFO, "[INIT]")
|
|
73
76
|
default_init_params_workflow(driver, context)
|
|
@@ -367,7 +367,7 @@ class SecAggPlusWorkflow:
|
|
|
367
367
|
|
|
368
368
|
# Send setup configuration to clients
|
|
369
369
|
cfgs_record = ConfigsRecord(sa_params_dict) # type: ignore
|
|
370
|
-
content = RecordSet(
|
|
370
|
+
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
|
371
371
|
|
|
372
372
|
def make(nid: int) -> Message:
|
|
373
373
|
return driver.create_message(
|
|
@@ -417,7 +417,7 @@ class SecAggPlusWorkflow:
|
|
|
417
417
|
{str(nid): state.nid_to_publickeys[nid] for nid in neighbours}
|
|
418
418
|
)
|
|
419
419
|
cfgs_record[Key.STAGE] = Stage.SHARE_KEYS
|
|
420
|
-
content = RecordSet(
|
|
420
|
+
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
|
421
421
|
return driver.create_message(
|
|
422
422
|
content=content,
|
|
423
423
|
message_type=MessageType.TRAIN,
|
|
@@ -566,7 +566,7 @@ class SecAggPlusWorkflow:
|
|
|
566
566
|
Key.DEAD_NODE_ID_LIST: list(neighbours & dead_nids),
|
|
567
567
|
}
|
|
568
568
|
cfgs_record = ConfigsRecord(cfgs_dict) # type: ignore
|
|
569
|
-
content = RecordSet(
|
|
569
|
+
content = RecordSet({RECORD_KEY_CONFIGS: cfgs_record})
|
|
570
570
|
return driver.create_message(
|
|
571
571
|
content=content,
|
|
572
572
|
message_type=MessageType.TRAIN,
|
flwr/superexec/app.py
CHANGED
|
@@ -16,26 +16,12 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import argparse
|
|
19
|
-
import sys
|
|
20
|
-
from logging import INFO
|
|
21
19
|
|
|
22
|
-
from flwr.common import log
|
|
23
20
|
from flwr.common.object_ref import load_app, validate
|
|
24
21
|
|
|
25
22
|
from .executor import Executor
|
|
26
23
|
|
|
27
24
|
|
|
28
|
-
def run_superexec() -> None:
|
|
29
|
-
"""Run Flower SuperExec."""
|
|
30
|
-
log(INFO, "Starting Flower SuperExec")
|
|
31
|
-
|
|
32
|
-
sys.exit(
|
|
33
|
-
"Manually launching the SuperExec is deprecated. Since `flwr 1.13.0` "
|
|
34
|
-
"the executor service runs in the SuperLink. Launching it manually is not "
|
|
35
|
-
"recommended."
|
|
36
|
-
)
|
|
37
|
-
|
|
38
|
-
|
|
39
25
|
def load_executor(
|
|
40
26
|
args: argparse.Namespace,
|
|
41
27
|
) -> Executor:
|
flwr/superexec/exec_servicer.py
CHANGED
|
@@ -120,7 +120,7 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
120
120
|
run_status = state.get_run_status({run_id})[run_id]
|
|
121
121
|
if run_status.status == Status.FINISHED:
|
|
122
122
|
log(INFO, "All logs for run ID `%s` returned", request.run_id)
|
|
123
|
-
|
|
123
|
+
break
|
|
124
124
|
|
|
125
125
|
time.sleep(LOG_STREAM_INTERVAL) # Sleep briefly to avoid busy waiting
|
|
126
126
|
|
|
@@ -163,10 +163,10 @@ class ExecServicer(exec_pb2_grpc.ExecServicer):
|
|
|
163
163
|
)
|
|
164
164
|
|
|
165
165
|
if update_success:
|
|
166
|
-
|
|
166
|
+
message_ids: set[UUID] = state.get_message_ids_from_run_id(request.run_id)
|
|
167
167
|
|
|
168
|
-
# Delete
|
|
169
|
-
state.
|
|
168
|
+
# Delete Messages and their replies for the `run_id`
|
|
169
|
+
state.delete_messages(message_ids)
|
|
170
170
|
|
|
171
171
|
return StopRunResponse(success=update_success)
|
|
172
172
|
|
|
@@ -77,9 +77,11 @@ class ExecUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
|
|
|
77
77
|
) -> Response:
|
|
78
78
|
call = method_handler.unary_unary or method_handler.unary_stream
|
|
79
79
|
metadata = context.invocation_metadata()
|
|
80
|
-
if isinstance(
|
|
81
|
-
|
|
82
|
-
|
|
80
|
+
if isinstance(request, (GetLoginDetailsRequest, GetAuthTokensRequest)):
|
|
81
|
+
return call(request, context) # type: ignore
|
|
82
|
+
|
|
83
|
+
valid_tokens, _ = self.auth_plugin.validate_tokens_in_metadata(metadata)
|
|
84
|
+
if valid_tokens:
|
|
83
85
|
return call(request, context) # type: ignore
|
|
84
86
|
|
|
85
87
|
tokens = self.auth_plugin.refresh_tokens(context.invocation_metadata())
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: flwr
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.16.0
|
|
4
4
|
Summary: Flower: A Friendly Federated AI Framework
|
|
5
5
|
Home-page: https://flower.ai
|
|
6
6
|
License: Apache-2.0
|
|
7
7
|
Keywords: Artificial Intelligence,Federated AI,Federated Analytics,Federated Evaluation,Federated Learning,Flower,Machine Learning
|
|
8
8
|
Author: The Flower Authors
|
|
9
9
|
Author-email: hello@flower.ai
|
|
10
|
-
Requires-Python: >=3.9,<4.0
|
|
10
|
+
Requires-Python: >=3.9.2,<4.0.0
|
|
11
11
|
Classifier: Development Status :: 5 - Production/Stable
|
|
12
12
|
Classifier: Intended Audience :: Developers
|
|
13
13
|
Classifier: Intended Audience :: Science/Research
|
|
@@ -16,12 +16,12 @@ Classifier: Operating System :: MacOS :: MacOS X
|
|
|
16
16
|
Classifier: Operating System :: POSIX :: Linux
|
|
17
17
|
Classifier: Programming Language :: Python
|
|
18
18
|
Classifier: Programming Language :: Python :: 3
|
|
19
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
20
19
|
Classifier: Programming Language :: Python :: 3.10
|
|
21
20
|
Classifier: Programming Language :: Python :: 3.11
|
|
22
21
|
Classifier: Programming Language :: Python :: 3.12
|
|
23
22
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
24
23
|
Classifier: Programming Language :: Python :: 3.13
|
|
24
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
25
25
|
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
26
26
|
Classifier: Topic :: Scientific/Engineering
|
|
27
27
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
@@ -32,7 +32,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
32
32
|
Classifier: Typing :: Typed
|
|
33
33
|
Provides-Extra: rest
|
|
34
34
|
Provides-Extra: simulation
|
|
35
|
-
Requires-Dist: cryptography (>=
|
|
35
|
+
Requires-Dist: cryptography (>=44.0.1,<45.0.0)
|
|
36
36
|
Requires-Dist: grpcio (>=1.62.3,<2.0.0,!=1.65.0)
|
|
37
37
|
Requires-Dist: iterators (>=0.0.2,<0.0.3)
|
|
38
38
|
Requires-Dist: numpy (>=1.26.0,<3.0.0)
|