flwr 1.12.0__py3-none-any.whl → 1.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/app.py +2 -0
- flwr/cli/build.py +2 -2
- flwr/cli/config_utils.py +97 -0
- flwr/cli/install.py +0 -16
- flwr/cli/log.py +63 -97
- flwr/cli/ls.py +228 -0
- flwr/cli/new/new.py +23 -13
- flwr/cli/new/templates/app/README.md.tpl +11 -0
- flwr/cli/new/templates/app/code/flwr_tune/dataset.py.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/run/run.py +37 -89
- flwr/client/app.py +73 -34
- flwr/client/clientapp/app.py +58 -37
- flwr/client/grpc_rere_client/connection.py +7 -12
- flwr/client/nodestate/__init__.py +25 -0
- flwr/client/nodestate/in_memory_nodestate.py +38 -0
- flwr/client/nodestate/nodestate.py +30 -0
- flwr/client/nodestate/nodestate_factory.py +37 -0
- flwr/client/rest_client/connection.py +4 -14
- flwr/client/{node_state.py → run_info_store.py} +4 -3
- flwr/client/supernode/app.py +34 -58
- flwr/common/args.py +152 -0
- flwr/common/config.py +10 -0
- flwr/common/constant.py +59 -7
- flwr/common/context.py +9 -4
- flwr/common/date.py +21 -3
- flwr/common/grpc.py +4 -1
- flwr/common/logger.py +108 -1
- flwr/common/object_ref.py +47 -16
- flwr/common/serde.py +34 -0
- flwr/common/telemetry.py +0 -6
- flwr/common/typing.py +32 -2
- flwr/proto/exec_pb2.py +23 -17
- flwr/proto/exec_pb2.pyi +58 -22
- flwr/proto/exec_pb2_grpc.py +34 -0
- flwr/proto/exec_pb2_grpc.pyi +13 -0
- flwr/proto/log_pb2.py +29 -0
- flwr/proto/log_pb2.pyi +39 -0
- flwr/proto/log_pb2_grpc.py +4 -0
- flwr/proto/log_pb2_grpc.pyi +4 -0
- flwr/proto/message_pb2.py +8 -8
- flwr/proto/message_pb2.pyi +4 -1
- flwr/proto/run_pb2.py +32 -27
- flwr/proto/run_pb2.pyi +44 -1
- flwr/proto/serverappio_pb2.py +52 -0
- flwr/proto/{driver_pb2.pyi → serverappio_pb2.pyi} +54 -0
- flwr/proto/serverappio_pb2_grpc.py +376 -0
- flwr/proto/serverappio_pb2_grpc.pyi +147 -0
- flwr/proto/simulationio_pb2.py +38 -0
- flwr/proto/simulationio_pb2.pyi +65 -0
- flwr/proto/simulationio_pb2_grpc.py +205 -0
- flwr/proto/simulationio_pb2_grpc.pyi +81 -0
- flwr/server/app.py +297 -162
- flwr/server/driver/driver.py +15 -1
- flwr/server/driver/grpc_driver.py +89 -50
- flwr/server/driver/inmemory_driver.py +6 -16
- flwr/server/run_serverapp.py +11 -235
- flwr/server/{superlink/state → serverapp}/__init__.py +3 -9
- flwr/server/serverapp/app.py +234 -0
- flwr/server/strategy/aggregate.py +4 -4
- flwr/server/strategy/fedadam.py +11 -1
- flwr/server/superlink/driver/__init__.py +1 -1
- flwr/server/superlink/driver/{driver_grpc.py → serverappio_grpc.py} +19 -16
- flwr/server/superlink/driver/{driver_servicer.py → serverappio_servicer.py} +125 -39
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +2 -2
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -2
- flwr/server/superlink/fleet/grpc_rere/server_interceptor.py +2 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +7 -7
- flwr/server/superlink/fleet/rest_rere/rest_api.py +10 -9
- flwr/server/superlink/fleet/vce/vce_api.py +23 -23
- flwr/server/superlink/linkstate/__init__.py +28 -0
- flwr/server/superlink/{state/in_memory_state.py → linkstate/in_memory_linkstate.py} +237 -64
- flwr/server/superlink/{state/state.py → linkstate/linkstate.py} +166 -22
- flwr/server/superlink/{state/state_factory.py → linkstate/linkstate_factory.py} +9 -9
- flwr/server/superlink/{state/sqlite_state.py → linkstate/sqlite_linkstate.py} +383 -174
- flwr/server/superlink/linkstate/utils.py +389 -0
- flwr/server/superlink/simulation/__init__.py +15 -0
- flwr/server/superlink/simulation/simulationio_grpc.py +65 -0
- flwr/server/superlink/simulation/simulationio_servicer.py +153 -0
- flwr/simulation/__init__.py +5 -1
- flwr/simulation/app.py +236 -347
- flwr/simulation/legacy_app.py +402 -0
- flwr/simulation/ray_transport/ray_client_proxy.py +2 -2
- flwr/simulation/run_simulation.py +56 -141
- flwr/simulation/simulationio_connection.py +86 -0
- flwr/superexec/app.py +6 -134
- flwr/superexec/deployment.py +70 -69
- flwr/superexec/exec_grpc.py +15 -8
- flwr/superexec/exec_servicer.py +65 -65
- flwr/superexec/executor.py +26 -7
- flwr/superexec/simulation.py +62 -150
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/METADATA +9 -7
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/RECORD +105 -85
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/entry_points.txt +2 -0
- flwr/client/node_state_tests.py +0 -66
- flwr/proto/driver_pb2.py +0 -42
- flwr/proto/driver_pb2_grpc.py +0 -239
- flwr/proto/driver_pb2_grpc.pyi +0 -94
- flwr/server/superlink/state/utils.py +0 -148
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/LICENSE +0 -0
- {flwr-1.12.0.dist-info → flwr-1.13.1.dist-info}/WHEEL +0 -0
flwr/common/args.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Common Flower arguments."""
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import sys
|
|
19
|
+
from logging import DEBUG, ERROR, WARN
|
|
20
|
+
from os.path import isfile
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Optional
|
|
23
|
+
|
|
24
|
+
from flwr.common.constant import (
|
|
25
|
+
TRANSPORT_TYPE_GRPC_ADAPTER,
|
|
26
|
+
TRANSPORT_TYPE_GRPC_RERE,
|
|
27
|
+
TRANSPORT_TYPE_REST,
|
|
28
|
+
)
|
|
29
|
+
from flwr.common.logger import log
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def add_args_flwr_app_common(parser: argparse.ArgumentParser) -> None:
    """Add common Flower arguments for flwr-*app to the provided parser.

    Two options are registered on ``parser``:

    - ``--flwr-dir`` (default ``None``): directory containing installed Flower Apps.
    - ``--insecure`` (boolean flag): disable HTTPS.
    """
    flwr_dir_help = """The path containing installed Flower Apps.
By default, this value is equal to:

- `$FLWR_HOME/` if `$FLWR_HOME` is defined
- `$XDG_DATA_HOME/.flwr/` if `$XDG_DATA_HOME` is defined
- `$HOME/.flwr/` in all other cases
"""
    insecure_help = (
        "Run the server without HTTPS, regardless of whether certificate "
        "paths are provided. By default, the server runs with HTTPS enabled. "
        "Use this flag only if you understand the risks."
    )
    parser.add_argument("--flwr-dir", default=None, help=flwr_dir_help)
    parser.add_argument("--insecure", action="store_true", help=insecure_help)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def try_obtain_root_certificates(
    args: argparse.Namespace,
    grpc_server_address: str,
) -> Optional[bytes]:
    """Validate and return the root certificates.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; ``insecure`` and ``root_certificates`` are read.
    grpc_server_address : str
        Address of the gRPC server; only used in log messages.

    Returns
    -------
    Optional[bytes]
        Contents of the ``--root-certificates`` file, or ``None`` when running
        in insecure mode or when no path was given (system certificates).

    Raises
    ------
    SystemExit
        If ``--insecure`` is combined with ``--root-certificates``, or if the
        given certificate path is not a file.
    """
    root_cert_path = args.root_certificates

    if args.insecure:
        # An explicit certificate contradicts the request for an insecure channel
        if root_cert_path is not None:
            sys.exit(
                "Conflicting options: The '--insecure' flag disables HTTPS, "
                "but '--root-certificates' was also specified. Please remove "
                "the '--root-certificates' option when running in insecure mode, "
                "or omit '--insecure' to use HTTPS."
            )
        log(
            WARN,
            "Option `--insecure` was set. Starting insecure HTTP channel to %s.",
            grpc_server_address,
        )
        return None

    if root_cert_path is None:
        # No custom certificate: signal "use the system certificates" with `None`
        log(
            WARN,
            "Both `--insecure` and `--root-certificates` were not set. "
            "Using system certificates.",
        )
        return None

    if not isfile(root_cert_path):
        log(ERROR, "Path argument `--root-certificates` does not point to a file.")
        sys.exit(1)

    certificate_bytes = Path(root_cert_path).read_bytes()
    log(
        DEBUG,
        "Starting secure HTTPS channel to %s "
        "with the following certificates: %s.",
        grpc_server_address,
        root_cert_path,
    )
    return certificate_bytes
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def try_obtain_server_certificates(
    args: argparse.Namespace,
    transport_type: str,
) -> Optional[tuple[bytes, bytes, bytes]]:
    """Validate and return the CA cert, server cert, and server private key.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments; ``insecure``, ``ssl_certfile``, ``ssl_keyfile``
        and ``ssl_ca_certfile`` are read.
    transport_type : str
        Fleet API transport. gRPC-rere and gRPC-adapter require all three
        files; REST requires only the certificate and key files.

    Returns
    -------
    Optional[tuple[bytes, bytes, bytes]]
        ``(ca_cert, server_cert, server_key)`` file contents (for REST the CA
        slot is ``b""``), or ``None`` when ``--insecure`` is set.

    Raises
    ------
    SystemExit
        If a provided path is not a file, only some of the required paths are
        provided, or no certificates are provided outside insecure mode.
    """
    if args.insecure:
        log(WARN, "Option `--insecure` was set. Starting insecure HTTP server.")
        return None
    # Check if certificates are provided
    if transport_type in [TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_GRPC_ADAPTER]:
        if args.ssl_certfile and args.ssl_keyfile and args.ssl_ca_certfile:
            if not isfile(args.ssl_ca_certfile):
                sys.exit("Path argument `--ssl-ca-certfile` does not point to a file.")
            if not isfile(args.ssl_certfile):
                sys.exit("Path argument `--ssl-certfile` does not point to a file.")
            if not isfile(args.ssl_keyfile):
                sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
            certificates = (
                Path(args.ssl_ca_certfile).read_bytes(),  # CA certificate
                Path(args.ssl_certfile).read_bytes(),  # server certificate
                Path(args.ssl_keyfile).read_bytes(),  # server private key
            )
            return certificates
        # Some, but not all, of the three paths were given
        if args.ssl_certfile or args.ssl_keyfile or args.ssl_ca_certfile:
            sys.exit(
                "You need to provide valid file paths to `--ssl-certfile`, "
                "`--ssl-keyfile`, and `--ssl-ca-certfile` to create a secure "
                "connection in Fleet API server (gRPC-rere)."
            )
    if transport_type == TRANSPORT_TYPE_REST:
        if args.ssl_certfile and args.ssl_keyfile:
            if not isfile(args.ssl_certfile):
                sys.exit("Path argument `--ssl-certfile` does not point to a file.")
            if not isfile(args.ssl_keyfile):
                sys.exit("Path argument `--ssl-keyfile` does not point to a file.")
            certificates = (
                b"",  # empty CA-certificate placeholder for REST
                Path(args.ssl_certfile).read_bytes(),  # server certificate
                Path(args.ssl_keyfile).read_bytes(),  # server private key
            )
            return certificates
        # Only one of the two paths was given
        if args.ssl_certfile or args.ssl_keyfile:
            sys.exit(
                "You need to provide valid file paths to `--ssl-certfile` "
                "and `--ssl-keyfile` to create a secure connection "
                "in Fleet API server (REST, experimental)."
            )
    log(
        ERROR,
        "Certificates are required unless running in insecure mode. "
        "Please provide certificate paths to `--ssl-certfile`, "
        "`--ssl-keyfile`, and `--ssl-ca-certfile` or run the server "
        "in insecure mode using '--insecure' if you understand the risks.",
    )
    sys.exit(1)
|
flwr/common/config.py
CHANGED
|
@@ -22,6 +22,7 @@ from typing import Any, Optional, Union, cast, get_args
|
|
|
22
22
|
import tomli
|
|
23
23
|
|
|
24
24
|
from flwr.cli.config_utils import get_fab_config, validate_fields
|
|
25
|
+
from flwr.common import ConfigsRecord
|
|
25
26
|
from flwr.common.constant import (
|
|
26
27
|
APP_DIR,
|
|
27
28
|
FAB_CONFIG_FILE,
|
|
@@ -229,3 +230,12 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
|
|
|
229
230
|
config["project"]["version"],
|
|
230
231
|
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
|
|
231
232
|
)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def user_config_to_configsrecord(config: UserConfig) -> ConfigsRecord:
    """Construct a `ConfigsRecord` out of a `UserConfig`.

    Every key/value pair of ``config`` is copied, in iteration order, into a
    new, initially empty ``ConfigsRecord``.
    """
    record = ConfigsRecord()
    for key in config:
        record[key] = config[key]
    return record
|
flwr/common/constant.py
CHANGED
|
@@ -38,17 +38,30 @@ TRANSPORT_TYPES = [
|
|
|
38
38
|
]
|
|
39
39
|
|
|
40
40
|
# Addresses
|
|
41
|
+
# Ports
|
|
42
|
+
CLIENTAPPIO_PORT = "9094"
|
|
43
|
+
SERVERAPPIO_PORT = "9091"
|
|
44
|
+
FLEETAPI_GRPC_RERE_PORT = "9092"
|
|
45
|
+
FLEETAPI_PORT = "9095"
|
|
46
|
+
EXEC_API_PORT = "9093"
|
|
47
|
+
SIMULATIONIO_PORT = "9096"
|
|
48
|
+
# Octets
|
|
49
|
+
SERVER_OCTET = "0.0.0.0"
|
|
50
|
+
CLIENT_OCTET = "127.0.0.1"
|
|
41
51
|
# SuperNode
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
|
|
52
|
+
CLIENTAPPIO_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{CLIENTAPPIO_PORT}"
|
|
53
|
+
CLIENTAPPIO_API_DEFAULT_CLIENT_ADDRESS = f"{CLIENT_OCTET}:{CLIENTAPPIO_PORT}"
|
|
45
54
|
# SuperLink
|
|
46
|
-
|
|
47
|
-
|
|
55
|
+
SERVERAPPIO_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{SERVERAPPIO_PORT}"
|
|
56
|
+
SERVERAPPIO_API_DEFAULT_CLIENT_ADDRESS = f"{CLIENT_OCTET}:{SERVERAPPIO_PORT}"
|
|
57
|
+
FLEET_API_GRPC_RERE_DEFAULT_ADDRESS = f"{SERVER_OCTET}:{FLEETAPI_GRPC_RERE_PORT}"
|
|
48
58
|
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS = (
|
|
49
59
|
"[::]:8080" # IPv6 to keep start_server compatible
|
|
50
60
|
)
|
|
51
|
-
FLEET_API_REST_DEFAULT_ADDRESS = "
|
|
61
|
+
FLEET_API_REST_DEFAULT_ADDRESS = f"{SERVER_OCTET}:{FLEETAPI_PORT}"
|
|
62
|
+
EXEC_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{EXEC_API_PORT}"
|
|
63
|
+
SIMULATIONIO_API_DEFAULT_SERVER_ADDRESS = f"{SERVER_OCTET}:{SIMULATIONIO_PORT}"
|
|
64
|
+
SIMULATIONIO_API_DEFAULT_CLIENT_ADDRESS = f"{CLIENT_OCTET}:{SIMULATIONIO_PORT}"
|
|
52
65
|
|
|
53
66
|
# Constants for ping
|
|
54
67
|
PING_DEFAULT_INTERVAL = 30
|
|
@@ -84,6 +97,19 @@ GRPC_ADAPTER_METADATA_MESSAGE_QUALNAME_KEY = "grpc-message-qualname"
|
|
|
84
97
|
# Message TTL
|
|
85
98
|
MESSAGE_TTL_TOLERANCE = 1e-1
|
|
86
99
|
|
|
100
|
+
# Isolation modes
|
|
101
|
+
ISOLATION_MODE_SUBPROCESS = "subprocess"
|
|
102
|
+
ISOLATION_MODE_PROCESS = "process"
|
|
103
|
+
|
|
104
|
+
# Log streaming configurations
|
|
105
|
+
CONN_REFRESH_PERIOD = 60 # Stream connection refresh period
|
|
106
|
+
CONN_RECONNECT_INTERVAL = 0.5 # Reconnect interval between two stream connections
|
|
107
|
+
LOG_STREAM_INTERVAL = 0.5 # Log stream interval for `ExecServicer.StreamLogs`
|
|
108
|
+
LOG_UPLOAD_INTERVAL = 0.2 # Minimum interval between two log uploads
|
|
109
|
+
|
|
110
|
+
# Retry configurations
|
|
111
|
+
MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.
|
|
112
|
+
|
|
87
113
|
|
|
88
114
|
class MessageType:
|
|
89
115
|
"""Message type."""
|
|
@@ -124,8 +150,34 @@ class ErrorCode:
|
|
|
124
150
|
UNKNOWN = 0
|
|
125
151
|
LOAD_CLIENT_APP_EXCEPTION = 1
|
|
126
152
|
CLIENT_APP_RAISED_EXCEPTION = 2
|
|
127
|
-
|
|
153
|
+
MESSAGE_UNAVAILABLE = 3
|
|
154
|
+
REPLY_MESSAGE_UNAVAILABLE = 4
|
|
128
155
|
|
|
129
156
|
def __new__(cls) -> ErrorCode:
|
|
130
157
|
"""Prevent instantiation."""
|
|
131
158
|
raise TypeError(f"{cls.__name__} cannot be instantiated.")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class Status:
    """Run status.

    Namespace of string constants naming the states a run can be in. The class
    only groups constants and refuses to be instantiated.
    """

    PENDING = "pending"
    STARTING = "starting"
    RUNNING = "running"
    FINISHED = "finished"

    def __new__(cls) -> Status:
        """Prevent instantiation."""
        class_name = cls.__name__
        raise TypeError(f"{class_name} cannot be instantiated.")
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
class SubStatus:
    """Run sub-status.

    Namespace of string constants for a run's sub-status. The class only
    groups constants and refuses to be instantiated.
    """

    COMPLETED = "completed"
    FAILED = "failed"
    STOPPED = "stopped"

    def __new__(cls) -> SubStatus:
        """Prevent instantiation."""
        class_name = cls.__name__
        raise TypeError(f"{class_name} cannot be instantiated.")
|
flwr/common/context.py
CHANGED
|
@@ -27,36 +27,41 @@ class Context:
|
|
|
27
27
|
|
|
28
28
|
Parameters
|
|
29
29
|
----------
|
|
30
|
+
run_id : int
|
|
31
|
+
The ID that identifies the run.
|
|
30
32
|
node_id : int
|
|
31
33
|
The ID that identifies the node.
|
|
32
34
|
node_config : UserConfig
|
|
33
35
|
A config (key/value mapping) unique to the node and independent of the
|
|
34
36
|
`run_config`. This config persists across all runs this node participates in.
|
|
35
37
|
state : RecordSet
|
|
36
|
-
Holds records added by the entity in a given
|
|
38
|
+
Holds records added by the entity in a given `run_id` and that will stay local.
|
|
37
39
|
This means that the data it holds will never leave the system it's running from.
|
|
38
40
|
This can be used as an intermediate storage or scratchpad when
|
|
39
41
|
executing mods. It can also be used as a memory to access
|
|
40
42
|
at different points during the lifecycle of this entity (e.g. across
|
|
41
43
|
multiple rounds)
|
|
42
44
|
run_config : UserConfig
|
|
43
|
-
A config (key/value mapping) held by the entity in a given
|
|
44
|
-
stay local. It can be used at any point during the lifecycle of this entity
|
|
45
|
+
A config (key/value mapping) held by the entity in a given `run_id` and that
|
|
46
|
+
will stay local. It can be used at any point during the lifecycle of this entity
|
|
45
47
|
(e.g. across multiple rounds)
|
|
46
48
|
"""
|
|
47
49
|
|
|
50
|
+
run_id: int
|
|
48
51
|
node_id: int
|
|
49
52
|
node_config: UserConfig
|
|
50
53
|
state: RecordSet
|
|
51
54
|
run_config: UserConfig
|
|
52
55
|
|
|
53
|
-
def __init__( # pylint: disable=too-many-arguments
|
|
56
|
+
def __init__( # pylint: disable=too-many-arguments, too-many-positional-arguments
|
|
54
57
|
self,
|
|
58
|
+
run_id: int,
|
|
55
59
|
node_id: int,
|
|
56
60
|
node_config: UserConfig,
|
|
57
61
|
state: RecordSet,
|
|
58
62
|
run_config: UserConfig,
|
|
59
63
|
) -> None:
|
|
64
|
+
self.run_id = run_id
|
|
60
65
|
self.node_id = node_id
|
|
61
66
|
self.node_config = node_config
|
|
62
67
|
self.state = state
|
flwr/common/date.py
CHANGED
|
@@ -15,9 +15,27 @@
|
|
|
15
15
|
"""Flower date utils."""
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
|
|
18
|
+
import datetime
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
def now() -> datetime:
|
|
21
|
+
def now() -> datetime.datetime:
|
|
22
22
|
"""Construct a datetime from time.time() with time zone set to UTC."""
|
|
23
|
-
return datetime.now(tz=timezone.utc)
|
|
23
|
+
return datetime.datetime.now(tz=datetime.timezone.utc)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def format_timedelta(td: datetime.timedelta) -> str:
    """Format a timedelta as a string.

    The result is ``HH:MM:SS``, prefixed by ``"<days>d "`` when the duration
    spans at least one day. Sub-second precision is truncated. Negative
    durations are formatted by magnitude with a leading ``-`` (e.g.
    ``-00:00:05``).
    """
    # `timedelta` normalises negative values to (negative days, positive
    # seconds), e.g. timedelta(seconds=-5) == timedelta(days=-1, seconds=86395),
    # so format the magnitude and carry the sign separately.
    magnitude = abs(td)
    is_negative = td < datetime.timedelta(0)
    # Avoid "-00:00:00" for negative durations shorter than one second
    sign = "-" if is_negative and (magnitude.days or magnitude.seconds) else ""

    days = magnitude.days
    hours, remainder = divmod(magnitude.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    if days > 0:
        return f"{sign}{days}d {hours:02}:{minutes:02}:{seconds:02}"
    return f"{sign}{hours:02}:{minutes:02}:{seconds:02}"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def isoformat8601_utc(dt: datetime.datetime) -> str:
    """Return the datetime formatted as an ISO 8601 string with a trailing 'Z'.

    The output has second precision (sub-second parts are dropped).

    Raises
    ------
    ValueError
        If ``dt.tzinfo`` is not ``datetime.timezone.utc``.
    """
    if dt.tzinfo == datetime.timezone.utc:
        # A UTC-aware datetime renders its offset as "+00:00"; use the "Z" form
        text = dt.isoformat(timespec="seconds")
        return text.replace("+00:00", "Z")
    raise ValueError("Expected datetime with timezone set to UTC")
|
flwr/common/grpc.py
CHANGED
|
@@ -53,7 +53,10 @@ def create_channel(
|
|
|
53
53
|
channel = grpc.insecure_channel(server_address, options=channel_options)
|
|
54
54
|
log(DEBUG, "Opened insecure gRPC connection (no certificates were passed)")
|
|
55
55
|
else:
|
|
56
|
-
|
|
56
|
+
try:
|
|
57
|
+
ssl_channel_credentials = grpc.ssl_channel_credentials(root_certificates)
|
|
58
|
+
except Exception as e:
|
|
59
|
+
raise ValueError(f"Failed to create SSL channel credentials: {e}") from e
|
|
57
60
|
channel = grpc.secure_channel(
|
|
58
61
|
server_address, ssl_channel_credentials, options=channel_options
|
|
59
62
|
)
|
flwr/common/logger.py
CHANGED
|
@@ -16,9 +16,22 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
import logging
|
|
19
|
+
import sys
|
|
20
|
+
import threading
|
|
21
|
+
import time
|
|
19
22
|
from logging import WARN, LogRecord
|
|
20
23
|
from logging.handlers import HTTPHandler
|
|
21
|
-
from
|
|
24
|
+
from queue import Empty, Queue
|
|
25
|
+
from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
|
|
26
|
+
|
|
27
|
+
import grpc
|
|
28
|
+
|
|
29
|
+
from flwr.proto.log_pb2 import PushLogsRequest # pylint: disable=E0611
|
|
30
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
31
|
+
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub # pylint: disable=E0611
|
|
32
|
+
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611
|
|
33
|
+
|
|
34
|
+
from .constant import LOG_UPLOAD_INTERVAL
|
|
22
35
|
|
|
23
36
|
# Create logger
|
|
24
37
|
LOGGER_NAME = "flwr"
|
|
@@ -259,3 +272,97 @@ def set_logger_propagation(
|
|
|
259
272
|
if not child_logger.propagate:
|
|
260
273
|
child_logger.log(logging.DEBUG, "Logger propagate set to False")
|
|
261
274
|
return child_logger
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def mirror_output_to_queue(log_queue: Queue[Optional[str]]) -> None:
    """Mirror stdout and stderr output to the provided queue.

    The ``write`` attribute of the current ``sys.stdout`` and ``sys.stderr``
    objects is replaced by a wrapper. The wrapper writes to the stream as
    before, flushes it, then puts the written text on ``log_queue``. Finally
    the console log handler is pointed at ``sys.stdout``.
    """

    def _mirrored(stream: TextIO) -> Any:
        write_to_stream = stream.write

        def write(text: str) -> int:
            written = write_to_stream(text)
            stream.flush()
            log_queue.put(text)
            return written

        return write

    for target in (sys.stdout, sys.stderr):
        setattr(target, "write", _mirrored(target))
    console_handler.stream = sys.stdout
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def restore_output() -> None:
    """Restore stdout and stderr.

    This will stop mirroring output to queues.
    """
    # `mirror_output_to_queue` mirrors by assigning a wrapper function to the
    # stream object's `write` attribute. Re-binding `sys.stdout`/`sys.stderr`
    # alone does not undo that when they already are `sys.__stdout__`/
    # `sys.__stderr__` (the usual case), so mirroring would silently continue.
    # Drop such instance-level overrides installed by this module so the
    # stream's own `write` method is used again.
    for stream in (sys.stdout, sys.stderr, sys.__stdout__, sys.__stderr__):
        overrides = getattr(stream, "__dict__", None)
        if not overrides:
            continue
        write_fn = overrides.get("write")
        if write_fn is not None and getattr(write_fn, "__module__", None) == __name__:
            delattr(stream, "write")
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__
    console_handler.stream = sys.stdout
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def _log_uploader(
    log_queue: Queue[Optional[str]],
    node_id: int,
    run_id: int,
    stub: Union[ServerAppIoStub, SimulationIoStub],
) -> None:
    """Upload logs to the SuperLink.

    Repeatedly drains ``log_queue`` and pushes the collected lines with
    ``stub.PushLogs``, sleeping ``LOG_UPLOAD_INTERVAL`` seconds between
    iterations. A ``None`` item in the queue is the stop signal: the pending
    lines get one final push attempt and the function returns.

    Parameters
    ----------
    log_queue : Queue[Optional[str]]
        Captured output chunks; ``None`` signals that the run has finished.
    node_id : int
        ID of the node the logs are attributed to.
    run_id : int
        ID of the run the logs belong to.
    stub : Union[ServerAppIoStub, SimulationIoStub]
        gRPC stub exposing ``PushLogs`` (``start_log_uploader`` accepts both).

    Raises
    ------
    grpc.RpcError
        If ``PushLogs`` fails with any status other than ``UNAVAILABLE``.
    """
    exit_flag = False
    node = Node(node_id=node_id, anonymous=False)
    msgs: list[str] = []
    while True:
        # Fetch all messages from the queue
        try:
            while True:
                msg = log_queue.get_nowait()
                # Quit the loops if the returned message is `None`
                # This is a signal that the run has finished
                if msg is None:
                    exit_flag = True
                    break
                msgs.append(msg)
        except Empty:
            pass

        # Upload if any logs
        if msgs:
            req = PushLogsRequest(
                node=node,
                run_id=run_id,
                logs=msgs,
            )
            try:
                stub.PushLogs(req)
                msgs.clear()
            except grpc.RpcError as e:
                # Ignore minor network errors: `msgs` is kept and retried on the
                # next iteration (a failed final push after `None` is not retried)
                # pylint: disable-next=no-member
                if e.code() != grpc.StatusCode.UNAVAILABLE:
                    raise e

        if exit_flag:
            break

        time.sleep(LOG_UPLOAD_INTERVAL)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def start_log_uploader(
    log_queue: Queue[Optional[str]],
    node_id: int,
    run_id: int,
    stub: Union[ServerAppIoStub, SimulationIoStub],
) -> threading.Thread:
    """Start the log uploader thread and return it.

    The thread runs ``_log_uploader`` with the given arguments; stop it with
    ``stop_log_uploader``.
    """
    uploader = threading.Thread(
        target=_log_uploader,
        kwargs={
            "log_queue": log_queue,
            "node_id": node_id,
            "run_id": run_id,
            "stub": stub,
        },
    )
    uploader.start()
    return uploader
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def stop_log_uploader(
    log_queue: Queue[Optional[str]], log_uploader: threading.Thread
) -> None:
    """Stop the log uploader thread.

    Enqueues the ``None`` sentinel that ends the uploader loop, then blocks
    until the thread has terminated.
    """
    log_queue.put(item=None)
    log_uploader.join(timeout=None)
|
flwr/common/object_ref.py
CHANGED
|
@@ -55,8 +55,8 @@ def validate(
|
|
|
55
55
|
specified attribute within it.
|
|
56
56
|
project_dir : Optional[Union[str, Path]] (default: None)
|
|
57
57
|
The directory containing the module. If None, the current working directory
|
|
58
|
-
is used. If `check_module` is True, the `project_dir` will be
|
|
59
|
-
the system path
|
|
58
|
+
is used. If `check_module` is True, the `project_dir` will be temporarily
|
|
59
|
+
inserted into the system path and then removed after the validation is complete.
|
|
60
60
|
|
|
61
61
|
Returns
|
|
62
62
|
-------
|
|
@@ -66,8 +66,8 @@ def validate(
|
|
|
66
66
|
|
|
67
67
|
Note
|
|
68
68
|
----
|
|
69
|
-
This function will modify `sys.path` by inserting the provided
|
|
70
|
-
|
|
69
|
+
This function will temporarily modify `sys.path` by inserting the provided
|
|
70
|
+
`project_dir`, which will be removed after the validation is complete.
|
|
71
71
|
"""
|
|
72
72
|
module_str, _, attributes_str = module_attribute_str.partition(":")
|
|
73
73
|
if not module_str:
|
|
@@ -82,11 +82,19 @@ def validate(
|
|
|
82
82
|
)
|
|
83
83
|
|
|
84
84
|
if check_module:
|
|
85
|
+
if project_dir is None:
|
|
86
|
+
project_dir = Path.cwd()
|
|
87
|
+
project_dir = Path(project_dir).absolute()
|
|
85
88
|
# Set the system path
|
|
86
|
-
|
|
89
|
+
sys.path.insert(0, str(project_dir))
|
|
87
90
|
|
|
88
91
|
# Load module
|
|
89
92
|
module = find_spec(module_str)
|
|
93
|
+
|
|
94
|
+
# Unset the system path
|
|
95
|
+
sys.path.remove(str(project_dir))
|
|
96
|
+
|
|
97
|
+
# Check if the module and the attribute exist
|
|
90
98
|
if module and module.origin:
|
|
91
99
|
if not _find_attribute_in_module(module.origin, attributes_str):
|
|
92
100
|
return (
|
|
@@ -133,8 +141,10 @@ def load_app( # pylint: disable= too-many-branches
|
|
|
133
141
|
|
|
134
142
|
Note
|
|
135
143
|
----
|
|
136
|
-
This function will
|
|
137
|
-
|
|
144
|
+
- This function will unload all modules in the previously provided `project_dir`,
|
|
145
|
+
if it is invoked again.
|
|
146
|
+
- This function will modify `sys.path` by inserting the provided `project_dir`
|
|
147
|
+
and removing the previously inserted `project_dir`.
|
|
138
148
|
"""
|
|
139
149
|
valid, error_msg = validate(module_attribute_str, check_module=False)
|
|
140
150
|
if not valid and error_msg:
|
|
@@ -143,8 +153,19 @@ def load_app( # pylint: disable= too-many-branches
|
|
|
143
153
|
module_str, _, attributes_str = module_attribute_str.partition(":")
|
|
144
154
|
|
|
145
155
|
try:
|
|
156
|
+
# Initialize project path
|
|
157
|
+
if project_dir is None:
|
|
158
|
+
project_dir = Path.cwd()
|
|
159
|
+
project_dir = Path(project_dir).absolute()
|
|
160
|
+
|
|
161
|
+
# Unload modules if the project directory has changed
|
|
162
|
+
if _current_sys_path and _current_sys_path != str(project_dir):
|
|
163
|
+
_unload_modules(Path(_current_sys_path))
|
|
164
|
+
|
|
165
|
+
# Set the system path
|
|
146
166
|
_set_sys_path(project_dir)
|
|
147
167
|
|
|
168
|
+
# Import the module
|
|
148
169
|
if module_str not in sys.modules:
|
|
149
170
|
module = importlib.import_module(module_str)
|
|
150
171
|
# Hack: `tabnet` does not work with `importlib.reload`
|
|
@@ -160,15 +181,7 @@ def load_app( # pylint: disable= too-many-branches
|
|
|
160
181
|
module = sys.modules[module_str]
|
|
161
182
|
else:
|
|
162
183
|
module = sys.modules[module_str]
|
|
163
|
-
|
|
164
|
-
if project_dir is None:
|
|
165
|
-
project_dir = Path.cwd()
|
|
166
|
-
|
|
167
|
-
# Reload cached modules in the project directory
|
|
168
|
-
for m in list(sys.modules.values()):
|
|
169
|
-
path: Optional[str] = getattr(m, "__file__", None)
|
|
170
|
-
if path is not None and path.startswith(str(project_dir)):
|
|
171
|
-
importlib.reload(m)
|
|
184
|
+
_reload_modules(project_dir)
|
|
172
185
|
|
|
173
186
|
except ModuleNotFoundError as err:
|
|
174
187
|
raise error_type(
|
|
@@ -189,6 +202,24 @@ def load_app( # pylint: disable= too-many-branches
|
|
|
189
202
|
return attribute
|
|
190
203
|
|
|
191
204
|
|
|
205
|
+
def _unload_modules(project_dir: Path) -> None:
    """Unload modules from the project directory.

    Removes from ``sys.modules`` every module whose ``__file__`` lies inside
    ``project_dir`` (made absolute first). A subsequent import of such a
    module therefore re-executes it.
    """
    root = project_dir.absolute()
    for name, m in list(sys.modules.items()):
        path: Optional[str] = getattr(m, "__file__", None)
        # Compare path components instead of raw string prefixes, so that a
        # sibling directory such as `<project_dir>2/` is not treated as inside
        # `project_dir`.
        if path is not None and Path(path).is_relative_to(root):
            sys.modules.pop(name, None)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _reload_modules(project_dir: Path) -> None:
    """Reload modules from the project directory.

    Calls ``importlib.reload`` on every loaded module whose ``__file__`` lies
    inside ``project_dir`` (made absolute first).
    """
    root = project_dir.absolute()
    for m in list(sys.modules.values()):
        path: Optional[str] = getattr(m, "__file__", None)
        # Compare path components instead of raw string prefixes, so that
        # modules from a sibling directory such as `<project_dir>2/` are not
        # reloaded by mistake.
        if path is not None and Path(path).is_relative_to(root):
            importlib.reload(m)
|
|
221
|
+
|
|
222
|
+
|
|
192
223
|
def _set_sys_path(directory: Optional[Union[str, Path]]) -> None:
|
|
193
224
|
"""Set the system path."""
|
|
194
225
|
if directory is None:
|
flwr/common/serde.py
CHANGED
|
@@ -40,6 +40,7 @@ from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
|
|
|
40
40
|
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
|
|
41
41
|
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
|
|
42
42
|
from flwr.proto.run_pb2 import Run as ProtoRun
|
|
43
|
+
from flwr.proto.run_pb2 import RunStatus as ProtoRunStatus
|
|
43
44
|
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
|
|
44
45
|
from flwr.proto.transport_pb2 import (
|
|
45
46
|
ClientMessage,
|
|
@@ -839,6 +840,7 @@ def message_from_proto(message_proto: ProtoMessage) -> Message:
|
|
|
839
840
|
def context_to_proto(context: Context) -> ProtoContext:
|
|
840
841
|
"""Serialize `Context` to ProtoBuf."""
|
|
841
842
|
proto = ProtoContext(
|
|
843
|
+
run_id=context.run_id,
|
|
842
844
|
node_id=context.node_id,
|
|
843
845
|
node_config=user_config_to_proto(context.node_config),
|
|
844
846
|
state=recordset_to_proto(context.state),
|
|
@@ -850,6 +852,7 @@ def context_to_proto(context: Context) -> ProtoContext:
|
|
|
850
852
|
def context_from_proto(context_proto: ProtoContext) -> Context:
|
|
851
853
|
"""Deserialize `Context` from ProtoBuf."""
|
|
852
854
|
context = Context(
|
|
855
|
+
run_id=context_proto.run_id,
|
|
853
856
|
node_id=context_proto.node_id,
|
|
854
857
|
node_config=user_config_from_proto(context_proto.node_config),
|
|
855
858
|
state=recordset_from_proto(context_proto.state),
|
|
@@ -869,6 +872,11 @@ def run_to_proto(run: typing.Run) -> ProtoRun:
|
|
|
869
872
|
fab_version=run.fab_version,
|
|
870
873
|
fab_hash=run.fab_hash,
|
|
871
874
|
override_config=user_config_to_proto(run.override_config),
|
|
875
|
+
pending_at=run.pending_at,
|
|
876
|
+
starting_at=run.starting_at,
|
|
877
|
+
running_at=run.running_at,
|
|
878
|
+
finished_at=run.finished_at,
|
|
879
|
+
status=run_status_to_proto(run.status),
|
|
872
880
|
)
|
|
873
881
|
return proto
|
|
874
882
|
|
|
@@ -881,6 +889,11 @@ def run_from_proto(run_proto: ProtoRun) -> typing.Run:
|
|
|
881
889
|
fab_version=run_proto.fab_version,
|
|
882
890
|
fab_hash=run_proto.fab_hash,
|
|
883
891
|
override_config=user_config_from_proto(run_proto.override_config),
|
|
892
|
+
pending_at=run_proto.pending_at,
|
|
893
|
+
starting_at=run_proto.starting_at,
|
|
894
|
+
running_at=run_proto.running_at,
|
|
895
|
+
finished_at=run_proto.finished_at,
|
|
896
|
+
status=run_status_from_proto(run_proto.status),
|
|
884
897
|
)
|
|
885
898
|
return run
|
|
886
899
|
|
|
@@ -910,3 +923,24 @@ def clientappstatus_from_proto(
|
|
|
910
923
|
if msg.code == ClientAppOutputCode.UNKNOWN_ERROR:
|
|
911
924
|
code = typing.ClientAppOutputCode.UNKNOWN_ERROR
|
|
912
925
|
return typing.ClientAppOutputStatus(code=code, message=msg.message)
|
|
926
|
+
|
|
927
|
+
|
|
928
|
+
# === Run status ===
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def run_status_to_proto(run_status: typing.RunStatus) -> ProtoRunStatus:
    """Serialize `RunStatus` to ProtoBuf.

    The ``status``, ``sub_status`` and ``details`` fields are copied verbatim.
    """
    proto = ProtoRunStatus(
        status=run_status.status,
        sub_status=run_status.sub_status,
        details=run_status.details,
    )
    return proto
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
def run_status_from_proto(run_status_proto: ProtoRunStatus) -> typing.RunStatus:
    """Deserialize `RunStatus` from ProtoBuf.

    The ``status``, ``sub_status`` and ``details`` fields are copied verbatim.
    """
    run_status = typing.RunStatus(
        status=run_status_proto.status,
        sub_status=run_status_proto.sub_status,
        details=run_status_proto.details,
    )
    return run_status
|
flwr/common/telemetry.py
CHANGED
|
@@ -150,12 +150,6 @@ class EventType(str, Enum):
|
|
|
150
150
|
|
|
151
151
|
# Not yet implemented
|
|
152
152
|
|
|
153
|
-
# --- SuperExec --------------------------------------------------------------------
|
|
154
|
-
|
|
155
|
-
# SuperExec
|
|
156
|
-
RUN_SUPEREXEC_ENTER = auto()
|
|
157
|
-
RUN_SUPEREXEC_LEAVE = auto()
|
|
158
|
-
|
|
159
153
|
# --- Simulation Engine ------------------------------------------------------------
|
|
160
154
|
|
|
161
155
|
# CLI: flower-simulation
|