Source code for blazefl.utils.seed

import os
import random
from dataclasses import dataclass

import numpy as np
import torch


def seed_everything(seed: int, device: str) -> None:
    """
    Make Python, NumPy, and PyTorch random number generation reproducible.

    The same seed is applied to Python's ``random`` module, the
    ``PYTHONHASHSEED`` environment variable, NumPy's legacy global RNG, and
    PyTorch. When ``device`` names a CUDA device, the CUDA generators are
    seeded as well and cuDNN is switched to deterministic mode with
    benchmarking disabled.

    Args:
        seed (int): The seed value to set.
        device (str): The device type ('cpu' or 'cuda'). Any string starting
            with 'cuda' (e.g. 'cuda:0') selects the CUDA path.

    Returns:
        None
    """
    # Only affects interpreters launched after this point (e.g. subprocesses).
    os.environ["PYTHONHASHSEED"] = str(seed)

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    if not device.startswith("cuda"):
        return

    for apply_cuda_seed in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        apply_cuda_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
@dataclass
class CUDARandomState:
    """
    Snapshot of the CUDA-related random configuration.

    Attributes:
        manual_seed (int): The seed of the CUDA generator.
        cudnn_deterministic (bool): Value of the cuDNN ``deterministic`` flag.
        cudnn_benchmark (bool): Value of the cuDNN ``benchmark`` flag.
        cuda_rng_state (torch.Tensor): The CUDA generator's RNG state.
    """

    manual_seed: int
    cudnn_deterministic: bool
    cudnn_benchmark: bool
    cuda_rng_state: torch.Tensor
@dataclass
class RandomState:
    """
    A dataclass representing the random state for Python, NumPy, and PyTorch.

    Attributes:
        random (tuple): The state of Python's random module.
        environ (str): The PYTHONHASHSEED environment variable.
        numpy (dict): The state of NumPy's RNG, as returned by
            ``np.random.get_state()``.
        torch_seed (int): The initial seed for PyTorch.
        torch_rng_state (torch.Tensor): The RNG state of PyTorch's CPU
            generator. It is captured for every device, including CUDA.
        cuda (CUDARandomState | None): The CUDA-specific random state, or
            None when the state was captured for a non-CUDA device.
    """

    random: tuple
    environ: str
    # NOTE: np.random.get_state() returns a tuple by default (a dict only with
    # legacy=False); np.random.set_state() accepts either form. The annotation
    # is left unchanged.
    numpy: dict
    torch_seed: int
    torch_rng_state: torch.Tensor
    cuda: CUDARandomState | None

    @classmethod
    def get_random_state(cls, device: str) -> "RandomState":
        """
        Capture the current random state.

        The CPU generator state is always recorded in ``torch_rng_state``.
        For CUDA devices the CUDA seed, cuDNN flags, and CUDA generator state
        are recorded additionally in ``cuda``.

        Args:
            device (str): The device type ('cpu' or 'cuda'). Any string
                starting with 'cuda' selects the CUDA path.

        Returns:
            RandomState: The captured random state.

        Raises:
            KeyError: If the PYTHONHASHSEED environment variable is not set.
        """
        cuda_state = None
        if device.startswith("cuda"):
            cuda_state = CUDARandomState(
                torch.cuda.initial_seed(),
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.benchmark,
                torch.cuda.get_rng_state(),
            )
        return cls(
            random.getstate(),
            os.environ["PYTHONHASHSEED"],
            np.random.get_state(),
            torch.initial_seed(),
            # Always the CPU generator state; the CUDA generator state lives
            # in cuda_state.cuda_rng_state.
            torch.get_rng_state(),
            cuda_state,
        )

    @staticmethod
    def set_random_state(random_state: "RandomState") -> None:
        """
        Restore the random state from a RandomState object.

        The CPU generator state is restored for every device. The CUDA seed,
        cuDNN flags, and CUDA generator state are restored only when
        ``random_state.cuda`` is present.

        Args:
            random_state (RandomState): The random state to restore.

        Returns:
            None
        """
        random.setstate(random_state.random)
        os.environ["PYTHONHASHSEED"] = random_state.environ
        np.random.set_state(random_state.numpy)

        # manual_seed() resets the generator to the start of its stream, so the
        # captured state must be re-applied afterwards to resume mid-stream.
        torch.manual_seed(random_state.torch_seed)
        # NOTE(review): a state object created by an older capture that stored
        # the CUDA generator state in torch_rng_state is presumably not
        # restorable here - confirm if such objects are persisted anywhere.
        torch.set_rng_state(random_state.torch_rng_state)

        cuda_state = random_state.cuda
        if cuda_state is not None:
            torch.cuda.manual_seed(cuda_state.manual_seed)
            torch.backends.cudnn.deterministic = cuda_state.cudnn_deterministic
            torch.backends.cudnn.benchmark = cuda_state.cudnn_benchmark
            torch.cuda.set_rng_state(cuda_state.cuda_rng_state)