flwr-nightly 1.4.0.dev20230321__py3-none-any.whl → 1.4.0.dev20230323__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/__init__.py +2 -1
- flwr/client/app.py +12 -2
- flwr/client/rest_client/connection.py +7 -1
- flwr/server/app.py +11 -2
- flwr/server/rest_server/rest_api.py +11 -5
- flwr/server/strategy/fedxgb_nn_avg.py +0 -3
- {flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/METADATA +1 -1
- {flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/RECORD +11 -12
- flwr/server/strategy/fedxgb.py +0 -153
- {flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/entry_points.txt +0 -0
flwr/__init__.py
CHANGED
@@ -16,10 +16,11 @@
|
|
16
16
|
|
17
17
|
from flwr.common.version import package_version as _package_version
|
18
18
|
|
19
|
-
from . import client, server, simulation
|
19
|
+
from . import client, common, server, simulation
|
20
20
|
|
21
21
|
__all__ = [
|
22
22
|
"client",
|
23
|
+
"common",
|
23
24
|
"server",
|
24
25
|
"simulation",
|
25
26
|
]
|
flwr/client/app.py
CHANGED
@@ -49,7 +49,6 @@ from .numpy_client import has_evaluate as numpyclient_has_evaluate
|
|
49
49
|
from .numpy_client import has_fit as numpyclient_has_fit
|
50
50
|
from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
|
51
51
|
from .numpy_client import has_get_properties as numpyclient_has_get_properties
|
52
|
-
from .rest_client.connection import http_request_response
|
53
52
|
|
54
53
|
EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
|
55
54
|
NumPyClient.fit did not return a tuple with 3 elements.
|
@@ -81,6 +80,7 @@ Example
|
|
81
80
|
ClientLike = Union[Client, NumPyClient]
|
82
81
|
|
83
82
|
|
83
|
+
# pylint: disable=import-outside-toplevel
|
84
84
|
def start_client(
|
85
85
|
*,
|
86
86
|
server_address: str,
|
@@ -138,7 +138,17 @@ def start_client(
|
|
138
138
|
event(EventType.START_CLIENT_ENTER)
|
139
139
|
|
140
140
|
# Use either gRPC bidirectional streaming or REST request/response
|
141
|
-
|
141
|
+
if rest:
|
142
|
+
try:
|
143
|
+
from .rest_client.connection import http_request_response
|
144
|
+
except ImportError as missing_dep:
|
145
|
+
raise ImportError(
|
146
|
+
"To use the REST API you must install the "
|
147
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
148
|
+
) from missing_dep
|
149
|
+
connection = http_request_response
|
150
|
+
else:
|
151
|
+
connection = grpc_connection
|
142
152
|
while True:
|
143
153
|
sleep_duration: int = 0
|
144
154
|
with connection(
|
@@ -19,7 +19,13 @@ from contextlib import contextmanager
|
|
19
19
|
from logging import ERROR, INFO, WARN
|
20
20
|
from typing import Callable, Dict, Iterator, Optional, Tuple
|
21
21
|
|
22
|
-
|
22
|
+
try:
|
23
|
+
import requests
|
24
|
+
except ImportError as missing_dep:
|
25
|
+
raise ImportError(
|
26
|
+
"To use the REST API you must install the "
|
27
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
28
|
+
) from missing_dep
|
23
29
|
|
24
30
|
from flwr.client.message_handler.task_handler import get_server_message
|
25
31
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
flwr/server/app.py
CHANGED
@@ -25,7 +25,6 @@ from types import FrameType
|
|
25
25
|
from typing import List, Optional, Tuple
|
26
26
|
|
27
27
|
import grpc
|
28
|
-
import uvicorn
|
29
28
|
|
30
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
31
30
|
from flwr.common.logger import log
|
@@ -40,7 +39,6 @@ from flwr.server.grpc_server.grpc_server import (
|
|
40
39
|
start_grpc_server,
|
41
40
|
)
|
42
41
|
from flwr.server.history import History
|
43
|
-
from flwr.server.rest_server.rest_api import app as fast_api_app
|
44
42
|
from flwr.server.server import Server
|
45
43
|
from flwr.server.state import StateFactory
|
46
44
|
from flwr.server.strategy import FedAvg, Strategy
|
@@ -436,11 +434,22 @@ def _run_fleet_api_grpc_bidi(
|
|
436
434
|
return fleet_grpc_server
|
437
435
|
|
438
436
|
|
437
|
+
# pylint: disable=import-outside-toplevel
|
439
438
|
def _run_fleet_api_rest(
|
440
439
|
address: str,
|
441
440
|
state_factory: StateFactory,
|
442
441
|
) -> None:
|
443
442
|
"""Run Driver API (REST-based)."""
|
443
|
+
try:
|
444
|
+
import uvicorn
|
445
|
+
|
446
|
+
from flwr.server.rest_server.rest_api import app as fast_api_app
|
447
|
+
except ImportError as missing_dep:
|
448
|
+
raise ImportError(
|
449
|
+
"To use the REST API you must install the "
|
450
|
+
"extra dependencies by running "
|
451
|
+
"`pip install flwr['rest']`."
|
452
|
+
) from missing_dep
|
444
453
|
log(INFO, "Starting Flower REST server")
|
445
454
|
|
446
455
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
@@ -19,8 +19,14 @@ from logging import INFO
|
|
19
19
|
from typing import List, Optional
|
20
20
|
from uuid import UUID
|
21
21
|
|
22
|
-
|
23
|
-
from
|
22
|
+
try:
|
23
|
+
from fastapi import FastAPI, HTTPException, Request, Response
|
24
|
+
from starlette.datastructures import Headers
|
25
|
+
except ImportError as missing_dep:
|
26
|
+
raise ImportError(
|
27
|
+
"To use the REST API you must install the "
|
28
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
29
|
+
) from missing_dep
|
24
30
|
|
25
31
|
from flwr.common.logger import log
|
26
32
|
from flwr.proto.fleet_pb2 import (
|
@@ -33,10 +39,10 @@ from flwr.proto.fleet_pb2 import (
|
|
33
39
|
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
34
40
|
from flwr.server.state import State
|
35
41
|
|
36
|
-
app = FastAPI()
|
42
|
+
app: FastAPI = FastAPI()
|
37
43
|
|
38
44
|
|
39
|
-
@app.post("/api/v0/fleet/pull-task-ins", response_class=Response)
|
45
|
+
@app.post("/api/v0/fleet/pull-task-ins", response_class=Response) # type: ignore
|
40
46
|
async def pull_task_ins(request: Request) -> Response:
|
41
47
|
"""Pull TaskIns."""
|
42
48
|
_check_headers(request.headers)
|
@@ -72,7 +78,7 @@ async def pull_task_ins(request: Request) -> Response:
|
|
72
78
|
)
|
73
79
|
|
74
80
|
|
75
|
-
@app.post("/api/v0/fleet/push-task-res", response_class=Response)
|
81
|
+
@app.post("/api/v0/fleet/push-task-res", response_class=Response) # type: ignore
|
76
82
|
async def push_task_res(request: Request) -> Response: # Check if token is needed here
|
77
83
|
"""Push TaskRes."""
|
78
84
|
_check_headers(request.headers)
|
@@ -37,9 +37,6 @@ from flwr.server.client_proxy import ClientProxy
|
|
37
37
|
from .aggregate import aggregate
|
38
38
|
from .fedavg import FedAvg
|
39
39
|
|
40
|
-
# from xgboost import XGBClassifier, XGBRegressor # pylint: disable=W0611
|
41
|
-
|
42
|
-
|
43
40
|
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
|
44
41
|
Setting `min_available_clients` lower than `min_fit_clients` or
|
45
42
|
`min_evaluate_clients` can cause the server to fail when there are too few clients
|
{flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
|
|
1
|
-
flwr/__init__.py,sha256=
|
1
|
+
flwr/__init__.py,sha256=KJDlgYheeHeI1KlSwlGbjtzO1mqLgtL5zYhthG9flFk,929
|
2
2
|
flwr/client/__init__.py,sha256=UkiAEKds0mtpmVTZvvpJHKfTC_VWzflwL5IQUdbAf6U,1164
|
3
|
-
flwr/client/app.py,sha256=
|
3
|
+
flwr/client/app.py,sha256=TiCranPrEF20JptsY8ppUAVXkbcLHqFs4OzSciGsogc,11880
|
4
4
|
flwr/client/client.py,sha256=pR7A0ojBD2YEqz0nokSxKtbyBxmUQ6qlFmiJ9SAW8jE,6503
|
5
5
|
flwr/client/dpfedavg_numpy_client.py,sha256=CGGFLuUCquVSGYm2ap0PjRapJ3A_h5uuO8NYr3k-P3E,3351
|
6
6
|
flwr/client/grpc_client/__init__.py,sha256=svcQQhSaBJujNfZevHBsQrMKmqggpq-3I8Y6FUm_trM,728
|
@@ -10,7 +10,7 @@ flwr/client/message_handler/message_handler.py,sha256=8gXodt8fIZLWTcUJQN63fE78MF
|
|
10
10
|
flwr/client/message_handler/task_handler.py,sha256=KtNnwrHBd1aekkZeWsCfhO-BhYzBCYttwQl0InqpydM,1686
|
11
11
|
flwr/client/numpy_client.py,sha256=znnXgAO_BY-QUCqnTiCUJKS9TUY_z0boZ6D1dNHCTUA,5100
|
12
12
|
flwr/client/rest_client/__init__.py,sha256=FWlVEPeCUt2VHKQ9oEqmr8HtGruIegCO5iobBWvtb-Q,728
|
13
|
-
flwr/client/rest_client/connection.py,sha256=
|
13
|
+
flwr/client/rest_client/connection.py,sha256=JsMGhNc9ZAhqvh_rG9GtRDUzTKrMedZXNvt5lOkk3IM,7455
|
14
14
|
flwr/common/__init__.py,sha256=JLu4NbyMb-zrvS-uXB9T5biOz1FzJa5h6xIvAu1QX14,2816
|
15
15
|
flwr/common/date.py,sha256=CKvSQd-BcQ-Drr6VBOgjhUY8xTPmmIae-UeAW7HXODs,884
|
16
16
|
flwr/common/dp.py,sha256=yifOQAOcLkVLt7XrOed1A_n_Ep2wKWZZz75PF2Wu8Jo,1827
|
@@ -46,7 +46,7 @@ flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPc
|
|
46
46
|
flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
|
47
47
|
flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
48
|
flwr/server/__init__.py,sha256=4-X514k82t4YCNn66UqEWAnyxpHsOhX1j-NRIoUSzFs,1331
|
49
|
-
flwr/server/app.py,sha256=
|
49
|
+
flwr/server/app.py,sha256=yct8HwpA0bJdmAFENxpS1Euz3TwPUVurnA6-tzRoUuc,17313
|
50
50
|
flwr/server/client_manager.py,sha256=OjJsIFos4kgol_knpGLa2IVCnFQJ1NuN-Qdr5BD6lYU,5941
|
51
51
|
flwr/server/client_proxy.py,sha256=f0XnGrhaBzWsKN7j0KD288GqlebhVeJuSOMN38cD8Dc,2228
|
52
52
|
flwr/server/criterion.py,sha256=9qfnUrDzgPhF8XIeBzJc1TrGrAmHOgAq24BEmZRN5BQ,1058
|
@@ -63,7 +63,7 @@ flwr/server/grpc_server/grpc_server.py,sha256=WiCurb8Juum4npU0I9Yz3GaBQX348d6Gdh
|
|
63
63
|
flwr/server/grpc_server/ins_scheduler.py,sha256=HgxavTZx6ejBRrf-9jxCyEzS8sHDSA0imiwOvfeHdS8,6314
|
64
64
|
flwr/server/history.py,sha256=i40ikSJcSLKT_W8DbxFuw6Zy8zx1xlNk0VfdMl-27dQ,4456
|
65
65
|
flwr/server/rest_server/__init__.py,sha256=2rG5KG7xojkD8jgwUrJPrW4Nsdm_HhwZaQ1poy9x3KQ,728
|
66
|
-
flwr/server/rest_server/rest_api.py,sha256=
|
66
|
+
flwr/server/rest_server/rest_api.py,sha256=X8XmSR4GpzwdEhi0qAq564t5Nymw8pictGaUXi4fugU,4640
|
67
67
|
flwr/server/server.py,sha256=X7hbDv39zaibGqAnUurmn9GrUYIP4pEea-p9dCpWXvg,15855
|
68
68
|
flwr/server/state/__init__.py,sha256=2ZRkns3PkXxUi_WzjbDWfSkLvGK4RE8o3SHgSHgoclQ,995
|
69
69
|
flwr/server/state/in_memory_state.py,sha256=Vsk0smDQ1pNLUeI5vIGkRLmTT-a0P6yOvyfMX8fDbv4,6521
|
@@ -83,8 +83,7 @@ flwr/server/strategy/fedavgm.py,sha256=mFhsEKkJHJQQPUXV5tMNOAckl6SJ41-DYFciomKmz
|
|
83
83
|
flwr/server/strategy/fedmedian.py,sha256=MAWrLijv2_wDT9KZeYLJ6W_dnVQGKEp2DK40V6qLr78,6469
|
84
84
|
flwr/server/strategy/fedopt.py,sha256=eXH1lIrjCNUEBkAs-ziULzwACxqcArooHWg9vlj9A3Q,5393
|
85
85
|
flwr/server/strategy/fedprox.py,sha256=sjHbjl9iUr6MTbinCYj5q2R0Tmr9o0X_5JrsUWXa7cE,7708
|
86
|
-
flwr/server/strategy/
|
87
|
-
flwr/server/strategy/fedxgb_nn_avg.py,sha256=fNvV2Mi2o8HDyx0PRZ1czevfTAGlDB1iilyQGXby97M,7849
|
86
|
+
flwr/server/strategy/fedxgb_nn_avg.py,sha256=-_iAjQtpmHUDC3qnvSEKdtazPHdW1V1mdaxyZfsj6NI,7772
|
88
87
|
flwr/server/strategy/fedyogi.py,sha256=cC9WoIci3xmVs5HBGrkXHM0cpdMjApBWZS4d7-bKnZE,7044
|
89
88
|
flwr/server/strategy/krum.py,sha256=PVfP8or0Rgj9TcxOzbngkYmm4VctX44NjOh1rs3Xmn8,7011
|
90
89
|
flwr/server/strategy/qfedavg.py,sha256=2CJE7jvtVUcWpofBg8mG1w79PwBUoltY8wkyHMaFdag,10744
|
@@ -96,8 +95,8 @@ flwr/simulation/__init__.py,sha256=7gYUX6zr_yJd1wtv65xRHajBvurkfoPVXv66buUy4H8,1
|
|
96
95
|
flwr/simulation/app.py,sha256=4UKy2OgCsTveHmtvIvqL2NsuL0OiIR4n1wDRLrugK88,7737
|
97
96
|
flwr/simulation/ray_transport/__init__.py,sha256=eJ3pijYkI7XhbX2rLu6FBGTo8hZkFL8RSj4twhApOGw,727
|
98
97
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=WLO9ZagPCOyj-7e3x3gz-w0oiMVRdIhdHOkbrnmYvQo,5472
|
99
|
-
flwr_nightly-1.4.0.
|
100
|
-
flwr_nightly-1.4.0.
|
101
|
-
flwr_nightly-1.4.0.
|
102
|
-
flwr_nightly-1.4.0.
|
103
|
-
flwr_nightly-1.4.0.
|
98
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
99
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
|
100
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
|
101
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/METADATA,sha256=X68ZNpzTS1Blep8olTcQaE6LIp_HCb-QdqU4zJjpib0,12839
|
102
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/RECORD,,
|
flwr/server/strategy/fedxgb.py
DELETED
@@ -1,153 +0,0 @@
|
|
1
|
-
# Copyright 2020 Adap GmbH. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Federated XGBoost utility functions."""
|
16
|
-
|
17
|
-
from typing import Any, List, Optional, Tuple, Union
|
18
|
-
|
19
|
-
import numpy as np
|
20
|
-
import torch # pylint: disable=E0401
|
21
|
-
import xgboost as xgb # pylint: disable=E0401
|
22
|
-
from matplotlib import pyplot as plt # pylint: disable=E0401
|
23
|
-
from torch.utils.data import DataLoader, Dataset # pylint: disable=E0401
|
24
|
-
from xgboost import XGBClassifier, XGBRegressor # pylint: disable=E0401
|
25
|
-
|
26
|
-
from flwr.common.typing import NDArray
|
27
|
-
|
28
|
-
|
29
|
-
def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:
|
30
|
-
"""Visualize the built xgboost tree."""
|
31
|
-
xgb.plot_tree(tree, num_trees=n_tree)
|
32
|
-
plt.rcParams["figure.figsize"] = [50, 10]
|
33
|
-
plt.show()
|
34
|
-
|
35
|
-
|
36
|
-
def construct_tree(
|
37
|
-
dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str
|
38
|
-
) -> Union[XGBClassifier, XGBRegressor]:
|
39
|
-
"""Construct a xgboost tree form tabular dataset."""
|
40
|
-
if tree_type == "BINARY":
|
41
|
-
tree = xgb.XGBClassifier(
|
42
|
-
objective="binary:logistic",
|
43
|
-
learning_rate=0.1,
|
44
|
-
max_depth=8,
|
45
|
-
n_estimators=n_estimators,
|
46
|
-
subsample=0.8,
|
47
|
-
colsample_bylevel=1,
|
48
|
-
colsample_bynode=1,
|
49
|
-
colsample_bytree=1,
|
50
|
-
alpha=5,
|
51
|
-
gamma=5,
|
52
|
-
num_parallel_tree=1,
|
53
|
-
min_child_weight=1,
|
54
|
-
)
|
55
|
-
|
56
|
-
elif tree_type == "REG":
|
57
|
-
tree = xgb.XGBRegressor(
|
58
|
-
objective="reg:squarederror",
|
59
|
-
learning_rate=0.1,
|
60
|
-
max_depth=8,
|
61
|
-
n_estimators=n_estimators,
|
62
|
-
subsample=0.8,
|
63
|
-
colsample_bylevel=1,
|
64
|
-
colsample_bynode=1,
|
65
|
-
colsample_bytree=1,
|
66
|
-
alpha=5,
|
67
|
-
gamma=5,
|
68
|
-
num_parallel_tree=1,
|
69
|
-
min_child_weight=1,
|
70
|
-
)
|
71
|
-
|
72
|
-
tree.fit(dataset, label)
|
73
|
-
return tree
|
74
|
-
|
75
|
-
|
76
|
-
def construct_tree_from_loader(
|
77
|
-
dataset_loader: DataLoader, n_estimators: int, tree_type: str
|
78
|
-
) -> Union[XGBClassifier, XGBRegressor]:
|
79
|
-
"""Construct a xgboost tree form tabular dataset loader."""
|
80
|
-
for dataset in dataset_loader:
|
81
|
-
data, label = dataset[0], dataset[1]
|
82
|
-
return construct_tree(data, label, n_estimators, tree_type)
|
83
|
-
|
84
|
-
|
85
|
-
def single_tree_prediction(
|
86
|
-
tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray
|
87
|
-
) -> Optional[NDArray]:
|
88
|
-
"""Extract the prediction result of a single tree in the xgboost tree
|
89
|
-
ensemble."""
|
90
|
-
# How to access a single tree
|
91
|
-
# https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb
|
92
|
-
num_t = len(tree.get_booster().get_dump())
|
93
|
-
if n_tree > num_t:
|
94
|
-
print(
|
95
|
-
"The tree index to be extracted is larger than the total number of trees."
|
96
|
-
)
|
97
|
-
return None
|
98
|
-
|
99
|
-
return tree.predict( # type: ignore
|
100
|
-
dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True
|
101
|
-
)
|
102
|
-
|
103
|
-
|
104
|
-
def tree_encoding( # pylint: disable=R0914
|
105
|
-
trainloader: DataLoader,
|
106
|
-
client_trees: Union[
|
107
|
-
Tuple[XGBClassifier, int],
|
108
|
-
Tuple[XGBRegressor, int],
|
109
|
-
List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
|
110
|
-
],
|
111
|
-
client_tree_num: int,
|
112
|
-
client_num: int,
|
113
|
-
) -> Optional[Tuple[NDArray, NDArray]]:
|
114
|
-
"""Transform the tabular dataset into prediction results using the
|
115
|
-
aggregated xgboost tree ensembles from all clients."""
|
116
|
-
if trainloader is None:
|
117
|
-
return None
|
118
|
-
|
119
|
-
for local_dataset in trainloader:
|
120
|
-
x_train, y_train = local_dataset[0], local_dataset[1]
|
121
|
-
|
122
|
-
x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))
|
123
|
-
x_train_enc = np.array(x_train_enc, copy=True)
|
124
|
-
|
125
|
-
temp_trees: Any = None
|
126
|
-
if isinstance(client_trees, list) is False:
|
127
|
-
temp_trees = [client_trees[0]] * client_num
|
128
|
-
elif isinstance(client_trees, list) and len(client_trees) != client_num:
|
129
|
-
temp_trees = [client_trees[0][0]] * client_num
|
130
|
-
else:
|
131
|
-
cids = []
|
132
|
-
temp_trees = []
|
133
|
-
for i, _ in enumerate(client_trees):
|
134
|
-
temp_trees.append(client_trees[i][0]) # type: ignore
|
135
|
-
cids.append(client_trees[i][1]) # type: ignore
|
136
|
-
sorted_index = np.argsort(np.asarray(cids))
|
137
|
-
temp_trees = np.asarray(temp_trees)[sorted_index]
|
138
|
-
|
139
|
-
for i, _ in enumerate(temp_trees):
|
140
|
-
for j in range(client_tree_num):
|
141
|
-
x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(
|
142
|
-
temp_trees[i], j, x_train
|
143
|
-
)
|
144
|
-
|
145
|
-
x_train_enc32: Any = np.float32(x_train_enc)
|
146
|
-
y_train32: Any = np.float32(y_train)
|
147
|
-
|
148
|
-
x_train_enc32, y_train32 = torch.from_numpy(
|
149
|
-
np.expand_dims(x_train_enc32, axis=1) # type: ignore
|
150
|
-
), torch.from_numpy(
|
151
|
-
np.expand_dims(y_train32, axis=-1) # type: ignore
|
152
|
-
)
|
153
|
-
return x_train_enc32, y_train32
|
{flwr_nightly-1.4.0.dev20230321.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|