天天看點

sGAN網絡的基本實作(mnist資料集)

SGAN(ssgan)

ssgan是半監督學習生成對抗網絡。

初衷是利用GAN生成器生成的樣本來改進和提高圖像分類任務的性能。

SGAN的主要思想在鑒别器的設計。

相比普通的GAN的鑒别器輸出0和1(真和假),SGAN通過使鑒别器網絡輸出label+1類别,将其轉換為半監督上下文。

我們希望設計的鑒别器既扮演執行圖像分類任務的分類器的角色,又能區分有生成器生成的生成樣本和真實資料。

在SGAN中,就是把這個二分類(sigmoid)轉化為多分類(softmax),類型數量為C+1,指代C個标簽的資料和“一個假資料”,表示為[C_1, C_2,…,C_n,Fake]

SGAN在資料集上訓練生成模型G和D(對C + 1類别執行分類)

在訓練時,D預測輸入屬于C+1lei中的哪一類,其中添加了一個額外的類對應生成圖檔。

該方法可以用于建立資料效率更高的分類器。

sGAN網絡的基本實作(mnist資料集)

資料集處理

通過對mnist資料集進行分割

原訓練資料集60000,分割為有标簽資料集(labeled_set)1000,無标簽資料集(unlabeled_set)59000

transforms = transforms.Compose([transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5)])
dataset = torchvision.datasets.MNIST("mnist", train=True, download=True,
    transform=transforms)
labeled_set_len = 1000
unlabeled_set_len = len(dataset) - labeled_set_len
labeled_set, unlabeled_set = random_split(dataset, [labeled_set_len, unlabeled_set_len])
           

生成器

生成器部分和DCGAN無無别

先将100大小噪聲進行線性擴大,再通過三層反卷積得到[1, 28, 28]像素矩陣

class Generate(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.genModel1 = nn.Sequential(
            nn.Linear(100,  256 * 7 * 7 ),
            nn.ReLU(),
            nn.BatchNorm1d(256 * 7 * 7),
        )
        self.genModel2 = nn.Sequential(
            # [128, 28, 28]
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            # [64, 14, 14] 
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            # [1, 28, 28]
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.genModel1(x)
        x = torch.reshape(x, (-1, 256, 7, 7))
        x = self.genModel2(x)
        return x
        
           

辨識器

由于通過辨識器有兩種判斷種類,一種屬于10分類結果,另一種為普通GAN分類結果(真\假),是以在輸出時分為兩種:[-1, 10] 和 [-1, 1]

總體思路還是通過卷積層增大channel,減小h、w

最後通過線性層1輸出10分類結果,通過線性層2輸出2分類結果

注:線上性之後并沒有使用激活函數,在損失函數方面做了處理。

class Discriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.modu1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout2d(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2),
            nn.LeakyReLU(),
            nn.Dropout2d()
        )
        self.fc1 = nn.Linear(128*6*6, 10) #分類輸出,十個類别
        self.fc2 = nn.Linear(128*6*6, 1) #判斷真假
        self.bn = nn.BatchNorm2d(128)
    def forward(self, x, if_bn = True):
        x = self.modu1(x) # x.shape([1, 128, 6, 6])
        if if_bn:
            x = self.bn(x)
        x = torch.reshape(x, (-1, 128*6*6)) # 四維 -> 二維
        class_out = self.fc1(x)
        real_fake_out = self.fc2(x)
        # 這裡未做激活,在損失函數中處理
        return class_out, real_fake_out
           

CrossEntropyLoss相當于loss+softmax

BCEWithLogitsLoss對應二分類未激活輸出(集合了sigmoid更穩定

# 分類損失函數
loss_classfication_fn = nn.CrossEntropyLoss()
# 二分類損失函數
loss_sigmoid_fn = nn.BCEWithLogitsLoss()
           

訓練模型

辨識器訓練部分,和普通GAN不同的是,普通GAN将真實圖檔損失和fake損失相加,這裡則是真實損失、fake損失、标注資料辨識損失

dis_optim.zero_grad()
        # 計算未标注資料損失
        # 首先将未标注資料放入鑒别器當中(隻判斷真假) 将10分類傳回值不接收
        _, real_sg_ou = dis(unla_img)
        dis_real_sg_loss = loss_sigmoid_fn(real_sg_ou, torch.ones_like(real_sg_ou, device=device))
        dis_real_sg_loss.backward()
        # 計算生成資料損失
        gen_img = gen(random_noise)
        _, fake_sg_ou = dis(gen_img.detach())
        dis_fake_sg_loss = loss_sigmoid_fn(fake_sg_ou, torch.zeros_like(fake_sg_ou, device=device))
        dis_fake_sg_loss.backward()
        # 計算标注資料損失
        d_real_sfm_out, _ = dis(lb_img)
        d_real_sfm_loss = loss_classfication_fn(d_real_sfm_out, label)
        d_real_sfm_loss.backward()
        dis_loss = dis_real_sg_loss + dis_fake_sg_loss + d_real_sfm_loss
        dis_optim.step()
           

生成器訓練部分,和普通GAN沒有差別

gen_optim.zero_grad()
        _, fake_ou = dis(gen_img)
        gen_fake_sg_loss = loss_sigmoid_fn(fake_ou, torch.ones_like(fake_ou, device=device))
        gen_fake_sg_loss.backward()
        gen_optim.step()
           

測試模型

epoch 0:

sGAN網絡的基本實作(mnist資料集)

epoch 50:

sGAN網絡的基本實作(mnist資料集)

epoch 100:

sGAN網絡的基本實作(mnist資料集)

epoch 150:

sGAN網絡的基本實作(mnist資料集)

epoch 200:

sGAN網絡的基本實作(mnist資料集)

epoch 250:

sGAN網絡的基本實作(mnist資料集)

生成器和判别器損失:

sGAN網絡的基本實作(mnist資料集)

繼續閱讀