天天看點

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

作者:計算機視覺研究院

點選藍字

關注我們

關注并星标

從此不迷路

計算機視覺研究院

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定
從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

公衆号ID|計算機視覺研究院

學習群|掃碼在首頁擷取加入方式

計算機視覺研究院專欄

Column of Computer Vision Institute

很翔實的一篇教程。OpenAI 的 Sora、Stability AI 的 Stable Video Diffusion 以及許多其他已經釋出或未來将出現的文本生成視訊模型,是繼大語言模型 (LLM) 之後 2024 年最流行的 AI 趨勢之一。

在這篇部落格中,作者将展示如何将從頭開始建構一個小規模的文本生成視訊模型,涵蓋了從了解理論概念、到編寫整個架構再到生成最終結果的所有内容。

由于作者沒有大算力的 GPU,是以僅編寫了小規模架構。以下是在不同處理器上訓練模型所需時間的比較。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

作者表示,在 CPU 上運作顯然需要更長的時間來訓練模型。如果你需要快速測試代碼中的更改并檢視結果,CPU 不是最佳選擇。是以建議使用 Colab 或 Kaggle 的 T4 GPU 進行更高效、更快速的訓練。

建構目标

我們采用了與傳統機器學習或深度學習模型類似的方法,即在資料集上進行訓練,然後在未見過資料上進行測試。在文本轉視訊的背景下,假設有一個包含 10 萬個狗撿球和貓追老鼠視訊的訓練資料集,然後訓練模型來生成貓撿球或狗追老鼠的視訊。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

圖源:iStock, GettyImages

雖然此類訓練資料集在網際網路上很容易獲得,但所需的算力極高。是以,我們将使用由 Python 代碼生成的移動對象視訊資料集。同時使用 GAN(生成對抗網絡)架構來建立模型,而不是 OpenAI Sora 使用的擴散模型。

我們也嘗試使用擴散模型,但記憶體要求超出了自己的能力。另一方面,GAN 可以更容易、更快地進行訓練和測試。

準備條件

我們将使用 OOP(面向對象程式設計),是以必須對它以及神經網絡有基本的了解。此外 GAN(生成對抗網絡)的知識不是必需的,因為這裡簡單介紹它們的架構。

  • OOP:https://www.youtube.com/watch?v=q2SGW2VgwAM
  • 神經網絡理論:https://www.youtube.com/watch?v=Jy4wM2X21u0
  • GAN 架構:https://www.youtube.com/watch?v=TpMIssRdhco
  • Python 基礎:https://www.youtube.com/watch?v=eWRfhZUzrAc

了解 GAN 架構

什麼是 GAN?

生成對抗網絡是一種深度學習模型,其中兩個神經網絡互相競争:一個從給定的資料集建立新資料(如圖像或音樂),另一個則判斷資料是真實的還是虛假的。這個過程一直持續到生成的資料與原始資料無法區分。

真實世界應用

  • 生成圖像:GAN 根據文本 prompt 建立逼真的圖像或修改現有圖像,例如增強分辨率或為黑白照片添加顔色。
  • 資料增強:GAN 生成合成資料來訓練其他機器學習模型,例如為欺詐檢測系統建立欺詐交易資料。
  • 補充缺失資訊:GAN 可以填充缺失資料,例如根據地形圖生成地下圖像以用于能源應用。
  • 生成 3D 模型:GAN 将 2D 圖像轉換為 3D 模型,在醫療保健等領域非常有用,可用于為手術規劃建立逼真的器官圖像。

GAN 工作原理

GAN 由兩個深度神經網絡組成:生成器和判别器。這兩個網絡在對抗設定中一起訓練,其中一個網絡生成新資料,另一個網絡評估資料是真是假。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

GAN 訓練示例

讓我們以圖像到圖像的轉換為例,解釋一下 GAN 模型,重點是修改人臉。

1. 輸入圖像:輸入圖像是一張真實的人臉圖像。2. 屬性修改:生成器會修改人臉的屬性,比如給眼睛加上墨鏡。3. 生成圖像:生成器會建立一組添加了太陽鏡的圖像。4. 判别器的任務:判别器接收到混合的真實圖像(帶有太陽鏡的人)和生成的圖像(添加了太陽鏡的人臉)。 5. 評估:判别器嘗試區分真實圖像和生成圖像。 6. 回報回路:如果判别器正确識别出假圖像,生成器會調整其參數以生成更逼真的圖像。如果生成器成功欺騙了判别器,判别器會更新其參數以提高檢測能力。

通過這一對抗過程,兩個網絡都在不斷改進。生成器越來越善于生成逼真的圖像,而判别器則越來越善于識别假圖像,直到達到平衡,判别器再也無法區分真實圖像和生成的圖像。此時,GAN 已成功學會生成逼真的修改圖像。

設定背景

我們将使用一系列 Python 庫,讓我們導入它們。

# Operating System module for interacting with the operating system              import os                  # Module for generating random numbers              import random                  # Module for numerical operations              import numpy as np                  # OpenCV library for image processing              import cv2                  # Python Imaging Library for image processing              from PIL import Image, ImageDraw, ImageFont                  # PyTorch library for deep learning              import torch                  # Dataset class for creating custom datasets in PyTorch              from torch.utils.data import Dataset                  # Module for image transformations              import torchvision.transforms as transforms                  # Neural network module in PyTorch              import torch.nn as nn                  # Optimization algorithms in PyTorch              import torch.optim as optim                  # Function for padding sequences in PyTorch              from torch.nn.utils.rnn import pad_sequence                  # Function for saving images in PyTorch              from torchvision.utils import save_image                  # Module for plotting graphs and images              import matplotlib.pyplot as plt                  # Module for displaying rich content in IPython environments              from IPython.display import clear_output, display, HTML                  # Module for encoding and decoding binary data to text              import base64           

現在我們已經導入了所有的庫,下一步就是定義我們的訓練資料,用于訓練 GAN 架構。

對訓練資料進行編碼

我們需要至少 10000 個視訊作為訓練資料。為什麼呢?因為我測試了較小數量的視訊,結果非常糟糕,幾乎沒有任何效果。下一個重要問題是:這些視訊内容是什麼? 我們的訓練視訊資料集包括一個圓圈以不同方向和不同運動方式移動的視訊。讓我們來編寫代碼并生成 10,000 個視訊,看看它的效果如何。

# Create a directory named 'training_dataset'              os.makedirs('training_dataset', exist_ok=True)                  # Define the number of videos to generate for the dataset              num_videos = 10000                  # Define the number of frames per video (1 Second Video)              frames_per_video = 10                  # Define the size of each image in the dataset              img_size = (64, 64)                  # Define the size of the shapes (Circle)              shape_size = 10            

設定一些基本參數後,接下來我們需要定義訓練資料集的文本 prompt,并據此生成訓練視訊。

# Define text prompts and corresponding movements for circles              prompts_and_movements = [              ("circle moving down", "circle", "down"), # Move circle downward              ("circle moving left", "circle", "left"), # Move circle leftward              ("circle moving right", "circle", "right"), # Move circle rightward              ("circle moving diagonally up-right", "circle", "diagonal_up_right"), # Move circle diagonally up-right              ("circle moving diagonally down-left", "circle", "diagonal_down_left"), # Move circle diagonally down-left              ("circle moving diagonally up-left", "circle", "diagonal_up_left"), # Move circle diagonally up-left              ("circle moving diagonally down-right", "circle", "diagonal_down_right"), # Move circle diagonally down-right              ("circle rotating clockwise", "circle", "rotate_clockwise"), # Rotate circle clockwise              ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"), # Rotate circle counter-clockwise              ("circle shrinking", "circle", "shrink"), # Shrink circle              ("circle expanding", "circle", "expand"), # Expand circle              ("circle bouncing vertically", "circle", "bounce_vertical"), # Bounce circle vertically              ("circle bouncing horizontally", "circle", "bounce_horizontal"), # Bounce circle horizontally              ("circle zigzagging vertically", "circle", "zigzag_vertical"), # Zigzag circle vertically              ("circle zigzagging horizontally", "circle", "zigzag_horizontal"), # Zigzag circle horizontally              ("circle moving up-left", "circle", "up_left"), # Move circle up-left              ("circle moving down-right", "circle", "down_right"), # Move circle down-right              ("circle moving down-left", "circle", "down_left"), # Move circle down-left              ]           

我們已經利用這些 prompt 定義了圓的幾個運動軌迹。現在,我們需要編寫一些數學公式,以便根據 prompt 移動圓。

# Define function with parameters              def create_image_with_moving_shape(size, frame_num, shape, direction):                  # Create a new RGB image with specified size and white background              img = Image.new('RGB', size, color=(255, 255, 255))                   # Create a drawing context for the image              draw = ImageDraw.Draw(img)                   # Calculate the center coordinates of the image              center_x, center_y = size[0] // 2, size[1] // 2                   # Initialize position with center for all movements              position = (center_x, center_y)                   # Define a dictionary mapping directions to their respective position adjustments or image transformations              direction_map = {               # Adjust position downwards based on frame number              "down": (0, frame_num * 5 % size[1]),               # Adjust position to the left based on frame number              "left": (-frame_num * 5 % size[0], 0),               # Adjust position to the right based on frame number              "right": (frame_num * 5 % size[0], 0),               # Adjust position diagonally up and to the right              "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),               # Adjust position diagonally down and to the left              "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),               # Adjust position diagonally up and to the left              "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),               # Adjust position diagonally down and to the right              "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),               # Rotate the image clockwise based on frame number              "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),               # Rotate the image counter-clockwise based on frame number              "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),               # Adjust position for a bouncing effect vertically              "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),               # Adjust position for a bouncing effect horizontally              "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),               # Adjust position for a zigzag effect vertically              "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),               # Adjust position for a zigzag effect horizontally              "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),               # Adjust position upwards and to the right based on frame number              "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),               # Adjust position upwards and to the left based on frame number              "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),               # Adjust position downwards and to the right based on frame number              "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),               # Adjust position downwards and to the left based on frame number              "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])               }                  # Check if direction is in the direction map              if direction in direction_map:               # Check if the direction maps to a position adjustment              if isinstance(direction_map[direction], tuple):               # Update position based on the adjustment              position = tuple(np.add(position, direction_map[direction]))               else: # If the direction maps to an image transformation              # Update the image based on the transformation              img = direction_map[direction]                   # Return the image as a numpy array              return np.array(img)           

上述函數用于根據所選方向在每一幀中移動我們的圓。我們隻需在其上運作一個循環,直至生成所有視訊的次數。

# Iterate over the number of videos to generate              for i in range(num_videos):              # Randomly choose a prompt and movement from the predefined list              prompt, shape, direction = random.choice(prompts_and_movements)                  # Create a directory for the current video              video_dir = f'training_dataset/video_{i}'              os.makedirs(video_dir, exist_ok=True)                  # Write the chosen prompt to a text file in the video directory              with open(f'{video_dir}/prompt.txt', 'w') as f:              f.write(prompt)                  # Generate frames for the current video              for frame_num in range(frames_per_video):              # Create an image with a moving shape based on the current frame number, shape, and direction              img = create_image_with_moving_shape(img_size, frame_num, shape, direction)                  # Save the generated image as a PNG file in the video directory              cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)           

運作上述代碼後,就會生成整個訓練資料集。以下是訓練資料集檔案的結構。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

每個訓練視訊檔案夾包含其幀以及對應的文本 prompt。讓我們看一下我們的訓練資料集樣本。

在我們的訓練資料集中,我們沒有包含圓圈先向上移動然後向右移動的運動。我們将使用這個作為測試 prompt,來評估我們訓練的模型在未見過的資料上的表現。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

還有一個重要的要點需要注意,我們的訓練資料包含許多物體從場景中移出或部分出現在錄影機前方的樣本,類似于我們在 OpenAI Sora 示範視訊中觀察到的情況。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

在我們的訓練資料中包含此類樣本的原因是為了測試當圓圈從角落進入場景時,模型是否能夠保持一緻性而不會破壞其形狀。

現在我們的訓練資料已經生成,需要将訓練視訊轉換為張量,這是 PyTorch 等深度學習架構中使用的主要資料類型。此外,通過将資料縮放到較小的範圍,執行歸一化等轉換有助于提高訓練架構的收斂性和穩定性。

預處理訓練資料

我們必須為文本轉視訊任務編寫一個資料集類,它可以從訓練資料集目錄中讀取視訊幀及其相應的文本 prompt,使其可以在 PyTorch 中使用。

# Define a dataset class inheriting from torch.utils.data.Dataset              class TextToVideoDataset(Dataset):              def __init__(self, root_dir, transform=None):              # Initialize the dataset with root directory and optional transform              self.root_dir = root_dir              self.transform = transform              # List all subdirectories in the root directory              self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]              # Initialize lists to store frame paths and corresponding prompts              self.frame_paths = []              self.prompts = []                  # Loop through each video directory              for video_dir in self.video_dirs:              # List all PNG files in the video directory and store their paths              frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]              self.frame_paths.extend(frames)              # Read the prompt text file in the video directory and store its content              with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:              prompt = f.read().strip()              # Repeat the prompt for each frame in the video and store in prompts list              self.prompts.extend([prompt] * len(frames))                  # Return the total number of samples in the dataset              def __len__(self):              return len(self.frame_paths)                  # Retrieve a sample from the dataset given an index              def __getitem__(self, idx):              # Get the path of the frame corresponding to the given index              frame_path = self.frame_paths[idx]              # Open the image using PIL (Python Imaging Library)              image = Image.open(frame_path)              # Get the prompt corresponding to the given index              prompt = self.prompts[idx]                  # Apply transformation if specified              if self.transform:              image = self.transform(image)                  # Return the transformed image and the prompt              return image, prompt           

在繼續編寫架構代碼之前,我們需要對訓練資料進行歸一化處理。我們使用 16 的 batch 大小并對資料進行混洗以引入更多随機性。

實作文本嵌入層

你可能已經看到,在 Transformer 架構中,起點是将文本輸入轉換為嵌入,進而在多頭注意力中進行進一步處理。類似地,我們在這裡必須編寫一個文本嵌入層。基于該層,GAN 架構訓練在我們的嵌入資料和圖像張量上進行。

# Define a class for text embedding              class TextEmbedding(nn.Module):              # Constructor method with vocab_size and embed_size parameters              def __init__(self, vocab_size, embed_size):              # Call the superclass constructor              super(TextEmbedding, self).__init__()              # Initialize embedding layer              self.embedding = nn.Embedding(vocab_size, embed_size)                  # Define the forward pass method              def forward(self, x):              # Return embedded representation of input              return self.embedding(x)            

詞彙量将基于我們的訓練資料,在稍後進行計算。嵌入大小将為 10。如果使用更大的資料集,你還可以使用 Hugging Face 上已有的嵌入模型。

實作生成器層

現在我們已經知道生成器在 GAN 中的作用,接下來讓我們對這一層進行編碼,然後了解其内容。

class Generator(nn.Module):              def __init__(self, text_embed_size):              super(Generator, self).__init__()                  # Fully connected layer that takes noise and text embedding as input              self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)                  # Transposed convolutional layers to upsample the input              self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)              self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)              self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1) # Output has 3 channels for RGB images                  # Activation functions              self.relu = nn.ReLU(True) # ReLU activation function              self.tanh = nn.Tanh() # Tanh activation function for final output                  def forward(self, noise, text_embed):              # Concatenate noise and text embedding along the channel dimension              x = torch.cat((noise, text_embed), dim=1)                  # Fully connected layer followed by reshaping to 4D tensor              x = self.fc1(x).view(-1, 256, 8, 8)                  # Upsampling through transposed convolution layers with ReLU activation              x = self.relu(self.deconv1(x))              x = self.relu(self.deconv2(x))                  # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)              x = self.tanh(self.deconv3(x))                  return x           

該 Generator 類負責根據随機噪聲和文本嵌入的組合建立視訊幀,旨在根據給定的文本描述生成逼真的視訊幀。該網絡從完全連接配接層 (nn.Linear) 開始,将噪聲向量和文本嵌入組合成單個特征向量。然後,該向量被重新整形并經過一系列的轉置卷積層 (nn.ConvTranspose2d),這些層将特征圖逐漸上采樣到所需的視訊幀大小。

這些層使用 ReLU 激活 (nn.ReLU) 實作非線性,最後一層使用 Tanh 激活 (nn.Tanh) 将輸出縮放到 [-1, 1] 的範圍。是以,生成器将抽象的高維輸入轉換為以視覺方式表示輸入文本的連貫視訊幀。

實作判别器層

在編寫完生成器層之後,我們需要實作另一半,即判别器部分。

class Discriminator(nn.Module):              def __init__(self):              super(Discriminator, self).__init__()                  # Convolutional layers to process input images              self.conv1 = nn.Conv2d(3, 64, 4, 2, 1) # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1              self.conv2 = nn.Conv2d(64, 128, 4, 2, 1) # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1              self.conv3 = nn.Conv2d(128, 256, 4, 2, 1) # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1                  # Fully connected layer for classification              self.fc1 = nn.Linear(256 * 8 * 8, 1) # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)                  # Activation functions              self.leaky_relu = nn.LeakyReLU(0.2, inplace=True) # Leaky ReLU activation with negative slope 0.2              self.sigmoid = nn.Sigmoid() # Sigmoid activation for final output (probability)                  def forward(self, input):              # Pass input through convolutional layers with LeakyReLU activation              x = self.leaky_relu(self.conv1(input))              x = self.leaky_relu(self.conv2(x))              x = self.leaky_relu(self.conv3(x))                  # Flatten the output of convolutional layers              x = x.view(-1, 256 * 8 * 8)                  # Pass through fully connected layer with Sigmoid activation for binary classification              x = self.sigmoid(self.fc1(x))                  return x           

判别器類用作二進制分類器,區分真實視訊幀和生成的視訊幀。目的是評估視訊幀的真實性,進而指導生成器産生更真實的輸出。該網絡由卷積層 (nn.Conv2d) 組成,這些卷積層從輸入視訊幀中提取分層特征, Leaky ReLU 激活 (nn.LeakyReLU) 增加非線性,同時允許負值的小梯度。

然後,特征圖被展平并通過完全連接配接層 (nn.Linear),最終以 S 形激活 (nn.Sigmoid) 輸出訓示幀是真實還是假的機率分數。

通過訓練判别器準确地對幀進行分類,生成器同時接受訓練以建立更令人信服的視訊幀,進而騙過判别器。

編寫訓練參數

我們必須設定用于訓練 GAN 的基礎元件,例如損失函數、優化器等。

# Check for GPU              device = torch.device("cuda" if torch.cuda.is_available() else "cpu")                  # Create a simple vocabulary for text prompts              all_prompts = [prompt for prompt, _, _ in prompts_and_movements] # Extract all prompts from prompts_and_movements list              vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))} # Create a vocabulary dictionary where each unique word is assigned an index              vocab_size = len(vocab) # Size of the vocabulary              embed_size = 10 # Size of the text embedding vector                  def encode_text(prompt):              # Encode a given prompt into a tensor of indices using the vocabulary              return torch.tensor([vocab[word] for word in prompt.split()])                  # Initialize models, loss function, and optimizers              text_embedding = TextEmbedding(vocab_size, embed_size).to(device) # Initialize TextEmbedding model with vocab_size and embed_size              netG = Generator(embed_size).to(device) # Initialize Generator model with embed_size              netD = Discriminator().to(device) # Initialize Discriminator model              criterion = nn.BCELoss().to(device) # Binary Cross Entropy loss function              optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for Discriminator              optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999)) # Adam optimizer for Generator           

這是我們必須轉換代碼以在 GPU 上運作的部分(如果可用)。我們已經編寫了代碼來查找 vocab_size,并且我們正在為生成器和判别器使用 ADAM 優化器。你可以選擇自己的優化器。在這裡,我們将學習率設定為較小的值 0.0002,嵌入大小為 10,這比其他可供公衆使用的 Hugging Face 模型要小得多。

編寫訓練 loop

就像其他神經網絡一樣,我們将以類似的方式對 GAN 架構訓練進行編碼。

# Number of epochs              num_epochs = 13                  # Iterate over each epoch              for epoch in range(num_epochs):              # Iterate over each batch of data              for i, (data, prompts) in enumerate(dataloader):              # Move real data to device              real_data = data.to(device)                  # Convert prompts to list              prompts = [prompt for prompt in prompts]                  # Update Discriminator              netD.zero_grad() # Zero the gradients of the Discriminator              batch_size = real_data.size(0) # Get the batch size              labels = torch.ones(batch_size, 1).to(device) # Create labels for real data (ones)              output = netD(real_data) # Forward pass real data through Discriminator              lossD_real = criterion(output, labels) # Calculate loss on real data              lossD_real.backward() # Backward pass to calculate gradients                  # Generate fake data              noise = torch.randn(batch_size, 100).to(device) # Generate random noise              text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts]) # Encode prompts into text embeddings              fake_data = netG(noise, text_embeds) # Generate fake data from noise and text embeddings              labels = torch.zeros(batch_size, 1).to(device) # Create labels for fake data (zeros)              output = netD(fake_data.detach()) # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)              lossD_fake = criterion(output, labels) # Calculate loss on fake data              lossD_fake.backward() # Backward pass to calculate gradients              optimizerD.step() # Update Discriminator parameters                  # Update Generator              netG.zero_grad() # Zero the gradients of the Generator              labels = torch.ones(batch_size, 1).to(device) # Create labels for fake data (ones) to fool Discriminator              output = netD(fake_data) # Forward pass fake data (now updated) through Discriminator              lossG = criterion(output, labels) # Calculate loss for Generator based on Discriminator's response              lossG.backward() # Backward pass to calculate gradients              optimizerG.step() # Update Generator parameters                  # Print epoch information              print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")           

通過反向傳播,我們的損失将針對生成器和判别器進行調整。我們在訓練 loop 中使用了 13 個 epoch。我們測試了不同的值,但如果 epoch 高于這個值,結果并沒有太大差異。此外,過度拟合的風險很高。如果我們的資料集更加多樣化,包含更多動作和形狀,則可以考慮使用更高的 epoch,但在這裡沒有這樣做。

當我們運作此代碼時,它會開始訓練,并在每個 epoch 之後 print 生成器和判别器的損失。

## OUTPUT ##                  Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996              Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648              Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776                  ...           

儲存訓練的模型

訓練完成後,我們需要儲存訓練好的 GAN 架構的判别器和生成器,這隻需兩行代碼即可實作。

# Save the Generator model's state dictionary to a file named 'generator.pth'              torch.save(netG.state_dict(), 'generator.pth')                  # Save the Discriminator model's state dictionary to a file named 'discriminator.pth'              torch.save(netD.state_dict(), 'discriminator.pth')           

生成 AI 視訊

正如我們所讨論的,我們在未見過的資料上測試模型的方法與我們訓練資料中涉及狗取球和貓追老鼠的示例類似。是以,我們的測試 prompt 可能涉及貓取球或狗追老鼠等場景。

在我們的特定情況下,圓圈向上移動然後向右移動的運動在訓練資料中不存在,是以模型不熟悉這種特定運動。但是,模型已經在其他動作上進行了訓練。我們可以使用此動作作為 prompt 來測試我們訓練過的模型并觀察其性能。

# Inference function to generate a video based on a given text promptdef generate_video(text_prompt, num_frames=10): # Create a directory for the generated video frames based on the text prompt os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True) # Encode the text prompt into a text embedding tensor text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0) # Generate frames for the video for frame_num in range(num_frames): # Generate random noise noise = torch.randn(1, 100).to(device) # Generate a fake frame using the Generator network with torch.no_grad(): fake_frame = netG(noise, text_embed) # Save the generated fake frame as an image file save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')# usage of the generate_video function with a specific text promptgenerate_video('circle moving up-right')           

當我們運作上述代碼時,它将生成一個目錄,其中包含我們生成視訊的所有幀。我們需要使用一些代碼将所有這些幀合并為一個短視訊。

# Define the path to your folder containing the PNG frames              folder_path = 'generated_video_circle_moving_up-right'                      # Get the list of all PNG files in the folder              image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]                  # Sort the images by name (assuming they are numbered sequentially)              image_files.sort()                  # Create a list to store the frames              frames = []                  # Read each image and append it to the frames list              for image_file in image_files:              image_path = os.path.join(folder_path, image_file)              frame = cv2.imread(image_path)              frames.append(frame)                  # Convert the frames list to a numpy array for easier processing              frames = np.array(frames)                  # Define the frame rate (frames per second)              fps = 10                  # Create a video writer object              fourcc = cv2.VideoWriter_fourcc(*'XVID')              out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))                  # Write each frame to the video              for frame in frames:              out.write(frame)                  # Release the video writer              out.release()           

確定檔案夾路徑指向你新生成的視訊所在的位置。運作此代碼後,你将成功建立 AI 視訊。讓我們看看它是什麼樣子。

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

我們進行了多次訓練,訓練次數相同。在兩種情況下,圓圈都是從底部開始,出現一半。好消息是,我們的模型在兩種情況下都嘗試執行直立運動。

例如,在嘗試 1 中,圓圈沿對角線向上移動,然後執行向上運動,而在嘗試 2 中,圓圈沿對角線移動,同時尺寸縮小。在兩種情況下,圓圈都沒有向左移動或完全消失,這是一個好兆頭。

最後,作者表示已經測試了該架構的各個方面,發現訓練資料是關鍵。通過在資料集中包含更多動作和形狀,你可以增加可變性并提高模型的性能。由于資料是通過代碼生成的,是以生成更多樣的資料不會花費太多時間;相反,你可以專注于完善邏輯。

此外,文章中讨論的 GAN 架構相對簡單。你可以通過內建進階技術或使用語言模型嵌入 (LLM) 而不是基本神經網絡嵌入來使其更複雜。此外,調整嵌入大小等參數會顯著影響模型的有效性。

原文連結:https://levelup.gitconnected.com/building-an-ai-text-to-video-model-from-scratch-using-python-35b4eb4002de

END

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

轉載請聯系本公衆号獲得授權

從零開始,用英偉達T4、A10訓練小型文生視訊模型,幾小時搞定

計算機視覺研究院學習群等你加入!

ABOUT

計算機視覺研究院

計算機視覺研究院主要涉及深度學習領域,主要緻力于目标檢測、目标跟蹤、圖像分割、OCR、模型量化、模型部署等研究方向。研究院每日分享最新的論文算法新架構,提供論文一鍵下載下傳,并分享實戰項目。研究院主要着重”技術研究“和“實踐落地”。研究院會針對不同領域分享實踐過程,讓大家真正體會擺脫理論的真實場景,培養愛動手程式設計愛動腦思考的習慣!

🔗

繼續閱讀