flwr-nightly 1.4.0.dev20230321__py3-none-any.whl → 1.4.0.dev20230323__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
flwr/__init__.py CHANGED
@@ -16,10 +16,11 @@
16
16
 
17
17
  from flwr.common.version import package_version as _package_version
18
18
 
19
- from . import client, server, simulation
19
+ from . import client, common, server, simulation
20
20
 
21
21
  __all__ = [
22
22
  "client",
23
+ "common",
23
24
  "server",
24
25
  "simulation",
25
26
  ]
flwr/client/app.py CHANGED
@@ -49,7 +49,6 @@ from .numpy_client import has_evaluate as numpyclient_has_evaluate
49
49
  from .numpy_client import has_fit as numpyclient_has_fit
50
50
  from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
51
51
  from .numpy_client import has_get_properties as numpyclient_has_get_properties
52
- from .rest_client.connection import http_request_response
53
52
 
54
53
  EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
55
54
  NumPyClient.fit did not return a tuple with 3 elements.
@@ -81,6 +80,7 @@ Example
81
80
  ClientLike = Union[Client, NumPyClient]
82
81
 
83
82
 
83
+ # pylint: disable=import-outside-toplevel
84
84
  def start_client(
85
85
  *,
86
86
  server_address: str,
@@ -138,7 +138,17 @@ def start_client(
138
138
  event(EventType.START_CLIENT_ENTER)
139
139
 
140
140
  # Use either gRPC bidirectional streaming or REST request/response
141
- connection = http_request_response if rest else grpc_connection
141
+ if rest:
142
+ try:
143
+ from .rest_client.connection import http_request_response
144
+ except ImportError as missing_dep:
145
+ raise ImportError(
146
+ "To use the REST API you must install the "
147
+ "extra dependencies by running `pip install flwr['rest']`."
148
+ ) from missing_dep
149
+ connection = http_request_response
150
+ else:
151
+ connection = grpc_connection
142
152
  while True:
143
153
  sleep_duration: int = 0
144
154
  with connection(
@@ -19,7 +19,13 @@ from contextlib import contextmanager
19
19
  from logging import ERROR, INFO, WARN
20
20
  from typing import Callable, Dict, Iterator, Optional, Tuple
21
21
 
22
- import requests
22
+ try:
23
+ import requests
24
+ except ImportError as missing_dep:
25
+ raise ImportError(
26
+ "To use the REST API you must install the "
27
+ "extra dependencies by running `pip install flwr['rest']`."
28
+ ) from missing_dep
23
29
 
24
30
  from flwr.client.message_handler.task_handler import get_server_message
25
31
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
flwr/server/app.py CHANGED
@@ -25,7 +25,6 @@ from types import FrameType
25
25
  from typing import List, Optional, Tuple
26
26
 
27
27
  import grpc
28
- import uvicorn
29
28
 
30
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
31
30
  from flwr.common.logger import log
@@ -40,7 +39,6 @@ from flwr.server.grpc_server.grpc_server import (
40
39
  start_grpc_server,
41
40
  )
42
41
  from flwr.server.history import History
43
- from flwr.server.rest_server.rest_api import app as fast_api_app
44
42
  from flwr.server.server import Server
45
43
  from flwr.server.state import StateFactory
46
44
  from flwr.server.strategy import FedAvg, Strategy
@@ -436,11 +434,22 @@ def _run_fleet_api_grpc_bidi(
436
434
  return fleet_grpc_server
437
435
 
438
436
 
437
+ # pylint: disable=import-outside-toplevel
439
438
  def _run_fleet_api_rest(
440
439
  address: str,
441
440
  state_factory: StateFactory,
442
441
  ) -> None:
443
442
  """Run Driver API (REST-based)."""
443
+ try:
444
+ import uvicorn
445
+
446
+ from flwr.server.rest_server.rest_api import app as fast_api_app
447
+ except ImportError as missing_dep:
448
+ raise ImportError(
449
+ "To use the REST API you must install the "
450
+ "extra dependencies by running "
451
+ "`pip install flwr['rest']`."
452
+ ) from missing_dep
444
453
  log(INFO, "Starting Flower REST server")
445
454
 
446
455
  # See: https://www.starlette.io/applications/#accessing-the-app-instance
@@ -19,8 +19,14 @@ from logging import INFO
19
19
  from typing import List, Optional
20
20
  from uuid import UUID
21
21
 
22
- from fastapi import FastAPI, HTTPException, Request, Response
23
- from starlette.datastructures import Headers
22
+ try:
23
+ from fastapi import FastAPI, HTTPException, Request, Response
24
+ from starlette.datastructures import Headers
25
+ except ImportError as missing_dep:
26
+ raise ImportError(
27
+ "To use the REST API you must install the "
28
+ "extra dependencies by running `pip install flwr['rest']`."
29
+ ) from missing_dep
24
30
 
25
31
  from flwr.common.logger import log
26
32
  from flwr.proto.fleet_pb2 import (
@@ -33,10 +39,10 @@ from flwr.proto.fleet_pb2 import (
33
39
  from flwr.proto.task_pb2 import TaskIns, TaskRes
34
40
  from flwr.server.state import State
35
41
 
36
- app = FastAPI()
42
+ app: FastAPI = FastAPI()
37
43
 
38
44
 
39
- @app.post("/api/v0/fleet/pull-task-ins", response_class=Response)
45
+ @app.post("/api/v0/fleet/pull-task-ins", response_class=Response) # type: ignore
40
46
  async def pull_task_ins(request: Request) -> Response:
41
47
  """Pull TaskIns."""
42
48
  _check_headers(request.headers)
@@ -72,7 +78,7 @@ async def pull_task_ins(request: Request) -> Response:
72
78
  )
73
79
 
74
80
 
75
- @app.post("/api/v0/fleet/push-task-res", response_class=Response)
81
+ @app.post("/api/v0/fleet/push-task-res", response_class=Response) # type: ignore
76
82
  async def push_task_res(request: Request) -> Response: # Check if token is needed here
77
83
  """Push TaskRes."""
78
84
  _check_headers(request.headers)
@@ -37,9 +37,6 @@ from flwr.server.client_proxy import ClientProxy
37
37
  from .aggregate import aggregate
38
38
  from .fedavg import FedAvg
39
39
 
40
- # from xgboost import XGBClassifier, XGBRegressor # pylint: disable=W0611
41
-
42
-
43
40
  WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
44
41
  Setting `min_available_clients` lower than `min_fit_clients` or
45
42
  `min_evaluate_clients` can cause the server to fail when there are too few clients
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.4.0.dev20230321
3
+ Version: 1.4.0.dev20230323
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.dev
6
6
  License: Apache-2.0
@@ -1,6 +1,6 @@
1
- flwr/__init__.py,sha256=DguD_tpDL2vR1F6rRs-H3ms6Cmbep0GCCGxln84yftU,907
1
+ flwr/__init__.py,sha256=KJDlgYheeHeI1KlSwlGbjtzO1mqLgtL5zYhthG9flFk,929
2
2
  flwr/client/__init__.py,sha256=UkiAEKds0mtpmVTZvvpJHKfTC_VWzflwL5IQUdbAf6U,1164
3
- flwr/client/app.py,sha256=zOZ3iMUz3T-7yiVQTaFwqjUoWPG2kKK32jaoFqL3d_s,11537
3
+ flwr/client/app.py,sha256=TiCranPrEF20JptsY8ppUAVXkbcLHqFs4OzSciGsogc,11880
4
4
  flwr/client/client.py,sha256=pR7A0ojBD2YEqz0nokSxKtbyBxmUQ6qlFmiJ9SAW8jE,6503
5
5
  flwr/client/dpfedavg_numpy_client.py,sha256=CGGFLuUCquVSGYm2ap0PjRapJ3A_h5uuO8NYr3k-P3E,3351
6
6
  flwr/client/grpc_client/__init__.py,sha256=svcQQhSaBJujNfZevHBsQrMKmqggpq-3I8Y6FUm_trM,728
@@ -10,7 +10,7 @@ flwr/client/message_handler/message_handler.py,sha256=8gXodt8fIZLWTcUJQN63fE78MF
10
10
  flwr/client/message_handler/task_handler.py,sha256=KtNnwrHBd1aekkZeWsCfhO-BhYzBCYttwQl0InqpydM,1686
11
11
  flwr/client/numpy_client.py,sha256=znnXgAO_BY-QUCqnTiCUJKS9TUY_z0boZ6D1dNHCTUA,5100
12
12
  flwr/client/rest_client/__init__.py,sha256=FWlVEPeCUt2VHKQ9oEqmr8HtGruIegCO5iobBWvtb-Q,728
13
- flwr/client/rest_client/connection.py,sha256=4qOMlGztSQedxVyaCPb5yF4LEfuwzGrmhrMdbYcOWfo,7245
13
+ flwr/client/rest_client/connection.py,sha256=JsMGhNc9ZAhqvh_rG9GtRDUzTKrMedZXNvt5lOkk3IM,7455
14
14
  flwr/common/__init__.py,sha256=JLu4NbyMb-zrvS-uXB9T5biOz1FzJa5h6xIvAu1QX14,2816
15
15
  flwr/common/date.py,sha256=CKvSQd-BcQ-Drr6VBOgjhUY8xTPmmIae-UeAW7HXODs,884
16
16
  flwr/common/dp.py,sha256=yifOQAOcLkVLt7XrOed1A_n_Ep2wKWZZz75PF2Wu8Jo,1827
@@ -46,7 +46,7 @@ flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPc
46
46
  flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
47
47
  flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
48
  flwr/server/__init__.py,sha256=4-X514k82t4YCNn66UqEWAnyxpHsOhX1j-NRIoUSzFs,1331
49
- flwr/server/app.py,sha256=tyQKbkNkcD2zf7Z-RQG2RIuBAQJPFybJoXlaHIjqU6I,17009
49
+ flwr/server/app.py,sha256=yct8HwpA0bJdmAFENxpS1Euz3TwPUVurnA6-tzRoUuc,17313
50
50
  flwr/server/client_manager.py,sha256=OjJsIFos4kgol_knpGLa2IVCnFQJ1NuN-Qdr5BD6lYU,5941
51
51
  flwr/server/client_proxy.py,sha256=f0XnGrhaBzWsKN7j0KD288GqlebhVeJuSOMN38cD8Dc,2228
52
52
  flwr/server/criterion.py,sha256=9qfnUrDzgPhF8XIeBzJc1TrGrAmHOgAq24BEmZRN5BQ,1058
@@ -63,7 +63,7 @@ flwr/server/grpc_server/grpc_server.py,sha256=WiCurb8Juum4npU0I9Yz3GaBQX348d6Gdh
63
63
  flwr/server/grpc_server/ins_scheduler.py,sha256=HgxavTZx6ejBRrf-9jxCyEzS8sHDSA0imiwOvfeHdS8,6314
64
64
  flwr/server/history.py,sha256=i40ikSJcSLKT_W8DbxFuw6Zy8zx1xlNk0VfdMl-27dQ,4456
65
65
  flwr/server/rest_server/__init__.py,sha256=2rG5KG7xojkD8jgwUrJPrW4Nsdm_HhwZaQ1poy9x3KQ,728
66
- flwr/server/rest_server/rest_api.py,sha256=JsQIT9azBDf6G5smLaXU6TqJ3ZAQ6TUhVkjSEdp88z8,4385
66
+ flwr/server/rest_server/rest_api.py,sha256=X8XmSR4GpzwdEhi0qAq564t5Nymw8pictGaUXi4fugU,4640
67
67
  flwr/server/server.py,sha256=X7hbDv39zaibGqAnUurmn9GrUYIP4pEea-p9dCpWXvg,15855
68
68
  flwr/server/state/__init__.py,sha256=2ZRkns3PkXxUi_WzjbDWfSkLvGK4RE8o3SHgSHgoclQ,995
69
69
  flwr/server/state/in_memory_state.py,sha256=Vsk0smDQ1pNLUeI5vIGkRLmTT-a0P6yOvyfMX8fDbv4,6521
@@ -83,8 +83,7 @@ flwr/server/strategy/fedavgm.py,sha256=mFhsEKkJHJQQPUXV5tMNOAckl6SJ41-DYFciomKmz
83
83
  flwr/server/strategy/fedmedian.py,sha256=MAWrLijv2_wDT9KZeYLJ6W_dnVQGKEp2DK40V6qLr78,6469
84
84
  flwr/server/strategy/fedopt.py,sha256=eXH1lIrjCNUEBkAs-ziULzwACxqcArooHWg9vlj9A3Q,5393
85
85
  flwr/server/strategy/fedprox.py,sha256=sjHbjl9iUr6MTbinCYj5q2R0Tmr9o0X_5JrsUWXa7cE,7708
86
- flwr/server/strategy/fedxgb.py,sha256=klRj1bykXK3iXYf7vJ1mqjw9BHsRKChzB9lK29j7Oq4,5373
87
- flwr/server/strategy/fedxgb_nn_avg.py,sha256=fNvV2Mi2o8HDyx0PRZ1czevfTAGlDB1iilyQGXby97M,7849
86
+ flwr/server/strategy/fedxgb_nn_avg.py,sha256=-_iAjQtpmHUDC3qnvSEKdtazPHdW1V1mdaxyZfsj6NI,7772
88
87
  flwr/server/strategy/fedyogi.py,sha256=cC9WoIci3xmVs5HBGrkXHM0cpdMjApBWZS4d7-bKnZE,7044
89
88
  flwr/server/strategy/krum.py,sha256=PVfP8or0Rgj9TcxOzbngkYmm4VctX44NjOh1rs3Xmn8,7011
90
89
  flwr/server/strategy/qfedavg.py,sha256=2CJE7jvtVUcWpofBg8mG1w79PwBUoltY8wkyHMaFdag,10744
@@ -96,8 +95,8 @@ flwr/simulation/__init__.py,sha256=7gYUX6zr_yJd1wtv65xRHajBvurkfoPVXv66buUy4H8,1
96
95
  flwr/simulation/app.py,sha256=4UKy2OgCsTveHmtvIvqL2NsuL0OiIR4n1wDRLrugK88,7737
97
96
  flwr/simulation/ray_transport/__init__.py,sha256=eJ3pijYkI7XhbX2rLu6FBGTo8hZkFL8RSj4twhApOGw,727
98
97
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=WLO9ZagPCOyj-7e3x3gz-w0oiMVRdIhdHOkbrnmYvQo,5472
99
- flwr_nightly-1.4.0.dev20230321.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
100
- flwr_nightly-1.4.0.dev20230321.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
101
- flwr_nightly-1.4.0.dev20230321.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
102
- flwr_nightly-1.4.0.dev20230321.dist-info/METADATA,sha256=EAJHuM2CJqJZNqpECYbaHoHvNOokVkIH7Yk87eFvp3A,12839
103
- flwr_nightly-1.4.0.dev20230321.dist-info/RECORD,,
98
+ flwr_nightly-1.4.0.dev20230323.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
99
+ flwr_nightly-1.4.0.dev20230323.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
100
+ flwr_nightly-1.4.0.dev20230323.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
101
+ flwr_nightly-1.4.0.dev20230323.dist-info/METADATA,sha256=X68ZNpzTS1Blep8olTcQaE6LIp_HCb-QdqU4zJjpib0,12839
102
+ flwr_nightly-1.4.0.dev20230323.dist-info/RECORD,,
@@ -1,153 +0,0 @@
1
- # Copyright 2020 Adap GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Federated XGBoost utility functions."""
16
-
17
- from typing import Any, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch # pylint: disable=E0401
21
- import xgboost as xgb # pylint: disable=E0401
22
- from matplotlib import pyplot as plt # pylint: disable=E0401
23
- from torch.utils.data import DataLoader, Dataset # pylint: disable=E0401
24
- from xgboost import XGBClassifier, XGBRegressor # pylint: disable=E0401
25
-
26
- from flwr.common.typing import NDArray
27
-
28
-
29
- def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:
30
- """Visualize the built xgboost tree."""
31
- xgb.plot_tree(tree, num_trees=n_tree)
32
- plt.rcParams["figure.figsize"] = [50, 10]
33
- plt.show()
34
-
35
-
36
- def construct_tree(
37
- dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str
38
- ) -> Union[XGBClassifier, XGBRegressor]:
39
- """Construct a xgboost tree form tabular dataset."""
40
- if tree_type == "BINARY":
41
- tree = xgb.XGBClassifier(
42
- objective="binary:logistic",
43
- learning_rate=0.1,
44
- max_depth=8,
45
- n_estimators=n_estimators,
46
- subsample=0.8,
47
- colsample_bylevel=1,
48
- colsample_bynode=1,
49
- colsample_bytree=1,
50
- alpha=5,
51
- gamma=5,
52
- num_parallel_tree=1,
53
- min_child_weight=1,
54
- )
55
-
56
- elif tree_type == "REG":
57
- tree = xgb.XGBRegressor(
58
- objective="reg:squarederror",
59
- learning_rate=0.1,
60
- max_depth=8,
61
- n_estimators=n_estimators,
62
- subsample=0.8,
63
- colsample_bylevel=1,
64
- colsample_bynode=1,
65
- colsample_bytree=1,
66
- alpha=5,
67
- gamma=5,
68
- num_parallel_tree=1,
69
- min_child_weight=1,
70
- )
71
-
72
- tree.fit(dataset, label)
73
- return tree
74
-
75
-
76
- def construct_tree_from_loader(
77
- dataset_loader: DataLoader, n_estimators: int, tree_type: str
78
- ) -> Union[XGBClassifier, XGBRegressor]:
79
- """Construct a xgboost tree form tabular dataset loader."""
80
- for dataset in dataset_loader:
81
- data, label = dataset[0], dataset[1]
82
- return construct_tree(data, label, n_estimators, tree_type)
83
-
84
-
85
- def single_tree_prediction(
86
- tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray
87
- ) -> Optional[NDArray]:
88
- """Extract the prediction result of a single tree in the xgboost tree
89
- ensemble."""
90
- # How to access a single tree
91
- # https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb
92
- num_t = len(tree.get_booster().get_dump())
93
- if n_tree > num_t:
94
- print(
95
- "The tree index to be extracted is larger than the total number of trees."
96
- )
97
- return None
98
-
99
- return tree.predict( # type: ignore
100
- dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True
101
- )
102
-
103
-
104
- def tree_encoding( # pylint: disable=R0914
105
- trainloader: DataLoader,
106
- client_trees: Union[
107
- Tuple[XGBClassifier, int],
108
- Tuple[XGBRegressor, int],
109
- List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
110
- ],
111
- client_tree_num: int,
112
- client_num: int,
113
- ) -> Optional[Tuple[NDArray, NDArray]]:
114
- """Transform the tabular dataset into prediction results using the
115
- aggregated xgboost tree ensembles from all clients."""
116
- if trainloader is None:
117
- return None
118
-
119
- for local_dataset in trainloader:
120
- x_train, y_train = local_dataset[0], local_dataset[1]
121
-
122
- x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))
123
- x_train_enc = np.array(x_train_enc, copy=True)
124
-
125
- temp_trees: Any = None
126
- if isinstance(client_trees, list) is False:
127
- temp_trees = [client_trees[0]] * client_num
128
- elif isinstance(client_trees, list) and len(client_trees) != client_num:
129
- temp_trees = [client_trees[0][0]] * client_num
130
- else:
131
- cids = []
132
- temp_trees = []
133
- for i, _ in enumerate(client_trees):
134
- temp_trees.append(client_trees[i][0]) # type: ignore
135
- cids.append(client_trees[i][1]) # type: ignore
136
- sorted_index = np.argsort(np.asarray(cids))
137
- temp_trees = np.asarray(temp_trees)[sorted_index]
138
-
139
- for i, _ in enumerate(temp_trees):
140
- for j in range(client_tree_num):
141
- x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(
142
- temp_trees[i], j, x_train
143
- )
144
-
145
- x_train_enc32: Any = np.float32(x_train_enc)
146
- y_train32: Any = np.float32(y_train)
147
-
148
- x_train_enc32, y_train32 = torch.from_numpy(
149
- np.expand_dims(x_train_enc32, axis=1) # type: ignore
150
- ), torch.from_numpy(
151
- np.expand_dims(y_train32, axis=-1) # type: ignore
152
- )
153
- return x_train_enc32, y_train32