SGAN(ssgan)
ssgan是半監督學習生成對抗網絡。
初衷是利用GAN生成器生成的樣本來改進和提高圖像分類任務的性能。
SGAN的主要思想在鑒别器的設計。
相比普通的GAN的鑒别器輸出0和1(真和假),SGAN通過使鑒别器網絡輸出label+1類别,将其轉換為半監督上下文。
我們希望設計的鑒别器既扮演執行圖像分類任務的分類器的角色,又能區分有生成器生成的生成樣本和真實資料。
在SGAN中,就是把這個二分類(sigmoid)轉化為多分類(softmax),類型數量為C+1,指代C個标簽的資料和“一個假資料”,表示為[C_1, C_2,…,C_n,Fake]
SGAN在資料集上訓練生成模型G和D(對C + 1類别執行分類)
在訓練時,D預測輸入屬于C+1lei中的哪一類,其中添加了一個額外的類對應生成圖檔。
該方法可以用于建立資料效率更高的分類器。
![](https://img.laitimes.com/img/_0nNw4CM6IyYiwiM6ICdiwiIyVGduV2YfNWawNyZuBnLiRWZygTOxQWY2Q2MklzN4QGMlRjYyMWN3QDOlFjZjNzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)
資料集處理
通過對mnist資料集進行分割
原訓練資料集60000,分割為有标簽資料集(labeled_set)1000,無标簽資料集(unlabeled_set)59000
transforms = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=0.5, std=0.5)])
dataset = torchvision.datasets.MNIST("mnist", train=True, download=True,
transform=transforms)
labeled_set_len = 1000
unlabeled_set_len = len(dataset) - labeled_set_len
labeled_set, unlabeled_set = random_split(dataset, [labeled_set_len, unlabeled_set_len])
生成器
生成器部分和DCGAN無無别
先将100大小噪聲進行線性擴大,再通過三層反卷積得到[1, 28, 28]像素矩陣
class Generate(nn.Module):
def __init__(self) -> None:
super().__init__()
self.genModel1 = nn.Sequential(
nn.Linear(100, 256 * 7 * 7 ),
nn.ReLU(),
nn.BatchNorm1d(256 * 7 * 7),
)
self.genModel2 = nn.Sequential(
# [128, 28, 28]
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
# [64, 14, 14]
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
# [1, 28, 28]
nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, x):
x = self.genModel1(x)
x = torch.reshape(x, (-1, 256, 7, 7))
x = self.genModel2(x)
return x
辨識器
由于通過辨識器有兩種判斷種類,一種屬于10分類結果,另一種為普通GAN分類結果(真\假),是以在輸出時分為兩種:[-1, 10] 和 [-1, 1]
總體思路還是通過卷積層增大channel,減小h、w
最後通過線性層1輸出10分類結果,通過線性層2輸出2分類結果
注:線上性之後并沒有使用激活函數,在損失函數方面做了處理。
class Discriminator(nn.Module):
def __init__(self) -> None:
super().__init__()
self.modu1 = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Dropout2d(),
nn.Conv2d(64, 128, kernel_size=3, stride=2),
nn.LeakyReLU(),
nn.Dropout2d()
)
self.fc1 = nn.Linear(128*6*6, 10) #分類輸出,十個類别
self.fc2 = nn.Linear(128*6*6, 1) #判斷真假
self.bn = nn.BatchNorm2d(128)
def forward(self, x, if_bn = True):
x = self.modu1(x) # x.shape([1, 128, 6, 6])
if if_bn:
x = self.bn(x)
x = torch.reshape(x, (-1, 128*6*6)) # 四維 -> 二維
class_out = self.fc1(x)
real_fake_out = self.fc2(x)
# 這裡未做激活,在損失函數中處理
return class_out, real_fake_out
CrossEntropyLoss相當于loss+softmax
BCEWithLogitsLoss對應二分類未激活輸出(集合了sigmoid更穩定
# 分類損失函數
loss_classfication_fn = nn.CrossEntropyLoss()
# 二分類損失函數
loss_sigmoid_fn = nn.BCEWithLogitsLoss()
訓練模型
辨識器訓練部分,和普通GAN不同的是,普通GAN将真實圖檔損失和fake損失相加,這裡則是真實損失、fake損失、标注資料辨識損失
dis_optim.zero_grad()
# 計算未标注資料損失
# 首先将未标注資料放入鑒别器當中(隻判斷真假) 将10分類傳回值不接收
_, real_sg_ou = dis(unla_img)
dis_real_sg_loss = loss_sigmoid_fn(real_sg_ou, torch.ones_like(real_sg_ou, device=device))
dis_real_sg_loss.backward()
# 計算生成資料損失
gen_img = gen(random_noise)
_, fake_sg_ou = dis(gen_img.detach())
dis_fake_sg_loss = loss_sigmoid_fn(fake_sg_ou, torch.zeros_like(fake_sg_ou, device=device))
dis_fake_sg_loss.backward()
# 計算标注資料損失
d_real_sfm_out, _ = dis(lb_img)
d_real_sfm_loss = loss_classfication_fn(d_real_sfm_out, label)
d_real_sfm_loss.backward()
dis_loss = dis_real_sg_loss + dis_fake_sg_loss + d_real_sfm_loss
dis_optim.step()
生成器訓練部分,和普通GAN沒有差別
gen_optim.zero_grad()
_, fake_ou = dis(gen_img)
gen_fake_sg_loss = loss_sigmoid_fn(fake_ou, torch.ones_like(fake_ou, device=device))
gen_fake_sg_loss.backward()
gen_optim.step()
測試模型
epoch 0:
epoch 50:
epoch 100:
epoch 150:
epoch 200:
epoch 250:
生成器和判别器損失: