오답노트

PyTorch - torch 기본 (autograd) 본문

PyTorch

PyTorch - torch 기본 (autograd)

장비 정 2021. 5. 26. 22:35

딥러닝 모델이 훈련을 진행함에 있어 back propagation 을 통해 gradient 를 업데이트 시키는데

PyTorch 에서는 이 부분을 어떻게 사용하는지에 대해 코드로 작성시켰다.

 

import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
else :
    device = torch.device('cpu')

Batch_size = 64     # 파라미터를 업데이트 할 때 계산 되는 데이터 수
Input_size = 1000   # 입력층의 노드 수 (64, 1000)
Hidden_size = 100   # 은닉층의 노드 수 (1000, 100)
Output_size = 10    # 출력층의 노드 수 (100, 10)

x = torch.randn(
    Batch_size,
    Input_size,
    device=device,
    dtype=torch.float,
    requires_grad=False,
) # 입력층으로 들어가는 데이터
  # randn : 평균이 0, 표준편차가 1 인 정규분포에서 샘플링함
  # parameter 값을 업데이트 하기 위해 gradient 를 하는 것이지 input data 를 gradient 하는 것이 아님

y = torch.randn(
    Batch_size,
    Output_size,
    device=device,
    dtype=torch.float,
    requires_grad=False,
) # 출력층을 통해 최종적으로 나가는 데이터

w1 = torch.randn(
    Input_size,
    Hidden_size,
    device=device,
    dtype=torch.float,
    requires_grad=True
) # 입력층을 통해 들어온 데이터가 연산 될 은닉층으로 들어가는 데이터

w2 = torch.randn(
    Hidden_size,
    Output_size,
    device=device,
    dtype=torch.float,
    requires_grad=True
) # 은닉층에서 연산이 끝나고 출력층으로 들어가는 데이터

learning_rate = 1e-6
for t in range(1, 501):
    y_pred = x.mm(w1).clamp(min = 0).mm(w2)
    # mm = 행렬 곱, clamp = torch 에서 제공하는 비선형 함수(relu)

    loss = (y_pred - y).pow(2).sum() # mse
    if t % 100 == 0:
        print('Iteration : ', t, '\t', 'Loss : ', loss.item())
    loss.backward() # 각 파라미터에 gradient 를 계산하고 back propagation 을 진행

    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad
        # w 값을 위에서 설정한 learning_rate 를 곱한 뒤 gradient 를 계산한다
        # 음수값을 취한 이유는, 가장 최소값의 gradient 를 찾기 위해 반대방향으로 계산한다

        w1.grad.zero_()
        w2.grad.zero_()
        # 최종적으로 gradient 가 계산이 되었다면, 0으로 초기화하여 다시 처음부터 반복문을 돌린다

'PyTorch' 카테고리의 다른 글

PyTorch - Dropout  (0) 2021.05.30
PyTorch - 기본 이미지 분류 모델  (0) 2021.05.30
PyTorch - torch 기본 (tensor)  (0) 2021.05.26
PyTorch - torch 기본 (matrix)  (0) 2021.05.26
PyTorch - torch 기본 (vector)  (0) 2021.05.26