flwr-nightly 1.4.0.dev20230321__py3-none-any.whl → 1.4.0.dev20230323__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flwr/__init__.py CHANGED
@@ -16,10 +16,11 @@
16
16
 
17
17
  from flwr.common.version import package_version as _package_version
18
18
 
19
- from . import client, server, simulation
19
+ from . import client, common, server, simulation
20
20
 
21
21
  __all__ = [
22
22
  "client",
23
+ "common",
23
24
  "server",
24
25
  "simulation",
25
26
  ]
flwr/client/app.py CHANGED
@@ -49,7 +49,6 @@ from .numpy_client import has_evaluate as numpyclient_has_evaluate
49
49
  from .numpy_client import has_fit as numpyclient_has_fit
50
50
  from .numpy_client import has_get_parameters as numpyclient_has_get_parameters
51
51
  from .numpy_client import has_get_properties as numpyclient_has_get_properties
52
- from .rest_client.connection import http_request_response
53
52
 
54
53
  EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT = """
55
54
  NumPyClient.fit did not return a tuple with 3 elements.
@@ -81,6 +80,7 @@ Example
81
80
  ClientLike = Union[Client, NumPyClient]
82
81
 
83
82
 
83
+ # pylint: disable=import-outside-toplevel
84
84
  def start_client(
85
85
  *,
86
86
  server_address: str,
@@ -138,7 +138,17 @@ def start_client(
138
138
  event(EventType.START_CLIENT_ENTER)
139
139
 
140
140
  # Use either gRPC bidirectional streaming or REST request/response
141
- connection = http_request_response if rest else grpc_connection
141
+ if rest:
142
+ try:
143
+ from .rest_client.connection import http_request_response
144
+ except ImportError as missing_dep:
145
+ raise ImportError(
146
+ "To use the REST API you must install the "
147
+ "extra dependencies by running `pip install flwr['rest']`."
148
+ ) from missing_dep
149
+ connection = http_request_response
150
+ else:
151
+ connection = grpc_connection
142
152
  while True:
143
153
  sleep_duration: int = 0
144
154
  with connection(
@@ -19,7 +19,13 @@ from contextlib import contextmanager
19
19
  from logging import ERROR, INFO, WARN
20
20
  from typing import Callable, Dict, Iterator, Optional, Tuple
21
21
 
22
- import requests
22
+ try:
23
+ import requests
24
+ except ImportError as missing_dep:
25
+ raise ImportError(
26
+ "To use the REST API you must install the "
27
+ "extra dependencies by running `pip install flwr['rest']`."
28
+ ) from missing_dep
23
29
 
24
30
  from flwr.client.message_handler.task_handler import get_server_message
25
31
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH
flwr/server/app.py CHANGED
@@ -25,7 +25,6 @@ from types import FrameType
25
25
  from typing import List, Optional, Tuple
26
26
 
27
27
  import grpc
28
- import uvicorn
29
28
 
30
29
  from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
31
30
  from flwr.common.logger import log
@@ -40,7 +39,6 @@ from flwr.server.grpc_server.grpc_server import (
40
39
  start_grpc_server,
41
40
  )
42
41
  from flwr.server.history import History
43
- from flwr.server.rest_server.rest_api import app as fast_api_app
44
42
  from flwr.server.server import Server
45
43
  from flwr.server.state import StateFactory
46
44
  from flwr.server.strategy import FedAvg, Strategy
@@ -436,11 +434,22 @@ def _run_fleet_api_grpc_bidi(
436
434
  return fleet_grpc_server
437
435
 
438
436
 
437
+ # pylint: disable=import-outside-toplevel
439
438
  def _run_fleet_api_rest(
440
439
  address: str,
441
440
  state_factory: StateFactory,
442
441
  ) -> None:
443
442
  """Run Driver API (REST-based)."""
443
+ try:
444
+ import uvicorn
445
+
446
+ from flwr.server.rest_server.rest_api import app as fast_api_app
447
+ except ImportError as missing_dep:
448
+ raise ImportError(
449
+ "To use the REST API you must install the "
450
+ "extra dependencies by running "
451
+ "`pip install flwr['rest']`."
452
+ ) from missing_dep
444
453
  log(INFO, "Starting Flower REST server")
445
454
 
446
455
  # See: https://www.starlette.io/applications/#accessing-the-app-instance
@@ -19,8 +19,14 @@ from logging import INFO
19
19
  from typing import List, Optional
20
20
  from uuid import UUID
21
21
 
22
- from fastapi import FastAPI, HTTPException, Request, Response
23
- from starlette.datastructures import Headers
22
+ try:
23
+ from fastapi import FastAPI, HTTPException, Request, Response
24
+ from starlette.datastructures import Headers
25
+ except ImportError as missing_dep:
26
+ raise ImportError(
27
+ "To use the REST API you must install the "
28
+ "extra dependencies by running `pip install flwr['rest']`."
29
+ ) from missing_dep
24
30
 
25
31
  from flwr.common.logger import log
26
32
  from flwr.proto.fleet_pb2 import (
@@ -33,10 +39,10 @@ from flwr.proto.fleet_pb2 import (
33
39
  from flwr.proto.task_pb2 import TaskIns, TaskRes
34
40
  from flwr.server.state import State
35
41
 
36
- app = FastAPI()
42
+ app: FastAPI = FastAPI()
37
43
 
38
44
 
39
- @app.post("/api/v0/fleet/pull-task-ins", response_class=Response)
45
+ @app.post("/api/v0/fleet/pull-task-ins", response_class=Response) # type: ignore
40
46
  async def pull_task_ins(request: Request) -> Response:
41
47
  """Pull TaskIns."""
42
48
  _check_headers(request.headers)
@@ -72,7 +78,7 @@ async def pull_task_ins(request: Request) -> Response:
72
78
  )
73
79
 
74
80
 
75
- @app.post("/api/v0/fleet/push-task-res", response_class=Response)
81
+ @app.post("/api/v0/fleet/push-task-res", response_class=Response) # type: ignore
76
82
  async def push_task_res(request: Request) -> Response: # Check if token is needed here
77
83
  """Push TaskRes."""
78
84
  _check_headers(request.headers)
@@ -37,9 +37,6 @@ from flwr.server.client_proxy import ClientProxy
37
37
  from .aggregate import aggregate
38
38
  from .fedavg import FedAvg
39
39
 
40
- # from xgboost import XGBClassifier, XGBRegressor # pylint: disable=W0611
41
-
42
-
43
40
  WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
44
41
  Setting `min_available_clients` lower than `min_fit_clients` or
45
42
  `min_evaluate_clients` can cause the server to fail when there are too few clients
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flwr-nightly
3
- Version: 1.4.0.dev20230321
3
+ Version: 1.4.0.dev20230323
4
4
  Summary: Flower: A Friendly Federated Learning Framework
5
5
  Home-page: https://flower.dev
6
6
  License: Apache-2.0
@@ -1,6 +1,6 @@
1
- flwr/__init__.py,sha256=DguD_tpDL2vR1F6rRs-H3ms6Cmbep0GCCGxln84yftU,907
1
+ flwr/__init__.py,sha256=KJDlgYheeHeI1KlSwlGbjtzO1mqLgtL5zYhthG9flFk,929
2
2
  flwr/client/__init__.py,sha256=UkiAEKds0mtpmVTZvvpJHKfTC_VWzflwL5IQUdbAf6U,1164
3
- flwr/client/app.py,sha256=zOZ3iMUz3T-7yiVQTaFwqjUoWPG2kKK32jaoFqL3d_s,11537
3
+ flwr/client/app.py,sha256=TiCranPrEF20JptsY8ppUAVXkbcLHqFs4OzSciGsogc,11880
4
4
  flwr/client/client.py,sha256=pR7A0ojBD2YEqz0nokSxKtbyBxmUQ6qlFmiJ9SAW8jE,6503
5
5
  flwr/client/dpfedavg_numpy_client.py,sha256=CGGFLuUCquVSGYm2ap0PjRapJ3A_h5uuO8NYr3k-P3E,3351
6
6
  flwr/client/grpc_client/__init__.py,sha256=svcQQhSaBJujNfZevHBsQrMKmqggpq-3I8Y6FUm_trM,728
@@ -10,7 +10,7 @@ flwr/client/message_handler/message_handler.py,sha256=8gXodt8fIZLWTcUJQN63fE78MF
10
10
  flwr/client/message_handler/task_handler.py,sha256=KtNnwrHBd1aekkZeWsCfhO-BhYzBCYttwQl0InqpydM,1686
11
11
  flwr/client/numpy_client.py,sha256=znnXgAO_BY-QUCqnTiCUJKS9TUY_z0boZ6D1dNHCTUA,5100
12
12
  flwr/client/rest_client/__init__.py,sha256=FWlVEPeCUt2VHKQ9oEqmr8HtGruIegCO5iobBWvtb-Q,728
13
- flwr/client/rest_client/connection.py,sha256=4qOMlGztSQedxVyaCPb5yF4LEfuwzGrmhrMdbYcOWfo,7245
13
+ flwr/client/rest_client/connection.py,sha256=JsMGhNc9ZAhqvh_rG9GtRDUzTKrMedZXNvt5lOkk3IM,7455
14
14
  flwr/common/__init__.py,sha256=JLu4NbyMb-zrvS-uXB9T5biOz1FzJa5h6xIvAu1QX14,2816
15
15
  flwr/common/date.py,sha256=CKvSQd-BcQ-Drr6VBOgjhUY8xTPmmIae-UeAW7HXODs,884
16
16
  flwr/common/dp.py,sha256=yifOQAOcLkVLt7XrOed1A_n_Ep2wKWZZz75PF2Wu8Jo,1827
@@ -46,7 +46,7 @@ flwr/proto/transport_pb2_grpc.py,sha256=vLN3EHtx2aEEMCO4f1Upu-l27BPzd3-5pV-u8wPc
46
46
  flwr/proto/transport_pb2_grpc.pyi,sha256=AGXf8RiIiW2J5IKMlm_3qT3AzcDa4F3P5IqUjve_esA,766
47
47
  flwr/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
48
  flwr/server/__init__.py,sha256=4-X514k82t4YCNn66UqEWAnyxpHsOhX1j-NRIoUSzFs,1331
49
- flwr/server/app.py,sha256=tyQKbkNkcD2zf7Z-RQG2RIuBAQJPFybJoXlaHIjqU6I,17009
49
+ flwr/server/app.py,sha256=yct8HwpA0bJdmAFENxpS1Euz3TwPUVurnA6-tzRoUuc,17313
50
50
  flwr/server/client_manager.py,sha256=OjJsIFos4kgol_knpGLa2IVCnFQJ1NuN-Qdr5BD6lYU,5941
51
51
  flwr/server/client_proxy.py,sha256=f0XnGrhaBzWsKN7j0KD288GqlebhVeJuSOMN38cD8Dc,2228
52
52
  flwr/server/criterion.py,sha256=9qfnUrDzgPhF8XIeBzJc1TrGrAmHOgAq24BEmZRN5BQ,1058
@@ -63,7 +63,7 @@ flwr/server/grpc_server/grpc_server.py,sha256=WiCurb8Juum4npU0I9Yz3GaBQX348d6Gdh
63
63
  flwr/server/grpc_server/ins_scheduler.py,sha256=HgxavTZx6ejBRrf-9jxCyEzS8sHDSA0imiwOvfeHdS8,6314
64
64
  flwr/server/history.py,sha256=i40ikSJcSLKT_W8DbxFuw6Zy8zx1xlNk0VfdMl-27dQ,4456
65
65
  flwr/server/rest_server/__init__.py,sha256=2rG5KG7xojkD8jgwUrJPrW4Nsdm_HhwZaQ1poy9x3KQ,728
66
- flwr/server/rest_server/rest_api.py,sha256=JsQIT9azBDf6G5smLaXU6TqJ3ZAQ6TUhVkjSEdp88z8,4385
66
+ flwr/server/rest_server/rest_api.py,sha256=X8XmSR4GpzwdEhi0qAq564t5Nymw8pictGaUXi4fugU,4640
67
67
  flwr/server/server.py,sha256=X7hbDv39zaibGqAnUurmn9GrUYIP4pEea-p9dCpWXvg,15855
68
68
  flwr/server/state/__init__.py,sha256=2ZRkns3PkXxUi_WzjbDWfSkLvGK4RE8o3SHgSHgoclQ,995
69
69
  flwr/server/state/in_memory_state.py,sha256=Vsk0smDQ1pNLUeI5vIGkRLmTT-a0P6yOvyfMX8fDbv4,6521
@@ -83,8 +83,7 @@ flwr/server/strategy/fedavgm.py,sha256=mFhsEKkJHJQQPUXV5tMNOAckl6SJ41-DYFciomKmz
83
83
  flwr/server/strategy/fedmedian.py,sha256=MAWrLijv2_wDT9KZeYLJ6W_dnVQGKEp2DK40V6qLr78,6469
84
84
  flwr/server/strategy/fedopt.py,sha256=eXH1lIrjCNUEBkAs-ziULzwACxqcArooHWg9vlj9A3Q,5393
85
85
  flwr/server/strategy/fedprox.py,sha256=sjHbjl9iUr6MTbinCYj5q2R0Tmr9o0X_5JrsUWXa7cE,7708
86
- flwr/server/strategy/fedxgb.py,sha256=klRj1bykXK3iXYf7vJ1mqjw9BHsRKChzB9lK29j7Oq4,5373
87
- flwr/server/strategy/fedxgb_nn_avg.py,sha256=fNvV2Mi2o8HDyx0PRZ1czevfTAGlDB1iilyQGXby97M,7849
86
+ flwr/server/strategy/fedxgb_nn_avg.py,sha256=-_iAjQtpmHUDC3qnvSEKdtazPHdW1V1mdaxyZfsj6NI,7772
88
87
  flwr/server/strategy/fedyogi.py,sha256=cC9WoIci3xmVs5HBGrkXHM0cpdMjApBWZS4d7-bKnZE,7044
89
88
  flwr/server/strategy/krum.py,sha256=PVfP8or0Rgj9TcxOzbngkYmm4VctX44NjOh1rs3Xmn8,7011
90
89
  flwr/server/strategy/qfedavg.py,sha256=2CJE7jvtVUcWpofBg8mG1w79PwBUoltY8wkyHMaFdag,10744
@@ -96,8 +95,8 @@ flwr/simulation/__init__.py,sha256=7gYUX6zr_yJd1wtv65xRHajBvurkfoPVXv66buUy4H8,1
96
95
  flwr/simulation/app.py,sha256=4UKy2OgCsTveHmtvIvqL2NsuL0OiIR4n1wDRLrugK88,7737
97
96
  flwr/simulation/ray_transport/__init__.py,sha256=eJ3pijYkI7XhbX2rLu6FBGTo8hZkFL8RSj4twhApOGw,727
98
97
  flwr/simulation/ray_transport/ray_client_proxy.py,sha256=WLO9ZagPCOyj-7e3x3gz-w0oiMVRdIhdHOkbrnmYvQo,5472
99
- flwr_nightly-1.4.0.dev20230321.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
100
- flwr_nightly-1.4.0.dev20230321.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
101
- flwr_nightly-1.4.0.dev20230321.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
102
- flwr_nightly-1.4.0.dev20230321.dist-info/METADATA,sha256=EAJHuM2CJqJZNqpECYbaHoHvNOokVkIH7Yk87eFvp3A,12839
103
- flwr_nightly-1.4.0.dev20230321.dist-info/RECORD,,
98
+ flwr_nightly-1.4.0.dev20230323.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
99
+ flwr_nightly-1.4.0.dev20230323.dist-info/entry_points.txt,sha256=1uLlD5tIunkzALMfMWnqjdE_D5hRUX_I1iMmOMv6tZI,181
100
+ flwr_nightly-1.4.0.dev20230323.dist-info/WHEEL,sha256=vVCvjcmxuUltf8cYhJ0sJMRDLr1XsPuxEId8YDzbyCY,88
101
+ flwr_nightly-1.4.0.dev20230323.dist-info/METADATA,sha256=X68ZNpzTS1Blep8olTcQaE6LIp_HCb-QdqU4zJjpib0,12839
102
+ flwr_nightly-1.4.0.dev20230323.dist-info/RECORD,,
@@ -1,153 +0,0 @@
1
- # Copyright 2020 Adap GmbH. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Federated XGBoost utility functions."""
16
-
17
- from typing import Any, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch # pylint: disable=E0401
21
- import xgboost as xgb # pylint: disable=E0401
22
- from matplotlib import pyplot as plt # pylint: disable=E0401
23
- from torch.utils.data import DataLoader, Dataset # pylint: disable=E0401
24
- from xgboost import XGBClassifier, XGBRegressor # pylint: disable=E0401
25
-
26
- from flwr.common.typing import NDArray
27
-
28
-
29
- def plot_xgbtree(tree: Union[XGBClassifier, XGBRegressor], n_tree: int) -> None:
30
- """Visualize the built xgboost tree."""
31
- xgb.plot_tree(tree, num_trees=n_tree)
32
- plt.rcParams["figure.figsize"] = [50, 10]
33
- plt.show()
34
-
35
-
36
- def construct_tree(
37
- dataset: Dataset, label: NDArray, n_estimators: int, tree_type: str
38
- ) -> Union[XGBClassifier, XGBRegressor]:
39
- """Construct a xgboost tree form tabular dataset."""
40
- if tree_type == "BINARY":
41
- tree = xgb.XGBClassifier(
42
- objective="binary:logistic",
43
- learning_rate=0.1,
44
- max_depth=8,
45
- n_estimators=n_estimators,
46
- subsample=0.8,
47
- colsample_bylevel=1,
48
- colsample_bynode=1,
49
- colsample_bytree=1,
50
- alpha=5,
51
- gamma=5,
52
- num_parallel_tree=1,
53
- min_child_weight=1,
54
- )
55
-
56
- elif tree_type == "REG":
57
- tree = xgb.XGBRegressor(
58
- objective="reg:squarederror",
59
- learning_rate=0.1,
60
- max_depth=8,
61
- n_estimators=n_estimators,
62
- subsample=0.8,
63
- colsample_bylevel=1,
64
- colsample_bynode=1,
65
- colsample_bytree=1,
66
- alpha=5,
67
- gamma=5,
68
- num_parallel_tree=1,
69
- min_child_weight=1,
70
- )
71
-
72
- tree.fit(dataset, label)
73
- return tree
74
-
75
-
76
- def construct_tree_from_loader(
77
- dataset_loader: DataLoader, n_estimators: int, tree_type: str
78
- ) -> Union[XGBClassifier, XGBRegressor]:
79
- """Construct a xgboost tree form tabular dataset loader."""
80
- for dataset in dataset_loader:
81
- data, label = dataset[0], dataset[1]
82
- return construct_tree(data, label, n_estimators, tree_type)
83
-
84
-
85
- def single_tree_prediction(
86
- tree: Union[XGBClassifier, XGBRegressor], n_tree: int, dataset: NDArray
87
- ) -> Optional[NDArray]:
88
- """Extract the prediction result of a single tree in the xgboost tree
89
- ensemble."""
90
- # How to access a single tree
91
- # https://github.com/bmreiniger/datascience.stackexchange/blob/master/57905.ipynb
92
- num_t = len(tree.get_booster().get_dump())
93
- if n_tree > num_t:
94
- print(
95
- "The tree index to be extracted is larger than the total number of trees."
96
- )
97
- return None
98
-
99
- return tree.predict( # type: ignore
100
- dataset, iteration_range=(n_tree, n_tree + 1), output_margin=True
101
- )
102
-
103
-
104
- def tree_encoding( # pylint: disable=R0914
105
- trainloader: DataLoader,
106
- client_trees: Union[
107
- Tuple[XGBClassifier, int],
108
- Tuple[XGBRegressor, int],
109
- List[Union[Tuple[XGBClassifier, int], Tuple[XGBRegressor, int]]],
110
- ],
111
- client_tree_num: int,
112
- client_num: int,
113
- ) -> Optional[Tuple[NDArray, NDArray]]:
114
- """Transform the tabular dataset into prediction results using the
115
- aggregated xgboost tree ensembles from all clients."""
116
- if trainloader is None:
117
- return None
118
-
119
- for local_dataset in trainloader:
120
- x_train, y_train = local_dataset[0], local_dataset[1]
121
-
122
- x_train_enc = np.zeros((x_train.shape[0], client_num * client_tree_num))
123
- x_train_enc = np.array(x_train_enc, copy=True)
124
-
125
- temp_trees: Any = None
126
- if isinstance(client_trees, list) is False:
127
- temp_trees = [client_trees[0]] * client_num
128
- elif isinstance(client_trees, list) and len(client_trees) != client_num:
129
- temp_trees = [client_trees[0][0]] * client_num
130
- else:
131
- cids = []
132
- temp_trees = []
133
- for i, _ in enumerate(client_trees):
134
- temp_trees.append(client_trees[i][0]) # type: ignore
135
- cids.append(client_trees[i][1]) # type: ignore
136
- sorted_index = np.argsort(np.asarray(cids))
137
- temp_trees = np.asarray(temp_trees)[sorted_index]
138
-
139
- for i, _ in enumerate(temp_trees):
140
- for j in range(client_tree_num):
141
- x_train_enc[:, i * client_tree_num + j] = single_tree_prediction(
142
- temp_trees[i], j, x_train
143
- )
144
-
145
- x_train_enc32: Any = np.float32(x_train_enc)
146
- y_train32: Any = np.float32(y_train)
147
-
148
- x_train_enc32, y_train32 = torch.from_numpy(
149
- np.expand_dims(x_train_enc32, axis=1) # type: ignore
150
- ), torch.from_numpy(
151
- np.expand_dims(y_train32, axis=-1) # type: ignore
152
- )
153
- return x_train_enc32, y_train32