Python Code for the Low-Light Enhancement Model Implementation
I implemented it as described in the paper and have started training.
Hopefully the results turn out well..
As mentioned in the previous post, once this model is trained, deployment will go through ONNX.
The ONNX model will be loaded and run using the C++ ONNX library,
and after that it will go through a C# integration step so the results can be displayed visually in the UI.
Of course, it is also possible to run inference with one of the well-known ONNX C# libraries that are already available.
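The C++ loading code and the C# side are not covered in this post, but just to show what the inference call against the exported model looks like, here is a minimal sketch of the same flow in Python using the onnxruntime package. The file names are placeholders, and the 'input' / 'output' node names match the ones passed to torch.onnx.export later in this post.

import cv2
import numpy as np
import onnxruntime as ort

# Load the exported model (placeholder path; use the file produced by the training script below).
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Preprocess a low-light image the same way the training code does: RGB, [0, 1], NCHW.
img = cv2.imread("low_light.jpg")                      # placeholder file name
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (1024, 1024)).astype(np.float32) / 255.0
x = np.transpose(img, (2, 0, 1))[np.newaxis, ...]      # shape (1, 3, 1024, 1024)

# 'input' / 'output' are the node names given at export time.
enhanced = session.run(["output"], {"input": x})[0]

# Convert back to an 8-bit BGR image for display.
enhanced = np.clip(np.transpose(enhanced[0], (1, 2, 0)), 0.0, 1.0)
cv2.imshow("enhanced", cv2.cvtColor((enhanced * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
cv2.waitKey(0)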
Oh, and the low-light enhancement dataset I used can be downloaded from the link below.
Separable Convolution Block
import torch

class DepthwiseSeparableConv(torch.nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DepthwiseSeparableConv, self).__init__()
        # 3x3 depthwise convolution: one filter per input channel (groups=in_ch)
        self.depth_conv = torch.nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch)
        # 1x1 pointwise convolution: mixes channels and projects to out_ch
        self.point_conv = torch.nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.point_conv(self.depth_conv(x))
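As a quick sanity check (my own snippet, not part of the original code), the depthwise 3x3 plus pointwise 1x1 pair uses far fewer parameters than a plain 3x3 convolution with the same input/output channels, while producing the same output shape:

import torch

def count_params(m):
    return sum(p.numel() for p in m.parameters())

dsc = DepthwiseSeparableConv(32, 32)
std = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)

print(count_params(dsc))  # 320 (depthwise) + 1056 (pointwise) = 1376 parameters
print(count_params(std))  # 3*3*32*32 + 32 = 9248 parameters

x = torch.randn(1, 32, 64, 64)
assert dsc(x).shape == std(x).shape == (1, 32, 64, 64)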
ZeroDCE Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.block import DepthwiseSeparableConv

class ZeroDCEPP(nn.Module):
    def __init__(self, scale_factor=4, num_features=32):
        super(ZeroDCEPP, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)

        self.e_conv1 = DepthwiseSeparableConv(3, num_features)
        self.e_conv2 = DepthwiseSeparableConv(num_features, num_features)
        self.e_conv3 = DepthwiseSeparableConv(num_features, num_features)
        self.e_conv4 = DepthwiseSeparableConv(num_features, num_features)
        self.e_conv5 = DepthwiseSeparableConv(num_features * 2, num_features)  # x3 + x4
        self.e_conv6 = DepthwiseSeparableConv(num_features * 2, num_features)  # x2 + x5
        self.e_conv7 = DepthwiseSeparableConv(num_features * 2, 3)             # x1 + x6

    def apply_curve(self, x, x_r):
        # Iteratively apply the light-enhancement curve LE(x) = x + A * (x^2 - x), 8 times
        for _ in range(8):
            x = x + x_r * (torch.pow(x, 2) - x)
        return x

    def forward(self, x):
        # Downscale the input before estimating the curve parameters
        if self.scale_factor == 1:
            x_down = x
        else:
            x_down = F.interpolate(x, scale_factor=1 / self.scale_factor, mode='bilinear', align_corners=False)

        # Feature extraction layers
        layer1 = self.relu(self.e_conv1(x_down))
        layer2 = self.relu(self.e_conv2(layer1))
        layer3 = self.relu(self.e_conv3(layer2))
        layer4 = self.relu(self.e_conv4(layer3))

        # Feature concatenation
        layer5 = self.relu(self.e_conv5(torch.cat([layer3, layer4], 1)))
        layer6 = self.relu(self.e_conv6(torch.cat([layer2, layer5], 1)))
        layer7 = self.e_conv7(torch.cat([layer1, layer6], 1))
        x_r = torch.tanh(layer7)

        # Upsample the curve map back to the original resolution
        if self.scale_factor != 1:
            x_r = self.upsample(x_r)

        # Apply the estimated curve map to the original image
        enhance_image = self.apply_curve(x, x_r)
        return enhance_image, x_r
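A quick shape check (again my own snippet): because the curve map is estimated at 1/scale_factor resolution and then upsampled, both outputs come back at the input resolution.

import torch

net = ZeroDCEPP(scale_factor=4, num_features=32)
net.eval()

x = torch.rand(1, 3, 256, 256)          # dummy low-light image in [0, 1]
with torch.no_grad():
    enhanced, curve = net(x)

print(enhanced.shape)  # torch.Size([1, 3, 256, 256])
print(curve.shape)     # torch.Size([1, 3, 256, 256])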
Model Code for ONNX Deployment
import torch

class OnnxModel(torch.nn.Module):
    def __init__(self, backbone: torch.nn.Module):
        super(OnnxModel, self).__init__()
        self.backbone = backbone

    def forward(self, x):
        ## The curve parameters are not needed at inference time, so this wrapper
        ## class is used when exporting the trained model and returns only the
        ## enhanced image, leaving the curve out.
        x, curve = self.backbone(x)
        return x
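The training code in the next section imports get_lle_loader from dataset.lle_dataset, which is not listed in this post. A minimal version might look like the sketch below; the file filtering, RGB ordering, and scaling to [0, 1] are my assumptions, chosen to match what the model and the visualization code expect, so the actual loader may differ.

import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class LLEDataset(Dataset):
    """Loads low-light images only; Zero-DCE trains without paired ground truth."""
    def __init__(self, root, resize_shape):
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))
                      if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
        self.resize_shape = resize_shape  # (height, width)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = cv2.imread(self.paths[idx])              # BGR, uint8
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)     # match the RGB assumption in the training code
        h, w = self.resize_shape
        img = cv2.resize(img, (w, h))                  # cv2.resize expects (width, height)
        img = img.astype(np.float32) / 255.0           # scale to [0, 1]
        return torch.from_numpy(img).permute(2, 0, 1)  # HWC -> CHW

def get_lle_loader(root, batch_size, resize_shape):
    return DataLoader(LLEDataset(root, resize_shape),
                      batch_size=batch_size, shuffle=True, num_workers=0)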
Training Code
import torch
import torch.onnx
import numpy as np
import cv2
import os

from model.zerodcepp import ZeroDCEPP
from loss.zerodce_loss import ZeroDCETotalLoss
from dataset.lle_dataset import get_lle_loader
from model.onnx_model import OnnxModel

# Check whether a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Current device: {device}")

# Hyperparameters
epochs = 50
learning_rate = 0.0001
weight_decay = 0.0001
batch_size = 8
scale_factor = 4
num_features = 32
image_width = 1024
image_height = 1024
image_channel = 3

dataset_path = "C://github//dataset//lol_dataset//our485//all"
weight_path = "C://github//Dot4Seminar//working//python//results//weights.pth"
onnx_model_path = "C://github//Dot4Seminar//working//python//results//model.onnx"

dummy_input = torch.randn(size=(1, image_channel, image_height, image_width)).to(device)

model = ZeroDCEPP(scale_factor=scale_factor, num_features=num_features)
model = model.to(device)

# Resume from a previous checkpoint if one exists
if os.path.exists(weight_path):
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)

dataloader = get_lle_loader(dataset_path, batch_size, resize_shape=(image_height, image_width))
total_batches = len(dataloader)

loss = ZeroDCETotalLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

temp_loss = 1000000

for epoch in range(epochs):
    avg_loss = 0
    model.train()

    for i, x_image in enumerate(dataloader):
        gpu_x_image = x_image.to(device)
        enhanced_img, curve_params = model(gpu_x_image)

        current_loss = loss(gpu_x_image, enhanced_img, curve_params)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()

        avg_loss += current_loss.item() / total_batches

        # Take the first sample of the batch, move it to the CPU, and convert to NumPy
        tensor_input_check = gpu_x_image[0]
        tensor_output_check = enhanced_img[0]

        input_img = tensor_input_check.detach().cpu().numpy()
        input_img = np.transpose(input_img, (1, 2, 0))
        input_img = (input_img * 255).astype(np.uint8)
        input_img = cv2.cvtColor(input_img, cv2.COLOR_RGB2BGR)

        output_img = tensor_output_check.detach().cpu().numpy()
        output_img = np.transpose(output_img, (1, 2, 0))
        output_img = (output_img * 255).astype(np.uint8)
        output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)

        # Visualization
        cv2.imshow("original", input_img)
        cv2.imshow("output", output_img)
        cv2.waitKey(1)  # wait 1 ms (keeps the training loop from blocking)

    # Save the weights and export to ONNX whenever the average loss improves
    if temp_loss > avg_loss:
        temp_loss = avg_loss
        model.eval()
        onnx_model = OnnxModel(backbone=model)
        torch.save(model.state_dict(), weight_path)
        torch.onnx.export(
            onnx_model,                # model to export
            dummy_input,               # example model input
            onnx_model_path,           # output file name
            export_params=True,        # store the trained parameters inside the model file
            opset_version=11,          # opset version that reliably supports the bilinear ops
            do_constant_folding=True,  # constant folding optimization (faster inference)
            input_names=['input'],     # input node name (used when calling from C++)
            output_names=['output'],   # output node name
        )

    print('current avg loss = ', avg_loss)
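ZeroDCETotalLoss from loss.zerodce_loss is also not listed in this post. For reference only, a simplified stand-in covering three of the four terms from the Zero-DCE paper (exposure control, color constancy, illumination smoothness; the spatial consistency term is omitted for brevity) could look like the sketch below. The class name, the term weights, and the target exposure level of 0.6 are my assumptions, so the actual loss used for training may well differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedZeroDCELoss(nn.Module):
    """Rough stand-in for the Zero-DCE losses (spatial consistency term omitted)."""
    def __init__(self, target_exposure=0.6, w_col=5.0, w_tv=200.0):
        super().__init__()
        self.target_exposure = target_exposure
        self.w_col = w_col
        self.w_tv = w_tv

    def forward(self, original, enhanced, curve):
        # 'original' is only needed by the spatial consistency term, which this sketch
        # leaves out; the parameter is kept to match the call signature above.

        # Exposure control: pull the mean intensity of 16x16 patches toward the target level.
        luminance = enhanced.mean(dim=1, keepdim=True)
        patch_mean = F.avg_pool2d(luminance, kernel_size=16)
        loss_exp = torch.mean((patch_mean - self.target_exposure) ** 2)

        # Color constancy: keep the per-image means of the R/G/B channels close to each other.
        mean_rgb = enhanced.mean(dim=(2, 3))  # shape (N, 3)
        r, g, b = mean_rgb[:, 0], mean_rgb[:, 1], mean_rgb[:, 2]
        loss_col = torch.mean((r - g) ** 2 + (r - b) ** 2 + (g - b) ** 2)

        # Illumination smoothness: total-variation penalty on the predicted curve map.
        tv_h = torch.mean((curve[:, :, 1:, :] - curve[:, :, :-1, :]) ** 2)
        tv_w = torch.mean((curve[:, :, :, 1:] - curve[:, :, :, :-1]) ** 2)
        loss_tv = tv_h + tv_w

        return loss_exp + self.w_col * loss_col + self.w_tv * loss_tv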