Source code for blazefl.utils.dataset
from collections.abc import Callable
from typing import Any

from torch.utils.data import Dataset
[docs]
class FilteredDataset(Dataset):
"""
A dataset wrapper that filters and transforms a subset of the original dataset.
This class allows selecting specific data points by their indices and
applying optional transformations to the data and targets.
Attributes:
data (list): The filtered subset of the original dataset.
targets (list | None): The filtered subset of targets, if provided.
transform (Callable | None): A function to apply transformations to the data.
target_transform (Callable | None): A function to apply
transformations to the targets.
"""
[docs]
def __init__(
self,
indices: list[int],
original_data: list,
original_targets: list | None = None,
transform: Callable | None = None,
target_transform: Callable | None = None,
) -> None:
"""
Initialize the FilteredDataset.
Args:
indices (list[int]): Indices of the data points to include in the dataset.
original_data (list): The original dataset.
original_targets (list | None): The original targets, if available.
transform (Callable | None): Transformation function for the data.
target_transform (Callable | None): Transformation function for the targets.
"""
self.data = [original_data[i] for i in indices]
if original_targets is not None:
assert len(original_data) == len(original_targets)
self.targets = [original_targets[i] for i in indices]
self.transform = transform
self.target_transform = target_transform
def __len__(self) -> int:
"""
Return the length of the filtered dataset.
Returns:
int: The number of data points in the dataset.
"""
return len(self.data)
def __getitem__(self, index: int) -> tuple:
"""
Retrieve a data item (and optionally its target) at a specific index.
Args:
index (int): The index of the data point to retrieve.
Returns:
tuple: A tuple containing the transformed data item and its target
(if available).
"""
img = self.data[index]
if self.transform is not None:
img = self.transform(img)
if hasattr(self, "targets"):
target = self.targets[index]
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
return img