flwr-nightly 1.4.0.dev20230322__py3-none-any.whl → 1.4.0.dev20230323__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- flwr/__init__.py +2 -1
- flwr/client/app.py +12 -2
- flwr/client/rest_client/connection.py +7 -1
- flwr/server/app.py +11 -2
- flwr/server/rest_server/rest_api.py +11 -5
- flwr/server/strategy/fedxgb_nn_avg.py +0 -3
- {flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/METADATA +1 -1
- {flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/RECORD +11 -12
- flwr/server/strategy/fedxgb.py +0 -153
- {flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/LICENSE +0 -0
- {flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/WHEEL +0 -0
- {flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/entry_points.txt +0 -0
flwr/__init__.py
CHANGED
@@ -16,10 +16,11 @@
|
|
16
16
|
|
17
17
|
from flwr.common.version import package_version as _package_version
|
18
18
|
|
19
|
-
from . import client, server, simulation
|
19
|
+
from . import client, common, server, simulation
|
20
20
|
|
21
21
|
__all__ = [
|
22
22
|
"client",
|
23
|
+
"common",
|
23
24
|
"server",
|
24
25
|
"simulation",
|
25
26
|
]
|
flwr/client/app.py
CHANGED
@@ -49,7 +49,6 @@ from .numpy_client import has_evaluate as numpyclient_has_evaluate
|
|
49
49
|
from .numpy_client import has_fit as numpyclient_has_fit
|
50
50
|
from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
|
51
51
|
from .numpy_client import has_get_properties as numpyclient_has_get_properties
|
52
|
-
from .rest_client.connection import http_request_response
|
53
52
|
|
54
53
|
EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
|
55
54
|
NumPyClient.fit did not return a tuple with 3 elements.
|
@@ -81,6 +80,7 @@ Example
|
|
81
80
|
ClientLike = Union[Client, NumPyClient]
|
82
81
|
|
83
82
|
|
83
|
+
# pylint: disable=import-outside-toplevel
|
84
84
|
def start_client(
|
85
85
|
*,
|
86
86
|
server_address: str,
|
@@ -138,7 +138,17 @@ def start_client(
|
|
138
138
|
event(EventType.START_CLIENT_ENTER)
|
139
139
|
|
140
140
|
# Use either gRPC bidirectional streaming or REST request/response
|
141
|
-
|
141
|
+
if rest:
|
142
|
+
try:
|
143
|
+
from .rest_client.connection import http_request_response
|
144
|
+
except ImportError as missing_dep:
|
145
|
+
raise ImportError(
|
146
|
+
"To use the REST API you must install the "
|
147
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
148
|
+
) from missing_dep
|
149
|
+
connection = http_request_response
|
150
|
+
else:
|
151
|
+
connection = grpc_connection
|
142
152
|
while True:
|
143
153
|
sleep_duration: int = 0
|
144
154
|
with connection(
|
flwr/client/rest_client/connection.py
CHANGED
@@ -19,7 +19,13 @@ from contextlib import contextmanager
|
|
19
19
|
from logging import ERROR, INFO, WARN
|
20
20
|
from typing import Callable, Dict, Iterator, Optional, Tuple
|
21
21
|
|
22
|
-
|
22
|
+
try:
|
23
|
+
import requests
|
24
|
+
except ImportError as missing_dep:
|
25
|
+
raise ImportError(
|
26
|
+
"To use the REST API you must install the "
|
27
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
28
|
+
) from missing_dep
|
23
29
|
|
24
30
|
from flwr.client.message_handler.task_handler import get_server_message
|
25
31
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH
|
flwr/server/app.py
CHANGED
@@ -25,7 +25,6 @@ from types import FrameType
|
|
25
25
|
from typing import List, Optional, Tuple
|
26
26
|
|
27
27
|
import grpc
|
28
|
-
import uvicorn
|
29
28
|
|
30
29
|
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
|
31
30
|
from flwr.common.logger import log
|
@@ -40,7 +39,6 @@ from flwr.server.grpc_server.grpc_server import (
|
|
40
39
|
start_grpc_server,
|
41
40
|
)
|
42
41
|
from flwr.server.history import History
|
43
|
-
from flwr.server.rest_server.rest_api import app as fast_api_app
|
44
42
|
from flwr.server.server import Server
|
45
43
|
from flwr.server.state import StateFactory
|
46
44
|
from flwr.server.strategy import FedAvg, Strategy
|
@@ -436,11 +434,22 @@ def _run_fleet_api_grpc_bidi(
|
|
436
434
|
return fleet_grpc_server
|
437
435
|
|
438
436
|
|
437
|
+
# pylint: disable=import-outside-toplevel
|
439
438
|
def _run_fleet_api_rest(
|
440
439
|
address: str,
|
441
440
|
state_factory: StateFactory,
|
442
441
|
) -> None:
|
443
442
|
"""Run Driver API (REST-based)."""
|
443
|
+
try:
|
444
|
+
import uvicorn
|
445
|
+
|
446
|
+
from flwr.server.rest_server.rest_api import app as fast_api_app
|
447
|
+
except ImportError as missing_dep:
|
448
|
+
raise ImportError(
|
449
|
+
"To use the REST API you must install the "
|
450
|
+
"extra dependencies by running "
|
451
|
+
"`pip install flwr['rest']`."
|
452
|
+
) from missing_dep
|
444
453
|
log(INFO, "Starting Flower REST server")
|
445
454
|
|
446
455
|
# See: https://www.starlette.io/applications/#accessing-the-app-instance
|
flwr/server/rest_server/rest_api.py
CHANGED
@@ -19,8 +19,14 @@ from logging import INFO
|
|
19
19
|
from typing import List, Optional
|
20
20
|
from uuid import UUID
|
21
21
|
|
22
|
-
|
23
|
-
from
|
22
|
+
try:
|
23
|
+
from fastapi import FastAPI, HTTPException, Request, Response
|
24
|
+
from starlette.datastructures import Headers
|
25
|
+
except ImportError as missing_dep:
|
26
|
+
raise ImportError(
|
27
|
+
"To use the REST API you must install the "
|
28
|
+
"extra dependencies by running `pip install flwr['rest']`."
|
29
|
+
) from missing_dep
|
24
30
|
|
25
31
|
from flwr.common.logger import log
|
26
32
|
from flwr.proto.fleet_pb2 import (
|
@@ -33,10 +39,10 @@ from flwr.proto.fleet_pb2 import (
|
|
33
39
|
from flwr.proto.task_pb2 import TaskIns, TaskRes
|
34
40
|
from flwr.server.state import State
|
35
41
|
|
36
|
-
app = FastAPI()
|
42
|
+
app: FastAPI = FastAPI()
|
37
43
|
|
38
44
|
|
39
|
-
@app.post("/api/v0/fleet/pull-task-ins", response_class=Response)
|
45
|
+
@app.post("/api/v0/fleet/pull-task-ins", response_class=Response) # type: ignore
|
40
46
|
async def pull_task_ins(request: Request) -> Response:
|
41
47
|
"""Pull TaskIns."""
|
42
48
|
_check_headers(request.headers)
|
@@ -72,7 +78,7 @@ async def pull_task_ins(request: Request) -> Response:
|
|
72
78
|
)
|
73
79
|
|
74
80
|
|
75
|
-
@app.post("/api/v0/fleet/push-task-res", response_class=Response)
|
81
|
+
@app.post("/api/v0/fleet/push-task-res", response_class=Response) # type: ignore
|
76
82
|
async def push_task_res(request: Request) -> Response: # Check if token is needed here
|
77
83
|
"""Push TaskRes."""
|
78
84
|
_check_headers(request.headers)
|
flwr/server/strategy/fedxgb_nn_avg.py
CHANGED
@@ -37,9 +37,6 @@ from flwr.server.client_proxy import ClientProxy
|
|
37
37
|
from .aggregate import aggregate
|
38
38
|
from .fedavg import FedAvg
|
39
39
|
|
40
|
-
# from xgboost import XGBClassifier, XGBRegressor # pylint: disable=W0611
|
41
|
-
|
42
|
-
|
43
40
|
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
|
44
41
|
Setting `min_available_clients` lower than `min_fit_clients` or
|
45
42
|
`min_evaluate_clients` can cause the server to fail when there are too few clients
|
{flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
|
|
1
|
-
flwr/__init__.py,sha256=
|
1
|
+
flwr/__init__.py,sha256=KJDlgYheeHeI1KlSwlGbjtzO1mqLgtL5zYhthG9flFk,929
|
2
2
|
flwr/client/__init__.py,sha256=UkiAEKds0mtpmVTZvvpJHKfTC_VWzflwL5IQUdbAf6U,1164
|
3
|
-
flwr/client/app.py,sha256=
|
3
|
+
flwr/client/app.py,sha256=TiCranPrEF20JptsY8ppUAVXkbcLHqFs4OzSciGsogc,11880
|
4
4
|
flwr/client/client.py,sha256=pR7A0ojBD2YEqz0nokSxKtbyBxmUQ6qlFmiJ9SAW8jE,6503
|
5
5
|
flwr/client/dpfedavg_numpy_client.py,sha256=CGGFLuUCquVSGYm2ap0PjRapJ3A_h5uuO8NYr3k-P3E,3351
|
6
6
|
flwr/client/grpc_client/__init__.py,sha256=svcQQhSaBJujNfZevHBsQrMKmqggpq-3I8Y6FUm_trM,728
|
@@ -10,7 +10,7 @@ flwr/client/message_handler/message_handler.py,sha256=8gXodt8fIZLWTcUJQN63fE78MF
|
|
10
10
|
flwr/client/message_handler/task_handler.py,sha256=KtNnwrHBd1aekkZeWsCfhO-BhYzBCYttwQl0InqpydM,1686
|
11
11
|
flwr/client/numpy_client.py,sha256=znnXgAO_BY-QUCqnTiCUJKS9TUY_z0boZ6D1dNHCTUA,5100
|
12
12
|
flwr/client/rest_client/__init__.py,sha256=FWlVEPeCUt2VHKQ9oEqmr8HtGruIegCO5iobBWvtb-Q,728
|
13
|
-
flwr/client/rest_client/connection.py,sha256=
|
13
|
+
flwr/client/rest_client/connection.py,sha256=JsMGhNc9ZAhqvh_rG9GtRDUzTKrMedZXNvt5lOkk3IM,7455
|
14
14
|
flwr/common/__init__.py,sha256=JLu4NbyMb-zrvS-uXB9T5biOz1FzJa5h6xIvAu1QX14,2816
|
15
15
|
flwr/common/date.py,sha256=CKvSQd-BcQ-Drr6VBOgjhUY8xTPmmIae-UeAW7HXODs,884
|
16
16
|
flwr/common/dp.py,sha256=yifOQAOcLkVLt7XrOed1A_n_Ep2wKWZZz75PF2Wu8Jo,1827
|
@@ -46,7 +46,7 @@ flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPc
|
|
46
46
|
flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
|
47
47
|
flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
48
48
|
flwr/server/__init__.py,sha256=4-X514k82t4YCNn66UqEWAnyxpHsOhX1j-NRIoUSzFs,1331
|
49
|
-
flwr/server/app.py,sha256=
|
49
|
+
flwr/server/app.py,sha256=yct8HwpA0bJdmAFENxpS1Euz3TwPUVurnA6-tzRoUuc,17313
|
50
50
|
flwr/server/client_manager.py,sha256=OjJsIFos4kgol_knpGLa2IVCnFQJ1NuN-Qdr5BD6lYU,5941
|
51
51
|
flwr/server/client_proxy.py,sha256=f0XnGrhaBzWsKN7j0KD288GqlebhVeJuSOMN38cD8Dc,2228
|
52
52
|
flwr/server/criterion.py,sha256=9qfnUrDzgPhF8XIeBzJc1TrGrAmHOgAq24BEmZRN5BQ,1058
|
@@ -63,7 +63,7 @@ flwr/server/grpc_server/grpc_server.py,sha256=WiCurb8Juum4npU0I9Yz3GaBQX348d6Gdh
|
|
63
63
|
flwr/server/grpc_server/ins_scheduler.py,sha256=HgxavTZx6ejBRrf-9jxCyEzS8sHDSA0imiwOvfeHdS8,6314
|
64
64
|
flwr/server/history.py,sha256=i40ikSJcSLKT_W8DbxFuw6Zy8zx1xlNk0VfdMl-27dQ,4456
|
65
65
|
flwr/server/rest_server/__init__.py,sha256=2rG5KG7xojkD8jgwUrJPrW4Nsdm_HhwZaQ1poy9x3KQ,728
|
66
|
-
flwr/server/rest_server/rest_api.py,sha256=
|
66
|
+
flwr/server/rest_server/rest_api.py,sha256=X8XmSR4GpzwdEhi0qAq564t5Nymw8pictGaUXi4fugU,4640
|
67
67
|
flwr/server/server.py,sha256=X7hbDv39zaibGqAnUurmn9GrUYIP4pEea-p9dCpWXvg,15855
|
68
68
|
flwr/server/state/__init__.py,sha256=2ZRkns3PkXxUi_WzjbDWfSkLvGK4RE8o3SHgSHgoclQ,995
|
69
69
|
flwr/server/state/in_memory_state.py,sha256=Vsk0smDQ1pNLUeI5vIGkRLmTT-a0P6yOvyfMX8fDbv4,6521
|
@@ -83,8 +83,7 @@ flwr/server/strategy/fedavgm.py,sha256=mFhsEKkJHJQQPUXV5tMNOAckl6SJ41-DYFciomKmz
|
|
83
83
|
flwr/server/strategy/fedmedian.py,sha256=MAWrLijv2_wDT9KZeYLJ6W_dnVQGKEp2DK40V6qLr78,6469
|
84
84
|
flwr/server/strategy/fedopt.py,sha256=eXH1lIrjCNUEBkAs-ziULzwACxqcArooHWg9vlj9A3Q,5393
|
85
85
|
flwr/server/strategy/fedprox.py,sha256=sjHbjl9iUr6MTbinCYj5q2R0Tmr9o0X_5JrsUWXa7cE,7708
|
86
|
-
flwr/server/strategy/
|
87
|
-
flwr/server/strategy/fedxgb_nn_avg.py,sha256=fNvV2Mi2o8HDyx0PRZ1czevfTAGlDB1iilyQGXby97M,7849
|
86
|
+
flwr/server/strategy/fedxgb_nn_avg.py,sha256=-_iAjQtpmHUDC3qnvSEKdtazPHdW1V1mdaxyZfsj6NI,7772
|
88
87
|
flwr/server/strategy/fedyogi.py,sha256=cC9WoIci3xmVs5HBGrkXHM0cpdMjApBWZS4d7-bKnZE,7044
|
89
88
|
flwr/server/strategy/krum.py,sha256=PVfP8or0Rgj9TcxOzbngkYmm4VctX44NjOh1rs3Xmn8,7011
|
90
89
|
flwr/server/strategy/qfedavg.py,sha256=2CJE7jvtVUcWpofBg8mG1w79PwBUoltY8wkyHMaFdag,10744
|
@@ -96,8 +95,8 @@ flwr/simulation/__init__.py,sha256=7gYUX6zr_yJd1wtv65xRHajBvurkfoPVXv66buUy4H8,1
|
|
96
95
|
flwr/simulation/app.py,sha256=4UKy2OgCsTveHmtvIvqL2NsuL0OiIR4n1wDRLrugK88,7737
|
97
96
|
flwr/simulation/ray_transport/__init__.py,sha256=eJ3pijYkI7XhbX2rLu6FBGTo8hZkFL8RSj4twhApOGw,727
|
98
97
|
flwr/simulation/ray_transport/ray_client_proxy.py,sha256=WLO9ZagPCOyj-7e3x3gz-w0oiMVRdIhdHOkbrnmYvQo,5472
|
99
|
-
flwr_nightly-1.4.0.
|
100
|
-
flwr_nightly-1.4.0.
|
101
|
-
flwr_nightly-1.4.0.
|
102
|
-
flwr_nightly-1.4.0.
|
103
|
-
flwr_nightly-1.4.0.
|
98
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
99
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
|
100
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
|
101
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/METADATA,sha256=X68ZNpzTS1Blep8olTcQaE6LIp_HCb-QdqU4zJjpib0,12839
|
102
|
+
flwr_nightly-1.4.0.dev20230323.dist-info/RECORD,,
|
flwr/server/strategy/fedxgb.py
DELETED
@@ -1,153 +0,0 @@
|
|
1
|
-
# Copyright 2020 Adap GmbH. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
"""Federated XGBoost utility functions."""
|
16
|
-
|
17
|
-
from typing import Any, List, Optional, Tuple, Union
|
18
|
-
|
19
|
-
import numpy as np
|
20
|
-
import torch # pylint: disable=E0401
|
21
|
-
import xgboost as xgb # pylint: disable=E0401
|
22
|
-
from matplotlib import pyplot as plt # pylint: disable=E0401
|
23
|
-
from torch.utils.data import DataLoader, Dataset # pylint: disable=E0401
|
24
|
-
from xgboost import XGBClassifier, XGBRegressor # pylint: disable=E0401
|
25
|
-
|
26
|
-
from flwr.common.typing import NDArray
|
27
|
-
|
28
|
-
|
29
|
-
def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:
|
30
|
-
"""Visualize the built xgboost tree."""
|
31
|
-
xgb.plot_tree(tree, num_trees=n_tree)
|
32
|
-
plt.rcParams["figure.figsize"] = [50, 10]
|
33
|
-
plt.show()
|
34
|
-
|
35
|
-
|
36
|
-
def construct_tree(
|
37
|
-
dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str
|
38
|
-
) -> Union[XGBClassifier, XGBRegressor]:
|
39
|
-
"""Construct a xgboost tree form tabular dataset."""
|
40
|
-
if tree_type == "BINARY":
|
41
|
-
tree = xgb.XGBClassifier(
|
42
|
-
objective="binary:logistic",
|
43
|
-
learning_rate=0.1,
|
44
|
-
max_depth=8,
|
45
|
-
n_estimators=n_estimators,
|
46
|
-
subsample=0.8,
|
47
|
-
colsample_bylevel=1,
|
48
|
-
colsample_bynode=1,
|
49
|
-
colsample_bytree=1,
|
50
|
-
alpha=5,
|
51
|
-
gamma=5,
|
52
|
-
num_parallel_tree=1,
|
53
|
-
min_child_weight=1,
|
54
|
-
)
|
55
|
-
|
56
|
-
elif tree_type == "REG":
|
57
|
-
tree = xgb.XGBRegressor(
|
58
|
-
objective="reg:squarederror",
|
59
|
-
learning_rate=0.1,
|
60
|
-
max_depth=8,
|
61
|
-
n_estimators=n_estimators,
|
62
|
-
subsample=0.8,
|
63
|
-
colsample_bylevel=1,
|
64
|
-
colsample_bynode=1,
|
65
|
-
colsample_bytree=1,
|
66
|
-
alpha=5,
|
67
|
-
gamma=5,
|
68
|
-
num_parallel_tree=1,
|
69
|
-
min_child_weight=1,
|
70
|
-
)
|
71
|
-
|
72
|
-
tree.fit(dataset, label)
|
73
|
-
return tree
|
74
|
-
|
75
|
-
|
76
|
-
def construct_tree_from_loader(
|
77
|
-
dataset_loader: DataLoader, n_estimators: int, tree_type: str
|
78
|
-
) -> Union[XGBClassifier, XGBRegressor]:
|
79
|
-
"""Construct a xgboost tree form tabular dataset loader."""
|
80
|
-
for dataset in dataset_loader:
|
81
|
-
data, label = dataset[0], dataset[1]
|
82
|
-
return construct_tree(data, label, n_estimators, tree_type)
|
83
|
-
|
84
|
-
|
85
|
-
def single_tree_prediction(
|
86
|
-
tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray
|
87
|
-
) -> Optional[NDArray]:
|
88
|
-
"""Extract the prediction result of a single tree in the xgboost tree
|
89
|
-
ensemble."""
|
90
|
-
# How to access a single tree
|
91
|
-
# https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb
|
92
|
-
num_t = len(tree.get_booster().get_dump())
|
93
|
-
if n_tree > num_t:
|
94
|
-
print(
|
95
|
-
"The tree index to be extracted is larger than the total number of trees."
|
96
|
-
)
|
97
|
-
return None
|
98
|
-
|
99
|
-
return tree.predict( # type: ignore
|
100
|
-
dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True
|
101
|
-
)
|
102
|
-
|
103
|
-
|
104
|
-
def tree_encoding( # pylint: disable=R0914
|
105
|
-
trainloader: DataLoader,
|
106
|
-
client_trees: Union[
|
107
|
-
Tuple[XGBClassifier, int],
|
108
|
-
Tuple[XGBRegressor, int],
|
109
|
-
List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
|
110
|
-
],
|
111
|
-
client_tree_num: int,
|
112
|
-
client_num: int,
|
113
|
-
) -> Optional[Tuple[NDArray, NDArray]]:
|
114
|
-
"""Transform the tabular dataset into prediction results using the
|
115
|
-
aggregated xgboost tree ensembles from all clients."""
|
116
|
-
if trainloader is None:
|
117
|
-
return None
|
118
|
-
|
119
|
-
for local_dataset in trainloader:
|
120
|
-
x_train, y_train = local_dataset[0], local_dataset[1]
|
121
|
-
|
122
|
-
x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))
|
123
|
-
x_train_enc = np.array(x_train_enc, copy=True)
|
124
|
-
|
125
|
-
temp_trees: Any = None
|
126
|
-
if isinstance(client_trees, list) is False:
|
127
|
-
temp_trees = [client_trees[0]] * client_num
|
128
|
-
elif isinstance(client_trees, list) and len(client_trees) != client_num:
|
129
|
-
temp_trees = [client_trees[0][0]] * client_num
|
130
|
-
else:
|
131
|
-
cids = []
|
132
|
-
temp_trees = []
|
133
|
-
for i, _ in enumerate(client_trees):
|
134
|
-
temp_trees.append(client_trees[i][0]) # type: ignore
|
135
|
-
cids.append(client_trees[i][1]) # type: ignore
|
136
|
-
sorted_index = np.argsort(np.asarray(cids))
|
137
|
-
temp_trees = np.asarray(temp_trees)[sorted_index]
|
138
|
-
|
139
|
-
for i, _ in enumerate(temp_trees):
|
140
|
-
for j in range(client_tree_num):
|
141
|
-
x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(
|
142
|
-
temp_trees[i], j, x_train
|
143
|
-
)
|
144
|
-
|
145
|
-
x_train_enc32: Any = np.float32(x_train_enc)
|
146
|
-
y_train32: Any = np.float32(y_train)
|
147
|
-
|
148
|
-
x_train_enc32, y_train32 = torch.from_numpy(
|
149
|
-
np.expand_dims(x_train_enc32, axis=1) # type: ignore
|
150
|
-
), torch.from_numpy(
|
151
|
-
np.expand_dims(y_train32, axis=-1) # type: ignore
|
152
|
-
)
|
153
|
-
return x_train_enc32, y_train32
|
{flwr_nightly-1.4.0.dev20230322.dist-info → flwr_nightly-1.4.0.dev20230323.dist-info}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|