본문 바로가기
옵티마이저

[옵티마이저 설명] Adam: A Method for Stochastic Optimization (2015) Part 03 PyTorch 코드 구현

by 인공지능과 함께 2023. 11. 5.

들어가기에 앞서

이 포스트는 Part 03에 해당되며, Adam의 수식을 기반으로 실제 사용 가능한 PyTorch 기반의 Adam optimizer를 구현하는 것을 목표로 하고 있습니다. 이 시리즈에 대한정보는 다음과 같습니다.

 

Python 코드부터 일단 구현해 보자

PyTorch용 코드를 구현하기 전에 python으로 대강 코드를 작성해 보겠습니다. 이걸 우린 유사코드라고 부르죠? 실제 PyTorch 코드와 유사하게 작성해 봤습니다. 우선 모델의 파라미터는 layer마다 다르기 때문에 for문으로 모든 layer를 꺼내옵니다. 그런 다음 각각의 layer에 대해 Adam 수식을 적용하는 것입니다. Part 01에서 본 코드와 거의 유사한 코드입니다.

m = []
v = []
def optimizer(params: list, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, is_first=True):
	for layer_index, param in enumerate(params):
		grad = param.grad
		if is_first == True:
			m[layer_index] = 0	#	m 초기화
			v[layer_index] = 0	#	v 초기화
			step = 1
		m[layer_index] = betas[0] * m[layer_index] + (1 - betas[0]) * grad
		v[layer_index] = betas[1] * v[layer_index] + (1 - betas[1]) * (grad ** 2)

		m_hat = m[layer_index] / (1 - betas[0] ** step)
		v_hat = v[layer_index] / (1 - betas[1] ** step)
		
		update = lr * m_hat / sqrt(v_hat + eps)
		param -= update
		step += 1

 

 

 

PyTorch 코드 작성

import torch.optim as optim

class Adam(optim.Optimizer):
	def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
		defaults = {
			'lr': lr,
			'betas': betas,
			'eps': eps
		}	#	이게 group으로 들어감

		super(Adam, self).__init__(params, defaults)

	@torch.no_grad()
	def step(self, closure=None):
		for group in self.param_groups:
			for layer_index, param in enumerate(group["params"]):
				grad = param.grad
				state = self.state[param]
				if len(state) == 0:	#	is_first == True
					state["m"] = torch.zeros(param.shape)
					state["v"] = torch.zeros(param.shape)
					state["step"] = 1
				state["m"] = group["betas"][0] * (1 - group["betas"][0]) * grad
				state["v"] = group["betas"][1] * (1 - group["betas"][1]) * (grad**2)
				
				m_hat = state["m"] / (1 - group["betas"][0] ** step)
				v_hat = state["v"] / (1 - group["betas"][1] ** step)
				
				update = group["lr"] * m_hat / torch.sqrt(v_hat + group["eps"])
				param -= update
				state["step"] += 1

차근차근 한 줄씩 보면 Python 코드와 동일하다는 것을 알 수 있으며, 나머지 부분은 PyTorch framework만의 요소들입니다.