
(AutoEncoder)(ResNet) Training Code Using ResNet

by Think_JUNG 2025. 8. 22.

I use ResNet a lot, and sometimes I want a custom implementation rather than pulling one from a library.

For a classification model you only need to swap out the FC layer, but an autoencoder has both an encoder and a decoder, which makes it harder to customize.

Let's write a simple version. (For comparison, the classification case is sketched right below.)
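A minimal sketch of the classification FC swap, assuming torchvision's resnet18 (the 10-class head is illustrative, not from the post):

import torch.nn as nn
from torchvision.models import resnet18

clf = resnet18()  # weights are optional; pass weights=... for a pretrained model
clf.fc = nn.Linear(clf.fc.in_features, 10)  # replace only the FC head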

(The project will be uploaded to GitHub.)

 

import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader,Dataset

from PIL import Image
import numpy as np
import random
import glob
import cv2
import os

import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F






class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Downsample path
        self.downsample = None
        if downsample or in_channels != out_channels:
            # Use a 1x1 Conv2d to match the channel and spatial dimensions
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        shortcut = x

        # Main path
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Shortcut connection: match dimensions before adding
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        out = self.relu(out)

        return out
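
# (Hypothetical sanity check, not in the original post.)
# With stride=2 and downsample=True, the block halves the spatial size
# while changing the channel count:
_block = ResNetBlock(64, 128, stride=2, downsample=True)
print(_block(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 128, 28, 28])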


# ResNet-based encoder
class ResNetEncoder(nn.Module):
    def __init__(self, latent_dim):
        super(ResNetEncoder, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.resblock1 = ResNetBlock(64, 128, stride=2, downsample=True)
        self.resblock2 = ResNetBlock(128, 256, stride=2, downsample=True)
        self.resblock3 = ResNetBlock(256, 512, stride=2, downsample=True)
        self.flatten = nn.Flatten()

        # Project the flattened features down to a latent vector of size latent_dim
        self.fc = nn.Linear(512 * 7 * 7, latent_dim)

    def forward(self, x):
        x = self.initial(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x
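
# (Hypothetical sanity check.) The fc layer hard-codes 512 * 7 * 7, so the
# encoder assumes 3x224x224 inputs: 224 -> 112 -> 56 -> 28 -> 14 -> 7.
_enc = ResNetEncoder(latent_dim=128)
print(_enc(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 128])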

class ResNetDecoder(nn.Module):
    def __init__(self, latent_dim):
        super(ResNetDecoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 512 * 7 * 7)  # latent vector -> 512x7x7 feature map
        self.relu = nn.ReLU(inplace=True)
        
        # Upsampling layers and ResNet blocks for decoding
        self.decode = nn.Sequential(
            # ResNet Blocks
            ResNetBlock(512, 256),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 14x14
            ResNetBlock(256, 128),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 28x28
            ResNetBlock(128, 64),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 56x56
            ResNetBlock(64, 32),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 112x112
        )

        # Final output layer
        self.final_layer = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 112x112 -> 224x224
            nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # output pixel values in the (-1, 1) range
            # nn.Sigmoid()  # alternative if inputs are scaled to [0, 1]
        )

    def forward(self, x):
        x = self.fc(x)  # expand from the latent vector
        x = torch.relu(x.view(-1, 512, 7, 7))
        x = self.decode(x)  # Sequential 실행
        x = self.final_layer(x)  # output size: (3, 224, 224)
        return x
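
# (Hypothetical sanity check.) The decoder mirrors the encoder back up:
# 7 -> 14 -> 28 -> 56 -> 112 -> 224, ending with 3 output channels.
_dec = ResNetDecoder(latent_dim=128)
print(_dec(torch.randn(1, 128)).shape)  # torch.Size([1, 3, 224, 224])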


# Combined autoencoder
class ResNetAutoencoder(nn.Module):
    def __init__(self, latent_dim=128):
        super(ResNetAutoencoder, self).__init__()
        self.encoder = ResNetEncoder(latent_dim)
        self.decoder = ResNetDecoder(latent_dim)

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return latent, reconstructed

 

 

To use it, simply:

model = ResNetAutoencoder()
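
The post stops at model creation, so here is a minimal training-loop sketch (my addition, not from the original post) using the MNIST imports above. MNIST is resized to 3x224x224, the size the encoder's fc layer assumes, and normalized to [-1, 1] to match the decoder's Tanh output; the batch size, epoch count, and learning rate are placeholder choices.

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),  # 1-channel MNIST -> 3 channels
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])
dataset = MNIST(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNetAutoencoder(latent_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    for x, _ in tqdm(loader):        # labels are unused for reconstruction
        x = x.to(device)
        latent, recon = model(x)     # forward returns (latent, reconstructed)
        loss = F.mse_loss(recon, x)  # pixel-wise reconstruction loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')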

 
