
Train a small text-to-video model from scratch with an NVIDIA T4 or A10 in a few hours

Author: Institute of Computer Vision



A very informative tutorial. OpenAI's Sora, Stability AI's Stable Video Diffusion, and the many other text-to-video models that have been released or will appear in the future are among the most popular AI trends of 2024, after large language models (LLMs).

In this blog, the authors show how to build a small-scale text-to-video model from scratch, covering everything from understanding the theoretical concepts to writing the entire architecture and generating the final result.

Since the authors didn't have access to a GPU with a lot of compute, they wrote only a small-scale architecture. Below is a comparison of how long it takes to train the model on different processors.

[Figure: training-time comparison across different processors]

According to the authors, training on a CPU obviously takes much longer. If you need to quickly test code changes and see the results, a CPU isn't the best choice. It is therefore recommended to use a T4 GPU on Colab or Kaggle for faster, more efficient training.

Define the goal

We take the same approach as traditional machine learning or deep learning models: train on a dataset, then test on unseen data. In the text-to-video context, suppose we have a training dataset of 100,000 videos of dogs picking up balls and cats chasing mice; we then train the model to generate videos of cats picking up balls or dogs chasing mice.

[Image: illustration of such hypothetical training examples (a dog fetching a ball, a cat chasing a mouse). Source: iStock, Getty Images]

While such training datasets are readily available on the internet, the computing power required to use them is extremely high. Therefore, we will use a dataset of moving-object videos generated by Python code, and we will build the model with a GAN (Generative Adversarial Network) architecture rather than the diffusion model used by OpenAI's Sora.

We also tried using a diffusion model, but its memory requirements were beyond our means. GANs, on the other hand, are easier and faster to train and test.

Prerequisites

We will be using OOP (object-oriented programming), so a basic understanding of it, as well as of neural networks, is essential. Prior knowledge of GANs (Generative Adversarial Networks) is not required, as their architecture is briefly described below.

  • OOP: https://www.youtube.com/watch?v=q2SGW2VgwAM
  • Neural Network Theory: https://www.youtube.com/watch?v=Jy4wM2X21u0
  • GAN Architecture: https://www.youtube.com/watch?v=TpMIssRdhco
  • Python Basics: https://www.youtube.com/watch?v=eWRfhZUzrAc

Understanding the GAN architecture

What is a GAN?

A generative adversarial network is a deep learning model in which two neural networks compete against each other: one creates new data (such as images or music) from a given dataset, and the other determines whether the data is real or fake. This process continues until the resulting data is indistinguishable from the original.

Real-world applications

  • Generate images: The GAN creates a realistic image based on a text prompt or modifies an existing image, such as upscaling the resolution or adding color to a black-and-white photo.
  • Data augmentation: The GAN generates synthetic data to train other machine learning models, such as creating fraudulent transaction data for fraud detection systems.
  • Fill in missing information: GANs can fill in missing data, such as generating subsurface imagery from topographic maps for use in energy applications.
  • Generate 3D models: GANs convert 2D images into 3D models, which are useful in areas such as healthcare and can be used to create photorealistic organ images for surgical planning.

How GANs work

A GAN consists of two deep neural networks: a generator and a discriminator. The two networks are trained together in an adversarial setup, with one generating new data and the other evaluating whether that data is real or fake.

[Figure: GAN architecture (generator vs. discriminator)]

GAN training example

Let's walk through an image-to-image translation example to explain the GAN model, focusing on modifying faces.

1. Input image: the input is a real face image.
2. Attribute modification: the generator modifies an attribute of the face, such as adding sunglasses to the eyes.
3. Generated images: the generator creates a set of images with sunglasses added.
4. Discriminator's task: the discriminator receives a mix of real images (people actually wearing sunglasses) and generated images (faces with sunglasses added).
5. Evaluation: the discriminator tries to distinguish real images from generated ones.
6. Feedback loop: if the discriminator correctly identifies a fake image, the generator adjusts its parameters to produce more realistic images; if the generator successfully fools the discriminator, the discriminator updates its parameters to improve its detection ability.

Through this adversarial process, both networks keep improving: the generator gets better at producing realistic images, and the discriminator gets better at spotting fakes, until an equilibrium is reached where the discriminator can no longer distinguish real images from generated ones. At that point, the GAN has successfully learned to produce photorealistic modified images.

Set up the environment

We'll be using a number of Python libraries; let's import them.

# Operating System module for interacting with the operating system
import os

# Module for generating random numbers
import random

# Module for numerical operations
import numpy as np

# OpenCV library for image processing
import cv2

# Python Imaging Library for image processing
from PIL import Image, ImageDraw, ImageFont

# PyTorch library for deep learning
import torch

# Dataset class for creating custom datasets in PyTorch
from torch.utils.data import Dataset

# Module for image transformations
import torchvision.transforms as transforms

# Neural network module in PyTorch
import torch.nn as nn

# Optimization algorithms in PyTorch
import torch.optim as optim

# Function for padding sequences in PyTorch
from torch.nn.utils.rnn import pad_sequence

# Function for saving images in PyTorch
from torchvision.utils import save_image

# Module for plotting graphs and images
import matplotlib.pyplot as plt

# Module for displaying rich content in IPython environments
from IPython.display import clear_output, display, HTML

# Module for encoding and decoding binary data to text
import base64

Now that we've imported all the libraries, the next step is to define our training data for training the GAN architecture.

Generate the training data

We need at least 10,000 videos as training data. Why? Because the author tested with fewer videos and the results were very poor, with hardly any visible effect. The next important question: what are these videos about? Our training dataset consists of videos of a circle moving in different directions with different motion patterns. Let's write the code and generate 10,000 videos to see what this looks like.

# Create a directory named 'training_dataset'
os.makedirs('training_dataset', exist_ok=True)

# Define the number of videos to generate for the dataset
num_videos = 10000

# Define the number of frames per video (1 second video)
frames_per_video = 10

# Define the size of each image in the dataset
img_size = (64, 64)

# Define the size of the shapes (circle)
shape_size = 10

After setting some basic parameters, the next step is to define the text prompts of the training dataset and generate the training videos from them.

# Define text prompts and corresponding movements for circles
prompts_and_movements = [
    ("circle moving down", "circle", "down"),  # Move circle downward
    ("circle moving left", "circle", "left"),  # Move circle leftward
    ("circle moving right", "circle", "right"),  # Move circle rightward
    ("circle moving diagonally up-right", "circle", "diagonal_up_right"),  # Move circle diagonally up-right
    ("circle moving diagonally down-left", "circle", "diagonal_down_left"),  # Move circle diagonally down-left
    ("circle moving diagonally up-left", "circle", "diagonal_up_left"),  # Move circle diagonally up-left
    ("circle moving diagonally down-right", "circle", "diagonal_down_right"),  # Move circle diagonally down-right
    ("circle rotating clockwise", "circle", "rotate_clockwise"),  # Rotate circle clockwise
    ("circle rotating counter-clockwise", "circle", "rotate_counter_clockwise"),  # Rotate circle counter-clockwise
    ("circle shrinking", "circle", "shrink"),  # Shrink circle
    ("circle expanding", "circle", "expand"),  # Expand circle
    ("circle bouncing vertically", "circle", "bounce_vertical"),  # Bounce circle vertically
    ("circle bouncing horizontally", "circle", "bounce_horizontal"),  # Bounce circle horizontally
    ("circle zigzagging vertically", "circle", "zigzag_vertical"),  # Zigzag circle vertically
    ("circle zigzagging horizontally", "circle", "zigzag_horizontal"),  # Zigzag circle horizontally
    ("circle moving up-left", "circle", "up_left"),  # Move circle up-left
    ("circle moving down-right", "circle", "down_right"),  # Move circle down-right
    ("circle moving down-left", "circle", "down_left"),  # Move circle down-left
]

We've used these prompts to define several trajectories of the circle. Now, we need to write some mathematical formulas to move the circle according to the prompt.

# Define function with parameters
def create_image_with_moving_shape(size, frame_num, shape, direction):

    # Create a new RGB image with specified size and white background
    img = Image.new('RGB', size, color=(255, 255, 255))

    # Create a drawing context for the image
    draw = ImageDraw.Draw(img)

    # Calculate the center coordinates of the image
    center_x, center_y = size[0] // 2, size[1] // 2

    # Initialize position with center for all movements
    position = (center_x, center_y)

    # Define a dictionary mapping directions to their respective position adjustments or image transformations
    direction_map = {
        # Adjust position downwards based on frame number
        "down": (0, frame_num * 5 % size[1]),
        # Adjust position to the left based on frame number
        "left": (-frame_num * 5 % size[0], 0),
        # Adjust position to the right based on frame number
        "right": (frame_num * 5 % size[0], 0),
        # Adjust position diagonally up and to the right
        "diagonal_up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position diagonally down and to the left
        "diagonal_down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Adjust position diagonally up and to the left
        "diagonal_up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position diagonally down and to the right
        "diagonal_down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Rotate the image clockwise based on frame number
        "rotate_clockwise": img.rotate(frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Rotate the image counter-clockwise based on frame number
        "rotate_counter_clockwise": img.rotate(-frame_num * 10 % 360, center=(center_x, center_y), fillcolor=(255, 255, 255)),
        # Adjust position for a bouncing effect vertically
        "bounce_vertical": (0, center_y - abs(frame_num * 5 % size[1] - center_y)),
        # Adjust position for a bouncing effect horizontally
        "bounce_horizontal": (center_x - abs(frame_num * 5 % size[0] - center_x), 0),
        # Adjust position for a zigzag effect vertically
        "zigzag_vertical": (0, center_y - frame_num * 5 % size[1]) if frame_num % 2 == 0 else (0, center_y + frame_num * 5 % size[1]),
        # Adjust position for a zigzag effect horizontally
        "zigzag_horizontal": (center_x - frame_num * 5 % size[0], center_y) if frame_num % 2 == 0 else (center_x + frame_num * 5 % size[0], center_y),
        # Adjust position upwards and to the right based on frame number
        "up_right": (frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position upwards and to the left based on frame number
        "up_left": (-frame_num * 5 % size[0], -frame_num * 5 % size[1]),
        # Adjust position downwards and to the right based on frame number
        "down_right": (frame_num * 5 % size[0], frame_num * 5 % size[1]),
        # Adjust position downwards and to the left based on frame number
        "down_left": (-frame_num * 5 % size[0], frame_num * 5 % size[1])
    }

    # Check if direction is in the direction map
    if direction in direction_map:
        # Check if the direction maps to a position adjustment
        if isinstance(direction_map[direction], tuple):
            # Update position based on the adjustment
            position = tuple(np.add(position, direction_map[direction]))
        else:  # If the direction maps to an image transformation
            # Update the image based on the transformation
            img = direction_map[direction]

    # Draw the circle at the computed position so it actually appears in the frame
    if shape == "circle":
        draw = ImageDraw.Draw(img)
        draw.ellipse(
            [position[0] - shape_size // 2, position[1] - shape_size // 2,
             position[0] + shape_size // 2, position[1] + shape_size // 2],
            fill=(0, 0, 255)
        )

    # Return the image as a numpy array
    return np.array(img)

The above function moves the circle in each frame according to the chosen direction. We just need to run a loop over it for as many videos as we want to generate.

# Iterate over the number of videos to generate
for i in range(num_videos):
    # Randomly choose a prompt and movement from the predefined list
    prompt, shape, direction = random.choice(prompts_and_movements)

    # Create a directory for the current video
    video_dir = f'training_dataset/video_{i}'
    os.makedirs(video_dir, exist_ok=True)

    # Write the chosen prompt to a text file in the video directory
    with open(f'{video_dir}/prompt.txt', 'w') as f:
        f.write(prompt)

    # Generate frames for the current video
    for frame_num in range(frames_per_video):
        # Create an image with a moving shape based on the current frame number, shape, and direction
        img = create_image_with_moving_shape(img_size, frame_num, shape, direction)

        # Save the generated image as a PNG file in the video directory
        cv2.imwrite(f'{video_dir}/frame_{frame_num}.png', img)

After running the above code, the entire training dataset is generated. The following is the structure of the training dataset directory.

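Based on the generation code above, each video folder holds its ten frames plus a prompt.txt, so the layout looks roughly like this (10,000 video folders in total):

training_dataset/
    video_0/
        frame_0.png
        ...
        frame_9.png
        prompt.txt
    video_1/
        ...
    video_9999/
        ...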

Each training video folder contains its frames and the corresponding text prompt. Let's take a look at a sample of our training dataset.

In our training dataset, we didn't include the motion where the circle moves up and then to the right. We'll use this as a test prompt to evaluate how our trained model performs on unseen data.

[Animation: sample clips from the training dataset]

Another important point to note is that our training data contains many samples in which objects move out of the scene or appear only partially in front of the camera, similar to what we observed in the OpenAI Sora demo videos.

[Animation: training samples where the circle is partially visible or leaves the scene]

The reason for including such samples in the training data is to test whether the model can maintain consistency when circles enter the scene from a corner, without their shape breaking apart.

Now that our training data has been generated, we need to convert the training videos into tensors, the primary data type used by deep learning frameworks such as PyTorch. Transformations such as normalization also help improve the convergence and stability of training by scaling the data into a smaller range.

Preprocess training data

We need to write a dataset class for the text-to-video task that reads the video frames and their corresponding text prompts from the training dataset directory, so they can be used in PyTorch.

# Define a dataset class inheriting from torch.utils.data.Dataset
class TextToVideoDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        # Initialize the dataset with root directory and optional transform
        self.root_dir = root_dir
        self.transform = transform
        # List all subdirectories in the root directory
        self.video_dirs = [os.path.join(root_dir, d) for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        # Initialize lists to store frame paths and corresponding prompts
        self.frame_paths = []
        self.prompts = []

        # Loop through each video directory
        for video_dir in self.video_dirs:
            # List all PNG files in the video directory and store their paths
            frames = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith('.png')]
            self.frame_paths.extend(frames)
            # Read the prompt text file in the video directory and store its content
            with open(os.path.join(video_dir, 'prompt.txt'), 'r') as f:
                prompt = f.read().strip()
            # Repeat the prompt for each frame in the video and store in prompts list
            self.prompts.extend([prompt] * len(frames))

    # Return the total number of samples in the dataset
    def __len__(self):
        return len(self.frame_paths)

    # Retrieve a sample from the dataset given an index
    def __getitem__(self, idx):
        # Get the path of the frame corresponding to the given index
        frame_path = self.frame_paths[idx]
        # Open the image using PIL (Python Imaging Library)
        image = Image.open(frame_path)
        # Get the prompt corresponding to the given index
        prompt = self.prompts[idx]

        # Apply transformation if specified
        if self.transform:
            image = self.transform(image)

        # Return the transformed image and the prompt
        return image, prompt

Before we continue with the architecture code, we need to normalize the training data. We use a batch size of 16 and shuffle the data to introduce more randomness.
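A minimal sketch of this preprocessing step (the exact normalization values here are an assumption; the training loop later expects a dataloader built roughly like this from the TextToVideoDataset class above):

# Transform: convert PIL frames to tensors and scale pixel values to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Build the dataset and a shuffled DataLoader with a batch size of 16
dataset = TextToVideoDataset(root_dir='training_dataset', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

Normalizing to [-1, 1] matches the Tanh output range of the generator defined below.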

Implement a text embedding layer

As you may have seen in the Transformer architecture, the starting point is converting text input into embeddings for further processing in multi-head attention. Similarly, we have to write a text embedding layer here; the GAN architecture is then trained on the embedded text and the image tensors.

# Define a class for text embedding
class TextEmbedding(nn.Module):
    # Constructor method with vocab_size and embed_size parameters
    def __init__(self, vocab_size, embed_size):
        # Call the superclass constructor
        super(TextEmbedding, self).__init__()
        # Initialize embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)

    # Define the forward pass method
    def forward(self, x):
        # Return embedded representation of input
        return self.embedding(x)

The vocabulary size will be computed later from our training data, and the embedding size will be 10. If you're working with a larger dataset, you can also use a ready-made embedding model from Hugging Face.
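To make the data flow concrete, here is a small usage sketch with a toy vocabulary (the real vocabulary is built later from the training prompts; the names here are purely illustrative):

# Toy vocabulary and embedding layer, just to show the shapes involved
toy_vocab = {"circle": 0, "moving": 1, "down": 2}
toy_embedding = TextEmbedding(vocab_size=len(toy_vocab), embed_size=10)

# Encode a prompt as word indices, embed each word, then average into one vector
indices = torch.tensor([toy_vocab[word] for word in "circle moving down".split()])
word_vectors = toy_embedding(indices)     # shape: (3, 10), one vector per word
prompt_vector = word_vectors.mean(dim=0)  # shape: (10,), as used in the training loop below
print(word_vectors.shape, prompt_vector.shape)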

Implement the generator layer

Now that we know what the generator does in the GAN, let's code this layer and then understand its contents.

class Generator(nn.Module):
    def __init__(self, text_embed_size):
        super(Generator, self).__init__()

        # Fully connected layer that takes noise and text embedding as input
        self.fc1 = nn.Linear(100 + text_embed_size, 256 * 8 * 8)

        # Transposed convolutional layers to upsample the input
        self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2, 1)  # Output has 3 channels for RGB images

        # Activation functions
        self.relu = nn.ReLU(True)  # ReLU activation function
        self.tanh = nn.Tanh()      # Tanh activation function for final output

    def forward(self, noise, text_embed):
        # Concatenate noise and text embedding along the channel dimension
        x = torch.cat((noise, text_embed), dim=1)

        # Fully connected layer followed by reshaping to 4D tensor
        x = self.fc1(x).view(-1, 256, 8, 8)

        # Upsampling through transposed convolution layers with ReLU activation
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))

        # Final layer with Tanh activation to ensure output values are between -1 and 1 (for images)
        x = self.tanh(self.deconv3(x))

        return x

This Generator class is responsible for creating video frames from a combination of random noise and text embeddings, with the aim of producing realistic frames conditioned on a given text description. The network starts with a fully connected layer (nn.Linear) that combines the noise vector and the text embedding into a single feature vector. This vector is then reshaped and passed through a series of transposed convolutional layers (nn.ConvTranspose2d), which progressively upsample the feature maps to the desired frame size.

These layers use ReLU activations (nn.ReLU) for nonlinearity, and the final layer uses a Tanh activation (nn.Tanh) to scale the output to the range [-1, 1]. The generator thus converts an abstract, high-dimensional input into coherent video frames that visually represent the input text.
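As a quick sanity check (a sketch, assuming the 64×64 frame size defined earlier), you can verify that one noise vector plus one text embedding yields a single RGB frame:

# Instantiate a throwaway generator and run one forward pass on random inputs
netG_test = Generator(text_embed_size=10)
noise = torch.randn(1, 100)      # latent noise vector
text_embed = torch.randn(1, 10)  # stand-in for a real prompt embedding
with torch.no_grad():
    frame = netG_test(noise, text_embed)
print(frame.shape)  # torch.Size([1, 3, 64, 64]): one 64x64 RGB frame with values in [-1, 1]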

Implement the discriminator layer

After writing the generator layer, we need to implement the other half: the discriminator.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Convolutional layers to process input images
        self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)     # 3 input channels (RGB), 64 output channels, kernel size 4x4, stride 2, padding 1
        self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)   # 64 input channels, 128 output channels, kernel size 4x4, stride 2, padding 1
        self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)  # 128 input channels, 256 output channels, kernel size 4x4, stride 2, padding 1

        # Fully connected layer for classification
        self.fc1 = nn.Linear(256 * 8 * 8, 1)  # Input size 256x8x8 (output size of last convolution), output size 1 (binary classification)

        # Activation functions
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)  # Leaky ReLU activation with negative slope 0.2
        self.sigmoid = nn.Sigmoid()                        # Sigmoid activation for final output (probability)

    def forward(self, input):
        # Pass input through convolutional layers with LeakyReLU activation
        x = self.leaky_relu(self.conv1(input))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))

        # Flatten the output of convolutional layers
        x = x.view(-1, 256 * 8 * 8)

        # Pass through fully connected layer with Sigmoid activation for binary classification
        x = self.sigmoid(self.fc1(x))

        return x

The discriminator class acts as a binary classifier that distinguishes real video frames from generated ones. Its purpose is to assess the realism of the frames, which in turn guides the generator to produce more realistic output. The network consists of convolutional layers (nn.Conv2d) that extract hierarchical features from the input frames, with Leaky ReLU activations (nn.LeakyReLU) adding nonlinearity while still allowing small gradients for negative values.

The feature maps are then flattened and passed through a fully connected layer (nn.Linear), and a final Sigmoid activation (nn.Sigmoid) outputs a probability score indicating whether the frame is real or fake.

By training the discriminator to accurately classify the frames, the generator is simultaneously trained to create more convincing video frames, fooling the discriminator.
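A similar sanity-check sketch for the discriminator (again assuming 64×64 RGB frames): a single frame in, a single probability out.

# Instantiate a throwaway discriminator and score one random frame
netD_test = Discriminator()
dummy_frame = torch.randn(1, 3, 64, 64)  # stand-in for a real or generated frame
with torch.no_grad():
    score = netD_test(dummy_frame)
print(score.shape, score.item())  # torch.Size([1, 1]) and a probability between 0 and 1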

Write training parameters

We have to set up the underlying components that are used to train the GAN, such as loss functions, optimizers, and so on.

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a simple vocabulary for text prompts
all_prompts = [prompt for prompt, _, _ in prompts_and_movements]  # Extract all prompts from prompts_and_movements list
vocab = {word: idx for idx, word in enumerate(set(" ".join(all_prompts).split()))}  # Create a vocabulary dictionary where each unique word is assigned an index
vocab_size = len(vocab)  # Size of the vocabulary
embed_size = 10          # Size of the text embedding vector

def encode_text(prompt):
    # Encode a given prompt into a tensor of indices using the vocabulary
    return torch.tensor([vocab[word] for word in prompt.split()])

# Initialize models, loss function, and optimizers
text_embedding = TextEmbedding(vocab_size, embed_size).to(device)  # Initialize TextEmbedding model with vocab_size and embed_size
netG = Generator(embed_size).to(device)  # Initialize Generator model with embed_size
netD = Discriminator().to(device)        # Initialize Discriminator model
criterion = nn.BCELoss().to(device)      # Binary Cross Entropy loss function
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Discriminator
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))  # Adam optimizer for Generator

This is where we move the code to the GPU, if one is available. We've also written the code that computes vocab_size, and we use the Adam optimizer for both the generator and the discriminator; you can choose a different optimizer if you prefer. Here we set the learning rate to a small value of 0.0002, and the embedding size is 10, much smaller than that of publicly available Hugging Face models.

Write a training loop

We code the GAN training loop much as we would for any other neural network.

# Number of epochs
num_epochs = 13

# Iterate over each epoch
for epoch in range(num_epochs):
    # Iterate over each batch of data
    for i, (data, prompts) in enumerate(dataloader):
        # Move real data to device
        real_data = data.to(device)

        # Convert prompts to list
        prompts = [prompt for prompt in prompts]

        # Update Discriminator
        netD.zero_grad()  # Zero the gradients of the Discriminator
        batch_size = real_data.size(0)  # Get the batch size
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for real data (ones)
        output = netD(real_data)  # Forward pass real data through Discriminator
        lossD_real = criterion(output, labels)  # Calculate loss on real data
        lossD_real.backward()  # Backward pass to calculate gradients

        # Generate fake data
        noise = torch.randn(batch_size, 100).to(device)  # Generate random noise
        text_embeds = torch.stack([text_embedding(encode_text(prompt).to(device)).mean(dim=0) for prompt in prompts])  # Encode prompts into text embeddings
        fake_data = netG(noise, text_embeds)  # Generate fake data from noise and text embeddings
        labels = torch.zeros(batch_size, 1).to(device)  # Create labels for fake data (zeros)
        output = netD(fake_data.detach())  # Forward pass fake data through Discriminator (detach to avoid gradients flowing back to Generator)
        lossD_fake = criterion(output, labels)  # Calculate loss on fake data
        lossD_fake.backward()  # Backward pass to calculate gradients
        optimizerD.step()  # Update Discriminator parameters

        # Update Generator
        netG.zero_grad()  # Zero the gradients of the Generator
        labels = torch.ones(batch_size, 1).to(device)  # Create labels for fake data (ones) to fool Discriminator
        output = netD(fake_data)  # Forward pass fake data through Discriminator
        lossG = criterion(output, labels)  # Calculate loss for Generator based on Discriminator's response
        lossG.backward()  # Backward pass to calculate gradients
        optimizerG.step()  # Update Generator parameters

    # Print epoch information
    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD_real + lossD_fake}, Loss G: {lossG}")

Through backpropagation, the losses drive the updates of both the generator and the discriminator. We used 13 epochs in the training loop. We tested different values; going beyond 13 epochs made little difference to the results while increasing the risk of overfitting. If our dataset were more diverse, with more motions and shapes, more epochs could be worthwhile, but that is not the case here.

When we run this code, training starts, and the generator and discriminator losses are printed after each epoch.

## OUTPUT ##

Epoch [1/13] Loss D: 0.8798642754554749, Loss G: 1.300612449645996
Epoch [2/13] Loss D: 0.8235711455345154, Loss G: 1.3729925155639648
Epoch [3/13] Loss D: 0.6098687052726746, Loss G: 1.3266581296920776

...

Save the trained model

Once training is complete, we need to save the generator and discriminator of the trained GAN, which takes just two lines of code.

# Save the Generator model's state dictionary to a file named 'generator.pth'
torch.save(netG.state_dict(), 'generator.pth')

# Save the Discriminator model's state dictionary to a file named 'discriminator.pth'
torch.save(netD.state_dict(), 'discriminator.pth')
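To reuse the trained generator later for inference, a minimal sketch (the file name matches the one saved above; the variable name netG_loaded is just illustrative):

# Recreate the generator with the same embedding size and load the saved weights
netG_loaded = Generator(embed_size).to(device)
netG_loaded.load_state_dict(torch.load('generator.pth', map_location=device))
netG_loaded.eval()  # switch to evaluation mode before generating frames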

Generate AI videos

As discussed, testing the model on unseen data mirrors our earlier example: if the training data contained dogs picking up balls and cats chasing mice, the test prompts might involve a cat picking up a ball or a dog chasing a mouse.

In our particular case, the motion of a circle moving up and then to the right is not present in the training data, so the model is unfamiliar with this specific motion. However, it has been trained on the other motions, so we can use this motion as a test prompt and observe how the trained model performs.

# Inference function to generate a video based on a given text prompt
def generate_video(text_prompt, num_frames=10):
    # Create a directory for the generated video frames based on the text prompt
    os.makedirs(f'generated_video_{text_prompt.replace(" ", "_")}', exist_ok=True)

    # Encode the text prompt into a text embedding tensor
    text_embed = text_embedding(encode_text(text_prompt).to(device)).mean(dim=0).unsqueeze(0)

    # Generate frames for the video
    for frame_num in range(num_frames):
        # Generate random noise
        noise = torch.randn(1, 100).to(device)

        # Generate a fake frame using the Generator network
        with torch.no_grad():
            fake_frame = netG(noise, text_embed)

        # Save the generated fake frame as an image file
        save_image(fake_frame, f'generated_video_{text_prompt.replace(" ", "_")}/frame_{frame_num}.png')

# Usage of the generate_video function with a specific text prompt
generate_video('circle moving up-right')

When we run the code above, it generates a directory containing all the frames of the video. We then need a bit more code to merge these frames into one short video.

# Define the path to your folder containing the PNG frames
folder_path = 'generated_video_circle_moving_up-right'

# Get the list of all PNG files in the folder
image_files = [f for f in os.listdir(folder_path) if f.endswith('.png')]

# Sort the images by name (assuming they are numbered sequentially)
image_files.sort()

# Create a list to store the frames
frames = []

# Read each image and append it to the frames list
for image_file in image_files:
    image_path = os.path.join(folder_path, image_file)
    frame = cv2.imread(image_path)
    frames.append(frame)

# Convert the frames list to a numpy array for easier processing
frames = np.array(frames)

# Define the frame rate (frames per second)
fps = 10

# Create a video writer object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('generated_video.avi', fourcc, fps, (frames[0].shape[1], frames[0].shape[0]))

# Write each frame to the video
for frame in frames:
    out.write(frame)

# Release the video writer
out.release()

Make sure the folder path points to the location of your newly generated frames. After running this code, you'll have created your AI-generated video. Let's see what it looks like.

[Animation: generated video for the prompt "circle moving up-right"]

We trained the model more than once with the same number of epochs. In both attempts, the circle starts from the bottom and is only half visible at first. The good news is that in both cases the model tries to perform the up-right motion.

For example, in Attempt 1 the circle moves diagonally up and then performs an upward motion, while in Attempt 2 the circle moves diagonally while shrinking in size. In neither case did the circle move to the left or disappear completely, which is a good sign.

Finally, the authors note that they tested various aspects of the architecture and found that the training data is key. Including more motions and shapes in the dataset increases variability and improves the model's performance. Since the data is generated by code, producing more diverse data doesn't take much time; instead, you can focus on perfecting the logic.

In addition, the GAN architecture discussed in this article is relatively simple. You can make it more sophisticated by integrating advanced techniques or by using embeddings from a language model (LLM) instead of a basic neural-network embedding layer. Adjusting parameters such as the embedding size can also significantly affect the model's effectiveness.
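As a rough illustration of that last suggestion, here is a sketch of how pretrained language-model embeddings could replace the basic embedding layer (this assumes the Hugging Face transformers library; the model name and the 768-dimensional output are only an example, and the generator's text_embed_size would need to match, or a projection layer would have to be added):

# Sketch: encode prompts with a pretrained language model instead of nn.Embedding
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
language_model = AutoModel.from_pretrained("bert-base-uncased")

def encode_text_with_lm(prompt):
    # Tokenize the prompt and average the last hidden states into one vector
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = language_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1)  # shape: (1, 768)

# The generator would then be created with a matching embedding size,
# e.g. netG = Generator(768).to(device), or a projection down to a smaller size added.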

Original link: https://levelup.gitconnected.com/building-an-ai-text-to-video-model-from-scratch-using-python-35b4eb4002de
