blazefl.contrib.FedAvgServerHandler

class blazefl.contrib.FedAvgServerHandler(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, global_round: int, num_clients: int, sample_ratio: float, device: str)

Bases: ServerHandler

Server-side handler for the Federated Averaging (FedAvg) algorithm.

Manages the global model, coordinates client sampling, aggregates client updates, and controls the training process across multiple rounds.

model

The global model being trained.

Type:

torch.nn.Module

dataset

Dataset partitioned across clients.

Type:

PartitionedDataset

global_round

Total number of federated learning rounds.

Type:

int

num_clients

Total number of clients in the federation.

Type:

int

sample_ratio

Fraction of clients to sample in each round.

Type:

float

device

Device to run the model on (‘cpu’ or ‘cuda’).

Type:

str

client_buffer_cache

Cache for storing client updates before aggregation.

Type:

list[FedAvgUplinkPackage]

num_clients_per_round

Number of clients sampled per round.

Type:

int

round

Current training round.

Type:

int

__init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, global_round: int, num_clients: int, sample_ratio: float, device: str) → None

Initialize the FedAvgServerHandler.

Parameters:
  • model_selector (ModelSelector) – Selector for initializing the model.

  • model_name (str) – Name of the model to be used.

  • dataset (PartitionedDataset) – Dataset partitioned across clients.

  • global_round (int) – Total number of federated learning rounds.

  • num_clients (int) – Total number of clients in the federation.

  • sample_ratio (float) – Fraction of clients to sample in each round.

  • device (str) – Device to run the model on (‘cpu’ or ‘cuda’).

Methods

__init__(model_selector, model_name, ...)

Initialize the FedAvgServerHandler.

aggregate(parameters_list, weights_list)

Aggregate model parameters from multiple clients using weighted averaging.

downlink_package()

Create a downlink package containing the current global model parameters to send to clients.

global_update(buffer)

Aggregate client updates and update the global model parameters.

if_stop()

Check if the training process should stop.

load(payload)

Load a client's uplink package into the server's buffer and perform a global update if all expected packages for the round are received.

sample_clients()

Randomly sample a subset of clients for the current training round.

static aggregate(parameters_list: list[Tensor], weights_list: list[int]) → Tensor

Aggregate model parameters from multiple clients using weighted averaging.

Parameters:
  • parameters_list (list[torch.Tensor]) – List of serialized model parameters from clients.

  • weights_list (list[int]) – List of data sizes corresponding to each client’s parameters.

Returns:

Aggregated model parameters.

Return type:

torch.Tensor
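Weighted averaging returns sum_k(n_k * θ_k) / sum_k(n_k), where θ_k is client k's parameter vector and n_k its data size. Below is a minimal sketch of that computation. It is not the library's source, and it assumes every entry of parameters_list is a flat 1-D tensor of the same length.

```python
import torch


def aggregate(parameters_list: list[torch.Tensor], weights_list: list[int]) -> torch.Tensor:
    """Weighted average of flat parameter vectors (illustrative sketch)."""
    # Stack client vectors into shape (num_clients, num_params).
    stacked = torch.stack(parameters_list)
    # One weight per client, shaped (num_clients, 1) to broadcast over parameters.
    weights = torch.tensor(weights_list, dtype=stacked.dtype).unsqueeze(1)
    # Sum of weighted vectors divided by the total data size.
    return (stacked * weights).sum(dim=0) / weights.sum()


# Two clients: the second holds three times as much data as the first.
p1 = torch.tensor([0.0, 4.0])
p2 = torch.tensor([4.0, 0.0])
print(aggregate([p1, p2], [1, 3]))  # tensor([3., 1.])
```

Weighting by data size means a client with more samples pulls the global model proportionally harder. With equal weights this reduces to a plain mean.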

downlink_package() → FedAvgDownlinkPackage

Create a downlink package containing the current global model parameters to send to clients.

Returns:

Downlink package with serialized model parameters.

Return type:

FedAvgDownlinkPackage
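A downlink package carries the global model's parameters in serialized form. The sketch below assumes that "serialized" means all parameters are concatenated into one 1-D tensor with torch.nn.utils.parameters_to_vector. The DownlinkPackage dataclass and its model_parameters field are hypothetical stand-ins for FedAvgDownlinkPackage.

```python
from dataclasses import dataclass

import torch
from torch.nn.utils import parameters_to_vector


@dataclass
class DownlinkPackage:
    # Hypothetical stand-in for FedAvgDownlinkPackage.
    model_parameters: torch.Tensor


def downlink_package(model: torch.nn.Module) -> DownlinkPackage:
    # Flatten every parameter into a single 1-D vector; detach and clone so the
    # package is independent of the live model and its autograd graph.
    flat = parameters_to_vector(model.parameters()).detach().clone()
    return DownlinkPackage(model_parameters=flat)


model = torch.nn.Linear(4, 2)  # 4*2 weights + 2 biases = 10 values
package = downlink_package(model)
print(package.model_parameters.shape)  # torch.Size([10])
```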

global_update(buffer: list[FedAvgUplinkPackage]) → None

Aggregate client updates and update the global model parameters.

Parameters:

buffer (list[FedAvgUplinkPackage]) – List of uplink packages from clients.
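Below is a minimal sketch of a global update. It assumes each uplink package exposes its flattened parameters and its local data size. The UplinkPackage dataclass and its fields are hypothetical stand-ins for FedAvgUplinkPackage. The weighted average is written back into the model with torch.nn.utils.vector_to_parameters.

```python
from dataclasses import dataclass

import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters


@dataclass
class UplinkPackage:
    # Hypothetical stand-in for FedAvgUplinkPackage.
    model_parameters: torch.Tensor  # flat 1-D parameter vector
    data_size: int                  # number of local training samples


def global_update(model: torch.nn.Module, buffer: list[UplinkPackage]) -> None:
    # Weighted average of client vectors, weighted by data size.
    stacked = torch.stack([p.model_parameters for p in buffer])
    weights = torch.tensor([p.data_size for p in buffer], dtype=stacked.dtype).unsqueeze(1)
    averaged = (stacked * weights).sum(dim=0) / weights.sum()
    # Slice the averaged vector back into the model's parameter tensors.
    vector_to_parameters(averaged, model.parameters())


model = torch.nn.Linear(2, 1)  # 2 weights + 1 bias = 3 values
buffer = [
    UplinkPackage(torch.zeros(3), data_size=1),
    UplinkPackage(torch.full((3,), 4.0), data_size=3),
]
global_update(model, buffer)
print(parameters_to_vector(model.parameters()))  # every entry equals 3.0
```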

if_stop() → bool

Check if the training process should stop.

Returns:

True if the current round is greater than or equal to global_round (the total number of global rounds); False otherwise.

Return type:

bool
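if_stop() is the natural termination check for the server's outer loop. The toy sketch below shows how such a loop might be driven in a single process. ToyServer is hypothetical. It imitates only the documented round counter and the load()/if_stop() contract, and it assumes that a completed round increments round.

```python
class ToyServer:
    """Hypothetical stand-in that only imitates the documented round logic."""

    def __init__(self, global_round: int, num_clients_per_round: int) -> None:
        self.global_round = global_round
        self.num_clients_per_round = num_clients_per_round
        self.round = 0
        self.received = 0

    def if_stop(self) -> bool:
        # True once the round counter has reached the configured total.
        return self.round >= self.global_round

    def load(self, payload: object) -> bool:
        # Count payloads; the last one of a round closes the round.
        self.received += 1
        if self.received < self.num_clients_per_round:
            return False
        self.received = 0
        self.round += 1
        return True


server = ToyServer(global_round=3, num_clients_per_round=4)
updates = 0
while not server.if_stop():
    for client_id in range(server.num_clients_per_round):
        if server.load(payload=client_id):
            updates += 1
print(server.round, updates)  # 3 3
```

Using >= rather than == keeps the loop safe even if the round counter were ever advanced past global_round.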

load(payload: FedAvgUplinkPackage) → bool

Load a client’s uplink package into the server’s buffer and perform a global update if all expected packages for the round are received.

Parameters:

payload (FedAvgUplinkPackage) – Uplink package from a client.

Returns:

True if a global update was performed; False otherwise.

Return type:

bool
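The following sketch of the buffering behaviour assumes one particular sequence. When the buffer holds num_clients_per_round packages, the server runs the global update, clears the buffer and advances round. That ordering, the UplinkBuffer class and its on_round_complete callback are assumptions for illustration, not the library's code.

```python
from typing import Callable, Generic, TypeVar

T = TypeVar("T")


class UplinkBuffer(Generic[T]):
    """Hypothetical helper reproducing the documented load() contract."""

    def __init__(self, num_clients_per_round: int,
                 on_round_complete: Callable[[list[T]], None]) -> None:
        self.num_clients_per_round = num_clients_per_round
        self.on_round_complete = on_round_complete  # plays the role of global_update
        self.client_buffer_cache: list[T] = []
        self.round = 0

    def load(self, payload: T) -> bool:
        self.client_buffer_cache.append(payload)
        if len(self.client_buffer_cache) < self.num_clients_per_round:
            return False  # still waiting for more clients this round
        self.on_round_complete(self.client_buffer_cache)
        self.client_buffer_cache = []  # fresh buffer for the next round
        self.round += 1
        return True


seen: list[list[str]] = []
buf = UplinkBuffer(num_clients_per_round=2, on_round_complete=seen.append)
print(buf.load("a"), buf.load("b"), buf.load("c"))  # False True False
print(seen, buf.round)  # [['a', 'b']] 1
```

The buffer is rebound to a new list rather than cleared in place. Any reference the update step kept to the completed round's packages therefore stays intact.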

sample_clients() → list[int]

Randomly sample a subset of clients for the current training round.

Returns:

Sorted list of sampled client IDs.

Return type:

list[int]
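Below is a minimal sketch of per-round sampling. It assumes the number of participants is int(num_clients * sample_ratio) and that clients are drawn uniformly without replacement. The standalone function and its seed argument are hypothetical.

```python
import random
from typing import Optional


def sample_clients(num_clients: int, sample_ratio: float,
                   seed: Optional[int] = None) -> list[int]:
    # Number of participants per round; int() truncates toward zero.
    num_clients_per_round = int(num_clients * sample_ratio)
    # A private RNG makes the draw reproducible for a given seed.
    rng = random.Random(seed)
    # Uniform sampling without replacement from IDs 0 .. num_clients - 1.
    sampled = rng.sample(range(num_clients), num_clients_per_round)
    return sorted(sampled)


print(sample_clients(num_clients=100, sample_ratio=0.1, seed=0))
```

Truncation interacts with floating-point rounding in this sketch. Because 100 * 0.29 evaluates to 28.999999999999996, that configuration samples 28 clients rather than 29.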