# utility.py

import datetime
import math
import os
import shutil
import time
from functools import reduce

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import scipy.misc as misc
from skimage.restoration import denoise_bilateral

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

class timer():
    """Stopwatch that can accumulate several timed intervals.

    ``tic``/``toc`` measure a single interval; ``hold`` adds the current
    interval to an accumulator that ``release`` returns and clears.
    """

    def __init__(self):
        self.acc = 0  # accumulated seconds from hold()
        self.tic()

    def tic(self):
        """Start (or restart) the current interval."""
        # perf_counter is monotonic, so intervals are not corrupted by
        # wall-clock adjustments the way time.time() differences can be.
        self.t0 = time.perf_counter()

    def toc(self):
        """Return the seconds elapsed since the last ``tic``."""
        return time.perf_counter() - self.t0

    def hold(self):
        """Add the time since the last ``tic`` to the accumulator.

        The interval is not restarted; call ``tic`` for that.
        """
        self.acc += self.toc()

    def release(self):
        """Return the accumulated seconds and reset the accumulator to 0."""
        ret = self.acc
        self.acc = 0
        return ret

    def reset(self):
        """Clear the accumulator without reading it."""
        self.acc = 0

class checkpoint():
    """Own one experiment directory and everything written into it.

    Creates ``<dir>/{model,results,residuals,branches}``, writes the run
    configuration to ``config.txt``, keeps ``log.txt`` open for
    ``write_log``, and holds the PSNR history in ``self.log`` (saved as
    ``psnr_log.pt`` by ``save``).
    """

    def __init__(self, args):
        """Resolve the experiment directory from ``args`` and prepare it.

        Reads ``args.load``, ``args.save`` and ``args.reset``. May rewrite
        ``args.save`` (set to the current timestamp when it is ``'.'``) and
        ``args.load`` (set back to ``'.'`` when the directory to resume from
        does not exist, or when ``args.reset`` is set).
        """
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.load == '.':
            if args.save == '.':
                args.save = now
            self.dir = '../experiment/' + args.save
        else:
            self.dir = '../experiment/' + args.load
            if not os.path.exists(self.dir):
                args.load = '.'
            else:
                log_path = self.dir + '/psnr_log.pt'
                # A run that stopped before its first save() has a directory
                # but no PSNR log yet; start an empty history instead of
                # crashing inside torch.load.
                if os.path.exists(log_path):
                    self.log = torch.load(log_path)
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            # Remove the directory without going through a shell: building
            # 'rm -rf ' + dir broke on paths containing spaces and executed
            # any shell metacharacters present in args.save / args.load.
            # NOTE(review): self.log keeps any history loaded above even
            # though its file is deleted here -- confirm this is intended.
            shutil.rmtree(self.dir, ignore_errors=True)
            args.load = '.'

        for sub in ('', '/model', '/results', '/residuals', '/branches'):
            os.makedirs(self.dir + sub, exist_ok=True)

        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Snapshot the training state for ``epoch``.

        Delegates to ``trainer.model.save`` and ``trainer.loss.save`` /
        ``plot_loss`` (defined elsewhere), redraws the PSNR plot, and writes
        ``psnr_log.pt`` and ``optimizer.pt`` into the experiment directory.
        """
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append ``log`` to the PSNR history (concatenated along dim 0)."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Write ``log`` plus a newline to ``log.txt``.

        With ``refresh=True`` the file is closed (flushing it) and reopened
        in append mode.
        """
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close ``log.txt``."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot PSNR against epochs 1..``epoch`` and save it as a PDF.

        Draws one curve per entry of ``args.scale`` using column
        ``idx_scale`` of ``self.log``; assumes ``self.log`` has ``epoch``
        rows -- TODO confirm against callers.
        """
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    @staticmethod
    def _to_uint8(ndarr):
        """Return ``ndarr`` as uint8, stretching non-uint8 data to 0..255.

        uint8 input is returned unchanged. Other dtypes are linearly mapped
        so their minimum becomes 0 and maximum 255 (a constant array maps to
        0), with round-half-up. This is intended to mirror what the removed
        ``scipy.misc.imsave`` did to non-uint8 arrays.
        """
        arr = np.asarray(ndarr)
        if arr.dtype == np.uint8:
            return arr
        lo, hi = float(arr.min()), float(arr.max())
        span = (hi - lo) or 1.0
        scaled = (arr - lo) * (255.0 / span)
        return (scaled.clip(0, 255) + 0.5).astype(np.uint8)

    @staticmethod
    def _imsave(path, ndarr):
        """Write the H x W or H x W x C array ``ndarr`` to ``path``.

        Uses ``scipy.misc.imsave`` when the installed SciPy still provides
        it (it was removed in SciPy 1.2). Otherwise it converts with
        ``_to_uint8`` and writes through ``matplotlib.pyplot.imsave``;
        2-D arrays are written with a gray colormap over 0..255.
        """
        imsave = getattr(misc, 'imsave', None)
        if imsave is not None:
            imsave(path, ndarr)
            return
        img = checkpoint._to_uint8(ndarr)
        if img.ndim == 2:
            plt.imsave(path, img, cmap='gray', vmin=0, vmax=255)
        else:
            plt.imsave(path, img)

    def save_results(self, filename, save_list, scale):
        """Save images as ``results/<filename>_x<scale>_{SR,LR,HR}.png``.

        For each tensor in ``save_list`` (paired with SR, LR, HR in order;
        extra entries are ignored) the first batch element is rescaled from
        ``args.rgb_range`` to 0..255, clamped, and saved as uint8. Tensors
        are assumed to be (N, C, H, W) -- the ``permute(1, 2, 0)`` after
        ``[0]`` relies on it. Single-channel images are saved as 2-D.
        """
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            normalized = v[0].data.mul(255 / self.args.rgb_range)
            # Clamp before the uint8 cast so out-of-range values saturate
            # instead of wrapping around.
            ndarr = normalized.clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()

            if ndarr.shape[-1] == 1:
                ndarr = ndarr[:, :, 0]

            self._imsave('{}{}.png'.format(filename, p), ndarr)

    def save_residuals(self, filename, save_list, scale):
        """Save ``|HR - SR|`` as ``residuals/<filename>_x<scale>.png``.

        SR is ``save_list[0]`` and HR is ``save_list[-1]``; both are divided
        by ``args.rgb_range`` before subtracting. Only the first batch
        element is used.
        """
        filename = '{}/residuals/{}_x{}'.format(self.dir, filename, scale)
        sr, hr = save_list[0], save_list[-1]

        def _prepare(x):
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            out = normalized.permute(1, 2, 0).cpu().numpy()

            if out.shape[-1] == 1:
                out = out[:, :, 0]

            return out

        ndarr_sr, ndarr_hr = _prepare(sr), _prepare(hr)
        out = np.abs(ndarr_hr - ndarr_sr)
        self._imsave('{}.png'.format(filename), out)

    def save_branches(self, filename, save_list, scale):
        """Save each entry of ``save_list`` as
        ``branches/<filename>_x<scale>_branch<i>.png``.

        Values are divided by ``args.rgb_range``; every entry except the
        first is saved as its absolute value. Only the first batch element
        is used.
        """
        filename = '{}/branches/{}_x{}'.format(self.dir, filename, scale)

        def _prepare(x, residual):
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            if not residual:
                out = normalized.permute(1, 2, 0).cpu().numpy()
            else:
                out = np.abs(normalized.permute(1, 2, 0).cpu().numpy())

            if out.shape[-1] == 1:
                out = out[:, :, 0]
            return out

        for i, branch_output in enumerate(save_list):
            ndarr = _prepare(branch_output, not (i == 0))
            self._imsave('{}{}.png'.format(filename, '_branch{}'.format(i)), ndarr)
        return

def get_bilateral(tensor, rgb_range):
    """Apply scikit-image's ``denoise_bilateral`` to every image in a batch.

    Args:
        tensor: 4-D torch tensor, assumed (N, C, H, W) -- the
            ``transpose(0, 2, 3, 1)`` below relies on that layout.
        rgb_range: maximum pixel value; images are divided by it before
            filtering and multiplied by it afterwards.

    Returns:
        A float32 ``torch.Tensor`` with the same shape as ``tensor``, on the
        same device.
    """
    device = tensor.device
    # detach().cpu() so tensors that require grad or live on a GPU can be
    # converted; a bare .numpy() raises for both.
    images = tensor.detach().cpu().numpy().transpose(0, 2, 3, 1) / rgb_range
    out = np.zeros_like(images)

    for i, img in enumerate(images):
        if img.shape[-1] == 1:
            # Single channel: filter it as a 2-D grayscale image, which every
            # scikit-image version accepts.
            out[i, ..., 0] = denoise_bilateral(img[..., 0])
        else:
            try:
                # Newer scikit-image rejects 3-D input unless the channel
                # axis is named explicitly.
                out[i] = denoise_bilateral(img, channel_axis=-1)
            except TypeError:
                # Older scikit-image only understands ``multichannel``.
                out[i] = denoise_bilateral(img, multichannel=True)

    return (torch.Tensor(out.transpose(0, 3, 1, 2)) * rgb_range).to(device)

def quantize(img, rgb_range):
    """Snap ``img`` to the 256-level grid of an 8-bit image.

    The tensor is rescaled so ``rgb_range`` maps to 255, clipped to
    [0, 255], rounded, and scaled back, so the result stays in
    ``rgb_range`` units.
    """
    levels_per_unit = 255 / rgb_range
    as_8bit = img.mul(levels_per_unit).clamp(0, 255).round()
    return as_8bit.div(levels_per_unit)

def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Return the PSNR in dB between ``sr`` and ``hr``.

    Args:
        sr, hr: tensors assumed (N, C, H, W), as implied by the
            ``[:, :, ...]`` crop and the ``size(1)`` channel check.
        scale: super-resolution scale; it sets the border cropped before
            measuring.
        rgb_range: maximum pixel value; the difference is divided by it so
            the peak signal is 1.
        benchmark: if True, crop ``scale`` pixels per side and, for
            multi-channel input, measure on a single Y channel formed with
            weights 65.738 / 129.057 / 25.064 over 256. Otherwise crop
            ``scale + 6`` pixels per side over all channels.

    Returns:
        PSNR as a Python float, or ``float('inf')`` when the cropped images
        are identical (``math.log10(0)`` would raise).
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        return float('inf')
    return -10 * math.log10(mse)

def make_optimizer(args, my_model):
    """Build a torch optimizer over the trainable parameters of ``my_model``.

    Reads ``args.optimizer`` ('SGD', 'ADAM' or 'RMSprop'), ``args.lr`` and
    ``args.weight_decay``, plus ``args.momentum`` (SGD),
    ``args.beta1`` / ``args.beta2`` / ``args.epsilon`` (ADAM) or
    ``args.epsilon`` (RMSprop). Parameters with ``requires_grad=False`` are
    left out.

    Raises:
        ValueError: if ``args.optimizer`` is not one of the names above
            (previously this surfaced as an UnboundLocalError).
    """
    trainable = [p for p in my_model.parameters() if p.requires_grad]

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

    kwargs['lr'] = args.lr
    kwargs['weight_decay'] = args.weight_decay
    return optimizer_function(trainable, **kwargs)

def make_scheduler(args, my_optimizer):
    """Build a learning-rate scheduler for ``my_optimizer``.

    ``args.decay_type == 'step'`` gives ``StepLR`` with
    ``step_size=args.lr_decay``. Any other value containing ``'step'`` is
    read as ``<prefix>_<m1>_<m2>...`` and gives ``MultiStepLR`` at those
    integer milestones. Both use ``gamma=args.gamma``.

    Raises:
        ValueError: if ``args.decay_type`` contains no ``'step'`` (previously
            an UnboundLocalError), or a milestone token is not an integer.
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma
        )
    elif 'step' in args.decay_type:
        # Drop the leading prefix token; skip empty tokens such as the one a
        # trailing '_' produces, which previously crashed in int('').
        milestones = [int(m) for m in args.decay_type.split('_')[1:] if m]
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    else:
        raise ValueError('Unsupported decay_type: {}'.format(args.decay_type))
    return scheduler
# 最后編輯于
# ©著作權歸作者所有,轉載或內容合作請聯系作者
# 【社區內容提示】社區部分內容疑似由AI輔助生成,瀏覽時請結合常識與多方信息審慎甄別。
# 平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
#
# 友情鏈接更多精彩內容