# PyTorch
# PyTorch - AutoEncoder
# 장비 정
# 2021. 6. 1. 14:01
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torchvision import datasets, transforms
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('current device is {}'.format(device))
# Samples per mini-batch for both loaders.
BATCH_SIZE = 32

# FashionMNIST splits; ToTensor() converts each image to a float tensor.
# The training set must be loaded with train=True -- with train=False the
# model would be trained on the test split and evaluated on the same data.
train_data = datasets.FashionMNIST(
    root='data',
    download=True,
    train=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.FashionMNIST(
    root='data',
    download=True,
    train=False,
    transform=transforms.ToTensor(),
)

# Shuffle only the training batches; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False
)
# Sanity check: report the size and tensor type of one training batch
# (images and labels).
x_train, y_train = next(iter(train_loader))
print(f'x_train : {x_train.size()}, x_train_type : {x_train.type()}')
print(f'y_train : {y_train.size()}, y_train_type : {y_train.type()}')
class Net(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder compresses a 28*28-dim input to a 32-dim code
    (784 -> 512 -> 256 -> 32); the decoder mirrors it back to 28*28 dims.
    """

    def __init__(self):
        # nn.Module.__init__ must actually be *called*; without it, assigning
        # submodules below raises because the module registry doesn't exist.
        super(Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 32),
            nn.ReLU(),  # final ReLU makes the 32-dim code non-negative
        )
        self.decoder = nn.Sequential(
            nn.Linear(32, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),  # no output activation: raw reconstruction
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: input batch whose last dimension is 28 * 28 (the size the
               first encoder Linear layer expects).

        Returns:
            Tuple ``(encoded, decoded)``: the 32-dim code and the
            28*28-dim reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
def init(m):
    """Re-initialise the weights of ``nn.Linear`` layers with Kaiming-uniform.

    Meant to be passed to ``model.apply(init)``, which calls it on every
    submodule. Modules that are not ``nn.Linear`` are left untouched, and
    Linear biases keep their existing values.
    """
    # Reference torch.nn.init through ``nn``: this function is itself named
    # ``init``, so once this def runs the module-level ``init`` alias points
    # at this function and ``init.kaiming_uniform_`` would fail.
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
model = Net().to(device)
model.apply(init)

# Reconstructing continuous pixel intensities is a regression problem, so use
# mean-squared error. CrossEntropyLoss would treat each 28*28 row as class
# scores over 784 "classes", which does not fit an image target.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
)
def train(model, train_loader, optimizer, log_interval):
    """Run one training epoch of the autoencoder.

    Each batch is flattened to (batch, 28 * 28) and used both as the model
    input and as the reconstruction target; loader labels are ignored.

    Args:
        model: module whose forward returns ``(encoded, decoded)``.
        train_loader: iterable of ``(image, label)`` batches with a
            ``.dataset`` attribute (used for progress reporting).
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the loss every ``log_interval`` batches.

    Note:
        Reads the module-level globals ``device``, ``criterion`` and
        ``epoch`` (the epoch loop variable).
    """
    model.train()
    for batch_idx, (image, _) in enumerate(train_loader):
        image = image.view(-1, 28 * 28).to(device)
        target = image  # autoencoder: the input is its own target

        optimizer.zero_grad()
        _, decoded = model(image)
        loss = criterion(decoded, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                f'Epochs : {epoch} '
                f'[{batch_idx * len(image)}/{len(train_loader.dataset)} '
                f'({100 * batch_idx / len(train_loader):.0f}%)] '
                f'Train Loss : {loss.item():.4f}'
            )
def test(model, test_loader):
    """Evaluate reconstruction loss over ``test_loader``.

    Args:
        model: module whose forward returns ``(encoded, decoded)``.
        test_loader: iterable of ``(image, label)`` batches.

    Returns:
        ``(test_loss, real_image, gen_image)``:
        ``test_loss`` is a float, the batch-size-weighted mean of
        ``criterion`` over all samples. ``real_image`` and ``gen_image`` are
        lists of per-batch flattened inputs and reconstructions, moved to
        the CPU.

    Note:
        Reads the module-level globals ``device`` and ``criterion``.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    real_image = []
    gen_image = []
    with torch.no_grad():
        for image, _ in test_loader:
            image = image.view(-1, 28 * 28).to(device)
            target = image  # autoencoder: the input is its own target
            _, decoded = model(image)

            batch_size = image.size(0)
            # Assuming criterion uses the default 'mean' reduction, weight each
            # batch by its size so a short final batch doesn't skew the mean.
            total_loss += criterion(decoded, target).item() * batch_size
            n_samples += batch_size

            # Store collected images on the CPU so they don't pile up in GPU
            # memory across the whole test set.
            real_image.append(target.cpu())
            gen_image.append(decoded.cpu())

    test_loss = total_loss / max(n_samples, 1)
    return test_loss, real_image, gen_image
# Number of passes over the training set (previously used but never defined).
epochs = 10

# ``epoch`` is a module-level name here; train() reads it for its log line.
for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, log_interval=200)
    test_loss, real_image, gen_image = test(model, test_loader)
    print(f'Epochs : {epoch}, test_loss : {test_loss}')