blazefl.contrib.FedAvgServerHandler#
- class blazefl.contrib.FedAvgServerHandler(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, global_round: int, num_clients: int, sample_ratio: float, device: str)[source]#
Bases:
ServerHandler
Server-side handler for the Federated Averaging (FedAvg) algorithm.
Manages the global model, coordinates client sampling, aggregates client updates, and controls the training process across multiple rounds.
- model#
The global model being trained.
- Type:
torch.nn.Module
- dataset#
Dataset partitioned across clients.
- Type:
PartitionedDataset
- global_round#
Total number of federated learning rounds.
- Type:
int
- num_clients#
Total number of clients in the federation.
- Type:
int
- sample_ratio#
Fraction of clients to sample in each round.
- Type:
float
- device#
Device to run the model on (‘cpu’ or ‘cuda’).
- Type:
str
- client_buffer_cache#
Cache for storing client updates before aggregation.
- Type:
list[FedAvgUplinkPackage]
- num_clients_per_round#
Number of clients sampled per round.
- Type:
int
- round#
Current training round.
- Type:
int
- __init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, global_round: int, num_clients: int, sample_ratio: float, device: str) None [source]#
Initialize the FedAvgServerHandler.
- Parameters:
model_selector (ModelSelector) – Selector for initializing the model.
model_name (str) – Name of the model to be used.
dataset (PartitionedDataset) – Dataset partitioned across clients.
global_round (int) – Total number of federated learning rounds.
num_clients (int) – Total number of clients in the federation.
sample_ratio (float) – Fraction of clients to sample in each round.
device (str) – Device to run the model on (‘cpu’ or ‘cuda’).
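Example (a minimal construction sketch; MyModelSelector and MyPartitionedDataset are hypothetical stand-ins for your own ModelSelector and PartitionedDataset implementations, and the hyperparameters are illustrative):

    from blazefl.contrib import FedAvgServerHandler

    model_selector = MyModelSelector()    # hypothetical ModelSelector implementation
    dataset = MyPartitionedDataset()      # hypothetical PartitionedDataset implementation

    handler = FedAvgServerHandler(
        model_selector=model_selector,
        model_name="cnn",                 # a name your ModelSelector can build
        dataset=dataset,
        global_round=100,                 # total federated rounds
        num_clients=100,                  # clients in the federation
        sample_ratio=0.1,                 # ~10 clients sampled per round
        device="cuda",
    )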
Methods
- __init__(model_selector, model_name, ...)
Initialize the FedAvgServerHandler.
- aggregate(parameters_list, weights_list)
Aggregate model parameters from multiple clients using weighted averaging.
- downlink_package()
Create a downlink package containing the current global model parameters to send to clients.
- global_update(buffer)
Aggregate client updates and update the global model parameters.
- if_stop()
Check if the training process should stop.
- load(payload)
Load a client's uplink package into the server's buffer and perform a global update if all expected packages for the round are received.
- sample_clients()
Randomly sample a subset of clients for the current training round.
- static aggregate(parameters_list: list[Tensor], weights_list: list[int]) Tensor [source]#
Aggregate model parameters from multiple clients using weighted averaging.
- Parameters:
parameters_list (list[torch.Tensor]) – List of serialized model parameters from clients.
weights_list (list[int]) – List of data sizes corresponding to each client’s parameters.
- Returns:
Aggregated model parameters.
- Return type:
torch.Tensor
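Example (an illustration of the data-size-weighted averaging described above, using made-up tensors and client data sizes; it assumes the standard FedAvg weighted mean):

    import torch
    from blazefl.contrib import FedAvgServerHandler

    # Two clients' serialized (flattened) model parameters and their data sizes.
    parameters_list = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
    weights_list = [10, 30]  # client 0 holds 10 samples, client 1 holds 30

    aggregated = FedAvgServerHandler.aggregate(parameters_list, weights_list)

    # Documented behaviour: a weighted mean with data sizes as weights.
    expected = (10 * parameters_list[0] + 30 * parameters_list[1]) / 40
    assert torch.allclose(aggregated, expected)  # tensor([2.5000, 3.5000])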
- downlink_package() FedAvgDownlinkPackage [source]#
Create a downlink package containing the current global model parameters to send to clients.
- Returns:
Downlink package with serialized model parameters.
- Return type:
FedAvgDownlinkPackage
- global_update(buffer: list[FedAvgUplinkPackage]) None [source]#
Aggregate client updates and update the global model parameters.
- Parameters:
buffer (list[FedAvgUplinkPackage]) – List of uplink packages from clients.
- if_stop() bool [source]#
Check if the training process should stop.
- Returns:
True if the current round exceeds or equals the total number of global rounds; False otherwise.
- Return type:
bool
- load(payload: FedAvgUplinkPackage) bool [source]#
Load a client’s uplink package into the server’s buffer and perform a global update if all expected packages for the round are received.
- Parameters:
payload (FedAvgUplinkPackage) – Uplink package from a client.
- Returns:
True if a global update was performed; False otherwise.
- Return type:
bool
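Example (a schematic round loop showing how downlink_package, load, and if_stop fit together; handler is the FedAvgServerHandler constructed above, and run_local_training is a hypothetical helper that performs one client's local training and returns a FedAvgUplinkPackage):

    while not handler.if_stop():
        downlink = handler.downlink_package()       # current global parameters
        for cid in handler.sample_clients():        # client ids sampled for this round
            uplink = run_local_training(cid, downlink)  # hypothetical local training
            updated = handler.load(uplink)          # True once the last expected package arrives
        # After the final sampled client reports, load() has aggregated the
        # updates and advanced the round, so `updated` is True here.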