PyTorch - AutoEncoder

장비 정 2021. 6. 1. 14:01
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'current device is {device}')

train_data = datasets.FashionMNIST(
    root = 'data',
    download = True,
    train = True,
    transform = transforms.ToTensor(),
)

test_data = datasets.FashionMNIST(
    root = 'data',
    download = True,
    train = False,
    transform = transforms.ToTensor(),
)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size = 32, shuffle = True
)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size = 32, shuffle = False
)

for (x_train, y_train) in train_loader:
    print(f'x_train : {x_train.size()}, x_train_type : {x_train.type()}')
    print(f'y_train : {y_train.size()}, y_train_type : {y_train.type()}')
    break
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # encoder: 784 -> 512 -> 256 -> 32 (latent code)
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 32),
            nn.ReLU(),
        )

        # decoder: 32 -> 256 -> 512 -> 784 (reconstruction)
        self.decoder = nn.Sequential(
            nn.Linear(32, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.Linear(512, 28 * 28),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
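
The encoder squeezes the flattened 28 * 28 image down to a 32-dimensional latent code and the decoder expands it back to 784 values. A quick shape check with a dummy batch (a sketch, not required for training):

net = Net()
dummy = torch.randn(4, 28 * 28)          # fake batch of 4 flattened images
z, recon = net(dummy)
print(z.shape)      # torch.Size([4, 32])   -> latent code
print(recon.shape)  # torch.Size([4, 784])  -> reconstruction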
    
def weight_init(m):
    # Kaiming (He) initialisation for every linear layer
    if isinstance(m, nn.Linear):
        init.kaiming_uniform_(m.weight.data)

model = Net().to(device)
model.apply(weight_init)
criterion = nn.MSELoss()   # reconstruction loss between decoded output and input image
optimizer = torch.optim.Adam(
    model.parameters(),
    lr = 1e-3
)
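
Since the decoder outputs real-valued pixels, the loss compares decoded and target element-wise; mean-squared error averages the squared per-pixel differences. A tiny illustration with hypothetical tensors:

mse = nn.MSELoss()
a = torch.tensor([[0.0, 1.0]])
b = torch.tensor([[0.5, 1.0]])
print(mse(a, b))    # tensor(0.1250) = (0.5**2 + 0.0**2) / 2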

def train(model, train_loader, optimizer, epoch, log_interval):
    model.train()

    for batch_idx, (image, _) in enumerate(train_loader):
        # the autoencoder's target is the input image itself
        image = image.view(-1, 28 * 28).to(device)
        target = image.view(-1, 28 * 28).to(device)

        optimizer.zero_grad()
        encoded, decoded = model(image)
        loss = criterion(decoded, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(f'Epochs : {epoch} [{batch_idx * len(image)}/{len(train_loader.dataset)} ({100 * batch_idx / len(train_loader):.0f}%)] Train Loss : {loss.item():.4f}')
            

def test(model, test_loader):
    model.eval()
    test_loss = 0
    real_image = list()
    gen_image = list()

    with torch.no_grad():
        for image, _ in test_loader:
            image = image.view(-1, 28 * 28).to(device)
            target = image.view(-1, 28 * 28).to(device)

            encoded, decoded = model(image)
            test_loss += criterion(decoded, target).item()

            # keep copies on the CPU for later visualisation
            real_image.append(target.cpu())
            gen_image.append(decoded.cpu())

    test_loss /= len(test_loader.dataset)

    return test_loss, real_image, gen_image
    
epochs = 10   # number of training epochs (arbitrary choice)

for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, epoch, log_interval = 200)
    test_loss, real_image, gen_image = test(model, test_loader)
    print(f'Epochs : {epoch}, test_loss : {test_loss:.4f}')
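
After training, the real_image and gen_image lists returned by test() can be compared side by side. A minimal visualisation sketch (matplotlib is an extra dependency, not used elsewhere in this post):

import matplotlib.pyplot as plt

# Take the first test batch: originals vs. their reconstructions.
originals = real_image[0][:5].view(-1, 28, 28).numpy()
reconstructions = gen_image[0][:5].view(-1, 28, 28).numpy()

fig, axes = plt.subplots(2, 5, figsize = (10, 4))
for i in range(5):
    axes[0][i].imshow(originals[i], cmap = 'gray')
    axes[0][i].axis('off')
    axes[1][i].imshow(reconstructions[i], cmap = 'gray')
    axes[1][i].axis('off')
axes[0][0].set_title('input')
axes[1][0].set_title('reconstruction')
plt.show()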