flwr 1.23.0__py3-none-any.whl → 1.24.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +16 -5
- flwr/app/error.py +2 -2
- flwr/app/exception.py +3 -3
- flwr/cli/app.py +19 -0
- flwr/cli/app_cmd/__init__.py +23 -0
- flwr/cli/app_cmd/publish.py +285 -0
- flwr/cli/app_cmd/review.py +252 -0
- flwr/cli/auth_plugin/auth_plugin.py +4 -5
- flwr/cli/auth_plugin/noop_auth_plugin.py +54 -11
- flwr/cli/auth_plugin/oidc_cli_plugin.py +32 -9
- flwr/cli/build.py +60 -18
- flwr/cli/cli_account_auth_interceptor.py +24 -7
- flwr/cli/config_utils.py +101 -13
- flwr/cli/federation/__init__.py +24 -0
- flwr/cli/federation/ls.py +140 -0
- flwr/cli/federation/show.py +317 -0
- flwr/cli/install.py +91 -13
- flwr/cli/log.py +52 -9
- flwr/cli/login/login.py +7 -4
- flwr/cli/ls.py +170 -130
- flwr/cli/new/new.py +33 -50
- flwr/cli/new/templates/app/code/task.pytorch.py.tpl +1 -0
- flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
- flwr/cli/new/templates/app/pyproject.pytorch_legacy_api.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
- flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +2 -2
- flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +1 -1
- flwr/cli/pull.py +10 -5
- flwr/cli/run/run.py +77 -30
- flwr/cli/run_utils.py +130 -0
- flwr/cli/stop.py +25 -7
- flwr/cli/supernode/ls.py +16 -8
- flwr/cli/supernode/register.py +9 -4
- flwr/cli/supernode/unregister.py +5 -3
- flwr/cli/utils.py +376 -16
- flwr/client/__init__.py +1 -1
- flwr/client/dpfedavg_numpy_client.py +4 -1
- flwr/client/grpc_adapter_client/connection.py +6 -7
- flwr/client/grpc_rere_client/connection.py +10 -11
- flwr/client/grpc_rere_client/grpc_adapter.py +6 -2
- flwr/client/grpc_rere_client/node_auth_client_interceptor.py +2 -1
- flwr/client/message_handler/message_handler.py +2 -2
- flwr/client/mod/secure_aggregation/secaggplus_mod.py +3 -3
- flwr/client/numpy_client.py +1 -1
- flwr/client/rest_client/connection.py +12 -14
- flwr/client/run_info_store.py +4 -5
- flwr/client/typing.py +1 -1
- flwr/clientapp/client_app.py +9 -10
- flwr/clientapp/mod/centraldp_mods.py +16 -17
- flwr/clientapp/mod/localdp_mod.py +8 -9
- flwr/clientapp/typing.py +1 -1
- flwr/clientapp/utils.py +3 -3
- flwr/common/address.py +1 -2
- flwr/common/args.py +3 -4
- flwr/common/config.py +13 -16
- flwr/common/constant.py +5 -2
- flwr/common/differential_privacy.py +3 -4
- flwr/common/event_log_plugin/event_log_plugin.py +3 -4
- flwr/common/exit/exit.py +15 -2
- flwr/common/exit/exit_code.py +19 -0
- flwr/common/exit/exit_handler.py +6 -2
- flwr/common/exit/signal_handler.py +5 -5
- flwr/common/grpc.py +6 -6
- flwr/common/inflatable_protobuf_utils.py +1 -1
- flwr/common/inflatable_utils.py +38 -21
- flwr/common/logger.py +19 -19
- flwr/common/message.py +4 -4
- flwr/common/object_ref.py +7 -7
- flwr/common/record/array.py +3 -3
- flwr/common/record/arrayrecord.py +18 -30
- flwr/common/record/configrecord.py +3 -3
- flwr/common/record/recorddict.py +5 -5
- flwr/common/record/typeddict.py +9 -2
- flwr/common/recorddict_compat.py +7 -10
- flwr/common/retry_invoker.py +20 -20
- flwr/common/secure_aggregation/ndarrays_arithmetic.py +3 -3
- flwr/common/serde.py +5 -4
- flwr/common/serde_utils.py +2 -2
- flwr/common/telemetry.py +9 -5
- flwr/common/typing.py +52 -37
- flwr/compat/client/app.py +38 -37
- flwr/compat/client/grpc_client/connection.py +11 -11
- flwr/compat/server/app.py +5 -6
- flwr/proto/appio_pb2.py +13 -3
- flwr/proto/appio_pb2.pyi +134 -65
- flwr/proto/appio_pb2_grpc.py +20 -0
- flwr/proto/appio_pb2_grpc.pyi +27 -0
- flwr/proto/clientappio_pb2.py +17 -7
- flwr/proto/clientappio_pb2.pyi +15 -0
- flwr/proto/clientappio_pb2_grpc.py +206 -40
- flwr/proto/clientappio_pb2_grpc.pyi +168 -53
- flwr/proto/control_pb2.py +71 -52
- flwr/proto/control_pb2.pyi +277 -111
- flwr/proto/control_pb2_grpc.py +249 -40
- flwr/proto/control_pb2_grpc.pyi +185 -52
- flwr/proto/error_pb2.py +13 -3
- flwr/proto/error_pb2.pyi +24 -6
- flwr/proto/error_pb2_grpc.py +20 -0
- flwr/proto/error_pb2_grpc.pyi +27 -0
- flwr/proto/fab_pb2.py +14 -4
- flwr/proto/fab_pb2.pyi +59 -31
- flwr/proto/fab_pb2_grpc.py +20 -0
- flwr/proto/fab_pb2_grpc.pyi +27 -0
- flwr/proto/federation_pb2.py +38 -0
- flwr/proto/federation_pb2.pyi +56 -0
- flwr/proto/federation_pb2_grpc.py +24 -0
- flwr/proto/federation_pb2_grpc.pyi +31 -0
- flwr/proto/fleet_pb2.py +14 -4
- flwr/proto/fleet_pb2.pyi +137 -61
- flwr/proto/fleet_pb2_grpc.py +189 -48
- flwr/proto/fleet_pb2_grpc.pyi +175 -61
- flwr/proto/grpcadapter_pb2.py +14 -4
- flwr/proto/grpcadapter_pb2.pyi +38 -16
- flwr/proto/grpcadapter_pb2_grpc.py +35 -4
- flwr/proto/grpcadapter_pb2_grpc.pyi +38 -7
- flwr/proto/heartbeat_pb2.py +17 -7
- flwr/proto/heartbeat_pb2.pyi +51 -22
- flwr/proto/heartbeat_pb2_grpc.py +20 -0
- flwr/proto/heartbeat_pb2_grpc.pyi +27 -0
- flwr/proto/log_pb2.py +13 -3
- flwr/proto/log_pb2.pyi +34 -11
- flwr/proto/log_pb2_grpc.py +20 -0
- flwr/proto/log_pb2_grpc.pyi +27 -0
- flwr/proto/message_pb2.py +15 -5
- flwr/proto/message_pb2.pyi +154 -86
- flwr/proto/message_pb2_grpc.py +20 -0
- flwr/proto/message_pb2_grpc.pyi +27 -0
- flwr/proto/node_pb2.py +15 -5
- flwr/proto/node_pb2.pyi +50 -25
- flwr/proto/node_pb2_grpc.py +20 -0
- flwr/proto/node_pb2_grpc.pyi +27 -0
- flwr/proto/recorddict_pb2.py +13 -3
- flwr/proto/recorddict_pb2.pyi +184 -107
- flwr/proto/recorddict_pb2_grpc.py +20 -0
- flwr/proto/recorddict_pb2_grpc.pyi +27 -0
- flwr/proto/run_pb2.py +40 -31
- flwr/proto/run_pb2.pyi +149 -84
- flwr/proto/run_pb2_grpc.py +20 -0
- flwr/proto/run_pb2_grpc.pyi +27 -0
- flwr/proto/serverappio_pb2.py +13 -3
- flwr/proto/serverappio_pb2.pyi +32 -8
- flwr/proto/serverappio_pb2_grpc.py +246 -65
- flwr/proto/serverappio_pb2_grpc.pyi +221 -85
- flwr/proto/simulationio_pb2.py +16 -8
- flwr/proto/simulationio_pb2.pyi +15 -0
- flwr/proto/simulationio_pb2_grpc.py +162 -41
- flwr/proto/simulationio_pb2_grpc.pyi +149 -55
- flwr/proto/transport_pb2.py +20 -10
- flwr/proto/transport_pb2.pyi +249 -160
- flwr/proto/transport_pb2_grpc.py +35 -4
- flwr/proto/transport_pb2_grpc.pyi +38 -8
- flwr/server/app.py +38 -17
- flwr/server/client_manager.py +4 -5
- flwr/server/client_proxy.py +10 -11
- flwr/server/compat/app.py +4 -5
- flwr/server/compat/app_utils.py +2 -1
- flwr/server/compat/grid_client_proxy.py +10 -12
- flwr/server/compat/legacy_context.py +3 -4
- flwr/server/fleet_event_log_interceptor.py +2 -1
- flwr/server/grid/grid.py +2 -3
- flwr/server/grid/grpc_grid.py +10 -8
- flwr/server/grid/inmemory_grid.py +4 -4
- flwr/server/run_serverapp.py +2 -3
- flwr/server/server.py +34 -39
- flwr/server/server_app.py +7 -8
- flwr/server/server_config.py +1 -2
- flwr/server/serverapp/app.py +34 -28
- flwr/server/serverapp_components.py +4 -5
- flwr/server/strategy/aggregate.py +9 -8
- flwr/server/strategy/bulyan.py +13 -11
- flwr/server/strategy/dp_adaptive_clipping.py +16 -20
- flwr/server/strategy/dp_fixed_clipping.py +12 -17
- flwr/server/strategy/dpfedavg_adaptive.py +3 -4
- flwr/server/strategy/dpfedavg_fixed.py +6 -10
- flwr/server/strategy/fault_tolerant_fedavg.py +14 -13
- flwr/server/strategy/fedadagrad.py +18 -14
- flwr/server/strategy/fedadam.py +16 -14
- flwr/server/strategy/fedavg.py +16 -17
- flwr/server/strategy/fedavg_android.py +15 -15
- flwr/server/strategy/fedavgm.py +21 -18
- flwr/server/strategy/fedmedian.py +2 -3
- flwr/server/strategy/fedopt.py +11 -10
- flwr/server/strategy/fedprox.py +10 -9
- flwr/server/strategy/fedtrimmedavg.py +12 -11
- flwr/server/strategy/fedxgb_bagging.py +13 -11
- flwr/server/strategy/fedxgb_cyclic.py +6 -6
- flwr/server/strategy/fedxgb_nn_avg.py +4 -4
- flwr/server/strategy/fedyogi.py +16 -14
- flwr/server/strategy/krum.py +12 -11
- flwr/server/strategy/qfedavg.py +16 -15
- flwr/server/strategy/strategy.py +6 -9
- flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +2 -1
- flwr/server/superlink/fleet/grpc_bidi/flower_service_servicer.py +1 -2
- flwr/server/superlink/fleet/grpc_bidi/grpc_bridge.py +3 -4
- flwr/server/superlink/fleet/grpc_bidi/grpc_client_proxy.py +10 -12
- flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +1 -3
- flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +4 -4
- flwr/server/superlink/fleet/grpc_rere/node_auth_server_interceptor.py +3 -2
- flwr/server/superlink/fleet/message_handler/message_handler.py +34 -28
- flwr/server/superlink/fleet/rest_rere/rest_api.py +2 -2
- flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
- flwr/server/superlink/fleet/vce/backend/raybackend.py +5 -5
- flwr/server/superlink/fleet/vce/vce_api.py +15 -9
- flwr/server/superlink/linkstate/in_memory_linkstate.py +115 -150
- flwr/server/superlink/linkstate/linkstate.py +59 -43
- flwr/server/superlink/linkstate/linkstate_factory.py +22 -5
- flwr/server/superlink/linkstate/sqlite_linkstate.py +447 -438
- flwr/server/superlink/linkstate/utils.py +6 -6
- flwr/server/superlink/serverappio/serverappio_grpc.py +1 -2
- flwr/server/superlink/serverappio/serverappio_servicer.py +26 -21
- flwr/server/superlink/simulation/simulationio_grpc.py +1 -2
- flwr/server/superlink/simulation/simulationio_servicer.py +18 -13
- flwr/server/superlink/utils.py +4 -6
- flwr/server/typing.py +1 -1
- flwr/server/utils/tensorboard.py +15 -8
- flwr/server/workflow/default_workflows.py +5 -5
- flwr/server/workflow/secure_aggregation/secagg_workflow.py +2 -4
- flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +8 -8
- flwr/serverapp/strategy/bulyan.py +16 -15
- flwr/serverapp/strategy/dp_adaptive_clipping.py +12 -11
- flwr/serverapp/strategy/dp_fixed_clipping.py +11 -14
- flwr/serverapp/strategy/fedadagrad.py +10 -11
- flwr/serverapp/strategy/fedadam.py +10 -11
- flwr/serverapp/strategy/fedavg.py +9 -10
- flwr/serverapp/strategy/fedavgm.py +17 -16
- flwr/serverapp/strategy/fedmedian.py +2 -2
- flwr/serverapp/strategy/fedopt.py +10 -11
- flwr/serverapp/strategy/fedprox.py +7 -8
- flwr/serverapp/strategy/fedtrimmedavg.py +9 -9
- flwr/serverapp/strategy/fedxgb_bagging.py +3 -3
- flwr/serverapp/strategy/fedxgb_cyclic.py +9 -9
- flwr/serverapp/strategy/fedyogi.py +9 -11
- flwr/serverapp/strategy/krum.py +7 -7
- flwr/serverapp/strategy/multikrum.py +9 -9
- flwr/serverapp/strategy/qfedavg.py +17 -16
- flwr/serverapp/strategy/strategy.py +6 -9
- flwr/serverapp/strategy/strategy_utils.py +7 -8
- flwr/simulation/app.py +46 -42
- flwr/simulation/legacy_app.py +12 -12
- flwr/simulation/ray_transport/ray_actor.py +10 -11
- flwr/simulation/ray_transport/ray_client_proxy.py +11 -12
- flwr/simulation/run_simulation.py +43 -43
- flwr/simulation/simulationio_connection.py +4 -4
- flwr/supercore/cli/flower_superexec.py +3 -4
- flwr/supercore/constant.py +31 -1
- flwr/supercore/corestate/corestate.py +24 -3
- flwr/supercore/corestate/in_memory_corestate.py +138 -0
- flwr/supercore/corestate/sqlite_corestate.py +157 -0
- flwr/supercore/ffs/disk_ffs.py +1 -2
- flwr/supercore/ffs/ffs.py +1 -2
- flwr/supercore/ffs/ffs_factory.py +1 -2
- flwr/{common → supercore}/heartbeat.py +20 -25
- flwr/supercore/object_store/in_memory_object_store.py +1 -2
- flwr/supercore/object_store/object_store.py +1 -2
- flwr/supercore/object_store/object_store_factory.py +1 -2
- flwr/supercore/object_store/sqlite_object_store.py +8 -7
- flwr/supercore/primitives/asymmetric.py +1 -1
- flwr/supercore/primitives/asymmetric_ed25519.py +11 -1
- flwr/supercore/sqlite_mixin.py +37 -34
- flwr/supercore/superexec/plugin/base_exec_plugin.py +1 -2
- flwr/supercore/superexec/plugin/exec_plugin.py +3 -3
- flwr/supercore/superexec/run_superexec.py +9 -13
- flwr/superlink/artifact_provider/artifact_provider.py +1 -2
- flwr/superlink/auth_plugin/auth_plugin.py +6 -9
- flwr/superlink/auth_plugin/noop_auth_plugin.py +6 -9
- flwr/superlink/federation/__init__.py +24 -0
- flwr/superlink/federation/federation_manager.py +64 -0
- flwr/superlink/federation/noop_federation_manager.py +71 -0
- flwr/superlink/servicer/control/control_account_auth_interceptor.py +22 -13
- flwr/superlink/servicer/control/control_event_log_interceptor.py +7 -7
- flwr/superlink/servicer/control/control_grpc.py +5 -6
- flwr/superlink/servicer/control/control_license_interceptor.py +3 -3
- flwr/superlink/servicer/control/control_servicer.py +102 -18
- flwr/supernode/cli/flower_supernode.py +58 -3
- flwr/supernode/nodestate/in_memory_nodestate.py +60 -49
- flwr/supernode/nodestate/nodestate.py +7 -8
- flwr/supernode/nodestate/nodestate_factory.py +7 -4
- flwr/supernode/runtime/run_clientapp.py +41 -22
- flwr/supernode/servicer/clientappio/clientappio_servicer.py +40 -10
- flwr/supernode/start_client_internal.py +158 -42
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/METADATA +8 -8
- flwr-1.24.0.dist-info/RECORD +454 -0
- flwr/supercore/object_store/utils.py +0 -43
- flwr-1.23.0.dist-info/RECORD +0 -439
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/WHEEL +0 -0
- {flwr-1.23.0.dist-info → flwr-1.24.0.dist-info}/entry_points.txt +0 -0
|
@@ -18,10 +18,8 @@ Paper: openreview.net/pdf?id=ByexElSYDr
|
|
|
18
18
|
"""
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
from collections import OrderedDict
|
|
22
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
23
22
|
from logging import INFO
|
|
24
|
-
from typing import Callable, Optional
|
|
25
23
|
|
|
26
24
|
import numpy as np
|
|
27
25
|
|
|
@@ -105,12 +103,12 @@ class QFedAvg(FedAvg):
|
|
|
105
103
|
weighted_by_key: str = "num-examples",
|
|
106
104
|
arrayrecord_key: str = "arrays",
|
|
107
105
|
configrecord_key: str = "config",
|
|
108
|
-
train_metrics_aggr_fn: Optional[
|
|
109
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
110
|
-
|
|
111
|
-
evaluate_metrics_aggr_fn: Optional[
|
|
112
|
-
Callable[[list[RecordDict], str], MetricRecord]
|
|
113
|
-
|
|
106
|
+
train_metrics_aggr_fn: (
|
|
107
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
108
|
+
) = None,
|
|
109
|
+
evaluate_metrics_aggr_fn: (
|
|
110
|
+
Callable[[list[RecordDict], str], MetricRecord] | None
|
|
111
|
+
) = None,
|
|
114
112
|
) -> None:
|
|
115
113
|
super().__init__(
|
|
116
114
|
fraction_train=fraction_train,
|
|
@@ -127,7 +125,7 @@ class QFedAvg(FedAvg):
|
|
|
127
125
|
self.q = q
|
|
128
126
|
self.client_learning_rate = client_learning_rate
|
|
129
127
|
self.train_loss_key = train_loss_key
|
|
130
|
-
self.current_arrays: Optional[ArrayRecord] = None
|
|
128
|
+
self.current_arrays: ArrayRecord | None = None
|
|
131
129
|
|
|
132
130
|
def summary(self) -> None:
|
|
133
131
|
"""Log summary configuration of the strategy."""
|
|
@@ -148,7 +146,7 @@ class QFedAvg(FedAvg):
|
|
|
148
146
|
self,
|
|
149
147
|
server_round: int,
|
|
150
148
|
replies: Iterable[Message],
|
|
151
|
-
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
|
149
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
152
150
|
"""Aggregate ArrayRecords and MetricRecords in the received Messages."""
|
|
153
151
|
# Call FedAvg aggregate_train to perform validation and aggregation
|
|
154
152
|
valid_replies, _ = self._check_and_log_replies(replies, is_train=True)
|
|
@@ -184,7 +182,7 @@ class QFedAvg(FedAvg):
|
|
|
184
182
|
if sum_delta is None:
|
|
185
183
|
sum_delta = delta
|
|
186
184
|
else:
|
|
187
|
-
sum_delta = [sd + d for sd, d in zip(sum_delta, delta)]
|
|
185
|
+
sum_delta = [sd + d for sd, d in zip(sum_delta, delta, strict=True)]
|
|
188
186
|
sum_h += h
|
|
189
187
|
|
|
190
188
|
# Compute new global weights and convert to Array type
|
|
@@ -192,7 +190,7 @@ class QFedAvg(FedAvg):
|
|
|
192
190
|
assert sum_delta is not None # Make mypy happy
|
|
193
191
|
array_list = [
|
|
194
192
|
Array(np.asarray(gw - (d / sum_h)))
|
|
195
|
-
for gw, d in zip(global_weights, sum_delta)
|
|
193
|
+
for gw, d in zip(global_weights, sum_delta, strict=True)
|
|
196
194
|
]
|
|
197
195
|
|
|
198
196
|
# Aggregate MetricRecords
|
|
@@ -200,13 +198,16 @@ class QFedAvg(FedAvg):
|
|
|
200
198
|
[msg.content for msg in valid_replies],
|
|
201
199
|
self.weighted_by_key,
|
|
202
200
|
)
|
|
203
|
-
return
|
|
201
|
+
return (
|
|
202
|
+
ArrayRecord(dict(zip(array_keys, array_list, strict=True))),
|
|
203
|
+
metrics,
|
|
204
|
+
)
|
|
204
205
|
|
|
205
206
|
|
|
206
207
|
def get_train_loss(msg: Message, loss_key: str) -> float:
|
|
207
208
|
"""Extract training loss from a Message."""
|
|
208
209
|
metrics = list(msg.content.metric_records.values())[0]
|
|
209
|
-
if (loss := metrics.get(loss_key)) is None or not isinstance(loss, (int, float)):
|
|
210
|
+
if (loss := metrics.get(loss_key)) is None or not isinstance(loss, (int | float)):
|
|
210
211
|
raise AggregationError(
|
|
211
212
|
"Missing or invalid training loss. "
|
|
212
213
|
f"The strategy expected a float value for the key '{loss_key}' "
|
|
@@ -236,7 +237,7 @@ def compute_delta_and_h(
|
|
|
236
237
|
) -> tuple[list[NDArray], float]:
|
|
237
238
|
"""Compute delta and h used in q-FedAvg aggregation."""
|
|
238
239
|
# Compute gradient_k = L * (w - w_k)
|
|
239
|
-
for gw, lw in zip(global_weights, local_weights):
|
|
240
|
+
for gw, lw in zip(global_weights, local_weights, strict=True):
|
|
240
241
|
np.subtract(gw, lw, out=lw)
|
|
241
242
|
lw *= L
|
|
242
243
|
grad = local_weights # After in-place operations, local_weights is now grad
|
|
@@ -18,9 +18,8 @@
|
|
|
18
18
|
import io
|
|
19
19
|
import time
|
|
20
20
|
from abc import ABC, abstractmethod
|
|
21
|
-
from collections.abc import Iterable
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
22
|
from logging import INFO
|
|
23
|
-
from typing import Callable, Optional
|
|
24
23
|
|
|
25
24
|
from flwr.common import ArrayRecord, ConfigRecord, Message, MetricRecord, log
|
|
26
25
|
from flwr.server import Grid
|
|
@@ -61,7 +60,7 @@ class Strategy(ABC):
|
|
|
61
60
|
self,
|
|
62
61
|
server_round: int,
|
|
63
62
|
replies: Iterable[Message],
|
|
64
|
-
) -> tuple[Optional[ArrayRecord], Optional[MetricRecord]]:
|
|
63
|
+
) -> tuple[ArrayRecord | None, MetricRecord | None]:
|
|
65
64
|
"""Aggregate training results from client nodes.
|
|
66
65
|
|
|
67
66
|
Parameters
|
|
@@ -109,7 +108,7 @@ class Strategy(ABC):
|
|
|
109
108
|
self,
|
|
110
109
|
server_round: int,
|
|
111
110
|
replies: Iterable[Message],
|
|
112
|
-
) -> Optional[MetricRecord]:
|
|
111
|
+
) -> MetricRecord | None:
|
|
113
112
|
"""Aggregate evaluation metrics from client nodes.
|
|
114
113
|
|
|
115
114
|
Parameters
|
|
@@ -138,11 +137,9 @@ class Strategy(ABC):
|
|
|
138
137
|
initial_arrays: ArrayRecord,
|
|
139
138
|
num_rounds: int = 3,
|
|
140
139
|
timeout: float = 3600,
|
|
141
|
-
train_config: Optional[ConfigRecord] = None,
|
|
142
|
-
evaluate_config: Optional[ConfigRecord] = None,
|
|
143
|
-
evaluate_fn: Optional[
|
|
144
|
-
Callable[[int, ArrayRecord], Optional[MetricRecord]]
|
|
145
|
-
] = None,
|
|
140
|
+
train_config: ConfigRecord | None = None,
|
|
141
|
+
evaluate_config: ConfigRecord | None = None,
|
|
142
|
+
evaluate_fn: Callable[[int, ArrayRecord], MetricRecord | None] | None = None,
|
|
146
143
|
) -> Result:
|
|
147
144
|
"""Execute the federated learning strategy.
|
|
148
145
|
|
|
@@ -17,10 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import json
|
|
19
19
|
import random
|
|
20
|
-
from collections import OrderedDict
|
|
21
20
|
from logging import INFO
|
|
22
21
|
from time import sleep
|
|
23
|
-
from typing import Optional, cast
|
|
22
|
+
from typing import cast
|
|
24
23
|
|
|
25
24
|
import numpy as np
|
|
26
25
|
|
|
@@ -49,8 +48,8 @@ def config_to_str(config: ConfigRecord) -> str:
|
|
|
49
48
|
def log_strategy_start_info(
|
|
50
49
|
num_rounds: int,
|
|
51
50
|
arrays: ArrayRecord,
|
|
52
|
-
train_config: Optional[ConfigRecord],
|
|
53
|
-
evaluate_config: Optional[ConfigRecord],
|
|
51
|
+
train_config: ConfigRecord | None,
|
|
52
|
+
evaluate_config: ConfigRecord | None,
|
|
54
53
|
) -> None:
|
|
55
54
|
"""Log information about the strategy start."""
|
|
56
55
|
log(INFO, "\t├── Number of rounds: %d", num_rounds)
|
|
@@ -92,7 +91,7 @@ def aggregate_arrayrecords(
|
|
|
92
91
|
# Perform weighted aggregation
|
|
93
92
|
aggregated_np_arrays: dict[str, NDArray] = {}
|
|
94
93
|
|
|
95
|
-
for record, weight in zip(records, weight_factors):
|
|
94
|
+
for record, weight in zip(records, weight_factors, strict=True):
|
|
96
95
|
for record_item in record.array_records.values():
|
|
97
96
|
# aggregate in-place
|
|
98
97
|
for key, value in record_item.items():
|
|
@@ -102,7 +101,7 @@ def aggregate_arrayrecords(
|
|
|
102
101
|
aggregated_np_arrays[key] += value.numpy() * weight
|
|
103
102
|
|
|
104
103
|
return ArrayRecord(
|
|
105
|
-
|
|
104
|
+
{k: Array(np.asarray(v)) for k, v in aggregated_np_arrays.items()}
|
|
106
105
|
)
|
|
107
106
|
|
|
108
107
|
|
|
@@ -125,7 +124,7 @@ def aggregate_metricrecords(
|
|
|
125
124
|
weight_factors = [w / total_weight for w in weights]
|
|
126
125
|
|
|
127
126
|
aggregated_metrics = MetricRecord()
|
|
128
|
-
for record, weight in zip(records, weight_factors):
|
|
127
|
+
for record, weight in zip(records, weight_factors, strict=True):
|
|
129
128
|
for record_item in record.metric_records.values():
|
|
130
129
|
# aggregate in-place
|
|
131
130
|
for key, value in record_item.items():
|
|
@@ -142,7 +141,7 @@ def aggregate_metricrecords(
|
|
|
142
141
|
current_list = cast(list[float], aggregated_metrics[key])
|
|
143
142
|
aggregated_metrics[key] = [
|
|
144
143
|
curr + val * weight
|
|
145
|
-
for curr, val in zip(current_list, value)
|
|
144
|
+
for curr, val in zip(current_list, value, strict=True)
|
|
146
145
|
]
|
|
147
146
|
else:
|
|
148
147
|
current_value = cast(float, aggregated_metrics[key])
|
flwr/simulation/app.py
CHANGED
|
@@ -18,7 +18,6 @@
|
|
|
18
18
|
import argparse
|
|
19
19
|
from logging import DEBUG, ERROR, INFO
|
|
20
20
|
from queue import Queue
|
|
21
|
-
from typing import Optional
|
|
22
21
|
|
|
23
22
|
from flwr.cli.config_utils import get_fab_metadata
|
|
24
23
|
from flwr.cli.install import install_from_fab
|
|
@@ -38,8 +37,7 @@ from flwr.common.constant import (
|
|
|
38
37
|
Status,
|
|
39
38
|
SubStatus,
|
|
40
39
|
)
|
|
41
|
-
from flwr.common.exit import ExitCode, flwr_exit
|
|
42
|
-
from flwr.common.heartbeat import HeartbeatSender, get_grpc_app_heartbeat_fn
|
|
40
|
+
from flwr.common.exit import ExitCode, flwr_exit, register_signal_handlers
|
|
43
41
|
from flwr.common.logger import (
|
|
44
42
|
log,
|
|
45
43
|
mirror_output_to_queue,
|
|
@@ -71,6 +69,7 @@ from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
|
|
|
71
69
|
from flwr.simulation.run_simulation import _run_simulation
|
|
72
70
|
from flwr.simulation.simulationio_connection import SimulationIoConnection
|
|
73
71
|
from flwr.supercore.app_utils import start_parent_process_monitor
|
|
72
|
+
from flwr.supercore.heartbeat import HeartbeatSender, make_app_heartbeat_fn_grpc
|
|
74
73
|
from flwr.supercore.superexec.plugin import SimulationExecPlugin
|
|
75
74
|
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
76
75
|
|
|
@@ -78,7 +77,7 @@ from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning
|
|
|
78
77
|
def flwr_simulation() -> None:
|
|
79
78
|
"""Run process-isolated Flower Simulation."""
|
|
80
79
|
# Capture stdout/stderr
|
|
81
|
-
log_queue: Queue[Optional[str]] = Queue()
|
|
80
|
+
log_queue: Queue[str | None] = Queue()
|
|
82
81
|
mirror_output_to_queue(log_queue)
|
|
83
82
|
|
|
84
83
|
args = _parse_args_run_flwr_simulation().parse_args()
|
|
@@ -125,11 +124,11 @@ def flwr_simulation() -> None:
|
|
|
125
124
|
|
|
126
125
|
def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
127
126
|
simulationio_api_address: str,
|
|
128
|
-
log_queue: Queue[Optional[str]],
|
|
127
|
+
log_queue: Queue[str | None],
|
|
129
128
|
token: str,
|
|
130
|
-
flwr_dir_: Optional[str] = None,
|
|
131
|
-
certificates: Optional[bytes] = None,
|
|
132
|
-
parent_pid: Optional[int] = None,
|
|
129
|
+
flwr_dir_: str | None = None,
|
|
130
|
+
certificates: bytes | None = None,
|
|
131
|
+
parent_pid: int | None = None,
|
|
133
132
|
) -> None:
|
|
134
133
|
"""Run Flower Simulation process."""
|
|
135
134
|
# Start monitoring the parent process if a PID is provided
|
|
@@ -141,11 +140,35 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
141
140
|
root_certificates=certificates,
|
|
142
141
|
)
|
|
143
142
|
|
|
144
|
-
#
|
|
143
|
+
# Initialize variables for finally block
|
|
145
144
|
flwr_dir = get_flwr_dir(flwr_dir_)
|
|
146
145
|
log_uploader = None
|
|
147
146
|
heartbeat_sender = None
|
|
147
|
+
run = None
|
|
148
148
|
run_status = None
|
|
149
|
+
exit_code = ExitCode.SUCCESS
|
|
150
|
+
|
|
151
|
+
def on_exit() -> None:
|
|
152
|
+
# Stop heartbeat sender
|
|
153
|
+
if heartbeat_sender and heartbeat_sender.is_running:
|
|
154
|
+
heartbeat_sender.stop()
|
|
155
|
+
|
|
156
|
+
# Stop log uploader for this run and upload final logs
|
|
157
|
+
if log_uploader:
|
|
158
|
+
stop_log_uploader(log_queue, log_uploader)
|
|
159
|
+
|
|
160
|
+
# Update run status
|
|
161
|
+
if run and run_status:
|
|
162
|
+
run_status_proto = run_status_to_proto(run_status)
|
|
163
|
+
conn._stub.UpdateRunStatus(
|
|
164
|
+
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
register_signal_handlers(
|
|
168
|
+
event_type=EventType.FLWR_SIMULATION_RUN_LEAVE,
|
|
169
|
+
exit_message="Run stopped by user.",
|
|
170
|
+
exit_handlers=[on_exit],
|
|
171
|
+
)
|
|
149
172
|
|
|
150
173
|
try:
|
|
151
174
|
# Pull SimulationInputs from LinkState
|
|
@@ -193,12 +216,6 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
193
216
|
app_path,
|
|
194
217
|
)
|
|
195
218
|
|
|
196
|
-
# Change status to Running
|
|
197
|
-
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
|
|
198
|
-
conn._stub.UpdateRunStatus(
|
|
199
|
-
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
|
|
200
|
-
)
|
|
201
|
-
|
|
202
219
|
# Pull Federation Options
|
|
203
220
|
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
|
|
204
221
|
GetFederationOptionsRequest(run_id=run.run_id)
|
|
@@ -216,23 +233,20 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
216
233
|
verbose: bool = fed_opt.get("verbose", False)
|
|
217
234
|
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)
|
|
218
235
|
|
|
236
|
+
run_id_hash = get_sha256_hash(run.run_id)
|
|
219
237
|
event(
|
|
220
238
|
EventType.FLWR_SIMULATION_RUN_ENTER,
|
|
221
239
|
event_details={
|
|
222
240
|
"backend": "ray",
|
|
223
241
|
"num-supernodes": num_supernodes,
|
|
224
|
-
"run-id-hash":
|
|
242
|
+
"run-id-hash": run_id_hash,
|
|
225
243
|
},
|
|
226
244
|
)
|
|
227
245
|
|
|
228
246
|
# Set up heartbeat sender
|
|
229
|
-
|
|
230
|
-
conn._stub,
|
|
231
|
-
run.run_id,
|
|
232
|
-
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
|
|
233
|
-
"not find the provided run ID, or the run status is invalid.",
|
|
247
|
+
heartbeat_sender = HeartbeatSender(
|
|
248
|
+
make_app_heartbeat_fn_grpc(conn._stub, token)
|
|
234
249
|
)
|
|
235
|
-
heartbeat_sender = HeartbeatSender(heartbeat_fn)
|
|
236
250
|
heartbeat_sender.start()
|
|
237
251
|
|
|
238
252
|
# Launch the simulation
|
|
@@ -264,27 +278,17 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
|
|
|
264
278
|
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
|
|
265
279
|
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
|
|
266
280
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
if heartbeat_sender:
|
|
270
|
-
heartbeat_sender.stop()
|
|
271
|
-
|
|
272
|
-
# Stop log uploader for this run and upload final logs
|
|
273
|
-
if log_uploader:
|
|
274
|
-
stop_log_uploader(log_queue, log_uploader)
|
|
281
|
+
# General exit code
|
|
282
|
+
exit_code = ExitCode.SIMULATION_EXCEPTION
|
|
275
283
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
try:
|
|
285
|
-
del updated_context
|
|
286
|
-
except NameError:
|
|
287
|
-
pass
|
|
284
|
+
flwr_exit(
|
|
285
|
+
code=exit_code,
|
|
286
|
+
event_type=EventType.FLWR_SIMULATION_RUN_LEAVE,
|
|
287
|
+
event_details={
|
|
288
|
+
"run-id-hash": run_id_hash,
|
|
289
|
+
"success": exit_code == ExitCode.SUCCESS,
|
|
290
|
+
},
|
|
291
|
+
)
|
|
288
292
|
|
|
289
293
|
|
|
290
294
|
def _parse_args_run_flwr_simulation() -> argparse.ArgumentParser:
|
flwr/simulation/legacy_app.py
CHANGED
|
@@ -22,7 +22,7 @@ import threading
|
|
|
22
22
|
import traceback
|
|
23
23
|
import warnings
|
|
24
24
|
from logging import ERROR, INFO
|
|
25
|
-
from typing import Any, Optional, Union
|
|
25
|
+
from typing import Any
|
|
26
26
|
|
|
27
27
|
import ray
|
|
28
28
|
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
|
@@ -101,17 +101,17 @@ def start_simulation(
|
|
|
101
101
|
*,
|
|
102
102
|
client_fn: ClientFnExt,
|
|
103
103
|
num_clients: int,
|
|
104
|
-
clients_ids: Optional[list[str]] = None,  # UNSUPPORTED, WILL BE REMOVED
|
|
105
|
-
client_resources: Optional[dict[str, float]] = None,
|
|
106
|
-
server: Optional[Server] = None,
|
|
107
|
-
config: Optional[ServerConfig] = None,
|
|
108
|
-
strategy: Optional[Strategy] = None,
|
|
109
|
-
client_manager: Optional[ClientManager] = None,
|
|
110
|
-
ray_init_args: Optional[dict[str, Any]] = None,
|
|
111
|
-
keep_initialised: Optional[bool] = False,
|
|
104
|
+
clients_ids: list[str] | None = None, # UNSUPPORTED, WILL BE REMOVED
|
|
105
|
+
client_resources: dict[str, float] | None = None,
|
|
106
|
+
server: Server | None = None,
|
|
107
|
+
config: ServerConfig | None = None,
|
|
108
|
+
strategy: Strategy | None = None,
|
|
109
|
+
client_manager: ClientManager | None = None,
|
|
110
|
+
ray_init_args: dict[str, Any] | None = None,
|
|
111
|
+
keep_initialised: bool | None = False,
|
|
112
112
|
actor_type: type[VirtualClientEngineActor] = ClientAppActor,
|
|
113
|
-
actor_kwargs: Optional[dict[str, Any]] = None,
|
|
114
|
-
actor_scheduling: Union[str, NodeAffinitySchedulingStrategy] = "DEFAULT",
|
|
113
|
+
actor_kwargs: dict[str, Any] | None = None,
|
|
114
|
+
actor_scheduling: str | NodeAffinitySchedulingStrategy = "DEFAULT",
|
|
115
115
|
) -> History:
|
|
116
116
|
"""Start a Ray-based Flower simulation server.
|
|
117
117
|
|
|
@@ -219,7 +219,7 @@ def start_simulation(
|
|
|
219
219
|
sys.exit()
|
|
220
220
|
|
|
221
221
|
# Set logger propagation
|
|
222
|
-
loop: Optional[asyncio.AbstractEventLoop] = None
|
|
222
|
+
loop: asyncio.AbstractEventLoop | None = None
|
|
223
223
|
try:
|
|
224
224
|
loop = asyncio.get_running_loop()
|
|
225
225
|
except RuntimeError:
|
|
@@ -17,8 +17,9 @@
|
|
|
17
17
|
|
|
18
18
|
import threading
|
|
19
19
|
from abc import ABC
|
|
20
|
+
from collections.abc import Callable
|
|
20
21
|
from logging import DEBUG, ERROR, WARNING
|
|
21
|
-
from typing import Any, Callable, Optional, Union
|
|
22
|
+
from typing import Any
|
|
22
23
|
|
|
23
24
|
import ray
|
|
24
25
|
from ray import ObjectRef
|
|
@@ -76,13 +77,13 @@ class ClientAppActor(VirtualClientEngineActor):
|
|
|
76
77
|
A function to execute upon actor initialization.
|
|
77
78
|
"""
|
|
78
79
|
|
|
79
|
-
def __init__(self, on_actor_init_fn: Optional[Callable[[], None]] = None) -> None:
|
|
80
|
+
def __init__(self, on_actor_init_fn: Callable[[], None] | None = None) -> None:
|
|
80
81
|
super().__init__()
|
|
81
82
|
if on_actor_init_fn:
|
|
82
83
|
on_actor_init_fn()
|
|
83
84
|
|
|
84
85
|
|
|
85
|
-
def pool_size_from_resources(client_resources: dict[str, Union[int, float]]) -> int:
|
|
86
|
+
def pool_size_from_resources(client_resources: dict[str, int | float]) -> int:
|
|
86
87
|
"""Calculate number of Actors that fit in the cluster.
|
|
87
88
|
|
|
88
89
|
For this we consider the resources available on each node and those required per
|
|
@@ -166,8 +167,8 @@ class VirtualClientEngineActorPool(ActorPool):
|
|
|
166
167
|
def __init__(
|
|
167
168
|
self,
|
|
168
169
|
create_actor_fn: Callable[[], type[VirtualClientEngineActor]],
|
|
169
|
-
client_resources: dict[str, Union[int, float]],
|
|
170
|
-
actor_list: Optional[list[type[VirtualClientEngineActor]]] = None,
|
|
170
|
+
client_resources: dict[str, int | float],
|
|
171
|
+
actor_list: list[type[VirtualClientEngineActor]] | None = None,
|
|
171
172
|
):
|
|
172
173
|
self.client_resources = client_resources
|
|
173
174
|
self.create_actor_fn = create_actor_fn
|
|
@@ -186,9 +187,7 @@ class VirtualClientEngineActorPool(ActorPool):
|
|
|
186
187
|
|
|
187
188
|
# A dict that maps cid to another dict containing: a reference to the remote job
|
|
188
189
|
# and its status (i.e. whether it is ready or not)
|
|
189
|
-
self._cid_to_future: dict[
|
|
190
|
-
str, dict[str, Union[bool, Optional[ObjectRef[Any]]]]
|
|
191
|
-
] = {}
|
|
190
|
+
self._cid_to_future: dict[str, dict[str, bool | ObjectRef[Any] | None]] = {}
|
|
192
191
|
self.actor_to_remove: set[str] = set() # a set
|
|
193
192
|
self.num_actors = len(actors)
|
|
194
193
|
|
|
@@ -353,7 +352,7 @@ class VirtualClientEngineActorPool(ActorPool):
|
|
|
353
352
|
|
|
354
353
|
return True
|
|
355
354
|
|
|
356
|
-
def process_unordered_future(self, timeout:
|
|
355
|
+
def process_unordered_future(self, timeout: float | None = None) -> None:
|
|
357
356
|
"""Similar to parent's get_next_unordered() but without final ray.get()."""
|
|
358
357
|
if not self.has_next(): # type: ignore
|
|
359
358
|
raise StopIteration("No more results to get")
|
|
@@ -384,7 +383,7 @@ class VirtualClientEngineActorPool(ActorPool):
|
|
|
384
383
|
actor.terminate.remote()
|
|
385
384
|
|
|
386
385
|
def get_client_result(
|
|
387
|
-
self, cid: str, timeout:
|
|
386
|
+
self, cid: str, timeout: float | None
|
|
388
387
|
) -> tuple[Message, Context]:
|
|
389
388
|
"""Get result from VirtualClient with specific cid."""
|
|
390
389
|
# Loop until all jobs submitted to the pool are completed. Break early
|
|
@@ -407,7 +406,7 @@ class BasicActorPool:
|
|
|
407
406
|
def __init__(
|
|
408
407
|
self,
|
|
409
408
|
actor_type: type[VirtualClientEngineActor],
|
|
410
|
-
client_resources: dict[str,
|
|
409
|
+
client_resources: dict[str, int | float],
|
|
411
410
|
actor_kwargs: dict[str, Any],
|
|
412
411
|
):
|
|
413
412
|
self.client_resources = client_resources
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
|
|
18
18
|
import traceback
|
|
19
19
|
from logging import ERROR
|
|
20
|
-
from typing import Optional
|
|
21
20
|
|
|
22
21
|
from flwr import common
|
|
23
22
|
from flwr.client import ClientFnExt
|
|
@@ -74,7 +73,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
74
73
|
},
|
|
75
74
|
)
|
|
76
75
|
|
|
77
|
-
def _submit_job(self, message: Message, timeout:
|
|
76
|
+
def _submit_job(self, message: Message, timeout: float | None) -> Message:
|
|
78
77
|
"""Sumbit a message to the ActorPool."""
|
|
79
78
|
run_id = message.metadata.run_id
|
|
80
79
|
|
|
@@ -114,8 +113,8 @@ class RayActorClientProxy(ClientProxy):
|
|
|
114
113
|
self,
|
|
115
114
|
recorddict: RecordDict,
|
|
116
115
|
message_type: str,
|
|
117
|
-
timeout:
|
|
118
|
-
group_id:
|
|
116
|
+
timeout: float | None,
|
|
117
|
+
group_id: int | None,
|
|
119
118
|
) -> Message:
|
|
120
119
|
"""Wrap a RecordDict inside a Message."""
|
|
121
120
|
return make_message(
|
|
@@ -136,8 +135,8 @@ class RayActorClientProxy(ClientProxy):
|
|
|
136
135
|
def get_properties(
|
|
137
136
|
self,
|
|
138
137
|
ins: common.GetPropertiesIns,
|
|
139
|
-
timeout:
|
|
140
|
-
group_id:
|
|
138
|
+
timeout: float | None,
|
|
139
|
+
group_id: int | None,
|
|
141
140
|
) -> common.GetPropertiesRes:
|
|
142
141
|
"""Return client's properties."""
|
|
143
142
|
recorddict = getpropertiesins_to_recorddict(ins)
|
|
@@ -155,8 +154,8 @@ class RayActorClientProxy(ClientProxy):
|
|
|
155
154
|
def get_parameters(
|
|
156
155
|
self,
|
|
157
156
|
ins: common.GetParametersIns,
|
|
158
|
-
timeout:
|
|
159
|
-
group_id:
|
|
157
|
+
timeout: float | None,
|
|
158
|
+
group_id: int | None,
|
|
160
159
|
) -> common.GetParametersRes:
|
|
161
160
|
"""Return the current local model parameters."""
|
|
162
161
|
recorddict = getparametersins_to_recorddict(ins)
|
|
@@ -172,7 +171,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
172
171
|
return recorddict_to_getparametersres(message_out.content, keep_input=False)
|
|
173
172
|
|
|
174
173
|
def fit(
|
|
175
|
-
self, ins: common.FitIns, timeout:
|
|
174
|
+
self, ins: common.FitIns, timeout: float | None, group_id: int | None
|
|
176
175
|
) -> common.FitRes:
|
|
177
176
|
"""Train model parameters on the locally held dataset."""
|
|
178
177
|
recorddict = fitins_to_recorddict(
|
|
@@ -190,7 +189,7 @@ class RayActorClientProxy(ClientProxy):
|
|
|
190
189
|
return recorddict_to_fitres(message_out.content, keep_input=False)
|
|
191
190
|
|
|
192
191
|
def evaluate(
|
|
193
|
-
self, ins: common.EvaluateIns, timeout:
|
|
192
|
+
self, ins: common.EvaluateIns, timeout: float | None, group_id: int | None
|
|
194
193
|
) -> common.EvaluateRes:
|
|
195
194
|
"""Evaluate model parameters on the locally held dataset."""
|
|
196
195
|
recorddict = evaluateins_to_recorddict(
|
|
@@ -210,8 +209,8 @@ class RayActorClientProxy(ClientProxy):
|
|
|
210
209
|
def reconnect(
|
|
211
210
|
self,
|
|
212
211
|
ins: common.ReconnectIns,
|
|
213
|
-
timeout:
|
|
214
|
-
group_id:
|
|
212
|
+
timeout: float | None,
|
|
213
|
+
group_id: int | None,
|
|
215
214
|
) -> common.DisconnectRes:
|
|
216
215
|
"""Disconnect and (optionally) reconnect later."""
|
|
217
216
|
return common.DisconnectRes(reason="") # Nothing to do here (yet)
|