# Source code for full_dia.dataloader
from torch.utils.data.dataset import Dataset
try:
    # Presumably `profile` is provided as a builtin by a line profiler
    # (e.g. kernprof) when one is active — confirm. Probing the name raises
    # NameError otherwise, in which case the no-op fallback below is defined.
    _ = profile
except NameError:

    def profile(func):
        """No-op stand-in for a profiler's ``@profile`` decorator.

        Returns ``func`` unchanged so functions decorated with ``@profile``
        keep working when no profiler has injected the real decorator.
        """
        return func
class MapDataset(Dataset):
    """Map-style dataset pairing per-sample maps with ion counts and labels.

    ``__getitem__(idx)`` returns ``(maps[idx], valid_ion_nums[idx],
    labels[idx])``; the dataset length is ``len(labels)``.

    Args:
        maps: Indexable collection of per-sample maps. The original author's
            note says each item is ``[ion_num, 13, 50]`` — TODO confirm.
        valid_ion_nums: Indexable collection, one entry per sample; presumably
            the number of valid (non-padded) ions — verify against caller.
        labels: Indexable collection of per-sample targets.

    Raises:
        ValueError: If the three inputs do not all have the same length.
    """

    def __init__(self, maps, valid_ion_nums, labels):
        # __len__ trusts ``labels`` alone, so misaligned inputs would
        # otherwise surface as an IndexError (or silently dropped samples)
        # partway through an epoch. Fail fast at construction instead.
        lengths = (len(maps), len(valid_ion_nums), len(labels))
        if len(set(lengths)) != 1:
            raise ValueError(
                "maps, valid_ion_nums and labels must have the same length, "
                f"got {lengths[0]}, {lengths[1]} and {lengths[2]}"
            )
        self.maps = maps
        self.valid_ion_nums = valid_ion_nums
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(maps, valid_ion_num, label)`` for sample ``idx``."""
        maps = self.maps[idx]  # original note: [ion_num, 13, 50]
        y = self.labels[idx]
        valid_ion_num = self.valid_ion_nums[idx]
        return (maps, valid_ion_num, y)
class MallDataset(Dataset):
    """Map-style dataset pairing per-sample ``malls`` with ion counts and labels.

    ``__getitem__(idx)`` returns ``(malls[idx], valid_ion_nums[idx],
    labels[idx])``; the dataset length is ``len(labels)``.

    Args:
        malls: Indexable collection of per-sample inputs (structure not
            visible here — TODO confirm shape against caller).
        valid_ion_nums: Indexable collection, one entry per sample; presumably
            the number of valid (non-padded) ions — verify against caller.
        labels: Indexable collection of per-sample targets.

    Raises:
        ValueError: If the three inputs do not all have the same length.
    """

    def __init__(self, malls, valid_ion_nums, labels):
        # __len__ trusts ``labels`` alone, so misaligned inputs would
        # otherwise surface as an IndexError (or silently dropped samples)
        # partway through an epoch. Fail fast at construction instead.
        lengths = (len(malls), len(valid_ion_nums), len(labels))
        if len(set(lengths)) != 1:
            raise ValueError(
                "malls, valid_ion_nums and labels must have the same length, "
                f"got {lengths[0]}, {lengths[1]} and {lengths[2]}"
            )
        self.malls = malls
        self.valid_ion_nums = valid_ion_nums
        self.labels = labels

    def __len__(self):
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(mall, valid_ion_num, label)`` for sample ``idx``."""
        mall = self.malls[idx]
        y = self.labels[idx]
        valid_ion_num = self.valid_ion_nums[idx]
        return (mall, valid_ion_num, y)