flwr 1.18.0__py3-none-any.whl → 1.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/app/__init__.py +15 -0
- flwr/app/error.py +68 -0
- flwr/app/metadata.py +223 -0
- flwr/cli/build.py +94 -59
- flwr/cli/log.py +3 -3
- flwr/cli/login/login.py +3 -7
- flwr/cli/ls.py +15 -36
- flwr/cli/new/new.py +12 -4
- flwr/cli/new/templates/app/README.flowertune.md.tpl +2 -0
- flwr/cli/new/templates/app/README.md.tpl +5 -0
- flwr/cli/new/templates/app/code/client.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/model.baseline.py.tpl +1 -1
- flwr/cli/new/templates/app/code/server.baseline.py.tpl +2 -3
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +25 -17
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +13 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +21 -2
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +19 -2
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +20 -3
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +18 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +18 -1
- flwr/cli/run/run.py +48 -49
- flwr/cli/stop.py +2 -2
- flwr/cli/utils.py +38 -5
- flwr/client/__init__.py +2 -2
- flwr/client/client_app.py +1 -1
- flwr/client/clientapp/__init__.py +0 -7
- flwr/client/grpc_adapter_client/connection.py +15 -8
- flwr/client/grpc_rere_client/connection.py +142 -97
- flwr/client/grpc_rere_client/grpc_adapter.py +34 -6
- flwr/client/message_handler/message_handler.py +1 -1
- flwr/client/mod/comms_mods.py +36 -17
- flwr/client/rest_client/connection.py +176 -103
- flwr/clientapp/__init__.py +15 -0
- flwr/common/__init__.py +2 -2
- flwr/common/auth_plugin/__init__.py +2 -0
- flwr/common/auth_plugin/auth_plugin.py +29 -3
- flwr/common/constant.py +39 -8
- flwr/common/event_log_plugin/event_log_plugin.py +3 -3
- flwr/common/exit/exit_code.py +16 -1
- flwr/common/exit_handlers.py +30 -0
- flwr/common/grpc.py +12 -1
- flwr/common/heartbeat.py +165 -0
- flwr/common/inflatable.py +290 -0
- flwr/common/inflatable_protobuf_utils.py +141 -0
- flwr/common/inflatable_utils.py +508 -0
- flwr/common/message.py +110 -242
- flwr/common/record/__init__.py +2 -1
- flwr/common/record/array.py +402 -0
- flwr/common/record/arraychunk.py +59 -0
- flwr/common/record/arrayrecord.py +103 -225
- flwr/common/record/configrecord.py +59 -4
- flwr/common/record/conversion_utils.py +1 -1
- flwr/common/record/metricrecord.py +55 -4
- flwr/common/record/recorddict.py +69 -1
- flwr/common/recorddict_compat.py +2 -2
- flwr/common/retry_invoker.py +5 -1
- flwr/common/serde.py +59 -211
- flwr/common/serde_utils.py +175 -0
- flwr/common/typing.py +5 -3
- flwr/compat/__init__.py +15 -0
- flwr/compat/client/__init__.py +15 -0
- flwr/{client → compat/client}/app.py +28 -185
- flwr/compat/common/__init__.py +15 -0
- flwr/compat/server/__init__.py +15 -0
- flwr/compat/server/app.py +174 -0
- flwr/compat/simulation/__init__.py +15 -0
- flwr/proto/appio_pb2.py +43 -0
- flwr/proto/appio_pb2.pyi +151 -0
- flwr/proto/appio_pb2_grpc.py +4 -0
- flwr/proto/appio_pb2_grpc.pyi +4 -0
- flwr/proto/clientappio_pb2.py +12 -19
- flwr/proto/clientappio_pb2.pyi +23 -101
- flwr/proto/clientappio_pb2_grpc.py +269 -28
- flwr/proto/clientappio_pb2_grpc.pyi +114 -20
- flwr/proto/fleet_pb2.py +24 -27
- flwr/proto/fleet_pb2.pyi +19 -35
- flwr/proto/fleet_pb2_grpc.py +117 -13
- flwr/proto/fleet_pb2_grpc.pyi +47 -6
- flwr/proto/heartbeat_pb2.py +33 -0
- flwr/proto/heartbeat_pb2.pyi +66 -0
- flwr/proto/heartbeat_pb2_grpc.py +4 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +28 -11
- flwr/proto/message_pb2.pyi +125 -0
- flwr/proto/recorddict_pb2.py +16 -28
- flwr/proto/recorddict_pb2.pyi +46 -64
- flwr/proto/run_pb2.py +24 -32
- flwr/proto/run_pb2.pyi +4 -52
- flwr/proto/serverappio_pb2.py +9 -23
- flwr/proto/serverappio_pb2.pyi +0 -110
- flwr/proto/serverappio_pb2_grpc.py +177 -72
- flwr/proto/serverappio_pb2_grpc.pyi +75 -33
- flwr/proto/simulationio_pb2.py +12 -11
- flwr/proto/simulationio_pb2_grpc.py +35 -0
- flwr/proto/simulationio_pb2_grpc.pyi +14 -0
- flwr/server/__init__.py +1 -1
- flwr/server/app.py +69 -187
- flwr/server/compat/app_utils.py +50 -28
- flwr/server/fleet_event_log_interceptor.py +6 -2
- flwr/server/grid/grpc_grid.py +148 -41
- flwr/server/grid/inmemory_grid.py +5 -4
- flwr/server/serverapp/app.py +45 -17
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +21 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +102 -8
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -5
- flwr/server/superlink/fleet/message_handler/message_handler.py +130 -19
- flwr/server/superlink/fleet/rest_rere/rest_api.py +73 -13
- flwr/server/superlink/fleet/vce/vce_api.py +6 -3
- flwr/server/superlink/linkstate/in_memory_linkstate.py +138 -43
- flwr/server/superlink/linkstate/linkstate.py +53 -20
- flwr/server/superlink/linkstate/sqlite_linkstate.py +149 -55
- flwr/server/superlink/linkstate/utils.py +33 -29
- flwr/server/superlink/serverappio/serverappio_grpc.py +4 -1
- flwr/server/superlink/serverappio/serverappio_servicer.py +230 -84
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -1
- flwr/server/superlink/simulation/simulationio_servicer.py +26 -2
- flwr/server/superlink/utils.py +9 -2
- flwr/server/utils/validator.py +2 -2
- flwr/serverapp/__init__.py +15 -0
- flwr/simulation/app.py +25 -0
- flwr/simulation/run_simulation.py +17 -0
- flwr/supercore/__init__.py +15 -0
- flwr/{server/superlink → supercore}/ffs/__init__.py +2 -0
- flwr/{server/superlink → supercore}/ffs/disk_ffs.py +1 -1
- flwr/supercore/grpc_health/__init__.py +22 -0
- flwr/supercore/grpc_health/simple_health_servicer.py +38 -0
- flwr/supercore/license_plugin/__init__.py +22 -0
- flwr/supercore/license_plugin/license_plugin.py +26 -0
- flwr/supercore/object_store/__init__.py +24 -0
- flwr/supercore/object_store/in_memory_object_store.py +229 -0
- flwr/supercore/object_store/object_store.py +170 -0
- flwr/supercore/object_store/object_store_factory.py +44 -0
- flwr/supercore/object_store/utils.py +43 -0
- flwr/supercore/scheduler/__init__.py +22 -0
- flwr/supercore/scheduler/plugin.py +71 -0
- flwr/{client/nodestate/nodestate.py → supercore/utils.py} +14 -13
- flwr/superexec/deployment.py +7 -4
- flwr/superexec/exec_event_log_interceptor.py +8 -4
- flwr/superexec/exec_grpc.py +25 -5
- flwr/superexec/exec_license_interceptor.py +82 -0
- flwr/superexec/exec_servicer.py +135 -24
- flwr/superexec/exec_user_auth_interceptor.py +45 -8
- flwr/superexec/executor.py +5 -1
- flwr/superexec/simulation.py +8 -3
- flwr/superlink/__init__.py +15 -0
- flwr/{client/supernode → supernode}/__init__.py +0 -7
- flwr/supernode/cli/__init__.py +24 -0
- flwr/{client/supernode/app.py → supernode/cli/flower_supernode.py} +3 -19
- flwr/supernode/cli/flwr_clientapp.py +88 -0
- flwr/supernode/nodestate/in_memory_nodestate.py +199 -0
- flwr/supernode/nodestate/nodestate.py +227 -0
- flwr/supernode/runtime/__init__.py +15 -0
- flwr/{client/clientapp/app.py → supernode/runtime/run_clientapp.py} +135 -89
- flwr/supernode/scheduler/__init__.py +22 -0
- flwr/supernode/scheduler/simple_clientapp_scheduler_plugin.py +49 -0
- flwr/supernode/servicer/__init__.py +15 -0
- flwr/supernode/servicer/clientappio/__init__.py +22 -0
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +303 -0
- flwr/supernode/start_client_internal.py +589 -0
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/METADATA +6 -4
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/RECORD +171 -123
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/WHEEL +1 -1
- {flwr-1.18.0.dist-info → flwr-1.20.0.dist-info}/entry_points.txt +2 -2
- flwr/client/clientapp/clientappio_servicer.py +0 -244
- flwr/client/heartbeat.py +0 -74
- flwr/client/nodestate/in_memory_nodestate.py +0 -38
- /flwr/{client → compat/client}/grpc_client/__init__.py +0 -0
- /flwr/{client → compat/client}/grpc_client/connection.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs.py +0 -0
- /flwr/{server/superlink → supercore}/ffs/ffs_factory.py +0 -0
- /flwr/{client → supernode}/nodestate/__init__.py +0 -0
- /flwr/{client → supernode}/nodestate/nodestate_factory.py +0 -0
|
@@ -38,8 +38,6 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
38
38
|
CreateNodeResponse,
|
|
39
39
|
DeleteNodeRequest,
|
|
40
40
|
DeleteNodeResponse,
|
|
41
|
-
PingRequest,
|
|
42
|
-
PingResponse,
|
|
43
41
|
PullMessagesRequest,
|
|
44
42
|
PullMessagesResponse,
|
|
45
43
|
PushMessagesRequest,
|
|
@@ -47,6 +45,18 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
47
45
|
)
|
|
48
46
|
from flwr.proto.grpcadapter_pb2 import MessageContainer # pylint: disable=E0611
|
|
49
47
|
from flwr.proto.grpcadapter_pb2_grpc import GrpcAdapterStub
|
|
48
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
49
|
+
SendNodeHeartbeatRequest,
|
|
50
|
+
SendNodeHeartbeatResponse,
|
|
51
|
+
)
|
|
52
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
53
|
+
ConfirmMessageReceivedRequest,
|
|
54
|
+
ConfirmMessageReceivedResponse,
|
|
55
|
+
PullObjectRequest,
|
|
56
|
+
PullObjectResponse,
|
|
57
|
+
PushObjectRequest,
|
|
58
|
+
PushObjectResponse,
|
|
59
|
+
)
|
|
50
60
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
51
61
|
|
|
52
62
|
T = TypeVar("T", bound=GrpcMessage)
|
|
@@ -120,11 +130,11 @@ class GrpcAdapter:
|
|
|
120
130
|
"""."""
|
|
121
131
|
return self._send_and_receive(request, DeleteNodeResponse, **kwargs)
|
|
122
132
|
|
|
123
|
-
def
|
|
124
|
-
self, request:
|
|
125
|
-
) ->
|
|
133
|
+
def SendNodeHeartbeat( # pylint: disable=C0103
|
|
134
|
+
self, request: SendNodeHeartbeatRequest, **kwargs: Any
|
|
135
|
+
) -> SendNodeHeartbeatResponse:
|
|
126
136
|
"""."""
|
|
127
|
-
return self._send_and_receive(request,
|
|
137
|
+
return self._send_and_receive(request, SendNodeHeartbeatResponse, **kwargs)
|
|
128
138
|
|
|
129
139
|
def PullMessages( # pylint: disable=C0103
|
|
130
140
|
self, request: PullMessagesRequest, **kwargs: Any
|
|
@@ -149,3 +159,21 @@ class GrpcAdapter:
|
|
|
149
159
|
) -> GetFabResponse:
|
|
150
160
|
"""."""
|
|
151
161
|
return self._send_and_receive(request, GetFabResponse, **kwargs)
|
|
162
|
+
|
|
163
|
+
def PushObject( # pylint: disable=C0103
|
|
164
|
+
self, request: PushObjectRequest, **kwargs: Any
|
|
165
|
+
) -> PushObjectResponse:
|
|
166
|
+
"""."""
|
|
167
|
+
return self._send_and_receive(request, PushObjectResponse, **kwargs)
|
|
168
|
+
|
|
169
|
+
def PullObject( # pylint: disable=C0103
|
|
170
|
+
self, request: PullObjectRequest, **kwargs: Any
|
|
171
|
+
) -> PullObjectResponse:
|
|
172
|
+
"""."""
|
|
173
|
+
return self._send_and_receive(request, PullObjectResponse, **kwargs)
|
|
174
|
+
|
|
175
|
+
def ConfirmMessageReceived( # pylint: disable=C0103
|
|
176
|
+
self, request: ConfirmMessageReceivedRequest, **kwargs: Any
|
|
177
|
+
) -> ConfirmMessageReceivedResponse:
|
|
178
|
+
"""."""
|
|
179
|
+
return self._send_and_receive(request, ConfirmMessageReceivedResponse, **kwargs)
|
|
@@ -164,7 +164,7 @@ def validate_out_message(out_message: Message, in_message_metadata: Metadata) ->
|
|
|
164
164
|
in_meta = in_message_metadata
|
|
165
165
|
if ( # pylint: disable-next=too-many-boolean-expressions
|
|
166
166
|
out_meta.run_id == in_meta.run_id
|
|
167
|
-
and out_meta.message_id ==
|
|
167
|
+
and out_meta.message_id == out_message.object_id # Should match the object id
|
|
168
168
|
and out_meta.src_node_id == in_meta.dst_node_id
|
|
169
169
|
and out_meta.dst_node_id == in_meta.src_node_id
|
|
170
170
|
and out_meta.reply_to_message_id == in_meta.message_id
|
flwr/client/mod/comms_mods.py
CHANGED
|
@@ -32,14 +32,17 @@ def message_size_mod(
|
|
|
32
32
|
|
|
33
33
|
This mod logs the size in bytes of the message being transmited.
|
|
34
34
|
"""
|
|
35
|
-
|
|
35
|
+
# Log the size of the incoming message in bytes
|
|
36
|
+
total_bytes = sum(record.count_bytes() for record in msg.content.values())
|
|
37
|
+
log(INFO, "Incoming message size: %i bytes", total_bytes)
|
|
36
38
|
|
|
37
|
-
|
|
38
|
-
|
|
39
|
+
# Call the next layer
|
|
40
|
+
msg = call_next(msg, ctxt)
|
|
39
41
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
42
|
+
# Log the size of the outgoing message in bytes
|
|
43
|
+
total_bytes = sum(record.count_bytes() for record in msg.content.values())
|
|
44
|
+
log(INFO, "Outgoing message size: %i bytes", total_bytes)
|
|
45
|
+
return msg
|
|
43
46
|
|
|
44
47
|
|
|
45
48
|
def arrays_size_mod(
|
|
@@ -50,25 +53,41 @@ def arrays_size_mod(
|
|
|
50
53
|
This mod logs the number of array elements transmitted in ``ArrayRecord`` objects
|
|
51
54
|
of the message as well as their sizes in bytes.
|
|
52
55
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
56
|
+
# Log the ArrayRecord size statistics and the total size in the incoming message
|
|
57
|
+
array_record_size_stats = _get_array_record_size_stats(msg)
|
|
58
|
+
total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
|
|
59
|
+
if array_record_size_stats:
|
|
60
|
+
log(INFO, "Incoming `ArrayRecord` size statistics:")
|
|
61
|
+
log(INFO, array_record_size_stats)
|
|
62
|
+
log(INFO, "Total array elements received: %i bytes", total_bytes)
|
|
63
|
+
|
|
64
|
+
msg = call_next(msg, ctxt)
|
|
65
|
+
|
|
66
|
+
# Log the ArrayRecord size statistics and the total size in the outgoing message
|
|
67
|
+
array_record_size_stats = _get_array_record_size_stats(msg)
|
|
68
|
+
total_bytes = sum(stat["bytes"] for stat in array_record_size_stats.values())
|
|
69
|
+
if array_record_size_stats:
|
|
70
|
+
log(INFO, "Outgoing `ArrayRecord` size statistics:")
|
|
71
|
+
log(INFO, array_record_size_stats)
|
|
72
|
+
log(INFO, "Total array elements sent: %i bytes", total_bytes)
|
|
73
|
+
return msg
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _get_array_record_size_stats(
|
|
77
|
+
msg: Message,
|
|
78
|
+
) -> dict[str, dict[str, int]]:
|
|
79
|
+
"""Get `ArrayRecord` size statistics from the message."""
|
|
80
|
+
array_record_size_stats = {}
|
|
55
81
|
for record_name, arr_record in msg.content.array_records.items():
|
|
56
82
|
arr_record_bytes = arr_record.count_bytes()
|
|
57
|
-
arrays_size_in_bytes += arr_record_bytes
|
|
58
83
|
element_count = 0
|
|
59
84
|
for array in arr_record.values():
|
|
60
85
|
element_count += (
|
|
61
86
|
int(np.prod(array.shape)) if array.shape else array.numpy().size
|
|
62
87
|
)
|
|
63
88
|
|
|
64
|
-
|
|
89
|
+
array_record_size_stats[record_name] = {
|
|
65
90
|
"elements": element_count,
|
|
66
91
|
"bytes": arr_record_bytes,
|
|
67
92
|
}
|
|
68
|
-
|
|
69
|
-
if model_size_stats:
|
|
70
|
-
log(INFO, model_size_stats)
|
|
71
|
-
|
|
72
|
-
log(INFO, "Total array elements transmitted: %i bytes", arrays_size_in_bytes)
|
|
73
|
-
|
|
74
|
-
return call_next(msg, ctxt)
|
|
93
|
+
return array_record_size_stats
|
|
@@ -15,30 +15,26 @@
|
|
|
15
15
|
"""Contextmanager for a REST request-response channel to the Flower server."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
import random
|
|
19
|
-
import threading
|
|
20
18
|
from collections.abc import Iterator
|
|
21
19
|
from contextlib import contextmanager
|
|
22
|
-
from
|
|
23
|
-
from logging import ERROR, INFO, WARN
|
|
20
|
+
from logging import ERROR, WARN
|
|
24
21
|
from typing import Callable, Optional, TypeVar, Union
|
|
25
22
|
|
|
26
23
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
27
24
|
from google.protobuf.message import Message as GrpcMessage
|
|
28
25
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
|
29
26
|
|
|
30
|
-
from flwr.client.heartbeat import start_ping_loop
|
|
31
|
-
from flwr.client.message_handler.message_handler import validate_out_message
|
|
32
27
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
|
33
|
-
from flwr.common.constant import
|
|
34
|
-
PING_BASE_MULTIPLIER,
|
|
35
|
-
PING_CALL_TIMEOUT,
|
|
36
|
-
PING_DEFAULT_INTERVAL,
|
|
37
|
-
PING_RANDOM_RANGE,
|
|
38
|
-
)
|
|
28
|
+
from flwr.common.constant import HEARTBEAT_DEFAULT_INTERVAL
|
|
39
29
|
from flwr.common.exit import ExitCode, flwr_exit
|
|
30
|
+
from flwr.common.heartbeat import HeartbeatSender
|
|
31
|
+
from flwr.common.inflatable_protobuf_utils import (
|
|
32
|
+
make_confirm_message_received_fn_protobuf,
|
|
33
|
+
make_pull_object_fn_protobuf,
|
|
34
|
+
make_push_object_fn_protobuf,
|
|
35
|
+
)
|
|
40
36
|
from flwr.common.logger import log
|
|
41
|
-
from flwr.common.message import Message,
|
|
37
|
+
from flwr.common.message import Message, remove_content_from_message
|
|
42
38
|
from flwr.common.retry_invoker import RetryInvoker
|
|
43
39
|
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
|
|
44
40
|
from flwr.common.typing import Fab, Run
|
|
@@ -48,13 +44,24 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
|
|
|
48
44
|
CreateNodeResponse,
|
|
49
45
|
DeleteNodeRequest,
|
|
50
46
|
DeleteNodeResponse,
|
|
51
|
-
PingRequest,
|
|
52
|
-
PingResponse,
|
|
53
47
|
PullMessagesRequest,
|
|
54
48
|
PullMessagesResponse,
|
|
55
49
|
PushMessagesRequest,
|
|
56
50
|
PushMessagesResponse,
|
|
57
51
|
)
|
|
52
|
+
from flwr.proto.heartbeat_pb2 import ( # pylint: disable=E0611
|
|
53
|
+
SendNodeHeartbeatRequest,
|
|
54
|
+
SendNodeHeartbeatResponse,
|
|
55
|
+
)
|
|
56
|
+
from flwr.proto.message_pb2 import ( # pylint: disable=E0611
|
|
57
|
+
ConfirmMessageReceivedRequest,
|
|
58
|
+
ConfirmMessageReceivedResponse,
|
|
59
|
+
ObjectTree,
|
|
60
|
+
PullObjectRequest,
|
|
61
|
+
PullObjectResponse,
|
|
62
|
+
PushObjectRequest,
|
|
63
|
+
PushObjectResponse,
|
|
64
|
+
)
|
|
58
65
|
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
59
66
|
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
|
|
60
67
|
|
|
@@ -68,9 +75,12 @@ PATH_CREATE_NODE: str = "api/v0/fleet/create-node"
|
|
|
68
75
|
PATH_DELETE_NODE: str = "api/v0/fleet/delete-node"
|
|
69
76
|
PATH_PULL_MESSAGES: str = "/api/v0/fleet/pull-messages"
|
|
70
77
|
PATH_PUSH_MESSAGES: str = "/api/v0/fleet/push-messages"
|
|
71
|
-
|
|
78
|
+
PATH_PULL_OBJECT: str = "/api/v0/fleet/pull-object"
|
|
79
|
+
PATH_PUSH_OBJECT: str = "/api/v0/fleet/push-object"
|
|
80
|
+
PATH_SEND_NODE_HEARTBEAT: str = "api/v0/fleet/send-node-heartbeat"
|
|
72
81
|
PATH_GET_RUN: str = "/api/v0/fleet/get-run"
|
|
73
82
|
PATH_GET_FAB: str = "/api/v0/fleet/get-fab"
|
|
83
|
+
PATH_CONFIRM_MESSAGE_RECEIVED: str = "/api/v0/fleet/confirm-message-received"
|
|
74
84
|
|
|
75
85
|
T = TypeVar("T", bound=GrpcMessage)
|
|
76
86
|
|
|
@@ -89,12 +99,15 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
89
99
|
] = None,
|
|
90
100
|
) -> Iterator[
|
|
91
101
|
tuple[
|
|
92
|
-
Callable[[], Optional[Message]],
|
|
93
|
-
Callable[[Message],
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
102
|
+
Callable[[], Optional[tuple[Message, ObjectTree]]],
|
|
103
|
+
Callable[[Message, ObjectTree], set[str]],
|
|
104
|
+
Callable[[], Optional[int]],
|
|
105
|
+
Callable[[], None],
|
|
106
|
+
Callable[[int], Run],
|
|
107
|
+
Callable[[str, int], Fab],
|
|
108
|
+
Callable[[int, str], bytes],
|
|
109
|
+
Callable[[int, str, bytes], None],
|
|
110
|
+
Callable[[int, str], None],
|
|
98
111
|
]
|
|
99
112
|
]:
|
|
100
113
|
"""Primitives for request/response-based interaction with a server.
|
|
@@ -130,6 +143,9 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
130
143
|
create_node : Optional[Callable]
|
|
131
144
|
delete_node : Optional[Callable]
|
|
132
145
|
get_run : Optional[Callable]
|
|
146
|
+
pull_object : Callable[[str], bytes]
|
|
147
|
+
push_object : Callable[[str, bytes], None]
|
|
148
|
+
confirm_message_received : Callable[[str], None]
|
|
133
149
|
"""
|
|
134
150
|
log(
|
|
135
151
|
WARN,
|
|
@@ -158,13 +174,10 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
158
174
|
log(ERROR, "Client authentication is not supported for this transport type.")
|
|
159
175
|
|
|
160
176
|
# Shared variables for inner functions
|
|
161
|
-
metadata: Optional[Metadata] = None
|
|
162
177
|
node: Optional[Node] = None
|
|
163
|
-
ping_thread: Optional[threading.Thread] = None
|
|
164
|
-
ping_stop_event = threading.Event()
|
|
165
178
|
|
|
166
179
|
###########################################################################
|
|
167
|
-
#
|
|
180
|
+
# heartbeat/create_node/delete_node/receive/send/get_run functions
|
|
168
181
|
###########################################################################
|
|
169
182
|
|
|
170
183
|
def _request(
|
|
@@ -214,57 +227,89 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
214
227
|
grpc_res.ParseFromString(res.content)
|
|
215
228
|
return grpc_res
|
|
216
229
|
|
|
217
|
-
def
|
|
230
|
+
def _pull_object_protobuf(request: PullObjectRequest) -> PullObjectResponse:
|
|
231
|
+
res = _request(
|
|
232
|
+
req=request,
|
|
233
|
+
res_type=PullObjectResponse,
|
|
234
|
+
api_path=PATH_PULL_OBJECT,
|
|
235
|
+
)
|
|
236
|
+
if res is None:
|
|
237
|
+
raise ValueError(f"{PullObjectResponse.__name__} is None.")
|
|
238
|
+
return res
|
|
239
|
+
|
|
240
|
+
def _push_object_protobuf(request: PushObjectRequest) -> PushObjectResponse:
|
|
241
|
+
res = _request(
|
|
242
|
+
req=request,
|
|
243
|
+
res_type=PushObjectResponse,
|
|
244
|
+
api_path=PATH_PUSH_OBJECT,
|
|
245
|
+
)
|
|
246
|
+
if res is None:
|
|
247
|
+
raise ValueError(f"{PushObjectResponse.__name__} is None.")
|
|
248
|
+
return res
|
|
249
|
+
|
|
250
|
+
def _confirm_message_received_protobuf(
|
|
251
|
+
request: ConfirmMessageReceivedRequest,
|
|
252
|
+
) -> ConfirmMessageReceivedResponse:
|
|
253
|
+
res = _request(
|
|
254
|
+
req=request,
|
|
255
|
+
res_type=ConfirmMessageReceivedResponse,
|
|
256
|
+
api_path=PATH_CONFIRM_MESSAGE_RECEIVED,
|
|
257
|
+
)
|
|
258
|
+
if res is None:
|
|
259
|
+
raise ValueError(f"{ConfirmMessageReceivedResponse.__name__} is None.")
|
|
260
|
+
return res
|
|
261
|
+
|
|
262
|
+
def send_node_heartbeat() -> bool:
|
|
218
263
|
# Get Node
|
|
219
264
|
if node is None:
|
|
220
265
|
log(ERROR, "Node instance missing")
|
|
221
|
-
return
|
|
266
|
+
return False
|
|
222
267
|
|
|
223
|
-
# Construct the
|
|
224
|
-
req =
|
|
268
|
+
# Construct the heartbeat request
|
|
269
|
+
req = SendNodeHeartbeatRequest(
|
|
270
|
+
node=node, heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL
|
|
271
|
+
)
|
|
225
272
|
|
|
226
273
|
# Send the request
|
|
227
|
-
res = _request(
|
|
274
|
+
res = _request(
|
|
275
|
+
req, SendNodeHeartbeatResponse, PATH_SEND_NODE_HEARTBEAT, retry=False
|
|
276
|
+
)
|
|
228
277
|
if res is None:
|
|
229
|
-
return
|
|
278
|
+
return False
|
|
230
279
|
|
|
231
280
|
# Check if success
|
|
232
281
|
if not res.success:
|
|
233
|
-
raise RuntimeError(
|
|
282
|
+
raise RuntimeError(
|
|
283
|
+
"Heartbeat failed unexpectedly. The SuperLink does not "
|
|
284
|
+
"recognize this SuperNode."
|
|
285
|
+
)
|
|
286
|
+
return True
|
|
234
287
|
|
|
235
|
-
|
|
236
|
-
rd = random.uniform(*PING_RANDOM_RANGE)
|
|
237
|
-
next_interval: float = PING_DEFAULT_INTERVAL - PING_CALL_TIMEOUT
|
|
238
|
-
next_interval *= PING_BASE_MULTIPLIER + rd
|
|
239
|
-
if not ping_stop_event.is_set():
|
|
240
|
-
ping_stop_event.wait(next_interval)
|
|
288
|
+
heartbeat_sender = HeartbeatSender(send_node_heartbeat)
|
|
241
289
|
|
|
242
290
|
def create_node() -> Optional[int]:
|
|
243
291
|
"""Set create_node."""
|
|
244
|
-
req = CreateNodeRequest(
|
|
292
|
+
req = CreateNodeRequest(heartbeat_interval=HEARTBEAT_DEFAULT_INTERVAL)
|
|
245
293
|
|
|
246
294
|
# Send the request
|
|
247
295
|
res = _request(req, CreateNodeResponse, PATH_CREATE_NODE)
|
|
248
296
|
if res is None:
|
|
249
297
|
return None
|
|
250
298
|
|
|
251
|
-
# Remember the node and the
|
|
252
|
-
nonlocal node
|
|
299
|
+
# Remember the node and start the heartbeat sender
|
|
300
|
+
nonlocal node
|
|
253
301
|
node = res.node
|
|
254
|
-
|
|
302
|
+
heartbeat_sender.start()
|
|
255
303
|
return node.node_id
|
|
256
304
|
|
|
257
305
|
def delete_node() -> None:
|
|
258
306
|
"""Set delete_node."""
|
|
259
307
|
nonlocal node
|
|
260
308
|
if node is None:
|
|
261
|
-
|
|
262
|
-
return
|
|
309
|
+
raise RuntimeError("Node instance missing")
|
|
263
310
|
|
|
264
|
-
# Stop the
|
|
265
|
-
|
|
266
|
-
if ping_thread is not None:
|
|
267
|
-
ping_thread.join()
|
|
311
|
+
# Stop the heartbeat sender
|
|
312
|
+
heartbeat_sender.stop()
|
|
268
313
|
|
|
269
314
|
# Send DeleteNode request
|
|
270
315
|
req = DeleteNodeRequest(node=node)
|
|
@@ -277,75 +322,54 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
277
322
|
# Cleanup
|
|
278
323
|
node = None
|
|
279
324
|
|
|
280
|
-
def receive() -> Optional[Message]:
|
|
281
|
-
"""
|
|
325
|
+
def receive() -> Optional[tuple[Message, ObjectTree]]:
|
|
326
|
+
"""Pull a message with its ObjectTree from SuperLink."""
|
|
282
327
|
# Get Node
|
|
283
328
|
if node is None:
|
|
284
|
-
|
|
285
|
-
return None
|
|
329
|
+
raise RuntimeError("Node instance missing")
|
|
286
330
|
|
|
287
|
-
#
|
|
331
|
+
# Try to pull a message with its object tree from SuperLink
|
|
288
332
|
req = PullMessagesRequest(node=node)
|
|
289
|
-
|
|
290
|
-
# Send the request
|
|
291
333
|
res = _request(req, PullMessagesResponse, PATH_PULL_MESSAGES)
|
|
292
334
|
if res is None:
|
|
335
|
+
raise ValueError("PushMessagesResponse is None.")
|
|
336
|
+
|
|
337
|
+
# If no messages are available, return None
|
|
338
|
+
if len(res.messages_list) == 0:
|
|
293
339
|
return None
|
|
294
340
|
|
|
295
|
-
# Get the current
|
|
296
|
-
message_proto =
|
|
297
|
-
|
|
298
|
-
# Discard the current message if not valid
|
|
299
|
-
if message_proto is not None and not (
|
|
300
|
-
message_proto.metadata.dst_node_id == node.node_id
|
|
301
|
-
):
|
|
302
|
-
message_proto = None
|
|
303
|
-
|
|
304
|
-
# Return the Message if available
|
|
305
|
-
nonlocal metadata
|
|
306
|
-
message = None
|
|
307
|
-
if message_proto is not None:
|
|
308
|
-
message = message_from_proto(message_proto)
|
|
309
|
-
metadata = copy(message.metadata)
|
|
310
|
-
log(INFO, "[Node] POST /%s: success", PATH_PULL_MESSAGES)
|
|
311
|
-
return message
|
|
312
|
-
|
|
313
|
-
def send(message: Message) -> None:
|
|
314
|
-
"""Send Message result back to server."""
|
|
315
|
-
# Get Node
|
|
316
|
-
if node is None:
|
|
317
|
-
log(ERROR, "Node instance missing")
|
|
318
|
-
return
|
|
341
|
+
# Get the current Message and its object tree
|
|
342
|
+
message_proto = res.messages_list[0]
|
|
343
|
+
object_tree = res.message_object_trees[0]
|
|
319
344
|
|
|
320
|
-
#
|
|
321
|
-
|
|
322
|
-
if metadata is None:
|
|
323
|
-
log(ERROR, "No current message")
|
|
324
|
-
return
|
|
345
|
+
# Construct the Message
|
|
346
|
+
in_message = message_from_proto(message_proto)
|
|
325
347
|
|
|
326
|
-
#
|
|
327
|
-
|
|
328
|
-
log(ERROR, "Invalid out message")
|
|
329
|
-
return
|
|
330
|
-
metadata = None
|
|
348
|
+
# Return the Message and its object tree
|
|
349
|
+
return in_message, object_tree
|
|
331
350
|
|
|
332
|
-
|
|
333
|
-
|
|
351
|
+
def send(message: Message, object_tree: ObjectTree) -> set[str]:
|
|
352
|
+
"""Send the message with its ObjectTree to SuperLink."""
|
|
353
|
+
# Get Node
|
|
354
|
+
if node is None:
|
|
355
|
+
raise RuntimeError("Node instance missing")
|
|
334
356
|
|
|
335
|
-
#
|
|
336
|
-
|
|
357
|
+
# Remove the content from the message if it has
|
|
358
|
+
if message.has_content():
|
|
359
|
+
message = remove_content_from_message(message)
|
|
337
360
|
|
|
338
|
-
# Send the
|
|
361
|
+
# Send the message with its ObjectTree to SuperLink
|
|
362
|
+
req = PushMessagesRequest(
|
|
363
|
+
node=node,
|
|
364
|
+
messages_list=[message_to_proto(message)],
|
|
365
|
+
message_object_trees=[object_tree],
|
|
366
|
+
)
|
|
339
367
|
res = _request(req, PushMessagesResponse, PATH_PUSH_MESSAGES)
|
|
340
368
|
if res is None:
|
|
341
|
-
|
|
369
|
+
raise ValueError("PushMessagesResponse is None.")
|
|
342
370
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
"[Node] POST /%s: success, created result %s",
|
|
346
|
-
PATH_PUSH_MESSAGES,
|
|
347
|
-
res.results, # pylint: disable=no-member
|
|
348
|
-
)
|
|
371
|
+
# Get and return the object IDs to push
|
|
372
|
+
return set(res.objects_to_push)
|
|
349
373
|
|
|
350
374
|
def get_run(run_id: int) -> Run:
|
|
351
375
|
# Construct the request
|
|
@@ -372,9 +396,58 @@ def http_request_response( # pylint: disable=R0913,R0914,R0915,R0917
|
|
|
372
396
|
res.fab.content,
|
|
373
397
|
)
|
|
374
398
|
|
|
399
|
+
def pull_object(run_id: int, object_id: str) -> bytes:
|
|
400
|
+
"""Pull the object from the SuperLink."""
|
|
401
|
+
# Check Node
|
|
402
|
+
if node is None:
|
|
403
|
+
raise RuntimeError("Node instance missing")
|
|
404
|
+
|
|
405
|
+
fn = make_pull_object_fn_protobuf(
|
|
406
|
+
pull_object_protobuf=_pull_object_protobuf,
|
|
407
|
+
node=node,
|
|
408
|
+
run_id=run_id,
|
|
409
|
+
)
|
|
410
|
+
return fn(object_id)
|
|
411
|
+
|
|
412
|
+
def push_object(run_id: int, object_id: str, contents: bytes) -> None:
|
|
413
|
+
"""Push the object to the SuperLink."""
|
|
414
|
+
# Check Node
|
|
415
|
+
if node is None:
|
|
416
|
+
raise RuntimeError("Node instance missing")
|
|
417
|
+
|
|
418
|
+
fn = make_push_object_fn_protobuf(
|
|
419
|
+
push_object_protobuf=_push_object_protobuf,
|
|
420
|
+
node=node,
|
|
421
|
+
run_id=run_id,
|
|
422
|
+
)
|
|
423
|
+
fn(object_id, contents)
|
|
424
|
+
|
|
425
|
+
def confirm_message_received(run_id: int, object_id: str) -> None:
|
|
426
|
+
"""Confirm that the message has been received."""
|
|
427
|
+
# Check Node
|
|
428
|
+
if node is None:
|
|
429
|
+
raise RuntimeError("Node instance missing")
|
|
430
|
+
|
|
431
|
+
fn = make_confirm_message_received_fn_protobuf(
|
|
432
|
+
confirm_message_received_protobuf=_confirm_message_received_protobuf,
|
|
433
|
+
node=node,
|
|
434
|
+
run_id=run_id,
|
|
435
|
+
)
|
|
436
|
+
fn(object_id)
|
|
437
|
+
|
|
375
438
|
try:
|
|
376
439
|
# Yield methods
|
|
377
|
-
yield (
|
|
440
|
+
yield (
|
|
441
|
+
receive,
|
|
442
|
+
send,
|
|
443
|
+
create_node,
|
|
444
|
+
delete_node,
|
|
445
|
+
get_run,
|
|
446
|
+
get_fab,
|
|
447
|
+
pull_object,
|
|
448
|
+
push_object,
|
|
449
|
+
confirm_message_received,
|
|
450
|
+
)
|
|
378
451
|
except Exception as exc: # pylint: disable=broad-except
|
|
379
452
|
log(ERROR, exc)
|
|
380
453
|
# Cleanup
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Public Flower ClientApp APIs."""
|
flwr/common/__init__.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
"""Common components shared between server and client."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
from ..app.error import Error as Error
|
|
19
|
+
from ..app.metadata import Metadata as Metadata
|
|
18
20
|
from .constant import MessageType as MessageType
|
|
19
21
|
from .constant import MessageTypeLegacy as MessageTypeLegacy
|
|
20
22
|
from .context import Context as Context
|
|
@@ -23,9 +25,7 @@ from .grpc import GRPC_MAX_MESSAGE_LENGTH
|
|
|
23
25
|
from .logger import configure as configure
|
|
24
26
|
from .logger import log as log
|
|
25
27
|
from .message import DEFAULT_TTL
|
|
26
|
-
from .message import Error as Error
|
|
27
28
|
from .message import Message as Message
|
|
28
|
-
from .message import Metadata as Metadata
|
|
29
29
|
from .parameter import bytes_to_ndarray as bytes_to_ndarray
|
|
30
30
|
from .parameter import ndarray_to_bytes as ndarray_to_bytes
|
|
31
31
|
from .parameter import ndarrays_to_parameters as ndarrays_to_parameters
|
|
@@ -17,8 +17,10 @@
|
|
|
17
17
|
|
|
18
18
|
from .auth_plugin import CliAuthPlugin as CliAuthPlugin
|
|
19
19
|
from .auth_plugin import ExecAuthPlugin as ExecAuthPlugin
|
|
20
|
+
from .auth_plugin import ExecAuthzPlugin as ExecAuthzPlugin
|
|
20
21
|
|
|
21
22
|
__all__ = [
|
|
22
23
|
"CliAuthPlugin",
|
|
23
24
|
"ExecAuthPlugin",
|
|
25
|
+
"ExecAuthzPlugin",
|
|
24
26
|
]
|
|
@@ -20,7 +20,7 @@ from collections.abc import Sequence
|
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Optional, Union
|
|
22
22
|
|
|
23
|
-
from flwr.common.typing import
|
|
23
|
+
from flwr.common.typing import AccountInfo
|
|
24
24
|
from flwr.proto.exec_pb2_grpc import ExecStub
|
|
25
25
|
|
|
26
26
|
from ..typing import UserAuthCredentials, UserAuthLoginDetails
|
|
@@ -33,6 +33,9 @@ class ExecAuthPlugin(ABC):
|
|
|
33
33
|
----------
|
|
34
34
|
user_auth_config_path : Path
|
|
35
35
|
Path to the YAML file containing the authentication configuration.
|
|
36
|
+
verify_tls_cert : bool
|
|
37
|
+
Boolean indicating whether to verify the TLS certificate
|
|
38
|
+
when making requests to the server.
|
|
36
39
|
"""
|
|
37
40
|
|
|
38
41
|
@abstractmethod
|
|
@@ -50,7 +53,7 @@ class ExecAuthPlugin(ABC):
|
|
|
50
53
|
@abstractmethod
|
|
51
54
|
def validate_tokens_in_metadata(
|
|
52
55
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
53
|
-
) -> tuple[bool, Optional[
|
|
56
|
+
) -> tuple[bool, Optional[AccountInfo]]:
|
|
54
57
|
"""Validate authentication tokens in the provided metadata."""
|
|
55
58
|
|
|
56
59
|
@abstractmethod
|
|
@@ -60,10 +63,33 @@ class ExecAuthPlugin(ABC):
|
|
|
60
63
|
@abstractmethod
|
|
61
64
|
def refresh_tokens(
|
|
62
65
|
self, metadata: Sequence[tuple[str, Union[str, bytes]]]
|
|
63
|
-
) ->
|
|
66
|
+
) -> tuple[
|
|
67
|
+
Optional[Sequence[tuple[str, Union[str, bytes]]]], Optional[AccountInfo]
|
|
68
|
+
]:
|
|
64
69
|
"""Refresh authentication tokens in the provided metadata."""
|
|
65
70
|
|
|
66
71
|
|
|
72
|
+
class ExecAuthzPlugin(ABC): # pylint: disable=too-few-public-methods
|
|
73
|
+
"""Abstract Flower Authorization Plugin class for ExecServicer.
|
|
74
|
+
|
|
75
|
+
Parameters
|
|
76
|
+
----------
|
|
77
|
+
user_auth_config_path : Path
|
|
78
|
+
Path to the YAML file containing the authorization configuration.
|
|
79
|
+
verify_tls_cert : bool
|
|
80
|
+
Boolean indicating whether to verify the TLS certificate
|
|
81
|
+
when making requests to the server.
|
|
82
|
+
"""
|
|
83
|
+
|
|
84
|
+
@abstractmethod
|
|
85
|
+
def __init__(self, user_auth_config_path: Path, verify_tls_cert: bool):
|
|
86
|
+
"""Abstract constructor."""
|
|
87
|
+
|
|
88
|
+
@abstractmethod
|
|
89
|
+
def verify_user_authorization(self, account_info: AccountInfo) -> bool:
|
|
90
|
+
"""Verify user authorization request."""
|
|
91
|
+
|
|
92
|
+
|
|
67
93
|
class CliAuthPlugin(ABC):
|
|
68
94
|
"""Abstract Flower Auth Plugin class for CLI.
|
|
69
95
|
|