zip-sa | Data is important

Backpropagation with PyTorch 본문

AI/ML

Backpropagation with PyTorch

zip-sa 2024. 8. 16. 17:36

역전파 (Backpropagation)

1. Backpropagation이란?

Backpropagation(역전파)은 신경망에서 학습을 가능하게 하는 핵심 알고리즘입니다. 이 알고리즘은 모델의 예측 결과와 실제 결과 사이의 차이(오차)를 기반으로 가중치와 편향을 조정하여, 모델이 더 정확한 예측을 할 수 있도록 만듭니다. 이 과정은 네트워크의 마지막 층에서부터 첫 번째 층까지 거슬러 올라가며 진행되기 때문에 "역전파"라고 불립니다.

 

2. 계산 그래프 (Computational Graph)

계산 그래프는 복잡한 계산 과정을 시각적으로 표현한 구조입니다. 신경망에서 각 연산을 노드로 나타내고, 이들 간의 관계를 연결하여 표현합니다. 계산 그래프를 사용하면, 각 연산에 대한 미분(gradient)을 효율적으로 계산할 수 있습니다. 이를 통해 네트워크의 가중치를 업데이트하고, 오차를 줄일 수 있습니다.

 

출처 : https://zhangruochi.com/Computational-Graph/2019/12/06/

 

3. 연쇄 법칙 (Chain Rule)을 활용한 Backpropagation

Backpropagation의 핵심 원리는 연쇄 법칙(Chain Rule)을 사용해 그래디언트를 계산하는 것입니다. 연쇄 법칙은 복합 함수의 미분을 계산할 때, 각 중간 단계에서의 미분을 곱하여 전체 미분을 계산하는 방법입니다. 신경망에서는 이 방법을 통해 출력 층에서 입력 층까지 각 가중치에 대한 그래디언트를 계산하게 됩니다.

 

출처 : https://medium.com/@El_Fares_Anass/a-basic-explanation-how-the-gradient-descent-is-determined-during-back-propagation-864376f8f1a4

4. PyTorch로 구현하는 Backpropagation

아래는 PyTorch를 사용해 간단한 신경망에서 Backpropagation을 구현하는 예제입니다.

import torch
import torch.nn as nn
import torch.optim as optim

# 간단한 데이터셋 생성
x = torch.tensor([[1.0], [2.0], [3.0], [4.0]], requires_grad=True)
y = torch.tensor([[2.0], [4.0], [6.0], [8.0]], requires_grad=True)

# 간단한 신경망 정의
model = nn.Sequential(
    nn.Linear(1, 3),
    nn.ReLU(),
    nn.Linear(3, 1)
)

# 손실 함수와 최적화기 정의
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 모델 학습
for epoch in range(100):
    # Forward pass
    outputs = model(x)
    loss = criterion(outputs, y)

    # Backward pass and optimization
    optimizer.zero_grad()  # 기울기 초기화
    loss.backward()  # Backpropagation
    optimizer.step()  # 가중치 업데이트

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')

이 코드는 간단한 신경망을 정의하고, 주어진 데이터셋을 사용해 학습하는 예제입니다. loss.backward() 메서드를 통해 Backpropagation이 수행되며, optimizer.step()을 사용해 가중치가 업데이트됩니다.

 

5. 예제로 보는 Backpropagation 과정

계산 그래프 예제를 통해 Backpropagation이 어떻게 작동하는지 단계별로 설명해보겠습니다.

 

예제 함수:
$$ f(x, y, z) = (x + y) \times z $$

주어진 값:
$$(x = -2), (y = 5), (z = -4)$$

  1. Forward Pass:
    • $(q = x + y = 3)$
    • $(f = q \times z = -12)$
  2. Backward Pass:
    • $( \frac{\partial f}{\partial z} = q = 3 )$
    • $( \frac{\partial f}{\partial q} = z = -4 )$
    • $( \frac{\partial q}{\partial x} = 1 )$
    • $( \frac{\partial q}{\partial y} = 1 )$
    • 따라서, $( \frac{\partial f}{\partial x} = \frac{\partial f}{\partial q} \times \frac{\partial q}{\partial x} = -4 \times 1 = -4 )$
    • $( \frac{\partial f}{\partial y} = \frac{\partial f}{\partial q} \times \frac{\partial q}{\partial y} = -4 \times 1 = -4 )$

이처럼 Backpropagation을 통해 각 변수의 그래디언트를 계산하여, 네트워크의 가중치를 업데이트할 수 있습니다.

 

결론

Backpropagation은 신경망의 학습 과정에서 필수적인 역할을 합니다. 이 글에서는 Backpropagation의 원리, 계산 그래프, 연쇄 법칙을 통한 미분 계산, 그리고 PyTorch로 구현하는 방법을 살펴보았습니다. Backpropagation의 이해는 신경망 설계 및 최적화에 매우 중요하며, 이를 바탕으로 더 복잡한 네트워크 구조와 문제를 해결할 수 있습니다.