flwr-nightly 1.8.0.dev20240310__py3-none-any.whl → 1.8.0.dev20240312__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flwr/cli/new/new.py +6 -3
 - flwr/cli/utils.py +14 -1
 - flwr/client/app.py +25 -2
 - flwr/client/mod/__init__.py +2 -1
 - flwr/client/mod/secure_aggregation/__init__.py +2 -0
 - flwr/client/mod/secure_aggregation/secagg_mod.py +30 -0
 - flwr/client/mod/secure_aggregation/secaggplus_mod.py +42 -51
 - flwr/common/logger.py +6 -8
 - flwr/common/pyproject.py +41 -0
 - flwr/common/secure_aggregation/secaggplus_constants.py +2 -2
 - flwr/server/superlink/state/in_memory_state.py +34 -32
 - flwr/server/workflow/__init__.py +2 -1
 - flwr/server/workflow/default_workflows.py +39 -40
 - flwr/server/workflow/secure_aggregation/__init__.py +2 -0
 - flwr/server/workflow/secure_aggregation/secagg_workflow.py +112 -0
 - flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +98 -26
 - {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/METADATA +1 -1
 - {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/RECORD +21 -18
 - {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/LICENSE +0 -0
 - {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/WHEEL +0 -0
 - {flwr_nightly-1.8.0.dev20240310.dist-info → flwr_nightly-1.8.0.dev20240312.dist-info}/entry_points.txt +0 -0
 
| 
         @@ -15,8 +15,9 @@ 
     | 
|
| 
       15 
15 
     | 
    
         
             
            """Legacy default workflows."""
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
       17 
17 
     | 
    
         | 
| 
      
 18 
     | 
    
         
            +
            import io
         
     | 
| 
       18 
19 
     | 
    
         
             
            import timeit
         
     | 
| 
       19 
     | 
    
         
            -
            from logging import  
     | 
| 
      
 20 
     | 
    
         
            +
            from logging import INFO
         
     | 
| 
       20 
21 
     | 
    
         
             
            from typing import Optional, cast
         
     | 
| 
       21 
22 
     | 
    
         | 
| 
       22 
23 
     | 
    
         
             
            import flwr.common.recordset_compat as compat
         
     | 
| 
         @@ -58,16 +59,18 @@ class DefaultWorkflow: 
     | 
|
| 
       58 
59 
     | 
    
         
             
                    )
         
     | 
| 
       59 
60 
     | 
    
         | 
| 
       60 
61 
     | 
    
         
             
                    # Initialize parameters
         
     | 
| 
      
 62 
     | 
    
         
            +
                    log(INFO, "[INIT]")
         
     | 
| 
       61 
63 
     | 
    
         
             
                    default_init_params_workflow(driver, context)
         
     | 
| 
       62 
64 
     | 
    
         | 
| 
       63 
65 
     | 
    
         
             
                    # Run federated learning for num_rounds
         
     | 
| 
       64 
     | 
    
         
            -
                    log(INFO, "FL starting")
         
     | 
| 
       65 
66 
     | 
    
         
             
                    start_time = timeit.default_timer()
         
     | 
| 
       66 
67 
     | 
    
         
             
                    cfg = ConfigsRecord()
         
     | 
| 
       67 
68 
     | 
    
         
             
                    cfg[Key.START_TIME] = start_time
         
     | 
| 
       68 
69 
     | 
    
         
             
                    context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg
         
     | 
| 
       69 
70 
     | 
    
         | 
| 
       70 
71 
     | 
    
         
             
                    for current_round in range(1, context.config.num_rounds + 1):
         
     | 
| 
      
 72 
     | 
    
         
            +
                        log(INFO, "")
         
     | 
| 
      
 73 
     | 
    
         
            +
                        log(INFO, "[ROUND %s]", current_round)
         
     | 
| 
       71 
74 
     | 
    
         
             
                        cfg[Key.CURRENT_ROUND] = current_round
         
     | 
| 
       72 
75 
     | 
    
         | 
| 
       73 
76 
     | 
    
         
             
                        # Fit round
         
     | 
| 
         @@ -79,22 +82,19 @@ class DefaultWorkflow: 
     | 
|
| 
       79 
82 
     | 
    
         
             
                        # Evaluate round
         
     | 
| 
       80 
83 
     | 
    
         
             
                        self.evaluate_workflow(driver, context)
         
     | 
| 
       81 
84 
     | 
    
         | 
| 
       82 
     | 
    
         
            -
                    # Bookkeeping
         
     | 
| 
      
 85 
     | 
    
         
            +
                    # Bookkeeping and log results
         
     | 
| 
       83 
86 
     | 
    
         
             
                    end_time = timeit.default_timer()
         
     | 
| 
       84 
87 
     | 
    
         
             
                    elapsed = end_time - start_time
         
     | 
| 
       85 
     | 
    
         
            -
                    log(INFO, "FL finished in %s", elapsed)
         
     | 
| 
       86 
     | 
    
         
            -
             
     | 
| 
       87 
     | 
    
         
            -
                    # Log results
         
     | 
| 
       88 
88 
     | 
    
         
             
                    hist = context.history
         
     | 
| 
       89 
     | 
    
         
            -
                    log(INFO, " 
     | 
| 
       90 
     | 
    
         
            -
                    log(
         
     | 
| 
       91 
     | 
    
         
            -
             
     | 
| 
       92 
     | 
    
         
            -
             
     | 
| 
       93 
     | 
    
         
            -
                         
     | 
| 
       94 
     | 
    
         
            -
             
     | 
| 
       95 
     | 
    
         
            -
             
     | 
| 
       96 
     | 
    
         
            -
             
     | 
| 
       97 
     | 
    
         
            -
                    log(INFO, " 
     | 
| 
      
 89 
     | 
    
         
            +
                    log(INFO, "")
         
     | 
| 
      
 90 
     | 
    
         
            +
                    log(INFO, "[SUMMARY]")
         
     | 
| 
      
 91 
     | 
    
         
            +
                    log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed)
         
     | 
| 
      
 92 
     | 
    
         
            +
                    for idx, line in enumerate(io.StringIO(str(hist))):
         
     | 
| 
      
 93 
     | 
    
         
            +
                        if idx == 0:
         
     | 
| 
      
 94 
     | 
    
         
            +
                            log(INFO, "%s", line.strip("\n"))
         
     | 
| 
      
 95 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 96 
     | 
    
         
            +
                            log(INFO, "\t%s", line.strip("\n"))
         
     | 
| 
      
 97 
     | 
    
         
            +
                    log(INFO, "")
         
     | 
| 
       98 
98 
     | 
    
         | 
| 
       99 
99 
     | 
    
         
             
                    # Terminate the thread
         
     | 
| 
       100 
100 
     | 
    
         
             
                    f_stop.set()
         
     | 
| 
         @@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       107 
107 
     | 
    
         
             
                if not isinstance(context, LegacyContext):
         
     | 
| 
       108 
108 
     | 
    
         
             
                    raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
         
     | 
| 
       109 
109 
     | 
    
         | 
| 
       110 
     | 
    
         
            -
                log(INFO, "Initializing global parameters")
         
     | 
| 
       111 
110 
     | 
    
         
             
                parameters = context.strategy.initialize_parameters(
         
     | 
| 
       112 
111 
     | 
    
         
             
                    client_manager=context.client_manager
         
     | 
| 
       113 
112 
     | 
    
         
             
                )
         
     | 
| 
       114 
113 
     | 
    
         
             
                if parameters is not None:
         
     | 
| 
       115 
     | 
    
         
            -
                    log(INFO, "Using initial parameters provided by strategy")
         
     | 
| 
      
 114 
     | 
    
         
            +
                    log(INFO, "Using initial global parameters provided by strategy")
         
     | 
| 
       116 
115 
     | 
    
         
             
                    paramsrecord = compat.parameters_to_parametersrecord(
         
     | 
| 
       117 
116 
     | 
    
         
             
                        parameters, keep_input=True
         
     | 
| 
       118 
117 
     | 
    
         
             
                    )
         
     | 
| 
         @@ -128,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       128 
127 
     | 
    
         
             
                                content=content,
         
     | 
| 
       129 
128 
     | 
    
         
             
                                message_type=MessageTypeLegacy.GET_PARAMETERS,
         
     | 
| 
       130 
129 
     | 
    
         
             
                                dst_node_id=random_client.node_id,
         
     | 
| 
       131 
     | 
    
         
            -
                                group_id="",
         
     | 
| 
      
 130 
     | 
    
         
            +
                                group_id="0",
         
     | 
| 
       132 
131 
     | 
    
         
             
                                ttl="",
         
     | 
| 
       133 
132 
     | 
    
         
             
                            )
         
     | 
| 
       134 
133 
     | 
    
         
             
                        ]
         
     | 
| 
         @@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       140 
139 
     | 
    
         
             
                context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
         
     | 
| 
       141 
140 
     | 
    
         | 
| 
       142 
141 
     | 
    
         
             
                # Evaluate initial parameters
         
     | 
| 
       143 
     | 
    
         
            -
                log(INFO, "Evaluating initial parameters")
         
     | 
| 
      
 142 
     | 
    
         
            +
                log(INFO, "Evaluating initial global parameters")
         
     | 
| 
       144 
143 
     | 
    
         
             
                parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True)
         
     | 
| 
       145 
144 
     | 
    
         
             
                res = context.strategy.evaluate(0, parameters=parameters)
         
     | 
| 
       146 
145 
     | 
    
         
             
                if res is not None:
         
     | 
| 
         @@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None 
     | 
|
| 
       186 
185 
     | 
    
         
             
                    )
         
     | 
| 
       187 
186 
     | 
    
         | 
| 
       188 
187 
     | 
    
         | 
| 
       189 
     | 
    
         
            -
            def default_fit_workflow( 
     | 
| 
      
 188 
     | 
    
         
            +
            def default_fit_workflow(  # pylint: disable=R0914
         
     | 
| 
      
 189 
     | 
    
         
            +
                driver: Driver, context: Context
         
     | 
| 
      
 190 
     | 
    
         
            +
            ) -> None:
         
     | 
| 
       190 
191 
     | 
    
         
             
                """Execute the default workflow for a single fit round."""
         
     | 
| 
       191 
192 
     | 
    
         
             
                if not isinstance(context, LegacyContext):
         
     | 
| 
       192 
193 
     | 
    
         
             
                    raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.")
         
     | 
| 
         @@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       207 
208 
     | 
    
         
             
                )
         
     | 
| 
       208 
209 
     | 
    
         | 
| 
       209 
210 
     | 
    
         
             
                if not client_instructions:
         
     | 
| 
       210 
     | 
    
         
            -
                    log(INFO, " 
     | 
| 
      
 211 
     | 
    
         
            +
                    log(INFO, "configure_fit: no clients selected, cancel")
         
     | 
| 
       211 
212 
     | 
    
         
             
                    return
         
     | 
| 
       212 
213 
     | 
    
         
             
                log(
         
     | 
| 
       213 
     | 
    
         
            -
                     
     | 
| 
       214 
     | 
    
         
            -
                    " 
     | 
| 
       215 
     | 
    
         
            -
                    current_round,
         
     | 
| 
      
 214 
     | 
    
         
            +
                    INFO,
         
     | 
| 
      
 215 
     | 
    
         
            +
                    "configure_fit: strategy sampled %s clients (out of %s)",
         
     | 
| 
       216 
216 
     | 
    
         
             
                    len(client_instructions),
         
     | 
| 
       217 
217 
     | 
    
         
             
                    context.client_manager.num_available(),
         
     | 
| 
       218 
218 
     | 
    
         
             
                )
         
     | 
| 
         @@ -226,7 +226,7 @@ def default_fit_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       226 
226 
     | 
    
         
             
                        content=compat.fitins_to_recordset(fitins, True),
         
     | 
| 
       227 
227 
     | 
    
         
             
                        message_type=MessageType.TRAIN,
         
     | 
| 
       228 
228 
     | 
    
         
             
                        dst_node_id=proxy.node_id,
         
     | 
| 
       229 
     | 
    
         
            -
                        group_id= 
     | 
| 
      
 229 
     | 
    
         
            +
                        group_id=str(current_round),
         
     | 
| 
       230 
230 
     | 
    
         
             
                        ttl="",
         
     | 
| 
       231 
231 
     | 
    
         
             
                    )
         
     | 
| 
       232 
232 
     | 
    
         
             
                    for proxy, fitins in client_instructions
         
     | 
| 
         @@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       236 
236 
     | 
    
         
             
                # collect `fit` results from all clients participating in this round
         
     | 
| 
       237 
237 
     | 
    
         
             
                messages = list(driver.send_and_receive(out_messages))
         
     | 
| 
       238 
238 
     | 
    
         
             
                del out_messages
         
     | 
| 
      
 239 
     | 
    
         
            +
                num_failures = len([msg for msg in messages if msg.has_error()])
         
     | 
| 
       239 
240 
     | 
    
         | 
| 
       240 
241 
     | 
    
         
             
                # No exception/failure handling currently
         
     | 
| 
       241 
242 
     | 
    
         
             
                log(
         
     | 
| 
       242 
     | 
    
         
            -
                     
     | 
| 
       243 
     | 
    
         
            -
                    " 
     | 
| 
       244 
     | 
    
         
            -
                     
     | 
| 
       245 
     | 
    
         
            -
                     
     | 
| 
       246 
     | 
    
         
            -
                    0,
         
     | 
| 
      
 243 
     | 
    
         
            +
                    INFO,
         
     | 
| 
      
 244 
     | 
    
         
            +
                    "aggregate_fit: received %s results and %s failures",
         
     | 
| 
      
 245 
     | 
    
         
            +
                    len(messages) - num_failures,
         
     | 
| 
      
 246 
     | 
    
         
            +
                    num_failures,
         
     | 
| 
       247 
247 
     | 
    
         
             
                )
         
     | 
| 
       248 
248 
     | 
    
         | 
| 
       249 
249 
     | 
    
         
             
                # Aggregate training results
         
     | 
| 
         @@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       288 
288 
     | 
    
         
             
                    client_manager=context.client_manager,
         
     | 
| 
       289 
289 
     | 
    
         
             
                )
         
     | 
| 
       290 
290 
     | 
    
         
             
                if not client_instructions:
         
     | 
| 
       291 
     | 
    
         
            -
                    log(INFO, " 
     | 
| 
      
 291 
     | 
    
         
            +
                    log(INFO, "configure_evaluate: no clients selected, skipping evaluation")
         
     | 
| 
       292 
292 
     | 
    
         
             
                    return
         
     | 
| 
       293 
293 
     | 
    
         
             
                log(
         
     | 
| 
       294 
     | 
    
         
            -
                     
     | 
| 
       295 
     | 
    
         
            -
                    " 
     | 
| 
       296 
     | 
    
         
            -
                    current_round,
         
     | 
| 
      
 294 
     | 
    
         
            +
                    INFO,
         
     | 
| 
      
 295 
     | 
    
         
            +
                    "configure_evaluate: strategy sampled %s clients (out of %s)",
         
     | 
| 
       297 
296 
     | 
    
         
             
                    len(client_instructions),
         
     | 
| 
       298 
297 
     | 
    
         
             
                    context.client_manager.num_available(),
         
     | 
| 
       299 
298 
     | 
    
         
             
                )
         
     | 
| 
         @@ -307,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       307 
306 
     | 
    
         
             
                        content=compat.evaluateins_to_recordset(evalins, True),
         
     | 
| 
       308 
307 
     | 
    
         
             
                        message_type=MessageType.EVALUATE,
         
     | 
| 
       309 
308 
     | 
    
         
             
                        dst_node_id=proxy.node_id,
         
     | 
| 
       310 
     | 
    
         
            -
                        group_id= 
     | 
| 
      
 309 
     | 
    
         
            +
                        group_id=str(current_round),
         
     | 
| 
       311 
310 
     | 
    
         
             
                        ttl="",
         
     | 
| 
       312 
311 
     | 
    
         
             
                    )
         
     | 
| 
       313 
312 
     | 
    
         
             
                    for proxy, evalins in client_instructions
         
     | 
| 
         @@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: 
     | 
|
| 
       317 
316 
     | 
    
         
             
                # collect `evaluate` results from all clients participating in this round
         
     | 
| 
       318 
317 
     | 
    
         
             
                messages = list(driver.send_and_receive(out_messages))
         
     | 
| 
       319 
318 
     | 
    
         
             
                del out_messages
         
     | 
| 
      
 319 
     | 
    
         
            +
                num_failures = len([msg for msg in messages if msg.has_error()])
         
     | 
| 
       320 
320 
     | 
    
         | 
| 
       321 
321 
     | 
    
         
             
                # No exception/failure handling currently
         
     | 
| 
       322 
322 
     | 
    
         
             
                log(
         
     | 
| 
       323 
     | 
    
         
            -
                     
     | 
| 
       324 
     | 
    
         
            -
                    " 
     | 
| 
       325 
     | 
    
         
            -
                     
     | 
| 
       326 
     | 
    
         
            -
                     
     | 
| 
       327 
     | 
    
         
            -
                    0,
         
     | 
| 
      
 323 
     | 
    
         
            +
                    INFO,
         
     | 
| 
      
 324 
     | 
    
         
            +
                    "aggregate_evaluate: received %s results and %s failures",
         
     | 
| 
      
 325 
     | 
    
         
            +
                    len(messages) - num_failures,
         
     | 
| 
      
 326 
     | 
    
         
            +
                    num_failures,
         
     | 
| 
       328 
327 
     | 
    
         
             
                )
         
     | 
| 
       329 
328 
     | 
    
         | 
| 
       330 
329 
     | 
    
         
             
                # Aggregate the evaluation results
         
     | 
| 
         @@ -0,0 +1,112 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2024 Flower Labs GmbH. All Rights Reserved.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """Workflow for the SecAgg protocol."""
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            from typing import Optional, Union
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
            from .secaggplus_workflow import SecAggPlusWorkflow
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            class SecAggWorkflow(SecAggPlusWorkflow):
         
     | 
| 
      
 24 
     | 
    
         
            +
                """The workflow for the SecAgg protocol.
         
     | 
| 
      
 25 
     | 
    
         
            +
             
     | 
| 
      
 26 
     | 
    
         
            +
                The SecAgg protocol ensures the secure summation of integer vectors owned by
         
     | 
| 
      
 27 
     | 
    
         
            +
                multiple parties, without accessing any individual integer vector. This workflow
         
     | 
| 
      
 28 
     | 
    
         
            +
                allows the server to compute the weighted average of model parameters across all
         
     | 
| 
      
 29 
     | 
    
         
            +
                clients, ensuring individual contributions remain private. This is achieved by
         
     | 
| 
      
 30 
     | 
    
         
            +
                clients sending both, a weighting factor and a weighted version of the locally
         
     | 
| 
      
 31 
     | 
    
         
            +
                updated parameters, both of which are masked for privacy. Specifically, each
         
     | 
| 
      
 32 
     | 
    
         
            +
                client uploads "[w, w * params]" with masks, where weighting factor 'w' is the
         
     | 
| 
      
 33 
     | 
    
         
            +
                number of examples ('num_examples') and 'params' represents the model parameters
         
     | 
| 
      
 34 
     | 
    
         
            +
                ('parameters') from the client's `FitRes`. The server then aggregates these
         
     | 
| 
      
 35 
     | 
    
         
            +
                contributions to compute the weighted average of model parameters.
         
     | 
| 
      
 36 
     | 
    
         
            +
             
     | 
| 
      
 37 
     | 
    
         
            +
                The protocol involves four main stages:
         
     | 
| 
      
 38 
     | 
    
         
            +
                - 'setup': Send SecAgg configuration to clients and collect their public keys.
         
     | 
| 
      
 39 
     | 
    
         
            +
                - 'share keys': Broadcast public keys among clients and collect encrypted secret
         
     | 
| 
      
 40 
     | 
    
         
            +
                  key shares.
         
     | 
| 
      
 41 
     | 
    
         
            +
                - 'collect masked vectors': Forward encrypted secret key shares to target clients
         
     | 
| 
      
 42 
     | 
    
         
            +
                  and collect masked model parameters.
         
     | 
| 
      
 43 
     | 
    
         
            +
                - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
                Only the aggregated model parameters are exposed and passed to
         
     | 
| 
      
 46 
     | 
    
         
            +
                `Strategy.aggregate_fit`, ensuring individual data privacy.
         
     | 
| 
      
 47 
     | 
    
         
            +
             
     | 
| 
      
 48 
     | 
    
         
            +
                Parameters
         
     | 
| 
      
 49 
     | 
    
         
            +
                ----------
         
     | 
| 
      
 50 
     | 
    
         
            +
                reconstruction_threshold : Union[int, float]
         
     | 
| 
      
 51 
     | 
    
         
            +
                    The minimum number of shares required to reconstruct a client's private key,
         
     | 
| 
      
 52 
     | 
    
         
            +
                    or, if specified as a float, it represents the proportion of the total number
         
     | 
| 
      
 53 
     | 
    
         
            +
                    of shares needed for reconstruction. This threshold ensures privacy by allowing
         
     | 
| 
      
 54 
     | 
    
         
            +
                    for the recovery of contributions from dropped clients during aggregation,
         
     | 
| 
      
 55 
     | 
    
         
            +
                    without compromising individual client data.
         
     | 
| 
      
 56 
     | 
    
         
            +
                max_weight : Optional[float] (default: 1000.0)
         
     | 
| 
      
 57 
     | 
    
         
            +
                    The maximum value of the weight that can be assigned to any single client's
         
     | 
| 
      
 58 
     | 
    
         
            +
                    update during the weighted average calculation on the server side, e.g., in the
         
     | 
| 
      
 59 
     | 
    
         
            +
                    FedAvg algorithm.
         
     | 
| 
      
 60 
     | 
    
         
            +
                clipping_range : float, optional (default: 8.0)
         
     | 
| 
      
 61 
     | 
    
         
            +
                    The range within which model parameters are clipped before quantization.
         
     | 
| 
      
 62 
     | 
    
         
            +
                    This parameter ensures each model parameter is bounded within
         
     | 
| 
      
 63 
     | 
    
         
            +
                    [-clipping_range, clipping_range], facilitating quantization.
         
     | 
| 
      
 64 
     | 
    
         
            +
                quantization_range : int, optional (default: 4194304, this equals 2**22)
         
     | 
| 
      
 65 
     | 
    
         
            +
                    The size of the range into which floating-point model parameters are quantized,
         
     | 
| 
      
 66 
     | 
    
         
            +
                    mapping each parameter to an integer in [0, quantization_range-1]. This
         
     | 
| 
      
 67 
     | 
    
         
            +
                    facilitates cryptographic operations on the model updates.
         
     | 
| 
      
 68 
     | 
    
         
            +
                modulus_range : int, optional (default: 4294967296, this equals 2**32)
         
     | 
| 
      
 69 
     | 
    
         
            +
                    The range of values from which random mask entries are uniformly sampled
         
     | 
| 
      
 70 
     | 
    
         
            +
                    ([0, modulus_range-1]). `modulus_range` must not exceed 4294967296.
         
     | 
| 
      
 71 
     | 
    
         
            +
                    Please use 2**n values for `modulus_range` to prevent overflow issues.
         
     | 
| 
      
 72 
     | 
    
         
            +
                timeout : Optional[float] (default: None)
         
     | 
| 
      
 73 
     | 
    
         
            +
                    The timeout duration in seconds. If specified, the workflow will wait for
         
     | 
| 
      
 74 
     | 
    
         
            +
                    replies for this duration each time. If `None`, there is no time limit and
         
     | 
| 
      
 75 
     | 
    
         
            +
                    the workflow will wait until replies for all messages are received.
         
     | 
| 
      
 76 
     | 
    
         
            +
             
     | 
| 
      
 77 
     | 
    
         
            +
                Notes
         
     | 
| 
      
 78 
     | 
    
         
            +
                -----
         
     | 
| 
      
 79 
     | 
    
         
            +
                - Each client's private key is split into N shares under the SecAgg protocol, where
         
     | 
| 
      
 80 
     | 
    
         
            +
                  N is the number of selected clients.
         
     | 
| 
      
 81 
     | 
    
         
            +
                - Generally, higher `reconstruction_threshold` means better privacy guarantees but
         
     | 
| 
      
 82 
     | 
    
         
            +
                  less tolerance to dropouts.
         
     | 
| 
      
 83 
     | 
    
         
            +
                - Too large `max_weight` may compromise the precision of the quantization.
         
     | 
| 
      
 84 
     | 
    
         
            +
                - `modulus_range` must be 2**n and larger than `quantization_range`.
         
     | 
| 
      
 85 
     | 
    
         
            +
                - When `reconstruction_threshold` is a float, it is interpreted as the proportion of
         
     | 
| 
      
 86 
     | 
    
         
            +
                  the number of all selected clients needed for the reconstruction of a private key.
         
     | 
| 
      
 87 
     | 
    
         
            +
                  This feature enables flexibility in setting the security threshold relative to the
         
     | 
| 
      
 88 
     | 
    
         
            +
                  number of selected clients.
         
     | 
| 
      
 89 
     | 
    
         
            +
                - `reconstruction_threshold` and the quantization parameters
         
     | 
| 
      
 90 
     | 
    
         
            +
                  (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in
         
     | 
| 
      
 91 
     | 
    
         
            +
                  balancing privacy, robustness, and efficiency within the SecAgg protocol.
         
     | 
| 
      
 92 
     | 
    
         
            +
                """
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
                def __init__(  # pylint: disable=R0913
         
     | 
| 
      
 95 
     | 
    
         
            +
                    self,
         
     | 
| 
      
 96 
     | 
    
         
            +
                    reconstruction_threshold: Union[int, float],
         
     | 
| 
      
 97 
     | 
    
         
            +
                    *,
         
     | 
| 
      
 98 
     | 
    
         
            +
                    max_weight: float = 1000.0,
         
     | 
| 
      
 99 
     | 
    
         
            +
                    clipping_range: float = 8.0,
         
     | 
| 
      
 100 
     | 
    
         
            +
                    quantization_range: int = 4194304,
         
     | 
| 
      
 101 
     | 
    
         
            +
                    modulus_range: int = 4294967296,
         
     | 
| 
      
 102 
     | 
    
         
            +
                    timeout: Optional[float] = None,
         
     | 
| 
      
 103 
     | 
    
         
            +
                ) -> None:
         
     | 
| 
      
 104 
     | 
    
         
            +
                    super().__init__(
         
     | 
| 
      
 105 
     | 
    
         
            +
                        num_shares=1.0,
         
     | 
| 
      
 106 
     | 
    
         
            +
                        reconstruction_threshold=reconstruction_threshold,
         
     | 
| 
      
 107 
     | 
    
         
            +
                        max_weight=max_weight,
         
     | 
| 
      
 108 
     | 
    
         
            +
                        clipping_range=clipping_range,
         
     | 
| 
      
 109 
     | 
    
         
            +
                        quantization_range=quantization_range,
         
     | 
| 
      
 110 
     | 
    
         
            +
                        modulus_range=modulus_range,
         
     | 
| 
      
 111 
     | 
    
         
            +
                        timeout=timeout,
         
     | 
| 
      
 112 
     | 
    
         
            +
                    )
         
     | 
| 
         @@ -17,12 +17,11 @@ 
     | 
|
| 
       17 
17 
     | 
    
         | 
| 
       18 
18 
     | 
    
         
             
            import random
         
     | 
| 
       19 
19 
     | 
    
         
             
            from dataclasses import dataclass, field
         
     | 
| 
       20 
     | 
    
         
            -
            from logging import ERROR, WARN
         
     | 
| 
       21 
     | 
    
         
            -
            from typing import Dict, List, Optional, Set, Union, cast
         
     | 
| 
      
 20 
     | 
    
         
            +
            from logging import DEBUG, ERROR, INFO, WARN
         
     | 
| 
      
 21 
     | 
    
         
            +
            from typing import Dict, List, Optional, Set, Tuple, Union, cast
         
     | 
| 
       22 
22 
     | 
    
         | 
| 
       23 
23 
     | 
    
         
             
            import flwr.common.recordset_compat as compat
         
     | 
| 
       24 
24 
     | 
    
         
             
            from flwr.common import (
         
     | 
| 
       25 
     | 
    
         
            -
                Code,
         
     | 
| 
       26 
25 
     | 
    
         
             
                ConfigsRecord,
         
     | 
| 
       27 
26 
     | 
    
         
             
                Context,
         
     | 
| 
       28 
27 
     | 
    
         
             
                FitRes,
         
     | 
| 
         @@ -30,7 +29,6 @@ from flwr.common import ( 
     | 
|
| 
       30 
29 
     | 
    
         
             
                MessageType,
         
     | 
| 
       31 
30 
     | 
    
         
             
                NDArrays,
         
     | 
| 
       32 
31 
     | 
    
         
             
                RecordSet,
         
     | 
| 
       33 
     | 
    
         
            -
                Status,
         
     | 
| 
       34 
32 
     | 
    
         
             
                bytes_to_ndarray,
         
     | 
| 
       35 
33 
     | 
    
         
             
                log,
         
     | 
| 
       36 
34 
     | 
    
         
             
                ndarrays_to_parameters,
         
     | 
| 
         @@ -55,7 +53,7 @@ from flwr.common.secure_aggregation.secaggplus_constants import ( 
     | 
|
| 
       55 
53 
     | 
    
         
             
                Stage,
         
     | 
| 
       56 
54 
     | 
    
         
             
            )
         
     | 
| 
       57 
55 
     | 
    
         
             
            from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
         
     | 
| 
       58 
     | 
    
         
            -
            from flwr.server. 
     | 
| 
      
 56 
     | 
    
         
            +
            from flwr.server.client_proxy import ClientProxy
         
     | 
| 
       59 
57 
     | 
    
         
             
            from flwr.server.compat.legacy_context import LegacyContext
         
     | 
| 
       60 
58 
     | 
    
         
             
            from flwr.server.driver import Driver
         
     | 
| 
       61 
59 
     | 
    
         | 
| 
         @@ -67,6 +65,7 @@ from ..constant import Key as WorkflowKey 
     | 
|
| 
       67 
65 
     | 
    
         
             
            class WorkflowState:  # pylint: disable=R0902
         
     | 
| 
       68 
66 
     | 
    
         
             
                """The state of the SecAgg+ protocol."""
         
     | 
| 
       69 
67 
     | 
    
         | 
| 
      
 68 
     | 
    
         
            +
                nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
         
     | 
| 
       70 
69 
     | 
    
         
             
                nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
         
     | 
| 
       71 
70 
     | 
    
         
             
                sampled_node_ids: Set[int] = field(default_factory=set)
         
     | 
| 
       72 
71 
     | 
    
         
             
                active_node_ids: Set[int] = field(default_factory=set)
         
     | 
| 
         @@ -81,6 +80,7 @@ class WorkflowState:  # pylint: disable=R0902 
     | 
|
| 
       81 
80 
     | 
    
         
             
                forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
         
     | 
| 
       82 
81 
     | 
    
         
             
                forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
         
     | 
| 
       83 
82 
     | 
    
         
             
                aggregate_ndarrays: NDArrays = field(default_factory=list)
         
     | 
| 
      
 83 
     | 
    
         
            +
                legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)
         
     | 
| 
       84 
84 
     | 
    
         | 
| 
       85 
85 
     | 
    
         | 
| 
       86 
86 
     | 
    
         
             
            class SecAggPlusWorkflow:
         
     | 
| 
         @@ -101,7 +101,7 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       101 
101 
     | 
    
         
             
                - 'setup': Send SecAgg+ configuration to clients and collect their public keys.
         
     | 
| 
       102 
102 
     | 
    
         
             
                - 'share keys': Broadcast public keys among clients and collect encrypted secret
         
     | 
| 
       103 
103 
     | 
    
         
             
                  key shares.
         
     | 
| 
       104 
     | 
    
         
            -
                - 'collect masked  
     | 
| 
      
 104 
     | 
    
         
            +
                - 'collect masked vectors': Forward encrypted secret key shares to target clients
         
     | 
| 
       105 
105 
     | 
    
         
             
                  and collect masked model parameters.
         
     | 
| 
       106 
106 
     | 
    
         
             
                - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters.
         
     | 
| 
       107 
107 
     | 
    
         | 
| 
         @@ -195,12 +195,15 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       195 
195 
     | 
    
         
             
                    steps = (
         
     | 
| 
       196 
196 
     | 
    
         
             
                        self.setup_stage,
         
     | 
| 
       197 
197 
     | 
    
         
             
                        self.share_keys_stage,
         
     | 
| 
       198 
     | 
    
         
            -
                        self. 
     | 
| 
      
 198 
     | 
    
         
            +
                        self.collect_masked_vectors_stage,
         
     | 
| 
       199 
199 
     | 
    
         
             
                        self.unmask_stage,
         
     | 
| 
       200 
200 
     | 
    
         
             
                    )
         
     | 
| 
      
 201 
     | 
    
         
            +
                    log(INFO, "Secure aggregation commencing.")
         
     | 
| 
       201 
202 
     | 
    
         
             
                    for step in steps:
         
     | 
| 
       202 
203 
     | 
    
         
             
                        if not step(driver, context, state):
         
     | 
| 
      
 204 
     | 
    
         
            +
                            log(INFO, "Secure aggregation halted.")
         
     | 
| 
       203 
205 
     | 
    
         
             
                            return
         
     | 
| 
      
 206 
     | 
    
         
            +
                    log(INFO, "Secure aggregation completed.")
         
     | 
| 
       204 
207 
     | 
    
         | 
| 
       205 
208 
     | 
    
         
             
                def _check_init_params(self) -> None:  # pylint: disable=R0912
         
     | 
| 
       206 
209 
     | 
    
         
             
                    # Check `num_shares`
         
     | 
| 
         @@ -287,10 +290,21 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       287 
290 
     | 
    
         
             
                    proxy_fitins_lst = context.strategy.configure_fit(
         
     | 
| 
       288 
291 
     | 
    
         
             
                        current_round, parameters, context.client_manager
         
     | 
| 
       289 
292 
     | 
    
         
             
                    )
         
     | 
| 
      
 293 
     | 
    
         
            +
                    if not proxy_fitins_lst:
         
     | 
| 
      
 294 
     | 
    
         
            +
                        log(INFO, "configure_fit: no clients selected, cancel")
         
     | 
| 
      
 295 
     | 
    
         
            +
                        return False
         
     | 
| 
      
 296 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 297 
     | 
    
         
            +
                        INFO,
         
     | 
| 
      
 298 
     | 
    
         
            +
                        "configure_fit: strategy sampled %s clients (out of %s)",
         
     | 
| 
      
 299 
     | 
    
         
            +
                        len(proxy_fitins_lst),
         
     | 
| 
      
 300 
     | 
    
         
            +
                        context.client_manager.num_available(),
         
     | 
| 
      
 301 
     | 
    
         
            +
                    )
         
     | 
| 
      
 302 
     | 
    
         
            +
             
     | 
| 
       290 
303 
     | 
    
         
             
                    state.nid_to_fitins = {
         
     | 
| 
       291 
     | 
    
         
            -
                        proxy.node_id: compat.fitins_to_recordset(fitins,  
     | 
| 
      
 304 
     | 
    
         
            +
                        proxy.node_id: compat.fitins_to_recordset(fitins, True)
         
     | 
| 
       292 
305 
     | 
    
         
             
                        for proxy, fitins in proxy_fitins_lst
         
     | 
| 
       293 
306 
     | 
    
         
             
                    }
         
     | 
| 
      
 307 
     | 
    
         
            +
                    state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}
         
     | 
| 
       294 
308 
     | 
    
         | 
| 
       295 
309 
     | 
    
         
             
                    # Protocol config
         
     | 
| 
       296 
310 
     | 
    
         
             
                    sampled_node_ids = list(state.nid_to_fitins.keys())
         
     | 
| 
         @@ -362,12 +376,22 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       362 
376 
     | 
    
         
             
                            ttl="",
         
     | 
| 
       363 
377 
     | 
    
         
             
                        )
         
     | 
| 
       364 
378 
     | 
    
         | 
| 
      
 379 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 380 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 381 
     | 
    
         
            +
                        "[Stage 0] Sending configurations to %s clients.",
         
     | 
| 
      
 382 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 383 
     | 
    
         
            +
                    )
         
     | 
| 
       365 
384 
     | 
    
         
             
                    msgs = driver.send_and_receive(
         
     | 
| 
       366 
385 
     | 
    
         
             
                        [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         
     | 
| 
       367 
386 
     | 
    
         
             
                    )
         
     | 
| 
       368 
387 
     | 
    
         
             
                    state.active_node_ids = {
         
     | 
| 
       369 
388 
     | 
    
         
             
                        msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         
     | 
| 
       370 
389 
     | 
    
         
             
                    }
         
     | 
| 
      
 390 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 391 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 392 
     | 
    
         
            +
                        "[Stage 0] Received public keys from %s clients.",
         
     | 
| 
      
 393 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 394 
     | 
    
         
            +
                    )
         
     | 
| 
       371 
395 
     | 
    
         | 
| 
       372 
396 
     | 
    
         
             
                    for msg in msgs:
         
     | 
| 
       373 
397 
     | 
    
         
             
                        if msg.has_error():
         
     | 
| 
         @@ -401,12 +425,22 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       401 
425 
     | 
    
         
             
                        )
         
     | 
| 
       402 
426 
     | 
    
         | 
| 
       403 
427 
     | 
    
         
             
                    # Broadcast public keys to clients and receive secret key shares
         
     | 
| 
      
 428 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 429 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 430 
     | 
    
         
            +
                        "[Stage 1] Forwarding public keys to %s clients.",
         
     | 
| 
      
 431 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 432 
     | 
    
         
            +
                    )
         
     | 
| 
       404 
433 
     | 
    
         
             
                    msgs = driver.send_and_receive(
         
     | 
| 
       405 
434 
     | 
    
         
             
                        [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         
     | 
| 
       406 
435 
     | 
    
         
             
                    )
         
     | 
| 
       407 
436 
     | 
    
         
             
                    state.active_node_ids = {
         
     | 
| 
       408 
437 
     | 
    
         
             
                        msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         
     | 
| 
       409 
438 
     | 
    
         
             
                    }
         
     | 
| 
      
 439 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 440 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 441 
     | 
    
         
            +
                        "[Stage 1] Received encrypted key shares from %s clients.",
         
     | 
| 
      
 442 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 443 
     | 
    
         
            +
                    )
         
     | 
| 
       410 
444 
     | 
    
         | 
| 
       411 
445 
     | 
    
         
             
                    # Build forward packet list dictionary
         
     | 
| 
       412 
446 
     | 
    
         
             
                    srcs: List[int] = []
         
     | 
| 
         @@ -437,16 +471,16 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       437 
471 
     | 
    
         | 
| 
       438 
472 
     | 
    
         
             
                    return self._check_threshold(state)
         
     | 
| 
       439 
473 
     | 
    
         | 
| 
       440 
     | 
    
         
            -
                def  
     | 
| 
      
 474 
     | 
    
         
            +
                def collect_masked_vectors_stage(
         
     | 
| 
       441 
475 
     | 
    
         
             
                    self, driver: Driver, context: LegacyContext, state: WorkflowState
         
     | 
| 
       442 
476 
     | 
    
         
             
                ) -> bool:
         
     | 
| 
       443 
     | 
    
         
            -
                    """Execute the 'collect masked  
     | 
| 
      
 477 
     | 
    
         
            +
                    """Execute the 'collect masked vectors' stage."""
         
     | 
| 
       444 
478 
     | 
    
         
             
                    cfg = context.state.configs_records[MAIN_CONFIGS_RECORD]
         
     | 
| 
       445 
479 
     | 
    
         | 
| 
       446 
     | 
    
         
            -
                    # Send secret key shares to clients (plus FitIns) and collect masked  
     | 
| 
      
 480 
     | 
    
         
            +
                    # Send secret key shares to clients (plus FitIns) and collect masked vectors
         
     | 
| 
       447 
481 
     | 
    
         
             
                    def make(nid: int) -> Message:
         
     | 
| 
       448 
482 
     | 
    
         
             
                        cfgs_dict = {
         
     | 
| 
       449 
     | 
    
         
            -
                            Key.STAGE: Stage. 
     | 
| 
      
 483 
     | 
    
         
            +
                            Key.STAGE: Stage.COLLECT_MASKED_VECTORS,
         
     | 
| 
       450 
484 
     | 
    
         
             
                            Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid],
         
     | 
| 
       451 
485 
     | 
    
         
             
                            Key.SOURCE_LIST: state.forward_srcs[nid],
         
     | 
| 
       452 
486 
     | 
    
         
             
                        }
         
     | 
| 
         @@ -461,12 +495,22 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       461 
495 
     | 
    
         
             
                            ttl="",
         
     | 
| 
       462 
496 
     | 
    
         
             
                        )
         
     | 
| 
       463 
497 
     | 
    
         | 
| 
      
 498 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 499 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 500 
     | 
    
         
            +
                        "[Stage 2] Forwarding encrypted key shares to %s clients.",
         
     | 
| 
      
 501 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 502 
     | 
    
         
            +
                    )
         
     | 
| 
       464 
503 
     | 
    
         
             
                    msgs = driver.send_and_receive(
         
     | 
| 
       465 
504 
     | 
    
         
             
                        [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         
     | 
| 
       466 
505 
     | 
    
         
             
                    )
         
     | 
| 
       467 
506 
     | 
    
         
             
                    state.active_node_ids = {
         
     | 
| 
       468 
507 
     | 
    
         
             
                        msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         
     | 
| 
       469 
508 
     | 
    
         
             
                    }
         
     | 
| 
      
 509 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 510 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 511 
     | 
    
         
            +
                        "[Stage 2] Received masked vectors from %s clients.",
         
     | 
| 
      
 512 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 513 
     | 
    
         
            +
                    )
         
     | 
| 
       470 
514 
     | 
    
         | 
| 
       471 
515 
     | 
    
         
             
                    # Clear cache
         
     | 
| 
       472 
516 
     | 
    
         
             
                    del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins
         
     | 
| 
         @@ -485,9 +529,15 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       485 
529 
     | 
    
         
             
                        masked_vector = parameters_mod(masked_vector, state.mod_range)
         
     | 
| 
       486 
530 
     | 
    
         
             
                        state.aggregate_ndarrays = masked_vector
         
     | 
| 
       487 
531 
     | 
    
         | 
| 
      
 532 
     | 
    
         
            +
                    # Backward compatibility with Strategy
         
     | 
| 
      
 533 
     | 
    
         
            +
                    for msg in msgs:
         
     | 
| 
      
 534 
     | 
    
         
            +
                        fitres = compat.recordset_to_fitres(msg.content, True)
         
     | 
| 
      
 535 
     | 
    
         
            +
                        proxy = state.nid_to_proxies[msg.metadata.src_node_id]
         
     | 
| 
      
 536 
     | 
    
         
            +
                        state.legacy_results.append((proxy, fitres))
         
     | 
| 
      
 537 
     | 
    
         
            +
             
     | 
| 
       488 
538 
     | 
    
         
             
                    return self._check_threshold(state)
         
     | 
| 
       489 
539 
     | 
    
         | 
| 
       490 
     | 
    
         
            -
                def unmask_stage(  # pylint: disable=R0912, R0914
         
     | 
| 
      
 540 
     | 
    
         
            +
                def unmask_stage(  # pylint: disable=R0912, R0914, R0915
         
     | 
| 
       491 
541 
     | 
    
         
             
                    self, driver: Driver, context: LegacyContext, state: WorkflowState
         
     | 
| 
       492 
542 
     | 
    
         
             
                ) -> bool:
         
     | 
| 
       493 
543 
     | 
    
         
             
                    """Execute the 'unmask' stage."""
         
     | 
| 
         @@ -516,12 +566,22 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       516 
566 
     | 
    
         
             
                            ttl="",
         
     | 
| 
       517 
567 
     | 
    
         
             
                        )
         
     | 
| 
       518 
568 
     | 
    
         | 
| 
      
 569 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 570 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 571 
     | 
    
         
            +
                        "[Stage 3] Requesting key shares from %s clients to remove masks.",
         
     | 
| 
      
 572 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 573 
     | 
    
         
            +
                    )
         
     | 
| 
       519 
574 
     | 
    
         
             
                    msgs = driver.send_and_receive(
         
     | 
| 
       520 
575 
     | 
    
         
             
                        [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout
         
     | 
| 
       521 
576 
     | 
    
         
             
                    )
         
     | 
| 
       522 
577 
     | 
    
         
             
                    state.active_node_ids = {
         
     | 
| 
       523 
578 
     | 
    
         
             
                        msg.metadata.src_node_id for msg in msgs if not msg.has_error()
         
     | 
| 
       524 
579 
     | 
    
         
             
                    }
         
     | 
| 
      
 580 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 581 
     | 
    
         
            +
                        DEBUG,
         
     | 
| 
      
 582 
     | 
    
         
            +
                        "[Stage 3] Received key shares from %s clients.",
         
     | 
| 
      
 583 
     | 
    
         
            +
                        len(state.active_node_ids),
         
     | 
| 
      
 584 
     | 
    
         
            +
                    )
         
     | 
| 
       525 
585 
     | 
    
         | 
| 
       526 
586 
     | 
    
         
             
                    # Build collected shares dict
         
     | 
| 
       527 
587 
     | 
    
         
             
                    collected_shares_dict: Dict[int, List[bytes]] = {}
         
     | 
| 
         @@ -534,7 +594,7 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       534 
594 
     | 
    
         
             
                        for owner_nid, share in zip(nids, shares):
         
     | 
| 
       535 
595 
     | 
    
         
             
                            collected_shares_dict[owner_nid].append(share)
         
     | 
| 
       536 
596 
     | 
    
         | 
| 
       537 
     | 
    
         
            -
                    # Remove  
     | 
| 
      
 597 
     | 
    
         
            +
                    # Remove masks for every active client after collect_masked_vectors stage
         
     | 
| 
       538 
598 
     | 
    
         
             
                    masked_vector = state.aggregate_ndarrays
         
     | 
| 
       539 
599 
     | 
    
         
             
                    del state.aggregate_ndarrays
         
     | 
| 
       540 
600 
     | 
    
         
             
                    for nid, share_list in collected_shares_dict.items():
         
     | 
| 
         @@ -584,18 +644,30 @@ class SecAggPlusWorkflow: 
     | 
|
| 
       584 
644 
     | 
    
         
             
                    for vec in aggregated_vector:
         
     | 
| 
       585 
645 
     | 
    
         
             
                        vec += offset
         
     | 
| 
       586 
646 
     | 
    
         
             
                        vec *= inv_dq_total_ratio
         
     | 
| 
       587 
     | 
    
         
            -
             
     | 
| 
       588 
     | 
    
         
            -
                     
     | 
| 
       589 
     | 
    
         
            -
             
     | 
| 
       590 
     | 
    
         
            -
             
     | 
| 
       591 
     | 
    
         
            -
             
     | 
| 
       592 
     | 
    
         
            -
                         
     | 
| 
       593 
     | 
    
         
            -
             
     | 
| 
       594 
     | 
    
         
            -
                     
     | 
| 
      
 647 
     | 
    
         
            +
             
     | 
| 
      
 648 
     | 
    
         
            +
                    # Backward compatibility with Strategy
         
     | 
| 
      
 649 
     | 
    
         
            +
                    results = state.legacy_results
         
     | 
| 
      
 650 
     | 
    
         
            +
                    parameters = ndarrays_to_parameters(aggregated_vector)
         
     | 
| 
      
 651 
     | 
    
         
            +
                    for _, fitres in results:
         
     | 
| 
      
 652 
     | 
    
         
            +
                        fitres.parameters = parameters
         
     | 
| 
      
 653 
     | 
    
         
            +
             
     | 
| 
      
 654 
     | 
    
         
            +
                    # No exception/failure handling currently
         
     | 
| 
      
 655 
     | 
    
         
            +
                    log(
         
     | 
| 
      
 656 
     | 
    
         
            +
                        INFO,
         
     | 
| 
      
 657 
     | 
    
         
            +
                        "aggregate_fit: received %s results and %s failures",
         
     | 
| 
      
 658 
     | 
    
         
            +
                        len(results),
         
     | 
| 
       595 
659 
     | 
    
         
             
                        0,
         
     | 
| 
       596 
     | 
    
         
            -
                        driver.grpc_driver,  # type: ignore
         
     | 
| 
       597 
     | 
    
         
            -
                        False,
         
     | 
| 
       598 
     | 
    
         
            -
                        driver.run_id,  # type: ignore
         
     | 
| 
       599 
660 
     | 
    
         
             
                    )
         
     | 
| 
       600 
     | 
    
         
            -
                    context.strategy.aggregate_fit(current_round,  
     | 
| 
      
 661 
     | 
    
         
            +
                    aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
         
     | 
| 
      
 662 
     | 
    
         
            +
                    parameters_aggregated, metrics_aggregated = aggregated_result
         
     | 
| 
      
 663 
     | 
    
         
            +
             
     | 
| 
      
 664 
     | 
    
         
            +
                    # Update the parameters and write history
         
     | 
| 
      
 665 
     | 
    
         
            +
                    if parameters_aggregated:
         
     | 
| 
      
 666 
     | 
    
         
            +
                        paramsrecord = compat.parameters_to_parametersrecord(
         
     | 
| 
      
 667 
     | 
    
         
            +
                            parameters_aggregated, True
         
     | 
| 
      
 668 
     | 
    
         
            +
                        )
         
     | 
| 
      
 669 
     | 
    
         
            +
                        context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord
         
     | 
| 
      
 670 
     | 
    
         
            +
                        context.history.add_metrics_distributed_fit(
         
     | 
| 
      
 671 
     | 
    
         
            +
                            server_round=current_round, metrics=metrics_aggregated
         
     | 
| 
      
 672 
     | 
    
         
            +
                        )
         
     | 
| 
       601 
673 
     | 
    
         
             
                    return True
         
     |