flwr-nightly 1.9.0.dev20240516__py3-none-any.whl → 1.9.0.dev20240519__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flwr-nightly might be problematic. Click here for more details.
- flwr/cli/new/new.py +2 -0
- flwr/cli/new/templates/app/code/client.jax.py.tpl +55 -0
- flwr/cli/new/templates/app/code/server.jax.py.tpl +12 -0
- flwr/cli/new/templates/app/code/task.jax.py.tpl +57 -0
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +28 -0
- flwr/server/driver/__init__.py +3 -2
- flwr/server/driver/inmemory_driver.py +181 -0
- flwr/server/server.py +8 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +6 -1
- flwr/server/workflow/default_workflows.py +48 -25
- flwr/simulation/run_simulation.py +2 -31
- {flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/METADATA +1 -1
- {flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/RECORD +16 -11
- {flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/entry_points.txt +0 -0
flwr/cli/new/new.py
CHANGED
|
@@ -37,6 +37,7 @@ class MlFramework(str, Enum):
|
|
|
37
37
|
NUMPY = "NumPy"
|
|
38
38
|
PYTORCH = "PyTorch"
|
|
39
39
|
TENSORFLOW = "TensorFlow"
|
|
40
|
+
JAX = "JAX"
|
|
40
41
|
HUGGINGFACE = "HF"
|
|
41
42
|
MLX = "MLX"
|
|
42
43
|
SKLEARN = "sklearn"
|
|
@@ -155,6 +156,7 @@ def new(
|
|
|
155
156
|
# Depending on the framework, generate task.py file
|
|
156
157
|
frameworks_with_tasks = [
|
|
157
158
|
MlFramework.PYTORCH.value.lower(),
|
|
159
|
+
MlFramework.JAX.value.lower(),
|
|
158
160
|
MlFramework.HUGGINGFACE.value.lower(),
|
|
159
161
|
MlFramework.MLX.value.lower(),
|
|
160
162
|
MlFramework.TENSORFLOW.value.lower(),
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""$project_name: A Flower / JAX app."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from flwr.client import NumPyClient, ClientApp
|
|
5
|
+
|
|
6
|
+
from $import_name.task import (
|
|
7
|
+
evaluation,
|
|
8
|
+
get_params,
|
|
9
|
+
load_data,
|
|
10
|
+
load_model,
|
|
11
|
+
loss_fn,
|
|
12
|
+
set_params,
|
|
13
|
+
train,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Define Flower Client and client_fn
|
|
18
|
+
class FlowerClient(NumPyClient):
|
|
19
|
+
def __init__(self):
|
|
20
|
+
self.train_x, self.train_y, self.test_x, self.test_y = load_data()
|
|
21
|
+
self.grad_fn = jax.grad(loss_fn)
|
|
22
|
+
model_shape = self.train_x.shape[1:]
|
|
23
|
+
|
|
24
|
+
self.params = load_model(model_shape)
|
|
25
|
+
|
|
26
|
+
def get_parameters(self, config):
|
|
27
|
+
return get_params(self.params)
|
|
28
|
+
|
|
29
|
+
def set_parameters(self, parameters):
|
|
30
|
+
set_params(self.params, parameters)
|
|
31
|
+
|
|
32
|
+
def fit(self, parameters, config):
|
|
33
|
+
self.set_parameters(parameters)
|
|
34
|
+
self.params, loss, num_examples = train(
|
|
35
|
+
self.params, self.grad_fn, self.train_x, self.train_y
|
|
36
|
+
)
|
|
37
|
+
parameters = self.get_parameters(config={})
|
|
38
|
+
return parameters, num_examples, {"loss": float(loss)}
|
|
39
|
+
|
|
40
|
+
def evaluate(self, parameters, config):
|
|
41
|
+
self.set_parameters(parameters)
|
|
42
|
+
loss, num_examples = evaluation(
|
|
43
|
+
self.params, self.grad_fn, self.test_x, self.test_y
|
|
44
|
+
)
|
|
45
|
+
return float(loss), num_examples, {"loss": float(loss)}
|
|
46
|
+
|
|
47
|
+
def client_fn(cid):
|
|
48
|
+
# Return Client instance
|
|
49
|
+
return FlowerClient().to_client()
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Flower ClientApp
|
|
53
|
+
app = ClientApp(
|
|
54
|
+
client_fn,
|
|
55
|
+
)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""$project_name: A Flower / JAX app."""
|
|
2
|
+
|
|
3
|
+
import flwr as fl
|
|
4
|
+
|
|
5
|
+
# Configure the strategy
|
|
6
|
+
strategy = fl.server.strategy.FedAvg()
|
|
7
|
+
|
|
8
|
+
# Flower ServerApp
|
|
9
|
+
app = fl.server.ServerApp(
|
|
10
|
+
config=fl.server.ServerConfig(num_rounds=3),
|
|
11
|
+
strategy=strategy,
|
|
12
|
+
)
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""$project_name: A Flower / JAX app."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from sklearn.datasets import make_regression
|
|
6
|
+
from sklearn.model_selection import train_test_split
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
key = jax.random.PRNGKey(0)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def load_data():
|
|
13
|
+
# Load dataset
|
|
14
|
+
X, y = make_regression(n_features=3, random_state=0)
|
|
15
|
+
X, X_test, y, y_test = train_test_split(X, y)
|
|
16
|
+
return X, y, X_test, y_test
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_model(model_shape):
|
|
20
|
+
# Extract model parameters
|
|
21
|
+
params = {"b": jax.random.uniform(key), "w": jax.random.uniform(key, model_shape)}
|
|
22
|
+
return params
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def loss_fn(params, X, y):
|
|
26
|
+
# Return MSE as loss
|
|
27
|
+
err = jnp.dot(X, params["w"]) + params["b"] - y
|
|
28
|
+
return jnp.mean(jnp.square(err))
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def train(params, grad_fn, X, y):
|
|
32
|
+
loss = 1_000_000
|
|
33
|
+
num_examples = X.shape[0]
|
|
34
|
+
for epochs in range(50):
|
|
35
|
+
grads = grad_fn(params, X, y)
|
|
36
|
+
params = jax.tree.map(lambda p, g: p - 0.05 * g, params, grads)
|
|
37
|
+
loss = loss_fn(params, X, y)
|
|
38
|
+
return params, loss, num_examples
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def evaluation(params, grad_fn, X_test, y_test):
|
|
42
|
+
num_examples = X_test.shape[0]
|
|
43
|
+
err_test = loss_fn(params, X_test, y_test)
|
|
44
|
+
loss_test = jnp.mean(jnp.square(err_test))
|
|
45
|
+
return loss_test, num_examples
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_params(params):
|
|
49
|
+
parameters = []
|
|
50
|
+
for _, val in params.items():
|
|
51
|
+
parameters.append(np.array(val))
|
|
52
|
+
return parameters
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def set_params(local_params, global_params):
|
|
56
|
+
for key, value in list(zip(local_params.keys(), global_params)):
|
|
57
|
+
local_params[key] = value
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "$package_name"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = ""
|
|
9
|
+
authors = [
|
|
10
|
+
{ name = "The Flower Authors", email = "hello@flower.ai" },
|
|
11
|
+
]
|
|
12
|
+
license = {text = "Apache License (2.0)"}
|
|
13
|
+
dependencies = [
|
|
14
|
+
"flwr[simulation]>=1.8.0,<2.0",
|
|
15
|
+
"jax==0.4.26",
|
|
16
|
+
"jaxlib==0.4.26",
|
|
17
|
+
"scikit-learn==1.4.2",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
[tool.hatch.build.targets.wheel]
|
|
21
|
+
packages = ["."]
|
|
22
|
+
|
|
23
|
+
[flower]
|
|
24
|
+
publisher = "$username"
|
|
25
|
+
|
|
26
|
+
[flower.components]
|
|
27
|
+
serverapp = "$import_name.server:app"
|
|
28
|
+
clientapp = "$import_name.client:app"
|
flwr/server/driver/__init__.py
CHANGED
|
@@ -16,10 +16,11 @@
|
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
from .driver import Driver
|
|
19
|
-
from .grpc_driver import GrpcDriver
|
|
19
|
+
from .grpc_driver import GrpcDriver
|
|
20
|
+
from .inmemory_driver import InMemoryDriver
|
|
20
21
|
|
|
21
22
|
__all__ = [
|
|
22
23
|
"Driver",
|
|
23
24
|
"GrpcDriver",
|
|
24
|
-
"
|
|
25
|
+
"InMemoryDriver",
|
|
25
26
|
]
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Flower in-memory Driver."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
import time
|
|
19
|
+
import warnings
|
|
20
|
+
from typing import Iterable, List, Optional
|
|
21
|
+
from uuid import UUID
|
|
22
|
+
|
|
23
|
+
from flwr.common import DEFAULT_TTL, Message, Metadata, RecordSet
|
|
24
|
+
from flwr.common.serde import message_from_taskres, message_to_taskins
|
|
25
|
+
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
|
|
26
|
+
from flwr.server.superlink.state import StateFactory
|
|
27
|
+
|
|
28
|
+
from .driver import Driver
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class InMemoryDriver(Driver):
|
|
32
|
+
"""`InMemoryDriver` class provides an interface to the Driver API.
|
|
33
|
+
|
|
34
|
+
Parameters
|
|
35
|
+
----------
|
|
36
|
+
state_factory : StateFactory
|
|
37
|
+
A StateFactory embedding a state that this driver can interface with.
|
|
38
|
+
fab_id : str (default: None)
|
|
39
|
+
The identifier of the FAB used in the run.
|
|
40
|
+
fab_version : str (default: None)
|
|
41
|
+
The version of the FAB used in the run.
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
state_factory: StateFactory,
|
|
47
|
+
fab_id: Optional[str] = None,
|
|
48
|
+
fab_version: Optional[str] = None,
|
|
49
|
+
) -> None:
|
|
50
|
+
self.run_id: Optional[int] = None
|
|
51
|
+
self.fab_id = fab_id if fab_id is not None else ""
|
|
52
|
+
self.fab_version = fab_version if fab_version is not None else ""
|
|
53
|
+
self.node = Node(node_id=0, anonymous=True)
|
|
54
|
+
self.state = state_factory.state()
|
|
55
|
+
|
|
56
|
+
def _check_message(self, message: Message) -> None:
|
|
57
|
+
# Check if the message is valid
|
|
58
|
+
if not (
|
|
59
|
+
message.metadata.run_id == self.run_id
|
|
60
|
+
and message.metadata.src_node_id == self.node.node_id
|
|
61
|
+
and message.metadata.message_id == ""
|
|
62
|
+
and message.metadata.reply_to_message == ""
|
|
63
|
+
and message.metadata.ttl > 0
|
|
64
|
+
):
|
|
65
|
+
raise ValueError(f"Invalid message: {message}")
|
|
66
|
+
|
|
67
|
+
def _get_run_id(self) -> int:
|
|
68
|
+
"""Return run_id.
|
|
69
|
+
|
|
70
|
+
If unset, create a new run.
|
|
71
|
+
"""
|
|
72
|
+
if self.run_id is None:
|
|
73
|
+
self.run_id = self.state.create_run(
|
|
74
|
+
fab_id=self.fab_id, fab_version=self.fab_version
|
|
75
|
+
)
|
|
76
|
+
return self.run_id
|
|
77
|
+
|
|
78
|
+
def create_message( # pylint: disable=too-many-arguments
|
|
79
|
+
self,
|
|
80
|
+
content: RecordSet,
|
|
81
|
+
message_type: str,
|
|
82
|
+
dst_node_id: int,
|
|
83
|
+
group_id: str,
|
|
84
|
+
ttl: Optional[float] = None,
|
|
85
|
+
) -> Message:
|
|
86
|
+
"""Create a new message with specified parameters.
|
|
87
|
+
|
|
88
|
+
This method constructs a new `Message` with given content and metadata.
|
|
89
|
+
The `run_id` and `src_node_id` will be set automatically.
|
|
90
|
+
"""
|
|
91
|
+
run_id = self._get_run_id()
|
|
92
|
+
if ttl:
|
|
93
|
+
warnings.warn(
|
|
94
|
+
"A custom TTL was set, but note that the SuperLink does not enforce "
|
|
95
|
+
"the TTL yet. The SuperLink will start enforcing the TTL in a future "
|
|
96
|
+
"version of Flower.",
|
|
97
|
+
stacklevel=2,
|
|
98
|
+
)
|
|
99
|
+
ttl_ = DEFAULT_TTL if ttl is None else ttl
|
|
100
|
+
|
|
101
|
+
metadata = Metadata(
|
|
102
|
+
run_id=run_id,
|
|
103
|
+
message_id="", # Will be set by the server
|
|
104
|
+
src_node_id=self.node.node_id,
|
|
105
|
+
dst_node_id=dst_node_id,
|
|
106
|
+
reply_to_message="",
|
|
107
|
+
group_id=group_id,
|
|
108
|
+
ttl=ttl_,
|
|
109
|
+
message_type=message_type,
|
|
110
|
+
)
|
|
111
|
+
return Message(metadata=metadata, content=content)
|
|
112
|
+
|
|
113
|
+
def get_node_ids(self) -> List[int]:
|
|
114
|
+
"""Get node IDs."""
|
|
115
|
+
run_id = self._get_run_id()
|
|
116
|
+
return list(self.state.get_nodes(run_id))
|
|
117
|
+
|
|
118
|
+
def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
|
|
119
|
+
"""Push messages to specified node IDs.
|
|
120
|
+
|
|
121
|
+
This method takes an iterable of messages and sends each message
|
|
122
|
+
to the node specified in `dst_node_id`.
|
|
123
|
+
"""
|
|
124
|
+
task_ids: List[str] = []
|
|
125
|
+
for msg in messages:
|
|
126
|
+
# Check message
|
|
127
|
+
self._check_message(msg)
|
|
128
|
+
# Convert Message to TaskIns
|
|
129
|
+
taskins = message_to_taskins(msg)
|
|
130
|
+
# Store in state
|
|
131
|
+
taskins.task.pushed_at = time.time()
|
|
132
|
+
task_id = self.state.store_task_ins(taskins)
|
|
133
|
+
if task_id:
|
|
134
|
+
task_ids.append(str(task_id))
|
|
135
|
+
|
|
136
|
+
return task_ids
|
|
137
|
+
|
|
138
|
+
def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
|
|
139
|
+
"""Pull messages based on message IDs.
|
|
140
|
+
|
|
141
|
+
This method is used to collect messages from the SuperLink that correspond to a
|
|
142
|
+
set of given message IDs.
|
|
143
|
+
"""
|
|
144
|
+
msg_ids = {UUID(msg_id) for msg_id in message_ids}
|
|
145
|
+
# Pull TaskRes
|
|
146
|
+
task_res_list = self.state.get_task_res(task_ids=msg_ids, limit=len(msg_ids))
|
|
147
|
+
# Delete tasks in state
|
|
148
|
+
self.state.delete_tasks(msg_ids)
|
|
149
|
+
# Convert TaskRes to Message
|
|
150
|
+
msgs = [message_from_taskres(taskres) for taskres in task_res_list]
|
|
151
|
+
return msgs
|
|
152
|
+
|
|
153
|
+
def send_and_receive(
|
|
154
|
+
self,
|
|
155
|
+
messages: Iterable[Message],
|
|
156
|
+
*,
|
|
157
|
+
timeout: Optional[float] = None,
|
|
158
|
+
) -> Iterable[Message]:
|
|
159
|
+
"""Push messages to specified node IDs and pull the reply messages.
|
|
160
|
+
|
|
161
|
+
This method sends a list of messages to their destination node IDs and then
|
|
162
|
+
waits for the replies. It continues to pull replies until either all replies are
|
|
163
|
+
received or the specified timeout duration is exceeded.
|
|
164
|
+
"""
|
|
165
|
+
# Push messages
|
|
166
|
+
msg_ids = set(self.push_messages(messages))
|
|
167
|
+
|
|
168
|
+
# Pull messages
|
|
169
|
+
end_time = time.time() + (timeout if timeout is not None else 0.0)
|
|
170
|
+
ret: List[Message] = []
|
|
171
|
+
while timeout is None or time.time() < end_time:
|
|
172
|
+
res_msgs = self.pull_messages(msg_ids)
|
|
173
|
+
ret.extend(res_msgs)
|
|
174
|
+
msg_ids.difference_update(
|
|
175
|
+
{msg.metadata.reply_to_message for msg in res_msgs}
|
|
176
|
+
)
|
|
177
|
+
if len(msg_ids) == 0:
|
|
178
|
+
break
|
|
179
|
+
# Sleep
|
|
180
|
+
time.sleep(3)
|
|
181
|
+
return ret
|
flwr/server/server.py
CHANGED
|
@@ -282,7 +282,14 @@ class Server:
|
|
|
282
282
|
get_parameters_res = random_client.get_parameters(
|
|
283
283
|
ins=ins, timeout=timeout, group_id=server_round
|
|
284
284
|
)
|
|
285
|
-
|
|
285
|
+
if get_parameters_res.status.code == Code.OK:
|
|
286
|
+
log(INFO, "Received initial parameters from one random client")
|
|
287
|
+
else:
|
|
288
|
+
log(
|
|
289
|
+
WARN,
|
|
290
|
+
"Failed to receive initial parameters from the client."
|
|
291
|
+
" Empty initial parameters will be used.",
|
|
292
|
+
)
|
|
286
293
|
return get_parameters_res.parameters
|
|
287
294
|
|
|
288
295
|
|
|
@@ -55,7 +55,12 @@ class RayBackend(Backend):
|
|
|
55
55
|
runtime_env = (
|
|
56
56
|
self._configure_runtime_env(work_dir=work_dir) if work_dir else None
|
|
57
57
|
)
|
|
58
|
-
|
|
58
|
+
|
|
59
|
+
if backend_config.get("mute_logging", False):
|
|
60
|
+
init_ray(
|
|
61
|
+
logging_level=WARNING, log_to_driver=False, runtime_env=runtime_env
|
|
62
|
+
)
|
|
63
|
+
elif backend_config.get("silent", False):
|
|
59
64
|
init_ray(logging_level=WARNING, log_to_driver=True, runtime_env=runtime_env)
|
|
60
65
|
else:
|
|
61
66
|
init_ray(runtime_env=runtime_env)
|
|
@@ -18,12 +18,22 @@
|
|
|
18
18
|
import io
|
|
19
19
|
import timeit
|
|
20
20
|
from logging import INFO, WARN
|
|
21
|
-
from typing import Optional, cast
|
|
21
|
+
from typing import List, Optional, Tuple, Union, cast
|
|
22
22
|
|
|
23
23
|
import flwr.common.recordset_compat as compat
|
|
24
|
-
from flwr.common import
|
|
24
|
+
from flwr.common import (
|
|
25
|
+
Code,
|
|
26
|
+
ConfigsRecord,
|
|
27
|
+
Context,
|
|
28
|
+
EvaluateRes,
|
|
29
|
+
FitRes,
|
|
30
|
+
GetParametersIns,
|
|
31
|
+
ParametersRecord,
|
|
32
|
+
log,
|
|
33
|
+
)
|
|
25
34
|
from flwr.common.constant import MessageType, MessageTypeLegacy
|
|
26
35
|
|
|
36
|
+
from ..client_proxy import ClientProxy
|
|
27
37
|
from ..compat.app_utils import start_update_client_manager_thread
|
|
28
38
|
from ..compat.legacy_context import LegacyContext
|
|
29
39
|
from ..driver import Driver
|
|
@@ -136,7 +146,14 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None:
|
|
|
136
146
|
]
|
|
137
147
|
)
|
|
138
148
|
msg = list(messages)[0]
|
|
139
|
-
|
|
149
|
+
|
|
150
|
+
if (
|
|
151
|
+
msg.has_content()
|
|
152
|
+
and compat._extract_status_from_recordset( # pylint: disable=W0212
|
|
153
|
+
"getparametersres", msg.content
|
|
154
|
+
).code
|
|
155
|
+
== Code.OK
|
|
156
|
+
):
|
|
140
157
|
log(INFO, "Received initial parameters from one random client")
|
|
141
158
|
paramsrecord = next(iter(msg.content.parameters_records.values()))
|
|
142
159
|
else:
|
|
@@ -257,18 +274,20 @@ def default_fit_workflow( # pylint: disable=R0914
|
|
|
257
274
|
)
|
|
258
275
|
|
|
259
276
|
# Aggregate training results
|
|
260
|
-
results = [
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
277
|
+
results: List[Tuple[ClientProxy, FitRes]] = []
|
|
278
|
+
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
|
|
279
|
+
for msg in messages:
|
|
280
|
+
if msg.has_content():
|
|
281
|
+
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
282
|
+
fitres = compat.recordset_to_fitres(msg.content, False)
|
|
283
|
+
if fitres.status.code == Code.OK:
|
|
284
|
+
results.append((proxy, fitres))
|
|
285
|
+
else:
|
|
286
|
+
failures.append((proxy, fitres))
|
|
287
|
+
else:
|
|
288
|
+
failures.append(Exception(msg.error))
|
|
289
|
+
|
|
290
|
+
aggregated_result = context.strategy.aggregate_fit(current_round, results, failures)
|
|
272
291
|
parameters_aggregated, metrics_aggregated = aggregated_result
|
|
273
292
|
|
|
274
293
|
# Update the parameters and write history
|
|
@@ -341,17 +360,21 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None:
|
|
|
341
360
|
)
|
|
342
361
|
|
|
343
362
|
# Aggregate the evaluation results
|
|
344
|
-
results = [
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
363
|
+
results: List[Tuple[ClientProxy, EvaluateRes]] = []
|
|
364
|
+
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = []
|
|
365
|
+
for msg in messages:
|
|
366
|
+
if msg.has_content():
|
|
367
|
+
proxy = node_id_to_proxy[msg.metadata.src_node_id]
|
|
368
|
+
evalres = compat.recordset_to_evaluateres(msg.content)
|
|
369
|
+
if evalres.status.code == Code.OK:
|
|
370
|
+
results.append((proxy, evalres))
|
|
371
|
+
else:
|
|
372
|
+
failures.append((proxy, evalres))
|
|
373
|
+
else:
|
|
374
|
+
failures.append(Exception(msg.error))
|
|
375
|
+
|
|
353
376
|
aggregated_result = context.strategy.aggregate_evaluate(
|
|
354
|
-
current_round, results, failures
|
|
377
|
+
current_round, results, failures
|
|
355
378
|
)
|
|
356
379
|
|
|
357
380
|
loss_aggregated, metrics_aggregated = aggregated_result
|
|
@@ -24,16 +24,13 @@ from logging import DEBUG, ERROR, INFO, WARNING
|
|
|
24
24
|
from time import sleep
|
|
25
25
|
from typing import Dict, Optional
|
|
26
26
|
|
|
27
|
-
import grpc
|
|
28
|
-
|
|
29
27
|
from flwr.client import ClientApp
|
|
30
28
|
from flwr.common import EventType, event, log
|
|
31
29
|
from flwr.common.logger import set_logger_propagation, update_console_handler
|
|
32
30
|
from flwr.common.typing import ConfigsRecordValues
|
|
33
|
-
from flwr.server.driver import Driver, GrpcDriver
|
|
31
|
+
from flwr.server.driver import Driver, InMemoryDriver
|
|
34
32
|
from flwr.server.run_serverapp import run
|
|
35
33
|
from flwr.server.server_app import ServerApp
|
|
36
|
-
from flwr.server.superlink.driver.driver_grpc import run_driver_api_grpc
|
|
37
34
|
from flwr.server.superlink.fleet import vce
|
|
38
35
|
from flwr.server.superlink.state import StateFactory
|
|
39
36
|
from flwr.simulation.ray_transport.utils import (
|
|
@@ -56,7 +53,6 @@ def run_simulation_from_cli() -> None:
|
|
|
56
53
|
backend_name=args.backend,
|
|
57
54
|
backend_config=backend_config_dict,
|
|
58
55
|
app_dir=args.app_dir,
|
|
59
|
-
driver_api_address=args.driver_api_address,
|
|
60
56
|
enable_tf_gpu_growth=args.enable_tf_gpu_growth,
|
|
61
57
|
verbose_logging=args.verbose,
|
|
62
58
|
)
|
|
@@ -177,7 +173,6 @@ def _main_loop(
|
|
|
177
173
|
num_supernodes: int,
|
|
178
174
|
backend_name: str,
|
|
179
175
|
backend_config_stream: str,
|
|
180
|
-
driver_api_address: str,
|
|
181
176
|
app_dir: str,
|
|
182
177
|
enable_tf_gpu_growth: bool,
|
|
183
178
|
client_app: Optional[ClientApp] = None,
|
|
@@ -194,21 +189,11 @@ def _main_loop(
|
|
|
194
189
|
# Initialize StateFactory
|
|
195
190
|
state_factory = StateFactory(":flwr-in-memory-state:")
|
|
196
191
|
|
|
197
|
-
# Start Driver API
|
|
198
|
-
driver_server: grpc.Server = run_driver_api_grpc(
|
|
199
|
-
address=driver_api_address,
|
|
200
|
-
state_factory=state_factory,
|
|
201
|
-
certificates=None,
|
|
202
|
-
)
|
|
203
|
-
|
|
204
192
|
f_stop = asyncio.Event()
|
|
205
193
|
serverapp_th = None
|
|
206
194
|
try:
|
|
207
195
|
# Initialize Driver
|
|
208
|
-
driver = GrpcDriver(
|
|
209
|
-
driver_service_address=driver_api_address,
|
|
210
|
-
root_certificates=None,
|
|
211
|
-
)
|
|
196
|
+
driver = InMemoryDriver(state_factory)
|
|
212
197
|
|
|
213
198
|
# Get and run ServerApp thread
|
|
214
199
|
serverapp_th = run_serverapp_th(
|
|
@@ -239,9 +224,6 @@ def _main_loop(
|
|
|
239
224
|
raise RuntimeError("An error was encountered. Ending simulation.") from ex
|
|
240
225
|
|
|
241
226
|
finally:
|
|
242
|
-
# Stop Driver
|
|
243
|
-
driver_server.stop(grace=0)
|
|
244
|
-
driver.close()
|
|
245
227
|
# Trigger stop event
|
|
246
228
|
f_stop.set()
|
|
247
229
|
|
|
@@ -262,7 +244,6 @@ def _run_simulation(
|
|
|
262
244
|
client_app_attr: Optional[str] = None,
|
|
263
245
|
server_app_attr: Optional[str] = None,
|
|
264
246
|
app_dir: str = "",
|
|
265
|
-
driver_api_address: str = "0.0.0.0:9091",
|
|
266
247
|
enable_tf_gpu_growth: bool = False,
|
|
267
248
|
verbose_logging: bool = False,
|
|
268
249
|
) -> None:
|
|
@@ -302,9 +283,6 @@ def _run_simulation(
|
|
|
302
283
|
Add specified directory to the PYTHONPATH and load `ClientApp` from there.
|
|
303
284
|
(Default: current working directory.)
|
|
304
285
|
|
|
305
|
-
driver_api_address : str (default: "0.0.0.0:9091")
|
|
306
|
-
Driver API (gRPC) server address (IPv4, IPv6, or a domain name)
|
|
307
|
-
|
|
308
286
|
enable_tf_gpu_growth : bool (default: False)
|
|
309
287
|
A boolean to indicate whether to enable GPU growth on the main thread. This is
|
|
310
288
|
desirable if you make use of a TensorFlow model on your `ServerApp` while
|
|
@@ -342,7 +320,6 @@ def _run_simulation(
|
|
|
342
320
|
num_supernodes,
|
|
343
321
|
backend_name,
|
|
344
322
|
backend_config_stream,
|
|
345
|
-
driver_api_address,
|
|
346
323
|
app_dir,
|
|
347
324
|
enable_tf_gpu_growth,
|
|
348
325
|
client_app,
|
|
@@ -399,12 +376,6 @@ def _parse_args_run_simulation() -> argparse.ArgumentParser:
|
|
|
399
376
|
required=True,
|
|
400
377
|
help="Number of simulated SuperNodes.",
|
|
401
378
|
)
|
|
402
|
-
parser.add_argument(
|
|
403
|
-
"--driver-api-address",
|
|
404
|
-
default="0.0.0.0:9091",
|
|
405
|
-
type=str,
|
|
406
|
-
help="For example: `server:app` or `project.package.module:wrapper.app`",
|
|
407
|
-
)
|
|
408
379
|
parser.add_argument(
|
|
409
380
|
"--backend",
|
|
410
381
|
default="ray",
|
{flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/RECORD
RENAMED
|
@@ -5,7 +5,7 @@ flwr/cli/build.py,sha256=W30wnPSgFuHRnGB9G_vKO14rsaibWk7m-jv9r8rDqo4,5106
|
|
|
5
5
|
flwr/cli/config_utils.py,sha256=Hql5A5hbSpJ51hgpwaTkKqfPoaZN4Zq7FZfBuQYLMcQ,4899
|
|
6
6
|
flwr/cli/example.py,sha256=1bGDYll3BXQY2kRqSN-oICqS5n1b9m0g0RvXTopXHl4,2215
|
|
7
7
|
flwr/cli/new/__init__.py,sha256=cQzK1WH4JP2awef1t2UQ2xjl1agVEz9rwutV18SWV1k,789
|
|
8
|
-
flwr/cli/new/new.py,sha256=
|
|
8
|
+
flwr/cli/new/new.py,sha256=7BWziuEOE15MXX4xNLH-w0-x0ytOEfYn_AUrbaDp13Y,6223
|
|
9
9
|
flwr/cli/new/templates/__init__.py,sha256=4luU8RL-CK8JJCstQ_ON809W9bNTkY1l9zSaPKBkgwY,725
|
|
10
10
|
flwr/cli/new/templates/app/.gitignore.tpl,sha256=XixnHdyeMB2vwkGtGnwHqoWpH-9WChdyG0GXe57duhc,3078
|
|
11
11
|
flwr/cli/new/templates/app/README.md.tpl,sha256=_qGtgpKYKoCJVjQnvlBMKvFs_1gzTcL908I3KJg0oAM,668
|
|
@@ -13,22 +13,26 @@ flwr/cli/new/templates/app/__init__.py,sha256=DU7QMY7IhMQyuwm_tja66xU0KXTWQFqzfT
|
|
|
13
13
|
flwr/cli/new/templates/app/code/__init__.py,sha256=EM6vfvgAILKPaPn7H1wMV1Wi01WyZCP_Eg6NxD6oWg8,736
|
|
14
14
|
flwr/cli/new/templates/app/code/__init__.py.tpl,sha256=olwrBeJemHNBWvjc6gJURloFRqW40dAy7FRQA5pDqHU,21
|
|
15
15
|
flwr/cli/new/templates/app/code/client.hf.py.tpl,sha256=RaN89A8HgKp6kjhzH8tgtDSWW8BwwcvJdqRLcvG04zw,1450
|
|
16
|
+
flwr/cli/new/templates/app/code/client.jax.py.tpl,sha256=MtxhwQzxAWNvlz8B-L-2a8LXcgaLPW8dp5K0vBZHR_o,1434
|
|
16
17
|
flwr/cli/new/templates/app/code/client.mlx.py.tpl,sha256=53wJy6s3zk4CZwob_qPmMoOqJ-LZNKbdDe_hw5LwOXE,2113
|
|
17
18
|
flwr/cli/new/templates/app/code/client.numpy.py.tpl,sha256=mTh7Y_jOJrPUvDYHVJy4wJCnjXZV_q-jlDkB07U5GSk,521
|
|
18
19
|
flwr/cli/new/templates/app/code/client.pytorch.py.tpl,sha256=MgCtMSv1Th16Faod11HubVaARkLYt7vS9RYH962-2pk,1172
|
|
19
20
|
flwr/cli/new/templates/app/code/client.sklearn.py.tpl,sha256=S71SZiHaRXtKqUk3m5Elc_c6HhKAIKLalrKOQ3p20No,2801
|
|
20
21
|
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl,sha256=dxrTO9JwYrDBjLsmCiRLetN9KxbnWRTeGA0BQbnOu_A,1280
|
|
21
22
|
flwr/cli/new/templates/app/code/server.hf.py.tpl,sha256=Mld452y3SUkejlFzac5hpCjT7_mbA0ZEEMJIUyHtSTI,338
|
|
23
|
+
flwr/cli/new/templates/app/code/server.jax.py.tpl,sha256=YTi-wroUpjRDY_AZqnoN5X-n3U5V7laL6UJgqFLEbKE,246
|
|
22
24
|
flwr/cli/new/templates/app/code/server.mlx.py.tpl,sha256=Cqk3PvM0e7hzohXPqD5hG_cthXoxCfc30bpEThqMy7M,272
|
|
23
25
|
flwr/cli/new/templates/app/code/server.numpy.py.tpl,sha256=fRxrDXV7pB1aDhQUXMBmrCsC1zp0uKwsBxZBx1JzbHA,248
|
|
24
26
|
flwr/cli/new/templates/app/code/server.pytorch.py.tpl,sha256=ltdsnFSvFGPcycVmRL4ITlr-TV0CmmXcperZe7Vamow,593
|
|
25
27
|
flwr/cli/new/templates/app/code/server.sklearn.py.tpl,sha256=cLzOpQzGIUzEazuFsjBpXAQUNPy6in6zR33SCqhix6o,341
|
|
26
28
|
flwr/cli/new/templates/app/code/server.tensorflow.py.tpl,sha256=gsNrWCKTU77_65_gw9nlp1LSQojgP5QQIWILvqdjx2s,579
|
|
27
29
|
flwr/cli/new/templates/app/code/task.hf.py.tpl,sha256=Rw8cnds4Ym8o8TOq6kMkwlBJfIfvsfnb02jwyulOgF8,2857
|
|
30
|
+
flwr/cli/new/templates/app/code/task.jax.py.tpl,sha256=u4o3V019EH79szOw2xzVeC5r9xgQiayPi9ZTIopV2TA,1519
|
|
28
31
|
flwr/cli/new/templates/app/code/task.mlx.py.tpl,sha256=y7aVj3F_98-wBnDcbPsCNnFs9BOHTn0y6XIYkByzv7Y,2598
|
|
29
32
|
flwr/cli/new/templates/app/code/task.pytorch.py.tpl,sha256=NvajdZN-eTyfdqKK0v2MrvWITXw9BjJ3Ri5c1haPJDs,3684
|
|
30
33
|
flwr/cli/new/templates/app/code/task.tensorflow.py.tpl,sha256=cPOUUS07QbblT9PGFucwu9lY1clRA4-W4DQGA7cpcao,1044
|
|
31
34
|
flwr/cli/new/templates/app/pyproject.hf.toml.tpl,sha256=PNGBNTfWmNJ23aVnW5f1TMMJ0uEwIljevpOsI-mqX08,676
|
|
35
|
+
flwr/cli/new/templates/app/pyproject.jax.toml.tpl,sha256=o34H5MvQeu4H2nRolbIas9G63mR7nDDL4rqQMlJW6LA,568
|
|
32
36
|
flwr/cli/new/templates/app/pyproject.mlx.toml.tpl,sha256=JCEsuHZffO1KKkN65rSp6N-A9-OW8-kl6EQp5Z2H3uE,585
|
|
33
37
|
flwr/cli/new/templates/app/pyproject.numpy.toml.tpl,sha256=m276SKsjOZ4awGdXasUKvLim66agrpAsPNP9-PN6q4I,523
|
|
34
38
|
flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl,sha256=QikP3u5ht6qr2BkgcnvB3rCYK7jt1cS0nAm7V8g_zFc,592
|
|
@@ -143,12 +147,13 @@ flwr/server/compat/app_utils.py,sha256=06NHrPRPrjMjz5FglSPicJ9lAWZ-rIZ1cKQFs4nD6
|
|
|
143
147
|
flwr/server/compat/driver_client_proxy.py,sha256=Wc6jyyHY4OrJzeiy8tdXtkF8IdGREdxUPnom7VvvWPI,5444
|
|
144
148
|
flwr/server/compat/legacy_context.py,sha256=D2s7PvQoDnTexuRmf1uG9Von7GUj4Qqyr7qLklSlKAM,1766
|
|
145
149
|
flwr/server/criterion.py,sha256=ypbAexbztzGUxNen9RCHF91QeqiEQix4t4Ih3E-42MM,1061
|
|
146
|
-
flwr/server/driver/__init__.py,sha256=
|
|
150
|
+
flwr/server/driver/__init__.py,sha256=bikRv6CjTwSvYh7tf10gziU5o2YotOWhhftz2tr3KDc,886
|
|
147
151
|
flwr/server/driver/driver.py,sha256=t9SSSDlo9wT_y2Nl7waGYMTm2VlkvK3_bOb7ggPPlho,5090
|
|
148
152
|
flwr/server/driver/grpc_driver.py,sha256=rdjkcAmtRWKeqJw4xDFqULuwVf0G2nLhfbOTrNUvPeY,11832
|
|
153
|
+
flwr/server/driver/inmemory_driver.py,sha256=XfdLV3mVorTWBfthBkErJDLm8jXZ834IHF3139lTS5o,6490
|
|
149
154
|
flwr/server/history.py,sha256=bBOHKyX1eQONIsUx4EUU-UnAk1i0EbEl8ioyMq_UWQ8,5063
|
|
150
155
|
flwr/server/run_serverapp.py,sha256=avLi_yRNE5jD2ql95gzh04BTUbHvzH-N848_mdnnkVk,5972
|
|
151
|
-
flwr/server/server.py,sha256=
|
|
156
|
+
flwr/server/server.py,sha256=wsXsxMZ9SQ0B42nBnUlcV83NJPycgrgg5bFwcQ4BYBE,17821
|
|
152
157
|
flwr/server/server_app.py,sha256=KgAT_HqsfseTLNnfX2ph42PBbVqQ0lFzvYrT90V34y0,4402
|
|
153
158
|
flwr/server/server_config.py,sha256=CZaHVAsMvGLjpWVcLPkiYxgJN4xfIyAiUrCI3fETKY4,1349
|
|
154
159
|
flwr/server/strategy/__init__.py,sha256=7eVZ3hQEg2BgA_usAeL6tsLp9T6XI1VYYoFy08Xn-ew,2836
|
|
@@ -195,7 +200,7 @@ flwr/server/superlink/fleet/rest_rere/rest_api.py,sha256=8gNziOjBA8ygTzfVPYiNkg_
|
|
|
195
200
|
flwr/server/superlink/fleet/vce/__init__.py,sha256=36MHKiefnJeyjwMQzVUK4m06Ojon3WDcwZGQsAcyVhQ,783
|
|
196
201
|
flwr/server/superlink/fleet/vce/backend/__init__.py,sha256=oBIzmnrSSRvH_H0vRGEGWhWzQQwqe3zn6e13RsNwlIY,1466
|
|
197
202
|
flwr/server/superlink/fleet/vce/backend/backend.py,sha256=LJsKl7oixVvptcG98Rd9ejJycNWcEVB0ODvSreLGp-A,2260
|
|
198
|
-
flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=
|
|
203
|
+
flwr/server/superlink/fleet/vce/backend/raybackend.py,sha256=KCzV-n-czXxIKPwNfuD-JEVCl4-xAJaHe4taGmw9cTQ,6722
|
|
199
204
|
flwr/server/superlink/fleet/vce/vce_api.py,sha256=aH-1h1EhTPCxdiqgH0_t8oDPiXX8VNNLV_BiDvu6kRk,12456
|
|
200
205
|
flwr/server/superlink/state/__init__.py,sha256=ij-7Ms-hyordQdRmGQxY1-nVa4OhixJ0jr7_YDkys0s,1003
|
|
201
206
|
flwr/server/superlink/state/in_memory_state.py,sha256=WoIOwgayuCu1DLRkkV6KgBsc28SKzSDxtXwO2a9Phuw,12750
|
|
@@ -209,7 +214,7 @@ flwr/server/utils/tensorboard.py,sha256=k0G6bqsLx7wfYbH2KtXsDYcOCfyIeE12-hefXA7l
|
|
|
209
214
|
flwr/server/utils/validator.py,sha256=pzyXoOEEPSoYC2UEzened8IKSFRI-kIqqI0QlwRK9jk,5301
|
|
210
215
|
flwr/server/workflow/__init__.py,sha256=SXY0XkwbkezFBxxrFB5hKUtmtAgnYISBkPouR1V71ss,902
|
|
211
216
|
flwr/server/workflow/constant.py,sha256=q4DLdR8Krlxuewq2AQjwTL75hphxE5ODNz4AhViHMXk,1082
|
|
212
|
-
flwr/server/workflow/default_workflows.py,sha256=
|
|
217
|
+
flwr/server/workflow/default_workflows.py,sha256=_GqFCaxtiq3_UVCvZWgJ200QroGSI9qibeVcT2R71ao,14003
|
|
213
218
|
flwr/server/workflow/secure_aggregation/__init__.py,sha256=3XlgDOjD_hcukTGl6Bc1B-8M_dPlVSJuTbvXIbiO-Ic,880
|
|
214
219
|
flwr/server/workflow/secure_aggregation/secagg_workflow.py,sha256=wpAkYPId0nfK6SgpUAtsCni4_MQLd-uqJ81tUKu3xlI,5838
|
|
215
220
|
flwr/server/workflow/secure_aggregation/secaggplus_workflow.py,sha256=BRqhlnVe8CYNoUvb_KCfRXay02NTT6a-pCrMaOqAxGc,29038
|
|
@@ -219,9 +224,9 @@ flwr/simulation/ray_transport/__init__.py,sha256=FsaAnzC4cw4DqoouBCix6496k29jACk
|
|
|
219
224
|
flwr/simulation/ray_transport/ray_actor.py,sha256=_wv2eP7qxkCZ-6rMyYWnjLrGPBZRxjvTPjaVk8zIaQ4,19367
|
|
220
225
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=oDu4sEPIOu39vrNi-fqDAe10xtNUXMO49bM2RWfRcyw,6738
|
|
221
226
|
flwr/simulation/ray_transport/utils.py,sha256=TYdtfg1P9VfTdLMOJlifInGpxWHYs9UfUqIv2wfkRLA,2392
|
|
222
|
-
flwr/simulation/run_simulation.py,sha256=
|
|
223
|
-
flwr_nightly-1.9.0.
|
|
224
|
-
flwr_nightly-1.9.0.
|
|
225
|
-
flwr_nightly-1.9.0.
|
|
226
|
-
flwr_nightly-1.9.0.
|
|
227
|
-
flwr_nightly-1.9.0.
|
|
227
|
+
flwr/simulation/run_simulation.py,sha256=Jmc6DyN5UCY1U1PcDvL04NgYmEQ6ufJ1JisjG5yqfY8,15098
|
|
228
|
+
flwr_nightly-1.9.0.dev20240519.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
229
|
+
flwr_nightly-1.9.0.dev20240519.dist-info/METADATA,sha256=Awsqxt2JHQrRwiGgDZhkqzGIoV-oco2gU58BVxb5YW8,15302
|
|
230
|
+
flwr_nightly-1.9.0.dev20240519.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
|
231
|
+
flwr_nightly-1.9.0.dev20240519.dist-info/entry_points.txt,sha256=8JJPfpqMnXz9c5V_FSt07Xwd-wCWbAO3MFUDXQ5ZGsI,378
|
|
232
|
+
flwr_nightly-1.9.0.dev20240519.dist-info/RECORD,,
|
{flwr_nightly-1.9.0.dev20240516.dist-info → flwr_nightly-1.9.0.dev20240519.dist-info}/LICENSE
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|