[torch] Upsampling


서론

이 장은 이미지 크기를 키우기 위해 transposed convolution을 사용하지 않고 pytorch의 다른 library인 Upsampling으로 이미지 크기를 키우는데 사용하는 방법, 작동 원리, 예시등을 위주로 설명한다.

목차

  1. Dependency
  2. Upsampling
  3. Example
  4. Result

Dependency

위 library를 사용하기 위해서는import torch.nn as nn 시켜줘야 한다.

Upsampling

nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 과 같이 사용된다.

  • scale_factor: input이미지의 height, width에 scale_factor만큼 곱해진다.
    • 예를들어 input image가 (batch, channel, 224, 224)이고 scale_factor=2라면 input image의 height, width에 2가 곱해져 (batch, channel, 448, 448) 형태로 이미지의 크기가 늘어난다.
  • mode: Interpolation 방법을 선택하는 Parameter이다. 위 코드는 bilinear알고리즘으로 이미지를 확장하는데 사용한다.
  • align_corners: 이미지를 확장한 후 edge부분을 유지할지 말지에 대한 여부이다.
    • 만약 align_corner=True 이라면 이미지의 edge부분이 원본 픽셀과 동일하게 유지된다.
    • 1-dimension의 이미지가 [0, 1]로 구성되고 scale_factor=2, mode=bilinear, align_corners=False 로 Upsampling을 하면 2배로 확장해야 하기 때문에 [-0.25, 0.25, 0.75, 1.25]로 나오게 된다.
    • 반면에, align_corners=False인 경우는 양 끝의 픽셀은 유지하기 때문에 [0, 1/3, 2/3, 1]의 결과가 된다.

Example

아래 코드는 이미지하나를 불러와 224, 224로 크기를 조절한 다음 Upsampling을 적용해 이미지 크기를 2배로 키우고 키운 이미지와 키우기 전 원본 이미지를 출력하는 코드이다.

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np

# PIL.Image로 이미지를 로드합니다.
image = Image.open("./image-15.webp")
image = image.resize((224, 224))
# 이미지를 텐서로 변환하고 [0, 1] 범위로 정규화합니다.
transform = transforms.Compose([
    transforms.ToTensor()
])
image = transform(image) # 이미지의 크기는 다음과 같다: (3, 224, 224)
# 배치 차원을 추가한다.
image = image.unsqueeze(0) # (1, 3, 224, 224)
# nn.Upsample 객체를 생성한다.
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

# 이미지를 업샘플링한다.
upsampled_image = upsample(image)

# PyTorch 텐서를 numpy 배열로 변환하고 [0, 255] 범위로 조정한다.
# cv2는 [0, 255] 범위의 uint8 를 받는다..
original_image_np = image[0].permute(1, 2, 0).numpy() * 255
original_image_np = original_image_np.astype(np.uint8)

# print(image.shape)
# print(image[0].shape)

upsampled_image_np = upsampled_image[0].detach().permute(1, 2, 0).numpy() * 255
upsampled_image_np = upsampled_image_np.astype(np.uint8)

# BGR to RGB 변환
original_image_np = cv2.cvtColor(original_image_np, cv2.COLOR_BGR2RGB)
upsampled_image_np = cv2.cvtColor(upsampled_image_np, cv2.COLOR_BGR2RGB)

# 이미지의 크기를 표시한다.
print(f"Original Image:{original_image_np.shape}\nUpsampleing Image:{upsampled_image_np.shape}")

# 이미지를 표시한다.
cv2.imshow('Original Image', original_image_np)
cv2.imshow('Upsampled Image', upsampled_image_np)

cv2.waitKey(0)
cv2.destroyAllWindows()

Result