flwr-1.21.0-py3-none-any.whl → flwr-1.23.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. flwr/cli/app.py +17 -1
  2. flwr/cli/auth_plugin/__init__.py +15 -6
  3. flwr/cli/auth_plugin/auth_plugin.py +95 -0
  4. flwr/cli/auth_plugin/noop_auth_plugin.py +58 -0
  5. flwr/cli/auth_plugin/oidc_cli_plugin.py +16 -25
  6. flwr/cli/build.py +118 -47
  7. flwr/cli/{cli_user_auth_interceptor.py → cli_account_auth_interceptor.py} +6 -5
  8. flwr/cli/log.py +2 -2
  9. flwr/cli/login/login.py +34 -23
  10. flwr/cli/ls.py +13 -9
  11. flwr/cli/new/new.py +196 -42
  12. flwr/cli/new/templates/app/README.flowertune.md.tpl +1 -1
  13. flwr/cli/new/templates/app/code/client.baseline.py.tpl +64 -47
  14. flwr/cli/new/templates/app/code/client.huggingface.py.tpl +68 -30
  15. flwr/cli/new/templates/app/code/client.jax.py.tpl +63 -42
  16. flwr/cli/new/templates/app/code/client.mlx.py.tpl +80 -51
  17. flwr/cli/new/templates/app/code/client.numpy.py.tpl +36 -13
  18. flwr/cli/new/templates/app/code/client.pytorch.py.tpl +71 -46
  19. flwr/cli/new/templates/app/code/client.pytorch_legacy_api.py.tpl +55 -0
  20. flwr/cli/new/templates/app/code/client.sklearn.py.tpl +75 -30
  21. flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +69 -44
  22. flwr/cli/new/templates/app/code/client.xgboost.py.tpl +110 -0
  23. flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +56 -90
  24. flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +1 -23
  25. flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +37 -58
  26. flwr/cli/new/templates/app/code/flwr_tune/strategy.py.tpl +39 -44
  27. flwr/cli/new/templates/app/code/model.baseline.py.tpl +0 -14
  28. flwr/cli/new/templates/app/code/server.baseline.py.tpl +27 -29
  29. flwr/cli/new/templates/app/code/server.huggingface.py.tpl +23 -19
  30. flwr/cli/new/templates/app/code/server.jax.py.tpl +27 -14
  31. flwr/cli/new/templates/app/code/server.mlx.py.tpl +29 -19
  32. flwr/cli/new/templates/app/code/server.numpy.py.tpl +30 -17
  33. flwr/cli/new/templates/app/code/server.pytorch.py.tpl +36 -26
  34. flwr/cli/new/templates/app/code/server.pytorch_legacy_api.py.tpl +31 -0
  35. flwr/cli/new/templates/app/code/server.sklearn.py.tpl +29 -21
  36. flwr/cli/new/templates/app/code/server.tensorflow.py.tpl +28 -19
  37. flwr/cli/new/templates/app/code/server.xgboost.py.tpl +56 -0
  38. flwr/cli/new/templates/app/code/task.huggingface.py.tpl +16 -20
  39. flwr/cli/new/templates/app/code/task.jax.py.tpl +1 -1
  40. flwr/cli/new/templates/app/code/task.numpy.py.tpl +1 -1
  41. flwr/cli/new/templates/app/code/task.pytorch.py.tpl +14 -27
  42. flwr/cli/new/templates/app/code/{task.pytorch_msg_api.py.tpl → task.pytorch_legacy_api.py.tpl} +27 -14
  43. flwr/cli/new/templates/app/code/task.tensorflow.py.tpl +1 -2
  44. flwr/cli/new/templates/app/code/task.xgboost.py.tpl +67 -0
  45. flwr/cli/new/templates/app/pyproject.baseline.toml.tpl +4 -4
  46. flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +2 -2
  47. flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +4 -4
  48. flwr/cli/new/templates/app/pyproject.jax.toml.tpl +1 -1
  49. flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +2 -2
  50. flwr/cli/new/templates/app/pyproject.numpy.toml.tpl +1 -1
  51. flwr/cli/new/templates/app/pyproject.pytorch.toml.tpl +3 -3
  52. flwr/cli/new/templates/app/{pyproject.pytorch_msg_api.toml.tpl → pyproject.pytorch_legacy_api.toml.tpl} +3 -3
  53. flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +1 -1
  54. flwr/cli/new/templates/app/pyproject.tensorflow.toml.tpl +1 -1
  55. flwr/cli/new/templates/app/pyproject.xgboost.toml.tpl +61 -0
  56. flwr/cli/pull.py +100 -0
  57. flwr/cli/run/run.py +11 -7
  58. flwr/cli/stop.py +2 -2
  59. flwr/cli/supernode/__init__.py +25 -0
  60. flwr/cli/supernode/ls.py +260 -0
  61. flwr/cli/supernode/register.py +185 -0
  62. flwr/cli/supernode/unregister.py +138 -0
  63. flwr/cli/utils.py +109 -69
  64. flwr/client/__init__.py +2 -1
  65. flwr/client/grpc_adapter_client/connection.py +6 -8
  66. flwr/client/grpc_rere_client/connection.py +59 -31
  67. flwr/client/grpc_rere_client/grpc_adapter.py +28 -12
  68. flwr/client/grpc_rere_client/{client_interceptor.py → node_auth_client_interceptor.py} +3 -6
  69. flwr/client/mod/secure_aggregation/secaggplus_mod.py +7 -5
  70. flwr/client/rest_client/connection.py +82 -37
  71. flwr/clientapp/__init__.py +1 -2
  72. flwr/clientapp/mod/__init__.py +4 -1
  73. flwr/clientapp/mod/centraldp_mods.py +156 -40
  74. flwr/clientapp/mod/localdp_mod.py +169 -0
  75. flwr/clientapp/typing.py +22 -0
  76. flwr/{client/clientapp → clientapp}/utils.py +1 -1
  77. flwr/common/constant.py +56 -13
  78. flwr/common/exit/exit_code.py +24 -10
  79. flwr/common/inflatable_utils.py +10 -10
  80. flwr/common/record/array.py +3 -3
  81. flwr/common/record/arrayrecord.py +10 -1
  82. flwr/common/record/typeddict.py +12 -0
  83. flwr/common/secure_aggregation/crypto/symmetric_encryption.py +1 -89
  84. flwr/common/serde.py +4 -2
  85. flwr/common/typing.py +7 -6
  86. flwr/compat/client/app.py +1 -1
  87. flwr/compat/client/grpc_client/connection.py +2 -2
  88. flwr/proto/control_pb2.py +48 -31
  89. flwr/proto/control_pb2.pyi +95 -5
  90. flwr/proto/control_pb2_grpc.py +136 -0
  91. flwr/proto/control_pb2_grpc.pyi +52 -0
  92. flwr/proto/fab_pb2.py +11 -7
  93. flwr/proto/fab_pb2.pyi +21 -1
  94. flwr/proto/fleet_pb2.py +31 -23
  95. flwr/proto/fleet_pb2.pyi +63 -23
  96. flwr/proto/fleet_pb2_grpc.py +98 -28
  97. flwr/proto/fleet_pb2_grpc.pyi +45 -13
  98. flwr/proto/node_pb2.py +3 -1
  99. flwr/proto/node_pb2.pyi +48 -0
  100. flwr/server/app.py +152 -114
  101. flwr/server/superlink/fleet/grpc_adapter/grpc_adapter_servicer.py +17 -7
  102. flwr/server/superlink/fleet/grpc_rere/fleet_servicer.py +132 -38
  103. flwr/server/superlink/fleet/grpc_rere/{server_interceptor.py → node_auth_server_interceptor.py} +27 -51
  104. flwr/server/superlink/fleet/message_handler/message_handler.py +67 -22
  105. flwr/server/superlink/fleet/rest_rere/rest_api.py +52 -31
  106. flwr/server/superlink/fleet/vce/backend/backend.py +1 -1
  107. flwr/server/superlink/fleet/vce/backend/raybackend.py +1 -1
  108. flwr/server/superlink/fleet/vce/vce_api.py +18 -5
  109. flwr/server/superlink/linkstate/in_memory_linkstate.py +167 -73
  110. flwr/server/superlink/linkstate/linkstate.py +107 -24
  111. flwr/server/superlink/linkstate/linkstate_factory.py +2 -1
  112. flwr/server/superlink/linkstate/sqlite_linkstate.py +306 -255
  113. flwr/server/superlink/linkstate/utils.py +3 -54
  114. flwr/server/superlink/serverappio/serverappio_servicer.py +2 -2
  115. flwr/server/superlink/simulation/simulationio_servicer.py +1 -1
  116. flwr/server/utils/validator.py +2 -3
  117. flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +4 -2
  118. flwr/serverapp/strategy/__init__.py +26 -0
  119. flwr/serverapp/strategy/bulyan.py +238 -0
  120. flwr/serverapp/strategy/dp_adaptive_clipping.py +335 -0
  121. flwr/serverapp/strategy/dp_fixed_clipping.py +71 -49
  122. flwr/serverapp/strategy/fedadagrad.py +0 -3
  123. flwr/serverapp/strategy/fedadam.py +0 -3
  124. flwr/serverapp/strategy/fedavg.py +89 -64
  125. flwr/serverapp/strategy/fedavgm.py +198 -0
  126. flwr/serverapp/strategy/fedmedian.py +105 -0
  127. flwr/serverapp/strategy/fedprox.py +174 -0
  128. flwr/serverapp/strategy/fedtrimmedavg.py +176 -0
  129. flwr/serverapp/strategy/fedxgb_bagging.py +117 -0
  130. flwr/serverapp/strategy/fedxgb_cyclic.py +220 -0
  131. flwr/serverapp/strategy/fedyogi.py +0 -3
  132. flwr/serverapp/strategy/krum.py +112 -0
  133. flwr/serverapp/strategy/multikrum.py +247 -0
  134. flwr/serverapp/strategy/qfedavg.py +252 -0
  135. flwr/serverapp/strategy/strategy_utils.py +48 -0
  136. flwr/simulation/app.py +1 -1
  137. flwr/simulation/ray_transport/ray_actor.py +1 -1
  138. flwr/simulation/ray_transport/ray_client_proxy.py +1 -1
  139. flwr/simulation/run_simulation.py +28 -32
  140. flwr/supercore/cli/flower_superexec.py +26 -1
  141. flwr/supercore/constant.py +41 -0
  142. flwr/supercore/object_store/in_memory_object_store.py +0 -4
  143. flwr/supercore/object_store/object_store_factory.py +26 -6
  144. flwr/supercore/object_store/sqlite_object_store.py +252 -0
  145. flwr/{client/clientapp → supercore/primitives}/__init__.py +1 -1
  146. flwr/supercore/primitives/asymmetric.py +117 -0
  147. flwr/supercore/primitives/asymmetric_ed25519.py +165 -0
  148. flwr/supercore/sqlite_mixin.py +156 -0
  149. flwr/supercore/superexec/plugin/exec_plugin.py +11 -1
  150. flwr/supercore/superexec/run_superexec.py +16 -2
  151. flwr/supercore/utils.py +20 -0
  152. flwr/superlink/artifact_provider/__init__.py +22 -0
  153. flwr/superlink/artifact_provider/artifact_provider.py +37 -0
  154. flwr/{common → superlink}/auth_plugin/__init__.py +6 -6
  155. flwr/superlink/auth_plugin/auth_plugin.py +91 -0
  156. flwr/superlink/auth_plugin/noop_auth_plugin.py +87 -0
  157. flwr/superlink/servicer/control/{control_user_auth_interceptor.py → control_account_auth_interceptor.py} +19 -19
  158. flwr/superlink/servicer/control/control_event_log_interceptor.py +1 -1
  159. flwr/superlink/servicer/control/control_grpc.py +16 -11
  160. flwr/superlink/servicer/control/control_servicer.py +207 -58
  161. flwr/supernode/cli/flower_supernode.py +19 -26
  162. flwr/supernode/runtime/run_clientapp.py +2 -2
  163. flwr/supernode/servicer/clientappio/clientappio_servicer.py +1 -1
  164. flwr/supernode/start_client_internal.py +17 -9
  165. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/METADATA +6 -16
  166. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/RECORD +170 -140
  167. flwr/cli/new/templates/app/code/client.pytorch_msg_api.py.tpl +0 -80
  168. flwr/cli/new/templates/app/code/server.pytorch_msg_api.py.tpl +0 -41
  169. flwr/common/auth_plugin/auth_plugin.py +0 -149
  170. flwr/serverapp/dp_fixed_clipping.py +0 -352
  171. flwr/serverapp/strategy/strategy_utils_tests.py +0 -304
  172. /flwr/cli/new/templates/app/code/{__init__.pytorch_msg_api.py.tpl → __init__.pytorch_legacy_api.py.tpl} +0 -0
  173. /flwr/{client → clientapp}/client_app.py +0 -0
  174. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/WHEEL +0 -0
  175. {flwr-1.21.0.dist-info → flwr-1.23.0.dist-info}/entry_points.txt +0 -0
flwr/cli/new/templates/app/code/client.sklearn.py.tpl
@@ -2,10 +2,16 @@
2
2
 
3
3
  import warnings
4
4
 
5
- from sklearn.metrics import log_loss
5
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
6
+ from flwr.clientapp import ClientApp
7
+ from sklearn.metrics import (
8
+ accuracy_score,
9
+ f1_score,
10
+ log_loss,
11
+ precision_score,
12
+ recall_score,
13
+ )
6
14
 
7
- from flwr.client import ClientApp, NumPyClient
8
- from flwr.common import Context
9
15
  from $import_name.task import (
10
16
  get_model,
11
17
  get_model_params,
@@ -14,39 +20,52 @@ from $import_name.task import (
14
20
  set_model_params,
15
21
  )
16
22
 
23
+ # Flower ClientApp
24
+ app = ClientApp()
17
25
 
18
- class FlowerClient(NumPyClient):
19
- def __init__(self, model, X_train, X_test, y_train, y_test):
20
- self.model = model
21
- self.X_train = X_train
22
- self.X_test = X_test
23
- self.y_train = y_train
24
- self.y_test = y_test
25
26
 
26
- def fit(self, parameters, config):
27
- set_model_params(self.model, parameters)
27
+ @app.train()
28
+ def train(msg: Message, context: Context):
29
+ """Train the model on local data."""
28
30
 
29
- # Ignore convergence failure due to low local epochs
30
- with warnings.catch_warnings():
31
- warnings.simplefilter("ignore")
32
- self.model.fit(self.X_train, self.y_train)
31
+ # Create LogisticRegression Model
32
+ penalty = context.run_config["penalty"]
33
+ local_epochs = context.run_config["local-epochs"]
34
+ model = get_model(penalty, local_epochs)
35
+ # Setting initial parameters, akin to model.compile for keras models
36
+ set_initial_params(model)
33
37
 
34
- return get_model_params(self.model), len(self.X_train), {}
38
+ # Apply received pararameters
39
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
40
+ set_model_params(model, ndarrays)
35
41
 
36
- def evaluate(self, parameters, config):
37
- set_model_params(self.model, parameters)
42
+ # Load the data
43
+ partition_id = context.node_config["partition-id"]
44
+ num_partitions = context.node_config["num-partitions"]
45
+ X_train, _, y_train, _ = load_data(partition_id, num_partitions)
38
46
 
39
- loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
40
- accuracy = self.model.score(self.X_test, self.y_test)
47
+ # Ignore convergence failure due to low local epochs
48
+ with warnings.catch_warnings():
49
+ warnings.simplefilter("ignore")
50
+ # Train the model on local data
51
+ model.fit(X_train, y_train)
41
52
 
42
- return loss, len(self.X_test), {"accuracy": accuracy}
53
+ # Let's compute train loss
54
+ y_train_pred_proba = model.predict_proba(X_train)
55
+ train_logloss = log_loss(y_train, y_train_pred_proba)
43
56
 
57
+ # Construct and return reply Message
58
+ ndarrays = get_model_params(model)
59
+ model_record = ArrayRecord(ndarrays)
60
+ metrics = {"num-examples": len(X_train), "train_logloss": train_logloss}
61
+ metric_record = MetricRecord(metrics)
62
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
63
+ return Message(content=content, reply_to=msg)
44
64
 
45
- def client_fn(context: Context):
46
- partition_id = context.node_config["partition-id"]
47
- num_partitions = context.node_config["num-partitions"]
48
65
 
49
- X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions)
66
+ @app.evaluate()
67
+ def evaluate(msg: Message, context: Context):
68
+ """Evaluate the model on test data."""
50
69
 
51
70
  # Create LogisticRegression Model
52
71
  penalty = context.run_config["penalty"]
@@ -56,8 +75,34 @@ def client_fn(context: Context):
56
75
  # Setting initial parameters, akin to model.compile for keras models
57
76
  set_initial_params(model)
58
77
 
59
- return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
78
+ # Apply received pararameters
79
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
80
+ set_model_params(model, ndarrays)
60
81
 
61
-
62
- # Flower ClientApp
63
- app = ClientApp(client_fn=client_fn)
82
+ # Load the data
83
+ partition_id = context.node_config["partition-id"]
84
+ num_partitions = context.node_config["num-partitions"]
85
+ _, X_test, _, y_test = load_data(partition_id, num_partitions)
86
+
87
+ # Evaluate the model on local data
88
+ y_train_pred = model.predict(X_test)
89
+ y_train_pred_proba = model.predict_proba(X_test)
90
+
91
+ accuracy = accuracy_score(y_test, y_train_pred)
92
+ loss = log_loss(y_test, y_train_pred_proba)
93
+ precision = precision_score(y_test, y_train_pred, average="macro", zero_division=0)
94
+ recall = recall_score(y_test, y_train_pred, average="macro", zero_division=0)
95
+ f1 = f1_score(y_test, y_train_pred, average="macro", zero_division=0)
96
+
97
+ # Construct and return reply Message
98
+ metrics = {
99
+ "num-examples": len(X_test),
100
+ "test_logloss": loss,
101
+ "accuracy": accuracy,
102
+ "precision": precision,
103
+ "recall": recall,
104
+ "f1": f1,
105
+ }
106
+ metric_record = MetricRecord(metrics)
107
+ content = RecordDict({"metrics": metric_record})
108
+ return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.tensorflow.py.tpl
@@ -1,57 +1,82 @@
1
1
  """$project_name: A Flower / $framework_str app."""
2
2
 
3
- from flwr.client import NumPyClient, ClientApp
4
- from flwr.common import Context
3
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
4
+ from flwr.clientapp import ClientApp
5
5
 
6
6
  from $import_name.task import load_data, load_model
7
7
 
8
+ # Flower ClientApp
9
+ app = ClientApp()
8
10
 
9
- # Define Flower Client and client_fn
10
- class FlowerClient(NumPyClient):
11
- def __init__(
12
- self, model, data, epochs, batch_size, verbose
13
- ):
14
- self.model = model
15
- self.x_train, self.y_train, self.x_test, self.y_test = data
16
- self.epochs = epochs
17
- self.batch_size = batch_size
18
- self.verbose = verbose
19
-
20
- def fit(self, parameters, config):
21
- self.model.set_weights(parameters)
22
- self.model.fit(
23
- self.x_train,
24
- self.y_train,
25
- epochs=self.epochs,
26
- batch_size=self.batch_size,
27
- verbose=self.verbose,
28
- )
29
- return self.model.get_weights(), len(self.x_train), {}
30
-
31
- def evaluate(self, parameters, config):
32
- self.model.set_weights(parameters)
33
- loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
34
- return loss, len(self.x_test), {"accuracy": accuracy}
35
-
36
-
37
- def client_fn(context: Context):
38
- # Load model and data
39
- net = load_model()
40
11
 
41
- partition_id = context.node_config["partition-id"]
42
- num_partitions = context.node_config["num-partitions"]
43
- data = load_data(partition_id, num_partitions)
12
+ @app.train()
13
+ def train(msg: Message, context: Context):
14
+ """Train the model on local data."""
15
+
16
+ # Load the model and initialize it with the received weights
17
+ model = load_model()
18
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
19
+ model.set_weights(ndarrays)
20
+
21
+ # Read from config
44
22
  epochs = context.run_config["local-epochs"]
45
23
  batch_size = context.run_config["batch-size"]
46
24
  verbose = context.run_config.get("verbose")
47
25
 
48
- # Return Client instance
49
- return FlowerClient(
50
- net, data, epochs, batch_size, verbose
51
- ).to_client()
26
+ # Load the data
27
+ partition_id = context.node_config["partition-id"]
28
+ num_partitions = context.node_config["num-partitions"]
29
+ x_train, y_train, _, _ = load_data(partition_id, num_partitions)
52
30
 
31
+ # Train the model on local data
32
+ history = model.fit(
33
+ x_train,
34
+ y_train,
35
+ epochs=epochs,
36
+ batch_size=batch_size,
37
+ verbose=verbose,
38
+ )
53
39
 
54
- # Flower ClientApp
55
- app = ClientApp(
56
- client_fn=client_fn,
57
- )
40
+ # Get final training loss and accuracy
41
+ train_loss = history.history["loss"][-1] if "loss" in history.history else None
42
+ train_acc = history.history.get("accuracy")
43
+ train_acc = train_acc[-1] if train_acc is not None else None
44
+
45
+ # Construct and return reply Message
46
+ model_record = ArrayRecord(model.get_weights())
47
+ metrics = {"num-examples": len(x_train)}
48
+ if train_loss is not None:
49
+ metrics["train_loss"] = train_loss
50
+ if train_acc is not None:
51
+ metrics["train_acc"] = train_acc
52
+ metric_record = MetricRecord(metrics)
53
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
54
+ return Message(content=content, reply_to=msg)
55
+
56
+
57
+ @app.evaluate()
58
+ def evaluate(msg: Message, context: Context):
59
+ """Evaluate the model on local data."""
60
+
61
+ # Load the model and initialize it with the received weights
62
+ model = load_model()
63
+ ndarrays = msg.content["arrays"].to_numpy_ndarrays()
64
+ model.set_weights(ndarrays)
65
+
66
+ # Load the data
67
+ partition_id = context.node_config["partition-id"]
68
+ num_partitions = context.node_config["num-partitions"]
69
+ _, _, x_test, y_test = load_data(partition_id, num_partitions)
70
+
71
+ # Evaluate the model on local data
72
+ loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
73
+
74
+ # Construct and return reply Message
75
+ metrics = {
76
+ "eval_loss": loss,
77
+ "eval_acc": accuracy,
78
+ "num-examples": len(x_test),
79
+ }
80
+ metric_record = MetricRecord(metrics)
81
+ content = RecordDict({"metrics": metric_record})
82
+ return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/client.xgboost.py.tpl
@@ -0,0 +1,110 @@
1
+ """$project_name: A Flower / $framework_str app."""
2
+
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import xgboost as xgb
7
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
8
+ from flwr.clientapp import ClientApp
9
+ from flwr.common.config import unflatten_dict
10
+
11
+ from $import_name.task import load_data, replace_keys
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ # Flower ClientApp
17
+ app = ClientApp()
18
+
19
+
20
+ def _local_boost(bst_input, num_local_round, train_dmatrix):
21
+ # Update trees based on local training data.
22
+ for i in range(num_local_round):
23
+ bst_input.update(train_dmatrix, bst_input.num_boosted_rounds())
24
+
25
+ # Bagging: extract the last N=num_local_round trees for sever aggregation
26
+ bst = bst_input[
27
+ bst_input.num_boosted_rounds()
28
+ - num_local_round : bst_input.num_boosted_rounds()
29
+ ]
30
+ return bst
31
+
32
+
33
+ @app.train()
34
+ def train(msg: Message, context: Context) -> Message:
35
+ # Load model and data
36
+ partition_id = context.node_config["partition-id"]
37
+ num_partitions = context.node_config["num-partitions"]
38
+ train_dmatrix, _, num_train, _ = load_data(partition_id, num_partitions)
39
+
40
+ # Read from run config
41
+ num_local_round = context.run_config["local-epochs"]
42
+ # Flatted config dict and replace "-" with "_"
43
+ cfg = replace_keys(unflatten_dict(context.run_config))
44
+ params = cfg["params"]
45
+
46
+ global_round = msg.content["config"]["server-round"]
47
+ if global_round == 1:
48
+ # First round local training
49
+ bst = xgb.train(
50
+ params,
51
+ train_dmatrix,
52
+ num_boost_round=num_local_round,
53
+ )
54
+ else:
55
+ bst = xgb.Booster(params=params)
56
+ global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
57
+
58
+ # Load global model into booster
59
+ bst.load_model(global_model)
60
+
61
+ # Local training
62
+ bst = _local_boost(bst, num_local_round, train_dmatrix)
63
+
64
+ # Save model
65
+ local_model = bst.save_raw("json")
66
+ model_np = np.frombuffer(local_model, dtype=np.uint8)
67
+
68
+ # Construct reply message
69
+ # Note: we store the model as the first item in a list into ArrayRecord,
70
+ # which can be accessed using index ["0"].
71
+ model_record = ArrayRecord([model_np])
72
+ metrics = {
73
+ "num-examples": num_train,
74
+ }
75
+ metric_record = MetricRecord(metrics)
76
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
77
+ return Message(content=content, reply_to=msg)
78
+
79
+
80
+ @app.evaluate()
81
+ def evaluate(msg: Message, context: Context) -> Message:
82
+ # Load model and data
83
+ partition_id = context.node_config["partition-id"]
84
+ num_partitions = context.node_config["num-partitions"]
85
+ _, valid_dmatrix, _, num_val = load_data(partition_id, num_partitions)
86
+
87
+ # Load config
88
+ cfg = replace_keys(unflatten_dict(context.run_config))
89
+ params = cfg["params"]
90
+
91
+ # Load global model
92
+ bst = xgb.Booster(params=params)
93
+ global_model = bytearray(msg.content["arrays"]["0"].numpy().tobytes())
94
+ bst.load_model(global_model)
95
+
96
+ # Run evaluation
97
+ eval_results = bst.eval_set(
98
+ evals=[(valid_dmatrix, "valid")],
99
+ iteration=bst.num_boosted_rounds() - 1,
100
+ )
101
+ auc = float(eval_results.split("\t")[1].split(":")[1])
102
+
103
+ # Construct and return reply Message
104
+ metrics = {
105
+ "auc": auc,
106
+ "num-examples": num_val,
107
+ }
108
+ metric_record = MetricRecord(metrics)
109
+ content = RecordDict({"metrics": metric_record})
110
+ return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl
@@ -2,15 +2,12 @@
2
2
 
3
3
  import os
4
4
  import warnings
5
- from typing import Dict, Tuple
6
5
 
7
- import torch
8
- from flwr.client import ClientApp, NumPyClient
9
- from flwr.common import Context
6
+ from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
7
+ from flwr.clientapp import ClientApp
10
8
  from flwr.common.config import unflatten_dict
11
- from flwr.common.typing import NDArrays, Scalar
12
9
  from omegaconf import DictConfig
13
-
10
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
14
11
  from transformers import TrainingArguments
15
12
  from trl import SFTTrainer
16
13
 
@@ -19,12 +16,7 @@ from $import_name.dataset import (
19
16
  load_data,
20
17
  replace_keys,
21
18
  )
22
- from $import_name.models import (
23
- cosine_annealing,
24
- get_model,
25
- set_parameters,
26
- get_parameters,
27
- )
19
+ from $import_name.models import cosine_annealing, get_model
28
20
 
29
21
  # Avoid warnings
30
22
  os.environ["TOKENIZERS_PARALLELISM"] = "true"
@@ -32,95 +24,69 @@ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
32
24
  warnings.filterwarnings("ignore", category=UserWarning)
33
25
 
34
26
 
35
- # pylint: disable=too-many-arguments
36
- # pylint: disable=too-many-instance-attributes
37
- class FlowerClient(NumPyClient):
38
- """Flower client for LLM fine-tuning."""
27
+ # Avoid warnings
28
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
29
+ os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
30
+ warnings.filterwarnings("ignore", category=UserWarning)
39
31
 
40
- def __init__(
41
- self,
42
- model_cfg: DictConfig,
43
- train_cfg: DictConfig,
44
- trainset,
45
- tokenizer,
46
- formatting_prompts_func,
47
- data_collator,
48
- num_rounds,
49
- ): # pylint: disable=too-many-arguments
50
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51
- self.train_cfg = train_cfg
52
- self.training_arguments = TrainingArguments(**train_cfg.training_arguments)
53
- self.tokenizer = tokenizer
54
- self.formatting_prompts_func = formatting_prompts_func
55
- self.data_collator = data_collator
56
- self.num_rounds = num_rounds
57
- self.trainset = trainset
58
-
59
- # instantiate model
60
- self.model = get_model(model_cfg)
61
-
62
- def fit(
63
- self, parameters: NDArrays, config: Dict[str, Scalar]
64
- ) -> Tuple[NDArrays, int, Dict]:
65
- """Implement distributed fit function for a given client."""
66
- set_parameters(self.model, parameters)
67
-
68
- new_lr = cosine_annealing(
69
- int(config["current_round"]),
70
- self.num_rounds,
71
- self.train_cfg.learning_rate_max,
72
- self.train_cfg.learning_rate_min,
73
- )
74
-
75
- self.training_arguments.learning_rate = new_lr
76
- self.training_arguments.output_dir = config["save_path"]
77
-
78
- # Construct trainer
79
- trainer = SFTTrainer(
80
- model=self.model,
81
- tokenizer=self.tokenizer,
82
- args=self.training_arguments,
83
- max_seq_length=self.train_cfg.seq_length,
84
- train_dataset=self.trainset,
85
- formatting_func=self.formatting_prompts_func,
86
- data_collator=self.data_collator,
87
- )
88
-
89
- # Do local training
90
- results = trainer.train()
91
-
92
- return (
93
- get_parameters(self.model),
94
- len(self.trainset),
95
- {"train_loss": results.training_loss},
96
- )
97
-
98
-
99
- def client_fn(context: Context) -> FlowerClient:
100
- """Create a Flower client representing a single organization."""
32
+
33
+ # Flower ClientApp
34
+ app = ClientApp()
35
+
36
+
37
+ @app.train()
38
+ def train(msg: Message, context: Context):
39
+ """Train the model on local data."""
40
+ # Parse config
101
41
  partition_id = context.node_config["partition-id"]
102
42
  num_partitions = context.node_config["num-partitions"]
103
43
  num_rounds = context.run_config["num-server-rounds"]
104
44
  cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))
45
+ training_arguments = TrainingArguments(**cfg.train.training_arguments)
105
46
 
106
47
  # Let's get the client partition
107
- client_trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
48
+ trainset = load_data(partition_id, num_partitions, cfg.static.dataset.name)
108
49
  (
109
50
  tokenizer,
110
51
  data_collator,
111
52
  formatting_prompts_func,
112
53
  ) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)
113
54
 
114
- return FlowerClient(
115
- cfg.model,
116
- cfg.train,
117
- client_trainset,
118
- tokenizer,
119
- formatting_prompts_func,
120
- data_collator,
121
- num_rounds,
122
- ).to_client()
123
-
55
+ # Load the model and initialize it with the received weights
56
+ model = get_model(cfg.model)
57
+ set_peft_model_state_dict(model, msg.content["arrays"].to_torch_state_dict())
124
58
 
125
- # Flower ClientApp
126
- app = ClientApp(client_fn)
59
+ # Set learning rate for current round
60
+ new_lr = cosine_annealing(
61
+ msg.content["config"]["server-round"],
62
+ num_rounds,
63
+ cfg.train.learning_rate_max,
64
+ cfg.train.learning_rate_min,
65
+ )
66
+
67
+ training_arguments.learning_rate = new_lr
68
+ training_arguments.output_dir = msg.content["config"]["save_path"]
69
+
70
+ # Construct trainer
71
+ trainer = SFTTrainer(
72
+ model=model,
73
+ tokenizer=tokenizer,
74
+ args=training_arguments,
75
+ max_seq_length=cfg.train.seq_length,
76
+ train_dataset=trainset,
77
+ formatting_func=formatting_prompts_func,
78
+ data_collator=data_collator,
79
+ )
80
+
81
+ # Do local training
82
+ results = trainer.train()
83
+
84
+ # Construct and return reply Message
85
+ model_record = ArrayRecord(get_peft_model_state_dict(model))
86
+ metrics = {
87
+ "train_loss": results.training_loss,
88
+ "num-examples": len(trainset),
89
+ }
90
+ metric_record = MetricRecord(metrics)
91
+ content = RecordDict({"arrays": model_record, "metrics": metric_record})
92
+ return Message(content=content, reply_to=msg)
flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl
@@ -4,18 +4,10 @@ import math
4
4
 
5
5
  import torch
6
6
  from omegaconf import DictConfig
7
- from collections import OrderedDict
8
- from peft import (
9
- LoraConfig,
10
- get_peft_model,
11
- get_peft_model_state_dict,
12
- set_peft_model_state_dict,
13
- )
7
+ from peft import LoraConfig, get_peft_model
14
8
  from peft.utils import prepare_model_for_kbit_training
15
9
  from transformers import AutoModelForCausalLM, BitsAndBytesConfig
16
10
 
17
- from flwr.common.typing import NDArrays
18
-
19
11
 
20
12
  def cosine_annealing(
21
13
  current_round: int,
@@ -62,17 +54,3 @@ def get_model(model_cfg: DictConfig):
62
54
  model.config.use_cache = False
63
55
 
64
56
  return get_peft_model(model, peft_config)
65
-
66
-
67
- def set_parameters(model, parameters: NDArrays) -> None:
68
- """Change the parameters of the model using the given ones."""
69
- peft_state_dict_keys = get_peft_model_state_dict(model).keys()
70
- params_dict = zip(peft_state_dict_keys, parameters)
71
- state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
72
- set_peft_model_state_dict(model, state_dict)
73
-
74
-
75
- def get_parameters(model) -> NDArrays:
76
- """Return the parameters of the current net."""
77
- state_dict = get_peft_model_state_dict(model)
78
- return [val.cpu().numpy() for _, val in state_dict.items()]