天天看点

实验记录

1、basic model 训练采用X4的数据,noise的范围为【0,15】;

2、subnetwork训练的数据应该是采用noise范围为【0,66】;不采用tvloss,同样采用LR进行训练,

3、finetune两个模型。基于BN层。X4与【0,66】的数据;

python train.py -opt options/train/train_sr.json

noise level estimation 网络的setting

{
  "name": "subnetwork_X1_0_66" //  please remove "debug_" during training
  , "tb_logger_dir": "sft666"
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 1
  , "crop_scale": 0
  , "gpu_ids": [3,4]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft"
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_66"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 24 // 16
      , "HR_size": 96 // 128 | 192 | 96
      , "noise_gt": true//residual for the noise
      , "use_flip": true
      , "use_rot": true
    }

  , "val": {
      "name": "val_set5_c03s08_LR_mod4"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_66"
      , "noise_gt": true//residual for the noise
    }

  }

  , "path": {
    "root": "/home/guanwp/Blind_Restoration-master/sft666"
    , "pretrain_model_G": null
  }


//
  , "network_G": {
    "which_model_G": "denoise_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet
    //    , "norm_type": "adaptive_conv_res"
    //, "norm_type": "batch"
    ,
    "mode": "CNA",
    "nf": 64,
    "nb": 16,
    "in_nc": 3,
    "out_nc": 3
    //    , "gc": 32
    ,
    "group": 1
      ,"down_scale": 2//denoise srresnet
    //    , "gate_conv_bias": true
    //    , "ada_ksize": 1
    //    , "num_classes": 2
  }


//    , "network_G": {
//    "which_model_G": "srcnn" // RRDB_net | sr_resnet
    , "norm_type": null
//    , "norm_type": "adaptive_conv_res"
//    , "mode": "CNA"
//    , "nf": 64
//    , "in_nc": 1
//    , "out_nc": 1
//    , "ada_ksize": 5
//  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 1e-4
    , "lr_scheme": "MultiStepLR"
//    , "lr_steps": [200000, 400000, 600000, 800000]
    , "lr_steps": [500000]
    , "lr_gamma": 0.1
//    , "lr_gamma": 0.5

    , "pixel_criterion": "l2"

    , "pixel_criterion_reg": "tv"

    , "pixel_weight": 1.0
    , "val_freq": 2e3

    , "manual_seed": 0
    , "niter": 1e6
//    , "niter": 6e5
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 2e3
  }
}
           

basic model的setting

{
  "name": "subnetwork_X1_0_66" //  please remove "debug_" during training
  , "tb_logger_dir": "sft666"
  , "use_tb_logger": true
  , "model":"sr"
  , "scale": 4
  , "crop_scale": 0
  , "gpu_ids": [3,4]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft"
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_15"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 24 // 16
      , "HR_size": 96 // 128 | 192 | 96
      //, "noise_gt": true//residual for the noise
      , "use_flip": true
      , "use_rot": true
    }

  , "val": {
      "name": "val_set5_c03s08_LR_mod4"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_15"
      //, "noise_gt": true//residual for the noise
    }

  }

  , "path": {
    "root": "/home/guanwp/Blind_Restoration-master/sft666"
    , "pretrain_model_G": null
  }


//
  , "network_G": {
    "which_model_G": "sr_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet
    //    , "norm_type": "adaptive_conv_res"
    , "norm_type": null//"batch"
    ,
    "mode": "CNA",
    "nf": 64,
    "nb": 16,
    "in_nc": 3,
    "out_nc": 3
    //    , "gc": 32
    ,
    "group": 1
      //,"down_scale": 2//denoise srresnet
    //    , "gate_conv_bias": true
    //    , "ada_ksize": 1
    //    , "num_classes": 2
  }


//    , "network_G": {
//    "which_model_G": "srcnn" // RRDB_net | sr_resnet
    , "norm_type": null
//    , "norm_type": "adaptive_conv_res"
//    , "mode": "CNA"
//    , "nf": 64
//    , "in_nc": 1
//    , "out_nc": 1
//    , "ada_ksize": 5
//  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 1e-4
    , "lr_scheme": "MultiStepLR"
//    , "lr_steps": [200000, 400000, 600000, 800000]
    , "lr_steps": [500000]
    , "lr_gamma": 0.1
//    , "lr_gamma": 0.5

    , "pixel_criterion": "l2"

    , "pixel_criterion_reg": "tv"

    , "pixel_weight": 1.0
    , "val_freq": 2e3

    , "manual_seed": 0
    , "niter": 1e6
//    , "niter": 6e5
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 2e3
  }
}
           

模型训练好的结果如下:

实验记录

好,接下来训练SFT-layer。由于在训练SFT-layer的时候,原来的basicmodel是没有BN层的,当然也没有SFT-layer层,就是网络的结构是发生了变化的,故此需要用transfer_params_degration.py文件进行模型的转换

import torch
from torch.nn import init
from collections import OrderedDict

# Re-key a pretrained checkpoint: every tensor listed below is copied from its
# 'model.*' name to the matching 'fea_conv' / 'norm_branch' / 'LR_conv' /
# 'HR_branch' name, and the resulting dict is saved as a new checkpoint.
# Nothing else is added to the output (merging in a separately trained
# degradation sub-network and zero-initialising gamma/beta entries are both
# disabled in this version of the script).
# NOTE(review): the source/target paths are machine specific; other
# checkpoints were converted by swapping these two paths.
pretrained_net = torch.load('/home/jwhe/workspace/BasicSR_v3/sr/experiments/LR_srx4_resnet_denoise_DIV2K/models/516000_G.pth')
adaptive_norm_net = OrderedDict()


def _copy_layer(dst_prefix, src_prefix):
    """Copy one layer's weight and then its bias under a new parameter name."""
    for suffix in ('weight', 'bias'):
        adaptive_norm_net[dst_prefix + '.' + suffix] = pretrained_net[src_prefix + '.' + suffix]


# first feature-extraction convolution
_copy_layer('fea_conv.0', 'model.0')

# residual blocks: two convolutions per block
for block_idx in range(16):
    _copy_layer('norm_branch.{:d}.conv_block0.0'.format(block_idx),
                'model.1.sub.{:d}.res.0'.format(block_idx))
    _copy_layer('norm_branch.{:d}.conv_block1.0'.format(block_idx),
                'model.1.sub.{:d}.res.2'.format(block_idx))

# convolution that follows the residual blocks
_copy_layer('LR_conv.0', 'model.1.sub.16')

# HR branch: (target index, source index) pairs
for dst_idx, src_idx in ((0, 2), (3, 5), (6, 8), (8, 10)):
    _copy_layer('HR_branch.{:d}'.format(dst_idx), 'model.{:d}'.format(src_idx))

print('OK. \n Saving model...')
torch.save(adaptive_norm_net, '../../experiments/pretrained_models/sr/LR_srx4_resnet_denoise_DIV2K/basicmodel_516000.pth')
           

python train.py -opt options/train/train_sub_sr.json

setting(注意不要开启noise_gt,即不使用noise的ground truth)

{
  "name": "finetune_noiseestimation_sftlayer_15to66" //  please remove "debug_" during training
  , "tb_logger_dir": "sft666"
  , "use_tb_logger": true
  , "model":"sr_sub"
  , "scale": 4
  , "crop_scale": 4
  , "gpu_ids": [4,5]
//  , "init_type": "kaiming"
//
//  , "finetune_type": "sft_basic" //sft | basic
//  , "init_norm_type": "zero"

  , "datasets": {
    "train": {
      "name": "DIV2K"
      , "mode": "LRHR"
      , "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub"
      , "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4_noise_66"
      , "subset_file": null
      , "use_shuffle": true
      , "n_workers": 8
      , "batch_size": 24 // 16
      , "HR_size": 96 // 128 | 192 | 96
      //, "noise_gt": true
      , "use_flip": true
      , "use_rot": true
    }
//
  , "val": {
      "name": "val_set5_x4_c03s08_mod4",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_sub_bicLRx4_noise_66"
     // , "noise_gt": true
    }
//
//    , "val": {
//      "name": "val_set5_x3_gray_mod6"
//      , "mode": "LRHR"
//      , "dataroot_HR": "/media/sdc/jwhe/BasicSR_v2/data/val/Set5_val/mod6/Set5_gray_mod6"
//      , "dataroot_LR": "/media/sdc/jwhe/BasicSR_v2/data/val/Set5_val/mod6/Set5_gray_bicx3"
//    }
  }

  , "path": {
    "root": "/home/guanwp/Blind_Restoration-master/sft666"
  , "pretrain_model_G": "/home/guanwp/Blind_Restoration-master/sft/686000_G.pth"
  , "pretrain_model_sub": "/home/guanwp/Blind_Restoration-master/sft666/experiments/subnetwork_X1_0_66/models/528000_G.pth"
  }

  , "network_G": {
    "which_model_G": "modulate_sr_resnet" // RRDB_net | sr_resnet | modulate_sr_resnet
//    , "norm_type": "adaptive_conv_res"
    , "norm_type": "sft"
//    , "norm_type": null
    , "mode": "CNA"
    , "nf": 64
    , "nb": 16
    , "in_nc": 3
    , "out_nc": 3
//    , "gc": 32
    , "group": 1
    , "gate_conv_bias": true
  }

  , "network_sub": {
    "which_model_sub": "denoise_resnet" // RRDB_net | sr_resnet | modulate_denoise_resnet |noise_subnet
    //    , "norm_type": "adaptive_conv_res"
    ,
    //"norm_type": "batch",
    "mode": "CNA",
    "nf": 64
        , "nb": 16
    ,
    "in_nc": 3,
    "out_nc": 3,
    "group": 1,
    "down_scale": 2
  }


  , "train": {
//    "lr_G": 1e-3
    "lr_G": 1e-4
    , "lr_scheme": "MultiStepLR"
//    , "lr_steps": [200000, 400000, 600000, 800000]
    , "lr_steps": [500000]
    , "lr_gamma": 0.1
//    , "lr_gamma": 0.5

    , "pixel_criterion_basic": "l2"
    , "pixel_criterion_noise": "l2"
    , "pixel_criterion_reg_noise": "tv"
    , "pixel_weight_basic": 1.0
    , "pixel_weight_noise": 1.0
    , "val_freq": 2e3

    , "manual_seed": 0
    , "niter": 1e6
//    , "niter": 6e5
  }

  , "logger": {
    "print_freq": 200
    , "save_checkpoint_freq": 2e3
  }
}
           

首先给出sub_srmodel,关键部分应该是在optimize_parameters,可以看到数据的流向。与级联网络不一样的地方是,这里不需要concat

import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.optim import lr_scheduler

import models.networks as networks
from .base_model import BaseModel
from .modules.loss import TVLoss


class SRModel(BaseModel):
    """Restoration model made of a generator and a noise-estimation sub-network.

    ``subnet`` maps the LR input to a noise estimate and ``netG`` receives the
    pair ``(LR, noise_estimate)`` as a tuple (no channel concatenation), both
    during training and during testing.

    ``self.device``, ``self.is_train``, ``self.opt``, ``self.save_dir``,
    ``self.optimizers``, ``self.schedulers`` and the ``load_network`` /
    ``save_network`` / ``get_network_description`` helpers are not defined in
    this class; presumably they come from ``BaseModel`` -- TODO confirm.
    """

    def __init__(self, opt):
        """Create both networks, load checkpoints and set up training state.

        Args:
            opt: options mapping. ``opt['train']``, ``opt['finetune_type']``
                and (through ``load``) ``opt['path']`` are read by indexing.
                NOTE(review): optional keys such as ``finetune_type`` and
                ``weight_decay_G`` are indexed directly, so the mapping is
                assumed to yield ``None`` for missing keys -- confirm.

        Raises:
            NotImplementedError: for a pixel criterion other than 'l1'/'l2'
                or an LR scheme other than 'MultiStepLR'.
        """
        super(SRModel, self).__init__(opt)
        train_opt = opt['train']
        finetune_type = opt['finetune_type']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        self.subnet = networks.define_sub(opt).to(self.device)
        self.load()

        if self.is_train:
            self.netG.train()
            self.subnet.train()

            # loss on the sub-network output
            self.cri_pix_noise = self.__make_pixel_criterion(
                train_opt['pixel_criterion_noise'])
            self.l_pix_noise_w = train_opt['pixel_weight_noise']

            # Optional regulariser on the sub-network output. It is always
            # defined (possibly as None) so optimize_parameters can test it
            # instead of failing on a missing attribute.
            if train_opt['pixel_criterion_reg_noise'] == 'tv':
                self.cri_pix_reg_noise = TVLoss(0.00001).to(self.device)
            else:
                self.cri_pix_reg_noise = None

            # loss on the generator output
            self.cri_pix_basic = self.__make_pixel_criterion(
                train_opt['pixel_criterion_basic'])
            self.l_pix_basic_w = train_opt['pixel_weight_basic']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0

            self.optim_params = self.__define_grad_params(finetune_type)

            self.optimizer_G = torch.optim.Adam(
                self.optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(lr_scheduler.MultiStepLR(
                        optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def __make_pixel_criterion(self, loss_type):
        """Return an L1 ('l1') or MSE ('l2') loss module placed on self.device.

        Raises:
            NotImplementedError: if ``loss_type`` is neither 'l1' nor 'l2'.
        """
        if loss_type == 'l1':
            return nn.L1Loss().to(self.device)
        if loss_type == 'l2':
            return nn.MSELoss().to(self.device)
        raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))

    def feed_data(self, data, need_HR=True):
        """Move one batch to ``self.device`` and store it on the model.

        Args:
            data: mapping with an 'LR' tensor, an 'HR' tensor when ``need_HR``
                is true, and optionally an 'MR' tensor used as the target of
                the sub-network output.
            need_HR: whether the 'HR' target must be read from ``data``.
        """
        self.var_L = data['LR'].to(self.device)  # LR
        # Targets are reset on every call so a previous batch never leaks in.
        self.real_H = data['HR'].to(self.device) if need_HR else None  # HR
        self.mid_L = data['MR'].to(self.device) if 'MR' in data else None  # MR

    @staticmethod
    def __freeze_except(network, keyword, optim_params):
        """Freeze ``network`` except parameters whose name contains ``keyword``.

        The matching parameters are appended to ``optim_params``.
        """
        for k, v in network.named_parameters():
            v.requires_grad = False
            if keyword in k:
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))

    @staticmethod
    def __collect_trainable(network, optim_params):
        """Append every parameter of ``network`` that already requires grad."""
        for k, v in network.named_parameters():  # can optimize for a part of the model
            if v.requires_grad:
                optim_params.append(v)
                print('params [{:s}] will optimize.'.format(k))
            else:
                print('WARNING: params [{:s}] will not optimize.'.format(k))

    def __define_grad_params(self, finetune_type=None):
        """Select the parameters handed to the optimizer.

        * 'sft': only ``netG`` parameters whose name contains 'Gate'
          (``subnet`` parameters are left untouched and not optimized).
        * 'sub_sft': 'Gate' parameters of ``netG`` plus ``subnet`` parameters
          whose name contains 'degration' (spelling as used in the names).
        * 'basic' / 'sft_basic': all of ``netG``; ``subnet`` is frozen.
        * anything else: every parameter of both networks that already has
          ``requires_grad`` set.

        Returns:
            list of parameters to optimize.
        """
        optim_params = []

        if finetune_type == 'sft':
            self.__freeze_except(self.netG, 'Gate', optim_params)
        elif finetune_type == 'sub_sft':
            self.__freeze_except(self.netG, 'Gate', optim_params)
            self.__freeze_except(self.subnet, 'degration', optim_params)
        elif finetune_type == 'basic' or finetune_type == 'sft_basic':
            for k, v in self.netG.named_parameters():
                v.requires_grad = True
                optim_params.append(v)
                print('we only optimize params: {}'.format(k))
            for k, v in self.subnet.named_parameters():
                v.requires_grad = False
        else:
            self.__collect_trainable(self.netG, optim_params)
            self.__collect_trainable(self.subnet, optim_params)
        return optim_params

    def optimize_parameters(self, step):
        """Run one optimisation step on the batch stored by ``feed_data``.

        Data flow: ``subnet(LR) -> noise estimate``, then
        ``netG((LR, noise estimate)) -> SR``. The total loss is the weighted
        generator loss against HR, plus -- only when an 'MR' target was fed --
        the weighted sub-network loss and its optional regulariser. Without an
        'MR' target only the generator loss is used.

        Args:
            step: current iteration; not used by this method.

        Raises:
            ValueError: if no HR target was fed (``feed_data(need_HR=False)``).
        """
        if self.real_H is None:
            raise ValueError('optimize_parameters requires an HR target; '
                             'call feed_data with need_HR=True.')

        self.optimizer_G.zero_grad()

        self.fake_noise = self.subnet(self.var_L)
        self.fake_H = self.netG((self.var_L, self.fake_noise))
        l_pix_basic = self.l_pix_basic_w * self.cri_pix_basic(self.fake_H, self.real_H)

        if self.mid_L is not None:
            l_pix_noise = self.l_pix_noise_w * self.cri_pix_noise(self.fake_noise, self.mid_L)
            if self.cri_pix_reg_noise is not None:
                l_pix_noise = l_pix_noise + self.cri_pix_reg_noise(self.fake_noise)
            l_pix = l_pix_noise + l_pix_basic
        else:
            l_pix = l_pix_basic
        l_pix.backward()

        self.optimizer_G.step()

        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Run a forward pass in eval mode without building a graph.

        ``netG`` is fed the same ``(LR, noise estimate)`` tuple as in
        ``optimize_parameters`` so that training and testing agree. (An
        earlier variant concatenated LR with the noise estimate scaled by 2.6
        along the channel axis, which does not match the training input.)
        Both networks are switched back to train mode afterwards.
        """
        self.netG.eval()
        self.subnet.eval()
        # no_grad replaces manually flipping requires_grad on every parameter
        # and leaves the frozen/trainable split chosen at init untouched.
        with torch.no_grad():
            self.fake_noise = self.subnet(self.var_L)
            self.fake_H = self.netG((self.var_L, self.fake_noise))
        self.netG.train()
        self.subnet.train()

    def get_current_log(self):
        """Return the dict holding the latest logged loss values."""
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        """Return the first sample of the batch as float CPU tensors.

        Keys: 'LR' (input), 'MR' (sub-network output), 'SR' (generator
        output) and, when ``need_HR`` is true and an HR target was fed, 'HR'.
        """
        out_dict = OrderedDict()
        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
        out_dict['MR'] = self.fake_noise.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_HR and getattr(self, 'real_H', None) is not None:
            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Print the generator's parameter count; when training, also write
        both network descriptions to ``network.txt`` one level above
        ``self.save_dir`` and print the sub-network's parameter count."""
        # G
        s, n = self.get_network_description(self.netG)
        print('Number of parameters in G: {:,d}'.format(n))
        if self.is_train:
            message = '-------------- Generator --------------\n' + s + '\n'
            network_path = os.path.join(self.save_dir, '../', 'network.txt')
            with open(network_path, 'w') as f:
                f.write(message)

            # subnet
            s, n = self.get_network_description(self.subnet)
            print('Number of parameters in subnet: {:,d}'.format(n))
            message = '\n\n\n-------------- subnet --------------\n' + s + '\n'
            with open(network_path, 'a') as f:
                f.write(message)

    def load(self):
        """Load ``netG`` / ``subnet`` weights from the paths configured under
        ``opt['path']`` ('pretrain_model_G', 'pretrain_model_sub'); a path
        that is ``None`` is skipped."""
        load_path_G = self.opt['path']['pretrain_model_G']
        load_path_sub = self.opt['path']['pretrain_model_sub']
        if load_path_G is not None:
            print('loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG)
        if load_path_sub is not None:
            print('loading model for subnet [{:s}] ...'.format(load_path_sub))
            self.load_network(load_path_sub, self.subnet)

    def save(self, iter_label):
        """Save both networks into ``self.save_dir`` with the labels 'G' and
        'sub' for the given iteration label."""
        self.save_network(self.save_dir, self.netG, 'G', iter_label)
        self.save_network(self.save_dir, self.subnet, 'sub', iter_label)
           

关于SFT-layer部分可以参考博文《 学习笔记之——基于pytorch的SFTGAN(xintao代码学习,及数据处理部分的学习)》

实验记录

通过SFT-layer使得basic model从【0,15】到【0,66】

结果:

实验记录
实验记录

继续阅读