blazefl.contrib.FedAvgSerialClientTrainer#
- class blazefl.contrib.FedAvgSerialClientTrainer(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, device: str, num_clients: int, epochs: int, batch_size: int, lr: float)[source]#
Bases: SerialClientTrainer[FedAvgUplinkPackage, FedAvgDownlinkPackage]
Serial client trainer for the Federated Averaging (FedAvg) algorithm.
This trainer processes clients sequentially, training and evaluating a local model for each client based on the server-provided model parameters.
- model#
The client’s local model.
- Type:
torch.nn.Module
- dataset#
Dataset partitioned across clients.
- Type:
PartitionedDataset
- device#
Device to run the model on (‘cpu’ or ‘cuda’).
- Type:
str
- num_clients#
Total number of clients in the federation.
- Type:
int
- epochs#
Number of local training epochs per client.
- Type:
int
- batch_size#
Batch size for local training.
- Type:
int
- lr#
Learning rate for the optimizer.
- Type:
float
- cache#
Cache to store uplink packages for the server.
- Type:
list[FedAvgUplinkPackage]
- __init__(model_selector: ModelSelector, model_name: str, dataset: PartitionedDataset, device: str, num_clients: int, epochs: int, batch_size: int, lr: float) → None [source]#
Initialize the FedAvgSerialClientTrainer.
- Parameters:
model_selector (ModelSelector) – Selector for initializing the local model.
model_name (str) – Name of the model to be used.
dataset (PartitionedDataset) – Dataset partitioned across clients.
device (str) – Device to run the model on (‘cpu’ or ‘cuda’).
num_clients (int) – Total number of clients in the federation.
epochs (int) – Number of local training epochs per client.
batch_size (int) – Batch size for local training.
lr (float) – Learning rate for the optimizer.
Methods
__init__(model_selector, model_name, ...) – Initialize the FedAvgSerialClientTrainer.
evaluate(test_loader) – Evaluate the local model on the given test data loader.
local_process(payload, cid_list) – Train and evaluate the model for each client in the given list.
train(model_parameters, train_loader) – Train the local model on the given training data loader.
Retrieve the uplink packages for transmission to the server.
- evaluate(test_loader: DataLoader) → tuple[float, float] [source]#
Evaluate the local model on the given test data loader.
- Parameters:
test_loader (DataLoader) – DataLoader for the evaluation data.
- Returns:
A tuple containing the average loss and accuracy.
- Return type:
tuple[float, float]
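This page does not spell out how the two returned values are aggregated across batches. A minimal, framework-free sketch, assuming sample-weighted averaging; `evaluate_batches` and the per-batch tuple layout are hypothetical and not part of blazefl:

```python
def evaluate_batches(batches):
    """Aggregate per-batch results into (average loss, accuracy).

    Each batch is a hypothetical tuple: (mean_loss, num_correct, num_samples).
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for mean_loss, num_correct, num_samples in batches:
        # Weight each batch's mean loss by its size so that a short final
        # batch does not skew the overall average.
        total_loss += mean_loss * num_samples
        total_correct += num_correct
        total_samples += num_samples
    if total_samples == 0:
        raise ValueError("cannot evaluate on an empty loader")
    return total_loss / total_samples, total_correct / total_samples


# Two batches: 4 samples (3 correct) and 2 samples (1 correct).
avg_loss, accuracy = evaluate_batches([(0.5, 3, 4), (1.0, 1, 2)])
```

Weighting by batch size matters when the last batch of a loader is smaller than the others; an unweighted mean of per-batch losses would over-count it.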
- local_process(payload: FedAvgDownlinkPackage, cid_list: list[int]) → None [source]#
Train and evaluate the model for each client in the given list.
- Parameters:
payload (FedAvgDownlinkPackage) – Downlink package with global model parameters.
cid_list (list[int]) – List of client IDs to process.
- Returns:
None
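The serial control flow described above can be sketched without the library. Everything in this block is a hypothetical stand-in (`UplinkStub`, `SerialTrainerSketch`, and the placeholder "training" step are assumptions for illustration, not blazefl code); it only shows clients being processed one after another and their results accumulating in `cache`:

```python
from dataclasses import dataclass, field


@dataclass
class UplinkStub:
    """Hypothetical stand-in for FedAvgUplinkPackage."""
    model_parameters: list[float]
    data_size: int


@dataclass
class SerialTrainerSketch:
    """Hypothetical stand-in that shows only the serial loop and the cache."""
    client_data: dict[int, list[float]]
    cache: list[UplinkStub] = field(default_factory=list)

    def train(self, cid, global_parameters):
        data = self.client_data[cid]
        # Placeholder "training": shift every parameter by the client's data mean.
        shift = sum(data) / len(data)
        updated = [p + shift for p in global_parameters]
        return UplinkStub(updated, len(data))

    def local_process(self, global_parameters, cid_list):
        # Clients run sequentially; each one starts from the same global
        # parameters, and its result is appended to the cache.
        for cid in cid_list:
            self.cache.append(self.train(cid, global_parameters))


trainer = SerialTrainerSketch(client_data={0: [1.0, 3.0], 1: [4.0]})
result = trainer.local_process([0.0, 1.0], cid_list=[0, 1])
```

As in the documented signature, the sketch's `local_process` returns nothing; the per-client results are read from `cache` afterwards.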
- train(model_parameters: Tensor, train_loader: DataLoader) → FedAvgUplinkPackage [source]#
Train the local model on the given training data loader.
- Parameters:
model_parameters (torch.Tensor) – Global model parameters to initialize the local model.
train_loader (DataLoader) – DataLoader for the training data.
- Returns:
Uplink package containing updated model parameters and data size.
- Return type:
FedAvgUplinkPackage
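To illustrate the contract of `train` (start from the global parameters, run local epochs, return updated parameters together with the local data size), here is a toy pure-Python sketch. It assumes a two-parameter linear model trained with mini-batch SGD on squared error; `train_sketch` and `UplinkStub` are hypothetical names, and the real trainer works on a `torch.nn.Module` and a `DataLoader` instead:

```python
from dataclasses import dataclass


@dataclass
class UplinkStub:
    """Hypothetical stand-in for FedAvgUplinkPackage."""
    model_parameters: list[float]  # [w, b] of the toy model y = w * x + b
    data_size: int


def train_sketch(global_parameters, batches, epochs, lr):
    # Unpack into local variables so the caller's list is never modified.
    w, b = global_parameters
    for _ in range(epochs):
        for batch in batches:
            grad_w = 0.0
            grad_b = 0.0
            for x, y in batch:
                error = (w * x + b) - y
                grad_w += 2 * error * x
                grad_b += 2 * error
            # Mini-batch SGD step using the mean gradient of the batch.
            w -= lr * grad_w / len(batch)
            b -= lr * grad_b / len(batch)
    # The data size is the number of local samples, counted once.
    data_size = sum(len(batch) for batch in batches)
    return UplinkStub([w, b], data_size)


global_params = [0.0, 0.0]
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
package = train_sketch(global_params, batches, epochs=1, lr=0.1)
```

In FedAvg the data size travels with the parameters so that the server can weight each client's update by how many samples produced it.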