Source code for blazefl.core.client_trainer
import multiprocessing as mp
from abc import ABC, abstractmethod
from multiprocessing.pool import ApplyResult
from pathlib import Path
from typing import Generic, TypeVar

import torch
from tqdm import tqdm

UplinkPackage = TypeVar("UplinkPackage")
DownlinkPackage = TypeVar("DownlinkPackage")
class SerialClientTrainer(ABC, Generic[UplinkPackage, DownlinkPackage]):
    """
    Abstract base class for serial client training in federated learning.

    This class defines the interface for training clients in a serial manner,
    where each client is processed one after the other.

    Raises:
        NotImplementedError: If the methods are not implemented in a subclass.
    """
    @abstractmethod
    def uplink_package(self) -> list[UplinkPackage]:
        """
        Prepare the data package to be sent from the client to the server.

        Returns:
            list[UplinkPackage]: A list of data packages prepared for uplink
            transmission.
        """
        ...
    @abstractmethod
    def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
        """
        Process the downlink payload from the server for a list of client IDs.

        Args:
            payload (DownlinkPackage): The data package received from the server.
            cid_list (list[int]): A list of client IDs to process.

        Returns:
            None
        """
        ...
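
# Illustrative sketch (not part of the module): a minimal concrete
# SerialClientTrainer, assuming hypothetical dict-based uplink/downlink
# payloads. It only shows how local_process() stages results that
# uplink_package() later returns; the actual training step is omitted.
class ExampleSerialTrainer(SerialClientTrainer[dict, dict]):
    def __init__(self) -> None:
        self.cache: list[dict] = []

    def local_process(self, payload: dict, cid_list: list[int]) -> None:
        for cid in cid_list:
            # A real subclass would train on the client's data starting from
            # payload["model_state"]; here the state is passed through as-is.
            self.cache.append({"cid": cid, "model_state": payload["model_state"]})

    def uplink_package(self) -> list[dict]:
        # Hand the staged packages to the server and reset the cache.
        package, self.cache = self.cache, []
        return package
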
DiskSharedData = TypeVar("DiskSharedData")
class ParallelClientTrainer(
    SerialClientTrainer[UplinkPackage, DownlinkPackage],
    Generic[UplinkPackage, DownlinkPackage, DiskSharedData],
):
    """
    Abstract base class for parallel client training in federated learning.

    This class extends SerialClientTrainer to enable parallel processing of
    clients, allowing multiple clients to be trained concurrently.

    Attributes:
        num_parallels (int): Number of parallel processes to use for client training.
        share_dir (Path): Directory path for sharing data between processes.
        cache (list[UplinkPackage]): Cache to store uplink packages from clients.

    Raises:
        NotImplementedError: If the abstract methods are not implemented in a subclass.
    """
    def __init__(self, num_parallels: int, share_dir: Path) -> None:
        """
        Initialize the ParallelClientTrainer with parallelism settings.

        Args:
            num_parallels (int): Number of parallel processes to use.
            share_dir (Path): Directory path for sharing data between processes.
        """
        self.num_parallels = num_parallels
        self.share_dir = share_dir
        self.share_dir.mkdir(parents=True, exist_ok=True)
        self.cache: list[UplinkPackage] = []
    @abstractmethod
    def get_shared_data(self, cid: int, payload: DownlinkPackage) -> DiskSharedData:
        """
        Retrieve shared data for a given client ID and payload.

        Args:
            cid (int): Client ID.
            payload (DownlinkPackage): The data package received from the server.

        Returns:
            DiskSharedData: The shared data associated with the client ID and payload.
        """
        ...
    @staticmethod
    @abstractmethod
    def process_client(path: Path) -> Path:
        """
        Process a single client based on the provided path.

        Args:
            path (Path): Path to the client's data file.

        Returns:
            Path: Path to the processed client's data file.
        """
        ...
    def local_process(self, payload: DownlinkPackage, cid_list: list[int]) -> None:
        """
        Manage the parallel processing of clients.

        This method distributes the processing of multiple clients across
        parallel processes, handling data saving, loading, and caching.

        Args:
            payload (DownlinkPackage): The data package received from the server.
            cid_list (list[int]): A list of client IDs to process.

        Returns:
            None
        """
        pool = mp.Pool(processes=self.num_parallels)
        jobs: list[ApplyResult] = []
        for cid in cid_list:
            # Serialize each client's shared data to disk so that worker
            # processes can load it independently of the parent process.
            path = self.share_dir.joinpath(f"{cid}.pkl")
            data = self.get_shared_data(cid, payload)
            torch.save(data, path)
            jobs.append(pool.apply_async(self.process_client, (path,)))
        for job in tqdm(jobs, desc="Client", leave=False):
            # Each worker returns the path of its result file; load it and
            # cache the uplink package for the next call to uplink_package().
            path = job.get()
            assert isinstance(path, Path)
            package = torch.load(path, weights_only=False)
            self.cache.append(package)
        pool.close()
        pool.join()
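

# Illustrative sketch (not part of the module): a minimal concrete
# ParallelClientTrainer. ExampleSharedData, ExampleParallelTrainer, and the
# result-file naming are invented for illustration only. Note that
# process_client() is a @staticmethod, so the worker pool only has to pickle
# the Path argument rather than the trainer instance itself.
from dataclasses import dataclass


@dataclass
class ExampleSharedData:
    cid: int
    model_state: dict


class ExampleParallelTrainer(ParallelClientTrainer[dict, dict, ExampleSharedData]):
    def get_shared_data(self, cid: int, payload: dict) -> ExampleSharedData:
        # Bundle everything a worker process needs into one serializable object.
        return ExampleSharedData(cid=cid, model_state=payload["model_state"])

    @staticmethod
    def process_client(path: Path) -> Path:
        # Runs in a worker process: load the shared data, train (omitted here),
        # and write the uplink package back to disk for the parent to collect.
        data = torch.load(path, weights_only=False)
        package = {"cid": data.cid, "model_state": data.model_state}
        out_path = path.with_name(f"{path.stem}.result.pkl")
        torch.save(package, out_path)
        return out_path

    def uplink_package(self) -> list[dict]:
        package, self.cache = self.cache, []
        return package


# Hypothetical usage: spread ten clients over four worker processes, then
# collect their uplink packages.
# trainer = ExampleParallelTrainer(num_parallels=4, share_dir=Path("shared"))
# trainer.local_process({"model_state": {}}, cid_list=list(range(10)))
# packages = trainer.uplink_package()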