PyTorch - Chargement des données

PyTorch comprend un package appelé torchvision qui est utilisé pour charger et préparer l'ensemble de données. Il comprend deux fonctions de base à savoir Dataset et DataLoader qui aide à la transformation et au chargement de l'ensemble de données.

Base de données

L'ensemble de données est utilisé pour lire et transformer un point de données à partir de l'ensemble de données donné. La syntaxe de base à implémenter est mentionnée ci-dessous -

trainset = torchvision.datasets.CIFAR10(root = './data', train = True,
   download = True, transform = transform)

DataLoader est utilisé pour mélanger et regrouper les données. Il peut être utilisé pour charger les données en parallèle avec des nœuds de calcul multitraitement.

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 4,
   shuffle = True, num_workers = 2)

Exemple: chargement d'un fichier CSV

Nous utilisons le package Python Panda pour charger le fichier csv. Le fichier d'origine a le format suivant: (nom de l'image, 68 points de repère - chaque point de repère a des coordonnées ax, y).

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
landmarks = landmarks.astype('float').reshape(-1, 2)