Source code for blazefl.utils.serialize
import torch
[docs]
def serialize_model(model: torch.nn.Module, cpu: bool = True) -> torch.Tensor:
    """
    Serialize a PyTorch model's state into a flat tensor.

    Every tensor in ``model.state_dict()`` is flattened and the pieces are
    concatenated in the state dict's iteration order, which is the order
    ``deserialize_model`` reads them back in.

    Args:
        model (torch.nn.Module): The PyTorch model to serialize.
        cpu (bool): Whether to move the serialized parameters to the CPU.

    Returns:
        torch.Tensor: A one-dimensional tensor containing the serialized
        values. If the state dict has no entries, an empty tensor is returned.
    """
    # reshape() instead of view(): view(-1) raises for tensors whose memory
    # layout is not contiguous (for example channels_last convolution weights),
    # while reshape() falls back to a copy in that case.
    flat_tensors = [
        tensor.detach().reshape(-1) for tensor in model.state_dict().values()
    ]
    if not flat_tensors:
        # torch.cat() rejects an empty list, so handle stateless modules here.
        return torch.empty(0)
    serialized_parameters = torch.cat(flat_tensors)
    if cpu:
        serialized_parameters = serialized_parameters.cpu()
    return serialized_parameters
[docs]
def deserialize_model(
    model: torch.nn.Module, serialized_parameters: torch.Tensor
) -> None:
    """
    Deserialize a flat tensor back into a PyTorch model's parameters.

    The flat tensor is consumed front to back: each tensor in
    ``model.state_dict()`` (in iteration order) takes the next ``numel()``
    values, reshaped to that tensor's size, and is overwritten in place.

    Args:
        model (torch.nn.Module): The PyTorch model to update.
        serialized_parameters (torch.Tensor): The tensor containing the
            parameters. Elements beyond what the model needs are ignored.

    Returns:
        None

    Raises:
        RuntimeError: If ``serialized_parameters`` holds fewer elements than
            the model's state requires. The model is left unmodified.
    """
    state_tensors = list(model.state_dict().values())
    flat = serialized_parameters.reshape(-1)

    # Validate the total length before touching the model; otherwise a
    # too-short input would overwrite the leading tensors and only then fail,
    # leaving the model half-updated.
    required = sum(tensor.numel() for tensor in state_tensors)
    if flat.numel() < required:
        raise RuntimeError(
            f"serialized_parameters has {flat.numel()} elements, "
            f"but the model requires {required}"
        )

    current_index = 0
    # Pure data copy: nothing here should be recorded by autograd.
    with torch.no_grad():
        for param in state_tensors:
            numel = param.numel()
            chunk = flat[current_index : current_index + numel]
            # reshape() also copes with a non-contiguous input slice.
            param.copy_(chunk.reshape(param.size()))
            current_index += numel