flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
flwr/common/retry_invoker.py
CHANGED
|
@@ -25,10 +25,12 @@ from typing import Any, Callable, Optional, Union, cast
|
|
|
25
25
|
|
|
26
26
|
import grpc
|
|
27
27
|
|
|
28
|
+
from flwr.client.grpc_rere_client.grpc_adapter import GrpcAdapter
|
|
28
29
|
from flwr.common.constant import MAX_RETRY_DELAY
|
|
29
30
|
from flwr.common.logger import log
|
|
30
31
|
from flwr.common.typing import RunNotRunningException
|
|
31
32
|
from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
|
|
33
|
+
from flwr.proto.fleet_pb2_grpc import FleetStub
|
|
32
34
|
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
|
|
33
35
|
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub
|
|
34
36
|
|
|
@@ -366,7 +368,9 @@ def _make_simple_grpc_retry_invoker() -> RetryInvoker:
|
|
|
366
368
|
|
|
367
369
|
|
|
368
370
|
def _wrap_stub(
|
|
369
|
-
stub: Union[
|
|
371
|
+
stub: Union[
|
|
372
|
+
ServerAppIoStub, ClientAppIoStub, SimulationIoStub, FleetStub, GrpcAdapter
|
|
373
|
+
],
|
|
370
374
|
retry_invoker: RetryInvoker,
|
|
371
375
|
) -> None:
|
|
372
376
|
"""Wrap a gRPC stub with a retry invoker."""
|
flwr/common/serde.py
CHANGED
|
@@ -16,28 +16,19 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from collections import OrderedDict
|
|
19
|
-
from
|
|
20
|
-
from typing import Any, TypeVar, cast
|
|
21
|
-
|
|
22
|
-
from google.protobuf.message import Message as GrpcMessage
|
|
19
|
+
from typing import Any, cast
|
|
23
20
|
|
|
24
21
|
# pylint: disable=E0611
|
|
25
|
-
from flwr.proto.clientappio_pb2 import ClientAppOutputCode, ClientAppOutputStatus
|
|
26
|
-
from flwr.proto.error_pb2 import Error as ProtoError
|
|
27
22
|
from flwr.proto.fab_pb2 import Fab as ProtoFab
|
|
28
23
|
from flwr.proto.message_pb2 import Context as ProtoContext
|
|
29
24
|
from flwr.proto.message_pb2 import Message as ProtoMessage
|
|
30
|
-
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
31
25
|
from flwr.proto.recorddict_pb2 import Array as ProtoArray
|
|
32
26
|
from flwr.proto.recorddict_pb2 import ArrayRecord as ProtoArrayRecord
|
|
33
|
-
from flwr.proto.recorddict_pb2 import BoolList, BytesList
|
|
34
27
|
from flwr.proto.recorddict_pb2 import ConfigRecord as ProtoConfigRecord
|
|
35
28
|
from flwr.proto.recorddict_pb2 import ConfigRecordValue as ProtoConfigRecordValue
|
|
36
|
-
from flwr.proto.recorddict_pb2 import DoubleList
|
|
37
29
|
from flwr.proto.recorddict_pb2 import MetricRecord as ProtoMetricRecord
|
|
38
30
|
from flwr.proto.recorddict_pb2 import MetricRecordValue as ProtoMetricRecordValue
|
|
39
31
|
from flwr.proto.recorddict_pb2 import RecordDict as ProtoRecordDict
|
|
40
|
-
from flwr.proto.recorddict_pb2 import SintList, StringList, UintList
|
|
41
32
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
42
33
|
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
|
43
34
|
from flwr.proto.transport_pb2 import (
|
|
@@ -60,8 +51,16 @@ from . import (
|
|
|
60
51
|
RecordDict,
|
|
61
52
|
typing,
|
|
62
53
|
)
|
|
63
|
-
from .
|
|
64
|
-
from .
|
|
54
|
+
from .constant import INT64_MAX_VALUE
|
|
55
|
+
from .message import Message, make_message
|
|
56
|
+
from .serde_utils import (
|
|
57
|
+
error_from_proto,
|
|
58
|
+
error_to_proto,
|
|
59
|
+
metadata_from_proto,
|
|
60
|
+
metadata_to_proto,
|
|
61
|
+
record_value_dict_from_proto,
|
|
62
|
+
record_value_dict_to_proto,
|
|
63
|
+
)
|
|
65
64
|
|
|
66
65
|
# === Parameters message ===
|
|
67
66
|
|
|
@@ -339,7 +338,6 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:
|
|
|
339
338
|
|
|
340
339
|
|
|
341
340
|
# === Scalar messages ===
|
|
342
|
-
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1
|
|
343
341
|
|
|
344
342
|
|
|
345
343
|
def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
|
|
@@ -377,107 +375,21 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:
|
|
|
377
375
|
# === Record messages ===
|
|
378
376
|
|
|
379
377
|
|
|
380
|
-
_type_to_field: dict[type, str] = {
|
|
381
|
-
float: "double",
|
|
382
|
-
int: "sint64",
|
|
383
|
-
bool: "bool",
|
|
384
|
-
str: "string",
|
|
385
|
-
bytes: "bytes",
|
|
386
|
-
}
|
|
387
|
-
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
|
388
|
-
float: (DoubleList, "double_list"),
|
|
389
|
-
int: (SintList, "sint_list"),
|
|
390
|
-
bool: (BoolList, "bool_list"),
|
|
391
|
-
str: (StringList, "string_list"),
|
|
392
|
-
bytes: (BytesList, "bytes_list"),
|
|
393
|
-
}
|
|
394
|
-
T = TypeVar("T")
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
def _is_uint64(value: Any) -> bool:
|
|
398
|
-
"""Check if a value is uint64."""
|
|
399
|
-
return isinstance(value, int) and value > INT64_MAX_VALUE
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
def _record_value_to_proto(
|
|
403
|
-
value: Any, allowed_types: list[type], proto_class: type[T]
|
|
404
|
-
) -> T:
|
|
405
|
-
"""Serialize `*RecordValue` to ProtoBuf.
|
|
406
|
-
|
|
407
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
408
|
-
"""
|
|
409
|
-
arg = {}
|
|
410
|
-
for t in allowed_types:
|
|
411
|
-
# Single element
|
|
412
|
-
# Note: `isinstance(False, int) == True`.
|
|
413
|
-
if isinstance(value, t):
|
|
414
|
-
fld = _type_to_field[t]
|
|
415
|
-
if t is int and _is_uint64(value):
|
|
416
|
-
fld = "uint64"
|
|
417
|
-
arg[fld] = value
|
|
418
|
-
return proto_class(**arg)
|
|
419
|
-
# List
|
|
420
|
-
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
|
421
|
-
list_class, fld = _list_type_to_class_and_field[t]
|
|
422
|
-
# Use UintList if any element is of type `uint64`.
|
|
423
|
-
if t is int and any(_is_uint64(v) for v in value):
|
|
424
|
-
list_class, fld = UintList, "uint_list"
|
|
425
|
-
arg[fld] = list_class(vals=value)
|
|
426
|
-
return proto_class(**arg)
|
|
427
|
-
# Invalid types
|
|
428
|
-
raise TypeError(
|
|
429
|
-
f"The type of the following value is not allowed "
|
|
430
|
-
f"in '{proto_class.__name__}':\n{value}"
|
|
431
|
-
)
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
|
435
|
-
"""Deserialize `*RecordValue` from ProtoBuf."""
|
|
436
|
-
value_field = cast(str, value_proto.WhichOneof("value"))
|
|
437
|
-
if value_field.endswith("list"):
|
|
438
|
-
value = list(getattr(value_proto, value_field).vals)
|
|
439
|
-
else:
|
|
440
|
-
value = getattr(value_proto, value_field)
|
|
441
|
-
return value
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
def _record_value_dict_to_proto(
|
|
445
|
-
value_dict: TypedDict[str, Any],
|
|
446
|
-
allowed_types: list[type],
|
|
447
|
-
value_proto_class: type[T],
|
|
448
|
-
) -> dict[str, T]:
|
|
449
|
-
"""Serialize the record value dict to ProtoBuf.
|
|
450
|
-
|
|
451
|
-
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
452
|
-
"""
|
|
453
|
-
# Move bool to the front
|
|
454
|
-
if bool in allowed_types and allowed_types[0] != bool:
|
|
455
|
-
allowed_types.remove(bool)
|
|
456
|
-
allowed_types.insert(0, bool)
|
|
457
|
-
|
|
458
|
-
def proto(_v: Any) -> T:
|
|
459
|
-
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
|
460
|
-
|
|
461
|
-
return {k: proto(v) for k, v in value_dict.items()}
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
def _record_value_dict_from_proto(
|
|
465
|
-
value_dict_proto: MutableMapping[str, Any]
|
|
466
|
-
) -> dict[str, Any]:
|
|
467
|
-
"""Deserialize the record value dict from ProtoBuf."""
|
|
468
|
-
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
|
469
|
-
|
|
470
|
-
|
|
471
378
|
def array_to_proto(array: Array) -> ProtoArray:
|
|
472
379
|
"""Serialize Array to ProtoBuf."""
|
|
473
|
-
return ProtoArray(
|
|
380
|
+
return ProtoArray(
|
|
381
|
+
dtype=array.dtype,
|
|
382
|
+
shape=array.shape,
|
|
383
|
+
stype=array.stype,
|
|
384
|
+
data=array.data,
|
|
385
|
+
)
|
|
474
386
|
|
|
475
387
|
|
|
476
388
|
def array_from_proto(array_proto: ProtoArray) -> Array:
|
|
477
389
|
"""Deserialize Array from ProtoBuf."""
|
|
478
390
|
return Array(
|
|
479
391
|
dtype=array_proto.dtype,
|
|
480
|
-
shape=
|
|
392
|
+
shape=tuple(array_proto.shape),
|
|
481
393
|
stype=array_proto.stype,
|
|
482
394
|
data=array_proto.data,
|
|
483
395
|
)
|
|
@@ -486,8 +398,10 @@ def array_from_proto(array_proto: ProtoArray) -> Array:
|
|
|
486
398
|
def array_record_to_proto(record: ArrayRecord) -> ProtoArrayRecord:
|
|
487
399
|
"""Serialize ArrayRecord to ProtoBuf."""
|
|
488
400
|
return ProtoArrayRecord(
|
|
489
|
-
|
|
490
|
-
|
|
401
|
+
items=[
|
|
402
|
+
ProtoArrayRecord.Item(key=k, value=array_to_proto(v))
|
|
403
|
+
for k, v in record.items()
|
|
404
|
+
]
|
|
491
405
|
)
|
|
492
406
|
|
|
493
407
|
|
|
@@ -497,7 +411,7 @@ def array_record_from_proto(
|
|
|
497
411
|
"""Deserialize ArrayRecord from ProtoBuf."""
|
|
498
412
|
return ArrayRecord(
|
|
499
413
|
array_dict=OrderedDict(
|
|
500
|
-
|
|
414
|
+
{item.key: array_from_proto(item.value) for item in record_proto.items}
|
|
501
415
|
),
|
|
502
416
|
keep_input=False,
|
|
503
417
|
)
|
|
@@ -505,17 +419,19 @@ def array_record_from_proto(
|
|
|
505
419
|
|
|
506
420
|
def metric_record_to_proto(record: MetricRecord) -> ProtoMetricRecord:
|
|
507
421
|
"""Serialize MetricRecord to ProtoBuf."""
|
|
422
|
+
protos = record_value_dict_to_proto(record, [float, int], ProtoMetricRecordValue)
|
|
508
423
|
return ProtoMetricRecord(
|
|
509
|
-
|
|
424
|
+
items=[ProtoMetricRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
510
425
|
)
|
|
511
426
|
|
|
512
427
|
|
|
513
428
|
def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
514
429
|
"""Deserialize MetricRecord from ProtoBuf."""
|
|
430
|
+
protos = {item.key: item.value for item in record_proto.items}
|
|
515
431
|
return MetricRecord(
|
|
516
432
|
metric_dict=cast(
|
|
517
433
|
dict[str, typing.MetricRecordValues],
|
|
518
|
-
|
|
434
|
+
record_value_dict_from_proto(protos),
|
|
519
435
|
),
|
|
520
436
|
keep_input=False,
|
|
521
437
|
)
|
|
@@ -523,68 +439,60 @@ def metric_record_from_proto(record_proto: ProtoMetricRecord) -> MetricRecord:
|
|
|
523
439
|
|
|
524
440
|
def config_record_to_proto(record: ConfigRecord) -> ProtoConfigRecord:
|
|
525
441
|
"""Serialize ConfigRecord to ProtoBuf."""
|
|
442
|
+
protos = record_value_dict_to_proto(
|
|
443
|
+
record,
|
|
444
|
+
[bool, int, float, str, bytes],
|
|
445
|
+
ProtoConfigRecordValue,
|
|
446
|
+
)
|
|
526
447
|
return ProtoConfigRecord(
|
|
527
|
-
|
|
528
|
-
record,
|
|
529
|
-
[bool, int, float, str, bytes],
|
|
530
|
-
ProtoConfigRecordValue,
|
|
531
|
-
)
|
|
448
|
+
items=[ProtoConfigRecord.Item(key=k, value=v) for k, v in protos.items()]
|
|
532
449
|
)
|
|
533
450
|
|
|
534
451
|
|
|
535
452
|
def config_record_from_proto(record_proto: ProtoConfigRecord) -> ConfigRecord:
|
|
536
453
|
"""Deserialize ConfigRecord from ProtoBuf."""
|
|
454
|
+
protos = {item.key: item.value for item in record_proto.items}
|
|
537
455
|
return ConfigRecord(
|
|
538
456
|
config_dict=cast(
|
|
539
457
|
dict[str, typing.ConfigRecordValues],
|
|
540
|
-
|
|
458
|
+
record_value_dict_from_proto(protos),
|
|
541
459
|
),
|
|
542
460
|
keep_input=False,
|
|
543
461
|
)
|
|
544
462
|
|
|
545
463
|
|
|
546
|
-
# === Error message ===
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
def error_to_proto(error: Error) -> ProtoError:
|
|
550
|
-
"""Serialize Error to ProtoBuf."""
|
|
551
|
-
reason = error.reason if error.reason else ""
|
|
552
|
-
return ProtoError(code=error.code, reason=reason)
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
def error_from_proto(error_proto: ProtoError) -> Error:
|
|
556
|
-
"""Deserialize Error from ProtoBuf."""
|
|
557
|
-
reason = error_proto.reason if len(error_proto.reason) > 0 else None
|
|
558
|
-
return Error(code=error_proto.code, reason=reason)
|
|
559
|
-
|
|
560
|
-
|
|
561
464
|
# === RecordDict message ===
|
|
562
465
|
|
|
563
466
|
|
|
564
467
|
def recorddict_to_proto(recorddict: RecordDict) -> ProtoRecordDict:
|
|
565
468
|
"""Serialize RecordDict to ProtoBuf."""
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
469
|
+
item_cls = ProtoRecordDict.Item
|
|
470
|
+
items: list[ProtoRecordDict.Item] = []
|
|
471
|
+
for k, v in recorddict.items():
|
|
472
|
+
if isinstance(v, ArrayRecord):
|
|
473
|
+
items += [item_cls(key=k, array_record=array_record_to_proto(v))]
|
|
474
|
+
elif isinstance(v, MetricRecord):
|
|
475
|
+
items += [item_cls(key=k, metric_record=metric_record_to_proto(v))]
|
|
476
|
+
elif isinstance(v, ConfigRecord):
|
|
477
|
+
items += [item_cls(key=k, config_record=config_record_to_proto(v))]
|
|
478
|
+
else:
|
|
479
|
+
raise ValueError(f"Unsupported record type: {type(v)}")
|
|
480
|
+
return ProtoRecordDict(items=items)
|
|
577
481
|
|
|
578
482
|
|
|
579
483
|
def recorddict_from_proto(recorddict_proto: ProtoRecordDict) -> RecordDict:
|
|
580
484
|
"""Deserialize RecordDict from ProtoBuf."""
|
|
581
485
|
ret = RecordDict()
|
|
582
|
-
for
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
486
|
+
for item in recorddict_proto.items:
|
|
487
|
+
field = item.WhichOneof("value")
|
|
488
|
+
if field == "array_record":
|
|
489
|
+
ret[item.key] = array_record_from_proto(item.array_record)
|
|
490
|
+
elif field == "metric_record":
|
|
491
|
+
ret[item.key] = metric_record_from_proto(item.metric_record)
|
|
492
|
+
elif field == "config_record":
|
|
493
|
+
ret[item.key] = config_record_from_proto(item.config_record)
|
|
494
|
+
else:
|
|
495
|
+
raise ValueError(f"Unsupported record type: {field}")
|
|
588
496
|
return ret
|
|
589
497
|
|
|
590
498
|
|
|
@@ -646,41 +554,6 @@ def user_config_value_from_proto(scalar_msg: Scalar) -> typing.UserConfigValue:
|
|
|
646
554
|
return cast(typing.UserConfigValue, scalar)
|
|
647
555
|
|
|
648
556
|
|
|
649
|
-
# === Metadata messages ===
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
|
653
|
-
"""Serialize `Metadata` to ProtoBuf."""
|
|
654
|
-
proto = ProtoMetadata( # pylint: disable=E1101
|
|
655
|
-
run_id=metadata.run_id,
|
|
656
|
-
message_id=metadata.message_id,
|
|
657
|
-
src_node_id=metadata.src_node_id,
|
|
658
|
-
dst_node_id=metadata.dst_node_id,
|
|
659
|
-
reply_to_message_id=metadata.reply_to_message_id,
|
|
660
|
-
group_id=metadata.group_id,
|
|
661
|
-
ttl=metadata.ttl,
|
|
662
|
-
message_type=metadata.message_type,
|
|
663
|
-
created_at=metadata.created_at,
|
|
664
|
-
)
|
|
665
|
-
return proto
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
|
|
669
|
-
"""Deserialize `Metadata` from ProtoBuf."""
|
|
670
|
-
metadata = Metadata(
|
|
671
|
-
run_id=metadata_proto.run_id,
|
|
672
|
-
message_id=metadata_proto.message_id,
|
|
673
|
-
src_node_id=metadata_proto.src_node_id,
|
|
674
|
-
dst_node_id=metadata_proto.dst_node_id,
|
|
675
|
-
reply_to_message_id=metadata_proto.reply_to_message_id,
|
|
676
|
-
group_id=metadata_proto.group_id,
|
|
677
|
-
created_at=metadata_proto.created_at,
|
|
678
|
-
ttl=metadata_proto.ttl,
|
|
679
|
-
message_type=metadata_proto.message_type,
|
|
680
|
-
)
|
|
681
|
-
return metadata
|
|
682
|
-
|
|
683
|
-
|
|
684
557
|
# === Message messages ===
|
|
685
558
|
|
|
686
559
|
|
|
@@ -756,6 +629,7 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
|
756
629
|
running_at=run.running_at,
|
|
757
630
|
finished_at=run.finished_at,
|
|
758
631
|
status=run_status_to_proto(run.status),
|
|
632
|
+
flwr_aid=run.flwr_aid,
|
|
759
633
|
)
|
|
760
634
|
return proto
|
|
761
635
|
|
|
@@ -773,37 +647,11 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
|
773
647
|
running_at=run_proto.running_at,
|
|
774
648
|
finished_at=run_proto.finished_at,
|
|
775
649
|
status=run_status_from_proto(run_proto.status),
|
|
650
|
+
flwr_aid=run_proto.flwr_aid,
|
|
776
651
|
)
|
|
777
652
|
return run
|
|
778
653
|
|
|
779
654
|
|
|
780
|
-
# === ClientApp status messages ===
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
def clientappstatus_to_proto(
|
|
784
|
-
status: typing.ClientAppOutputStatus,
|
|
785
|
-
) -> ClientAppOutputStatus:
|
|
786
|
-
"""Serialize `ClientAppOutputStatus` to ProtoBuf."""
|
|
787
|
-
code = ClientAppOutputCode.SUCCESS
|
|
788
|
-
if status.code == typing.ClientAppOutputCode.DEADLINE_EXCEEDED:
|
|
789
|
-
code = ClientAppOutputCode.DEADLINE_EXCEEDED
|
|
790
|
-
if status.code == typing.ClientAppOutputCode.UNKNOWN_ERROR:
|
|
791
|
-
code = ClientAppOutputCode.UNKNOWN_ERROR
|
|
792
|
-
return ClientAppOutputStatus(code=code, message=status.message)
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
def clientappstatus_from_proto(
|
|
796
|
-
msg: ClientAppOutputStatus,
|
|
797
|
-
) -> typing.ClientAppOutputStatus:
|
|
798
|
-
"""Deserialize `ClientAppOutputStatus` from ProtoBuf."""
|
|
799
|
-
code = typing.ClientAppOutputCode.SUCCESS
|
|
800
|
-
if msg.code == ClientAppOutputCode.DEADLINE_EXCEEDED:
|
|
801
|
-
code = typing.ClientAppOutputCode.DEADLINE_EXCEEDED
|
|
802
|
-
if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
|
|
803
|
-
code = typing.ClientAppOutputCode.UNKNOWN_ERROR
|
|
804
|
-
return typing.ClientAppOutputStatus(code=code, message=msg.message)
|
|
805
|
-
|
|
806
|
-
|
|
807
655
|
# === Run status ===
|
|
808
656
|
|
|
809
657
|
|
|
flwr/common/serde_utils.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Utils for serde."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import MutableMapping
|
|
18
|
+
from typing import Any, TypeVar, cast
|
|
19
|
+
|
|
20
|
+
from google.protobuf.message import Message as GrpcMessage
|
|
21
|
+
|
|
22
|
+
# pylint: disable=E0611
|
|
23
|
+
from flwr.proto.error_pb2 import Error as ProtoError
|
|
24
|
+
from flwr.proto.message_pb2 import Metadata as ProtoMetadata
|
|
25
|
+
from flwr.proto.recorddict_pb2 import (
|
|
26
|
+
BoolList,
|
|
27
|
+
BytesList,
|
|
28
|
+
DoubleList,
|
|
29
|
+
SintList,
|
|
30
|
+
StringList,
|
|
31
|
+
UintList,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
from ..app.error import Error
|
|
35
|
+
from ..app.metadata import Metadata
|
|
36
|
+
from .constant import INT64_MAX_VALUE
|
|
37
|
+
from .record.typeddict import TypedDict
|
|
38
|
+
|
|
39
|
+
# pylint: enable=E0611
|
|
40
|
+
|
|
41
|
+
_type_to_field: dict[type, str] = {
|
|
42
|
+
float: "double",
|
|
43
|
+
int: "sint64",
|
|
44
|
+
bool: "bool",
|
|
45
|
+
str: "string",
|
|
46
|
+
bytes: "bytes",
|
|
47
|
+
}
|
|
48
|
+
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
|
|
49
|
+
float: (DoubleList, "double_list"),
|
|
50
|
+
int: (SintList, "sint_list"),
|
|
51
|
+
bool: (BoolList, "bool_list"),
|
|
52
|
+
str: (StringList, "string_list"),
|
|
53
|
+
bytes: (BytesList, "bytes_list"),
|
|
54
|
+
}
|
|
55
|
+
T = TypeVar("T")
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _is_uint64(value: Any) -> bool:
|
|
59
|
+
"""Check if a value is uint64."""
|
|
60
|
+
return isinstance(value, int) and value > INT64_MAX_VALUE
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _record_value_to_proto(
|
|
64
|
+
value: Any, allowed_types: list[type], proto_class: type[T]
|
|
65
|
+
) -> T:
|
|
66
|
+
"""Serialize `*RecordValue` to ProtoBuf.
|
|
67
|
+
|
|
68
|
+
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
69
|
+
"""
|
|
70
|
+
arg = {}
|
|
71
|
+
for t in allowed_types:
|
|
72
|
+
# Single element
|
|
73
|
+
# Note: `isinstance(False, int) == True`.
|
|
74
|
+
if isinstance(value, t):
|
|
75
|
+
fld = _type_to_field[t]
|
|
76
|
+
if t is int and _is_uint64(value):
|
|
77
|
+
fld = "uint64"
|
|
78
|
+
arg[fld] = value
|
|
79
|
+
return proto_class(**arg)
|
|
80
|
+
# List
|
|
81
|
+
if isinstance(value, list) and all(isinstance(item, t) for item in value):
|
|
82
|
+
list_class, fld = _list_type_to_class_and_field[t]
|
|
83
|
+
# Use UintList if any element is of type `uint64`.
|
|
84
|
+
if t is int and any(_is_uint64(v) for v in value):
|
|
85
|
+
list_class, fld = UintList, "uint_list"
|
|
86
|
+
arg[fld] = list_class(vals=value)
|
|
87
|
+
return proto_class(**arg)
|
|
88
|
+
# Invalid types
|
|
89
|
+
raise TypeError(
|
|
90
|
+
f"The type of the following value is not allowed "
|
|
91
|
+
f"in '{proto_class.__name__}':\n{value}"
|
|
92
|
+
)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
|
|
96
|
+
"""Deserialize `*RecordValue` from ProtoBuf."""
|
|
97
|
+
value_field = cast(str, value_proto.WhichOneof("value"))
|
|
98
|
+
if value_field.endswith("list"):
|
|
99
|
+
value = list(getattr(value_proto, value_field).vals)
|
|
100
|
+
else:
|
|
101
|
+
value = getattr(value_proto, value_field)
|
|
102
|
+
return value
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def record_value_dict_to_proto(
|
|
106
|
+
value_dict: TypedDict[str, Any],
|
|
107
|
+
allowed_types: list[type],
|
|
108
|
+
value_proto_class: type[T],
|
|
109
|
+
) -> dict[str, T]:
|
|
110
|
+
"""Serialize the record value dict to ProtoBuf.
|
|
111
|
+
|
|
112
|
+
This function will preserve the order of the keys in the input dictionary.
|
|
113
|
+
|
|
114
|
+
Note: `bool` MUST be put in the front of allowd_types if it exists.
|
|
115
|
+
"""
|
|
116
|
+
# Move bool to the front
|
|
117
|
+
if bool in allowed_types and allowed_types[0] != bool:
|
|
118
|
+
allowed_types.remove(bool)
|
|
119
|
+
allowed_types.insert(0, bool)
|
|
120
|
+
|
|
121
|
+
def proto(_v: Any) -> T:
|
|
122
|
+
return _record_value_to_proto(_v, allowed_types, value_proto_class)
|
|
123
|
+
|
|
124
|
+
return {k: proto(v) for k, v in value_dict.items()}
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def record_value_dict_from_proto(
|
|
128
|
+
value_dict_proto: MutableMapping[str, Any]
|
|
129
|
+
) -> dict[str, Any]:
|
|
130
|
+
"""Deserialize the record value dict from ProtoBuf."""
|
|
131
|
+
return {k: _record_value_from_proto(v) for k, v in value_dict_proto.items()}
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def error_to_proto(error: Error) -> ProtoError:
|
|
135
|
+
"""Serialize Error to ProtoBuf."""
|
|
136
|
+
reason = error.reason if error.reason else ""
|
|
137
|
+
return ProtoError(code=error.code, reason=reason)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def error_from_proto(error_proto: ProtoError) -> Error:
|
|
141
|
+
"""Deserialize Error from ProtoBuf."""
|
|
142
|
+
reason = error_proto.reason if len(error_proto.reason) > 0 else None
|
|
143
|
+
return Error(code=error_proto.code, reason=reason)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def metadata_to_proto(metadata: Metadata) -> ProtoMetadata:
|
|
147
|
+
"""Serialize `Metadata` to ProtoBuf."""
|
|
148
|
+
proto = ProtoMetadata( # pylint: disable=E1101
|
|
149
|
+
run_id=metadata.run_id,
|
|
150
|
+
message_id=metadata.message_id,
|
|
151
|
+
src_node_id=metadata.src_node_id,
|
|
152
|
+
dst_node_id=metadata.dst_node_id,
|
|
153
|
+
reply_to_message_id=metadata.reply_to_message_id,
|
|
154
|
+
group_id=metadata.group_id,
|
|
155
|
+
ttl=metadata.ttl,
|
|
156
|
+
message_type=metadata.message_type,
|
|
157
|
+
created_at=metadata.created_at,
|
|
158
|
+
)
|
|
159
|
+
return proto
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def metadata_from_proto(metadata_proto: ProtoMetadata) -> Metadata:
|
|
163
|
+
"""Deserialize `Metadata` from ProtoBuf."""
|
|
164
|
+
metadata = Metadata(
|
|
165
|
+
run_id=metadata_proto.run_id,
|
|
166
|
+
message_id=metadata_proto.message_id,
|
|
167
|
+
src_node_id=metadata_proto.src_node_id,
|
|
168
|
+
dst_node_id=metadata_proto.dst_node_id,
|
|
169
|
+
reply_to_message_id=metadata_proto.reply_to_message_id,
|
|
170
|
+
group_id=metadata_proto.group_id,
|
|
171
|
+
created_at=metadata_proto.created_at,
|
|
172
|
+
ttl=metadata_proto.ttl,
|
|
173
|
+
message_type=metadata_proto.message_type,
|
|
174
|
+
)
|
|
175
|
+
return metadata
|
flwr/common/typing.py
CHANGED
|
@@ -230,6 +230,7 @@ class Run: # pylint: disable=too-many-instance-attributes
|
|
|
230
230
|
running_at: str
|
|
231
231
|
finished_at: str
|
|
232
232
|
status: RunStatus
|
|
233
|
+
flwr_aid: str
|
|
233
234
|
|
|
234
235
|
@classmethod
|
|
235
236
|
def create_empty(cls, run_id: int) -> "Run":
|
|
@@ -245,6 +246,7 @@ class Run: # pylint: disable=too-many-instance-attributes
|
|
|
245
246
|
running_at="",
|
|
246
247
|
finished_at="",
|
|
247
248
|
status=RunStatus(status="", sub_status="", details=""),
|
|
249
|
+
flwr_aid="",
|
|
248
250
|
)
|
|
249
251
|
|
|
250
252
|
|
|
@@ -289,11 +291,11 @@ class UserAuthCredentials:
|
|
|
289
291
|
|
|
290
292
|
|
|
291
293
|
@dataclass
|
|
292
|
-
class
|
|
294
|
+
class AccountInfo:
|
|
293
295
|
"""User information for event log."""
|
|
294
296
|
|
|
295
|
-
|
|
296
|
-
|
|
297
|
+
flwr_aid: Optional[str]
|
|
298
|
+
account_name: Optional[str]
|
|
297
299
|
|
|
298
300
|
|
|
299
301
|
@dataclass
|
flwr/compat/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Compatibility package containing deprecated legacy components."""
|
|
flwr/compat/client/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Legacy components previously located in ``flwr.client``."""
|