Training CLIP Model from Scratch for an Image Retrieval App

Contrastive Language Image Pretraining (CLIP) by OpenAI is a model that connects text and images, allowing it to recognize and categorize images without needing specific training for each category. CLIP learns to understand how text and visual features relate, similar to how humans process information through multiple senses.

The article primarily discusses:

  • Exploring the zero-shot classification capabilities of pretrained CLIP.
  • Implementing the Vision and Text Encoders of CLIP in PyTorch.
  • Training a CLIP-like model on a Fashion Images Dataset.
  • Creating an image retrieval app for apparel search using Gradio.

This article is designed for readers who have an intermediate understanding of PyTorch and Transformers. You’ll gain insights into how CLIP works internally and learn to build a simple version of the model yourself.

To try out the code and use it with your dataset, click the “Download Code” banner.

  • Why Do We Need CLIP?
  • Capability of Contrastive Language Image Pretraining
  • Understanding the CLIP Architecture
  • Code Walkthrough
  • E-Commerce Image Retrieval App with Gradio
  • Key Takeaways
  • Conclusion
  • References

Why Do We Need CLIP?

You may wonder why we need advanced architectures for classification when task-specific models like ResNet already exist. These models, although highly accurate on the ImageNet-1k or ImageNet-21k classes, struggle to generalize to out-of-distribution classes. CLIP overcomes this limitation by producing flexible image embeddings that can be matched against any set of candidate labels provided at inference time.

Having been trained on a vast collection of text-image pairs with a contrastive (or symmetric) loss, it performs exceptionally well even on unseen categories. CLIP is also widely used in text-guided diffusion pipelines, where it helps condition the image generation process on context-rich textual prompts.

Capability of Contrastive Language Image Pretraining

First, let us understand what CLIP is capable of through some zero-shot classification examples.

You can find the code to perform inference using the pretrained CLIP model on its Hugging Face model card.
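As a rough sketch of what that inference looks like (using the openai/clip-vit-base-patch32 checkpoint and an illustrative image and label set, not the article's exact snippet), zero-shot classification with the transformers library can be done as follows:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Load the pretrained CLIP model and its processor from the Hugging Face Hub
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("hedgehog.jpg")  # any test image (path is illustrative)
labels = ["porcupine", "rat", "mole", "hedgehog"]

# Build a text prompt for each candidate label and run both encoders
inputs = processor(text=[f"a photo of a {c}" for c in labels], images=image,
                   return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image holds image-text similarity scores; softmax gives per-label probabilities
probs = outputs.logits_per_image.softmax(dim=1)
print({label: round(p.item(), 3) for label, p in zip(labels, probs[0])})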

Zero Shot Example 1: Facial Emotion Recognition

Zero Shot Example 2: Animals and Birds

Zero Shot Example 3: Human Activities

From the above tests, we see that CLIP performs well in most cases, with a few failure cases (shown at the end of this article).

Before we build an apparel search app with our own 0.5M-parameter model, note that an open-source FashionCLIP model, pretrained on a large fashion dataset, is available on Hugging Face. FashionCLIP is a CLIP-based model developed to produce general product representations for fashion concepts, and it performs better across fashion retrieval tasks.

Understanding the CLIP Architecture

Contrastive Language Image Pretraining (CLIP) is a combination of an Image Encoder and a Text Encoder. Let's understand these components at a high level.

Image or Vision Encoder: This is either a CNN like ResNet or a Vision Transformer (ViT) that processes an image (or its patches) and converts it into a dense embedding vector.

Text Encoder: This is often a transformer-based model like GPT or BERT. It can even be a simpler scheme such as a character-level (UTF-8 based) embedding or an algorithm like TF-IDF. The text encoder creates meaningful embeddings from captions or label descriptions.

The final layers of the text and image encoders must have the same dimensionality so that their embeddings can be compared directly in a shared feature space (or latent space). For this, a linear projection layer is typically added to one or both encoders, and training maximizes the cosine similarity between matching text-image feature pairs while minimizing it for unrelated pairs.

While similar architectures like ALIGN also exist, what sets CLIP apart is the intuitive approach employed during training. CLIP, which stands for Contrastive Language-Image Pre-training, uses a contrastive loss during its training phase. During contrastive learning, CLIP was trained on 400M image-text pairs, where it learned to maximize the similarity of diagonal elements (positive pairs, i.e., smaller cosine distance) while treating off-diagonal elements as negative pairs (maximizing their cosine distance).
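Concretely, for a batch of N image-text pairs with L2-normalized image embeddings I and text embeddings T and a learnable temperature t, the symmetric loss that we will implement later in CLIPLoss can be written (in our notation) as:

\[
\text{logits}_{ij} = e^{t}\,\big(I_i \cdot T_j\big), \qquad
\mathcal{L} = \tfrac{1}{2}\Big(\mathrm{CE}\big(\text{logits}, y\big) + \mathrm{CE}\big(\text{logits}^{\top}, y\big)\Big), \qquad y_i = i
\]

Each image is pulled toward its own caption (the diagonal of the logits matrix) and pushed away from every other caption in the batch, and symmetrically for each caption.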

Observe the above image, where we perform zero-shot classification on an image of a hedgehog. First, we define a set of class labels, say labels = ["porcupine", "rat", "mole", "hedgehog"]. These classes are encoded with the text encoder of the CLIP model, and the image to be classified is processed with the image encoder. The model then computes the distance between the textual label embeddings and the image embedding. After computing the similarity, the class with the highest similarity score (hedgehog) is taken as the predicted class of the image.

Once an image or a caption is encoded into the embedding space, the mapping is one-way: it is not possible to reverse the process and reconstruct the original image or caption exactly from its embedding.

Now that we conceptually understand how CLIP works, we will next code the individual components of CLIP, namely the Vision Encoder, Text Encoder, and Tokenizer, from the ground up and train the model on an e-commerce fashion images dataset.

Finally, we will build an image retrieval app with Gradio whose schema pipeline looks as follows:

Code Walkthrough

We will start with importing the necessary dependencies.

!pip install datasets -q
!pip install kaggle -q
!pip install gradio==3.50 -q
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageOps
import pdb
import os
from tqdm import tqdm
import warnings

warnings.filterwarnings('ignore')

TRANSFORMER ENCODER

Usually, it is decoder-only models that require masking, for autoregressive text generation. Even though CLIP is an encoder-only model, it still requires an additional mask in the attention head of the transformer encoder to account for padded positions in the input sequences. So we will build a custom MHA block for the transformer encoder instead of using PyTorch's nn.MultiheadAttention.

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, mlp_ratio=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        self.ln1 = nn.LayerNorm(d_model)
        self.mha = MultiheadAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * mlp_ratio),
            nn.GELU(),
            nn.Linear(d_model * mlp_ratio, d_model)
        )

    # For CLIP, even though it is an encoder-only model, a mask is required
    # to account for padding up to max_seq_length.
    def forward(self, x, mask=None):
        x_n = self.mha(self.ln1(x), mask=mask)
        x = x + self.mlp(self.ln2(x_n))
        return x  # x.shape --> [B, max_seq_len, d_model]

Here,

  • d_model: Hidden embedding dimension of transformer encoder or hidden size of the embedding vector.
  • n_heads: Number of attention heads.
  • mlp_ratio: Expansion ratio applied to the hidden size in the MLP. Set to 4 to balance computational cost and capacity.
  • The Transformer encoder block consists of layer normalization, multi-head self-attention, and finally a feed-forward network.
  • Like the vanilla Transformer encoder from the “Attention is All You Need” paper, we use LayerNorm to stabilize training, normalizing features across the embedding dimension rather than across the batch (BatchNorm).
  • The output of the transformer encoder is of shape [Batch_size, max_seq_len, d_model].

Even though the input and output shapes of a Transformer encoder are the same, the output contains learned contextual embeddings (thanks to attention).

The Transformer encoder, in turn, is composed of multi-head attention blocks, which we implement next.

ATTENTION BLOCK

Next, we will implement the attention head, where every input [x.shape = B, max_seq_len, d_model] is projected through linear layers into query (Q), key (K), and value (V) matrices of shape [B, max_seq_len, head_size]. The attention scores are calculated by matrix multiplication of Q and K^T.
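For reference, this is the scaled dot-product attention from “Attention is All You Need” that the code below implements (here d_k is the head size, qkv_dim):

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\]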

By the above formula, Attention(Q, K, V), the attention scores are scaled by the square root of head_size, and a softmax is applied to produce attention weights. Finally, the softmax output is multiplied with the Value matrix. The final output of each attention head is a weighted sum of value vectors, which captures how each token interacts with itself and with every other token in the sequence.

class AttentionHead(nn.Module):
    def __init__(self, d_model, qkv_dim):
        super().__init__()
        self.qkv_dim = qkv_dim
        self.query = nn.Linear(d_model, qkv_dim)
        self.key = nn.Linear(d_model, qkv_dim)
        self.value = nn.Linear(d_model, qkv_dim)

    def forward(self, x, mask=None):
        # x.shape --> [B, max_seq_len, d_model]
        Q = self.query(x)  # [B, max_seq_len, qkv_dim]
        K = self.key(x)
        V = self.value(x)

        attention = Q @ K.transpose(-2, -1)  # -2: second-last dim, -1: last dim --> [B, max_seq_len, max_seq_len]

        # Scaling
        attention = attention / self.qkv_dim ** 0.5  # [B, max_seq_len, max_seq_len]

        # Apply attention mask for padded sequences
        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-inf"))  # torch.Tensor.masked_fill

        # Apply softmax to obtain attention weights [W_ij]
        attention = torch.softmax(attention, dim=-1)  # along the last dim

        # softmax(Q.K^T / sqrt(d_k)).V --> [B, max_seq_len, qkv_dim]
        attention = attention @ V
        return attention  # Y_i

The output of AttentionHead is of shape [B, max_seq_len, head_size].

MULTI-HEAD ATTENTION

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # d_model --> embedding dimension
        # n_heads --> number of heads
        self.qkv_dim = d_model // n_heads  # or self.head_size
        self.W_o = nn.Linear(d_model, d_model)  # Dense layer
        self.multi_head = nn.ModuleList([AttentionHead(d_model, self.qkv_dim) for _ in range(n_heads)])

    def forward(self, x, mask=None):
        # x.shape --> [B, max_seq_len, d_model]
        # Concatenate the outputs from all attention heads along the last dimension (dim=-1)
        out = torch.cat([head(x, mask=mask) for head in self.multi_head], dim=-1)  # [B, max_seq_len, d_model]
        # Apply the linear transformation (Concat --> Dense)
        out = self.W_o(out)  # [B, max_seq_len, d_model]
        return out

The outputs from all the attention heads are concatenated, and a linear transformation is applied to mix these learned embeddings back to the original dimension [B, max_seq_len, d_model].

VISION ENCODER

We will use a ViT as the vision encoder, where an image is divided into (img_height / patch_size) * (img_width / patch_size) non-overlapping patches. With our configuration (80×80 images and 5×5 patches), that is 16 × 16 = 256 patches. These patches are flattened linearly, giving a tensor of shape [B, Num_Patches, Patch Dimension].

As we know, in a transformer-based architecture, to learn spatial features properly we need to add positional encodings to the image patches to retain information about each patch's position in the input image. The following image illustrates the importance of positional encoding in preserving the order of patches of an input image.

The sine and cosine functions are applied to alternating embedding dimensions and together determine the unique position of each patch in the sequence. We register the positional encoding as a buffer so it is saved as part of the model but is not a learnable parameter. There are also approaches that make 'pe' a learnable parameter.
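For reference, this is the sinusoidal positional encoding from “Attention is All You Need”, which the code below computes for every position pos and embedding dimension index i:

\[
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\]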

class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]

Along with this, a [CLS] token is prepended (i.e., at the zeroth position); this classification token acts as a summary representation of the entire image. The input to the transformer encoder is then of shape [B, max_seq_len, d_model].

class VisionEncoder(nn.Module):
    def __init__(self, d_model, img_size, patch_size, n_channels, n_heads, n_layers, emb_dim):
        super().__init__()
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "image dimensions should be divisible by patch dim"
        assert d_model % n_heads == 0, "d_model should be divisible by n_heads"

        self.num_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])
        self.max_seq_length = self.num_patches + 1

        self.linear_proj = nn.Conv2d(in_channels=n_channels, out_channels=d_model, kernel_size=patch_size[0], stride=patch_size[0])
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model), requires_grad=True)
        self.positional_embedding = PositionalEmbedding(d_model, self.max_seq_length)
        self.transformer_encoder = nn.ModuleList([TransformerEncoder(d_model, n_heads) for _ in range(n_layers)])
        self.projection = nn.Parameter(torch.randn(d_model, emb_dim))

    def forward(self, x, mask=None):
        x = self.linear_proj(x)  # (B, C, H, W) -> (B, d_model, H/patch, W/patch)

        # The transformer expects a sequence of patches (tokens), so num_patches must come before the hidden dim
        x = x.flatten(2).transpose(-2, -1)  # flatten to (B, d_model, num_patches), then transpose to (B, num_patches, d_model)

        # Prepend the cls token to the patch sequence --> [B, max_seq_len, d_model]
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.positional_embedding(x)  # [B, max_seq_len, d_model]

After passing through the transformer encoder layers defined earlier, we take the learned class ([CLS]) token. Finally, the output of the Vision Encoder is projected into the k-dimensional shared embedding space (also known as the joint embedding space) of the image and text encoders (emb_dim).

class VisionEncoder(nn.Module):
    ...
    def forward(self, x, mask=None):
        for encoder_layer in self.transformer_encoder:
            x = encoder_layer(x, mask)

        # Get the learned class token
        x = x[:, 0, :]  # [B, d_model]

        # Project to the shared embedding space
        if self.projection is not None:
            x = x @ self.projection  # [B, emb_dim]

        x = x / torch.norm(x, dim=-1, keepdim=True)
        return x

TOKENIZATION

def tokenizer(text, encode=True, mask=None, max_seq_length=32):
    if encode:
        # Add SOT and EOT tokens
        out = chr(2) + text + chr(3)
        # Truncate if length exceeds max_seq_length
        if len(out) > max_seq_length:
            out = out[:max_seq_length]
        # Add padding if needed
        out = out + "".join([chr(0) for _ in range(max_seq_length - len(out))])
        # Encode the text
        out = torch.IntTensor(list(out.encode("utf-8")))
        # Create the mask
        mask = torch.ones(len(out.nonzero()))
        # Pad the mask to max_seq_length
        if len(mask) < max_seq_length:
            mask = torch.cat((mask, torch.zeros(max_seq_length - len(mask)))).type(torch.IntTensor)
        else:
            mask = mask.type(torch.IntTensor)
    else:
        # Decode the text
        out = [chr(x) for x in text[1:len(mask.nonzero()) - 1]]
        out = "".join(out)
        mask = None
    return out, mask

For the text encoder, we use a UTF-8 based character tokenizer, where special start-of-text and end-of-text tokens mark the beginning and end of each caption. The text is then padded with padding tokens up to the max sequence length. A binary mask is created where padded tokens hold a '0' value and non-padded tokens hold a '1', which helps the model focus on meaningful tokens.
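To see what the tokenizer defined above produces, a quick sanity check (with an arbitrary caption of our choosing) might look like this:

# Encode an example caption (illustrative input, not from the dataset)
tokens, mask = tokenizer("blue jeans")
print(tokens.shape, mask.shape)   # torch.Size([32]) torch.Size([32])
print(tokens[:5])                 # SOT token (2) followed by the UTF-8 codes of 'b', 'l', 'u', 'e'

# Decode back using the same mask
decoded, _ = tokenizer(tokens, encode=False, mask=mask)
print(decoded)                    # 'blue jeans'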

TEXT ENCODER

The text encoder first takes the tokenized text sequences and converts them into text embeddings. Positional embeddings are then added, just as in the vision encoder, and the combined embeddings are passed through n layers of the transformer encoder. The output of the Text Encoder has the same shape as that of the Vision Encoder, to accommodate the further operations we will perform in the CLIP model.

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, max_seq_length, n_layers, n_heads, emb_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embed = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = PositionalEmbedding(d_model, max_seq_length)
        self.transformer_encoder = nn.ModuleList([TransformerEncoder(d_model, n_heads) for _ in range(n_layers)])
        self.projection = nn.Parameter(torch.randn(d_model, emb_dim))

    # For training
    def forward(self, text, mask=None):
        x = self.embed(text)
        x = self.positional_embedding(x)
        for encoder_layer in self.transformer_encoder:
            x = encoder_layer(x, mask=mask)

        # The output of the encoder layers is the text features. We use the features at the EOT embedding.
        x = x[torch.arange(text.shape[0]), torch.sub(torch.sum(mask[:, 0], dim=1), 1)]

        if self.projection is not None:
            x = x @ self.projection

        x = x / torch.norm(x, dim=-1, keepdim=True)
        return x

Configuration of Contrastive Language Image Pretraining (CLIP)

# Vision
emb_dim = 128
vit_d_model = 32   # vit_heads * vit_layers = vit_d_model
img_size = (80, 80)
patch_size = (5, 5)
n_channels = 3
vit_layers = 8
vit_heads = 4

# Text
vocab_size = 256
text_d_model = 64  # text_heads * text_layers = text_d_model
max_seq_length = 128
text_heads = 8
text_layers = 8

lr = 1e-3
epochs = 50
batch_size = 128
  • emb_dim: Shared embedding space dimension for both Vision and Text Encoder.
  • vit_d_model: Hidden size of vision encoder.
  • vit_heads: Number of attention heads in the Vision Encoder.
  • vit_layers: Number of Transformer Encoder layers in Vision Encoder.
  • vocab_size: As the tokenizer is UTF-8 (byte) based, a vocab size of 256 covers all 256 possible byte values.
  • max_seq_length: Maximum length of the input sequence; shorter sequences will be padded.
  • text_heads: Number of attention heads in the Text Encoder.

Our training configuration uses a batch size of 128 with a learning rate of 0.001 for 50 epochs.

CLIP MODEL

Now it's time to integrate all the previous components into our CLIP model. The image and text encodings are matrix-multiplied to obtain the output logits (similarity scores). A learnable temperature parameter then scales the logits before the loss is calculated.

For the ground-truth labels, we generate a 1D tensor with values ranging from 0 to batch_size - 1. CLIP uses a symmetric contrastive loss to align text and image embeddings, maximizing the similarity of matching text-image pairs and minimizing it for mismatched ones.

class CLIP(nn.Module):
    def __init__(self, emb_dim, vit_layers, vit_d_model, img_size, patch_size, n_channels, vit_heads,
                 vocab_size, max_seq_length, text_heads, text_layers, text_d_model, retrieval=False):
        super().__init__()
        self.vision_encoder = VisionEncoder(vit_d_model, img_size, patch_size, n_channels, vit_heads, vit_layers, emb_dim)

        if retrieval:
            self.text_encoder = TextEncoder_Retrieval(vocab_size, text_d_model, max_seq_length, text_layers, text_heads, emb_dim)
        else:
            self.text_encoder = TextEncoder(vocab_size, text_d_model, max_seq_length, text_layers, text_heads, emb_dim)

        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def CLIPLoss(self, logits, device="cuda"):
        # Symmetric (contrastive) loss
        # arange generates labels 0 .. n-1: for row i, the matching pair sits at column i (the diagonal)
        labels = torch.arange(logits.shape[0]).to(device)

        loss_v = nn.functional.cross_entropy(logits.transpose(-2, -1), labels)
        loss_t = nn.functional.cross_entropy(logits, labels)
        loss = (loss_v + loss_t) / 2
        return loss
class CLIP(nn.Module):
    ...
    def forward(self, image, text, mask=None):
        V_e = self.vision_encoder(image)     # Vision encoder output [B, emb_dim]
        T_e = self.text_encoder(text, mask)  # Text encoder output [B, emb_dim]

        logits = (V_e @ T_e.transpose(-2, -1)) * torch.exp(self.temperature)
        loss = self.CLIPLoss(logits, self.device)
        return loss

For instance, before training, an image (Messi's jersey) and its relevant text label (sports dress) might be far apart in the latent space. After training, CLIP's contrastive loss (CLIPLoss) pulls these encodings very close to each other.

Now we are all set with the CLIP model components. As a sanity check, let's initialize our model with our configuration.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

model = CLIP(emb_dim, vit_layers, vit_d_model, img_size, patch_size, n_channels, vit_heads,
             vocab_size, max_seq_length, text_heads, text_layers, text_d_model, retrieval=False).to(device)

optimizer = optim.AdamW(model.parameters(), lr=lr)

total_params = sum([param.numel() for param in model.parameters() if param.requires_grad])
print(f"Total number of trainable parameters: {total_params}; i.e., {total_params/1000000:.2f} M")
# Total number of trainable parameters: 532641; i.e., 0.53 M

The total parameter count of our CLIP-like model is around 0.53 M, including both the vision and text encoders, and everything initializes as expected. Now let's move on to preparing the Fashion Images Dataset from Kaggle.

DATASET PREPARATION

Our dataset is a Fashion E-commerce subset, containing 44 classes and approximately 44k image samples.

!kaggle datasets download -d paramaggarwal/fashion-product-images-small -q

However, it is a heavily imbalanced dataset with many classes underrepresented. The original image dimensions are 80x60x3.

Classes: ['Accessories' 'Apparel Set' 'Bags' 'Bath and Body' 'Beauty Accessories' 'Belts' 'Bottomwear' 'Cufflinks' 'Dress' 'Eyes' 'Eyewear' 'Flip Flops' 'Fragrance' 'Free Gifts' 'Gloves' 'Hair' 'Headwear' 'Home Furnishing' 'Innerwear' 'Jewellery' 'Lips' 'Loungewear and Nightwear' 'Makeup' 'Mufflers' 'Nails' 'Perfumes' 'Sandal' 'Saree' 'Scarves' 'Shoe Accessories' 'Shoes' 'Skin' 'Skin Care' 'Socks' 'Sports Accessories' 'Sports Equipment' 'Stoles' 'Ties' 'Topwear' 'Umbrellas' 'Vouchers' 'Wallets' 'Watches' 'Water Bottle' 'Wristbands']

For this training, we focus on just two key columns: the image id column and subCategory, which contains the product labels. Finally, a captions dictionary is created mapping an idx to each class name.

# Load the dataset
df = pd.read_csv('fashion/myntradataset/styles.csv', usecols=['id', 'subCategory'])

unique, counts = np.unique(df["subCategory"].tolist(), return_counts=True)
print(f"Classes: {unique}: {counts}")

# Split the dataset into training and validation sets
train_df, val_df = train_test_split(df, test_size=0.10, random_state=42)

# Print the sizes of the datasets
print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}")

class_names = df['subCategory'].unique()
class_names = [str(name).lower() for name in class_names]

# Replace in-place
for i, name in enumerate(class_names):
    if name == "lips":
        class_names[i] = "lipstick"
    elif name == "eyes":
        class_names[i] = "eyelash"
    elif name == "nails":
        class_names[i] = "nail polish"

captions = {idx: class_name for idx, class_name in enumerate(class_names)}

for idx, caption in captions.items():
    print(f"{idx}: {caption}\n")

Let’s prepare our custom Dataset class.

class MyntraDataset(Dataset):
    def __init__(self, data_frame, captions, target_size=28):
        self.data_frame = data_frame[data_frame['subCategory'].str.lower() != 'innerwear']
        self.target_size = target_size  # Desired size for the square image
        self.transform = T.Compose([
            T.ToTensor()  # Convert image to tensor
        ])
        self.captions = captions

    def __len__(self):
        return len(self.data_frame)

The __getitem__() method retrieves and processes an image and its corresponding caption. The caption is then tokenized using our tokenizer, producing the input to the text encoder.

class MyntraDataset(Dataset):
    ...
    def __getitem__(self, idx):
        while True:
            sample = self.data_frame.iloc[idx]
            img_path = os.path.join("fashion/myntradataset/images", f"{sample['id']}.jpg")
            try:
                # Attempt to open the image
                image = Image.open(img_path).convert('RGB')
            except (FileNotFoundError, IOError):
                # If the image is not found, skip this sample by incrementing the index
                idx = (idx + 1) % len(self.data_frame)  # Loop back to the start if we reach the end
                continue  # Retry with the next index

            # Resize the image while maintaining aspect ratio
            image = self.resize_and_pad(image, self.target_size)

            # Apply transformations (convert to tensor)
            image = self.transform(image)

            # Retrieve the subCategory label and its corresponding caption
            label = sample['subCategory'].lower()
            label = {"lips": "lipstick", "eyes": "eyelash", "nails": "nail polish"}.get(label, label)
            label_idx = next(idx for idx, class_name in self.captions.items() if class_name == label)

            # Tokenize the caption using the tokenizer function
            cap, mask = tokenizer(self.captions[label_idx])

            # Make sure the mask is a tensor
            mask = torch.tensor(mask)
            # If the mask is a single dimension, make sure it is expanded correctly
            if len(mask.size()) == 1:
                mask = mask.unsqueeze(0)

            return {"image": image, "caption": cap, "mask": mask, "id": img_path}

To maintain the aspect ratio with a training img_size of 80×80, we resize the image and pad the shorter dimension with black pixels.

class MyntraDataset(Dataset):
    ...
    def resize_and_pad(self, image, target_size):
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height

        if aspect_ratio > 1:
            new_width = target_size
            new_height = int(target_size / aspect_ratio)
        else:
            new_height = target_size
            new_width = int(target_size * aspect_ratio)

        image = image.resize((new_width, new_height))

        pad_width = (target_size - new_width) // 2
        pad_height = (target_size - new_height) // 2
        padding = (pad_width, pad_height, target_size - new_width - pad_width, target_size - new_height - pad_height)

        image = ImageOps.expand(image, padding, fill=(0, 0, 0))
        return image

The training and validation data loaders are prepared with a batch size of 128.

train_dataset = MyntraDataset(data_frame=train_df, captions=captions, target_size=80)
val_dataset = MyntraDataset(data_frame=val_df, captions=captions, target_size=80)
test_dataset = MyntraDataset(data_frame=val_df, captions=captions, target_size=224)

print("Number of Samples in Train Dataset:", len(train_dataset))
print("Number of Samples in Validation Dataset:", len(val_dataset))
# Number of Samples in Train Dataset: 38360
# Number of Samples in Validation Dataset: 4278

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=5)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, num_workers=5)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, num_workers=5)

# Sanity check of dataloader initialization
len(next(iter(train_loader)))  # dict with image, caption, mask, id keys
# 4

TRAINING

The following training loop computes the loss for each batch, and the model with the best average loss over the 50 training epochs is saved.

best_loss = np.inf
for epoch in range(epochs):
    epoch_loss = 0.0  # To accumulate the loss over the epoch
    with tqdm(enumerate(train_loader, 0), total=len(train_loader), desc=f"Epoch [{epoch+1}/{epochs}]") as tepoch:
        for i, data in tepoch:
            img, cap, mask = data["image"].to(device), data["caption"].to(device), data["mask"].to(device)
            optimizer.zero_grad()
            loss = model(img, cap, mask)
            loss.backward()
            optimizer.step()

            # Update the progress bar with the current loss
            tepoch.set_postfix(loss=loss.item())
            epoch_loss += loss.item()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.3f}")

    # Save the model if it performed better than the previous best
    if avg_loss <= best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), "clip.pt")
        print("Model Saved.")

At the end of training, we evaluate on the val_loader by loading the best model weights. The images and the candidate captions are passed through the CLIP model's vision and text encoders separately. The image and text features are then L2-normalized along the last dimension. For each image, the index of the caption with the maximum similarity is taken as the prediction.

# Loading Best Model
model = CLIP(emb_dim, vit_layers, vit_d_model, img_size, patch_size, n_channels, vit_heads,
             vocab_size, max_seq_length, text_heads, text_layers, text_d_model, retrieval=False).to(device)
model.load_state_dict(torch.load("clip.pt", map_location=device))

# Getting dataset captions to compare images to
text = torch.stack([tokenizer(x)[0] for x in val_dataset.captions.values()]).to(device)
mask = torch.stack([tokenizer(x)[1] for x in val_dataset.captions.values()])
mask = mask.repeat(1, len(mask[0])).reshape(len(mask), len(mask[0]), len(mask[0])).to(device)

correct, total = 0, 0
with torch.no_grad():
    for data in val_loader:
        images, labels = data["image"].to(device), data["caption"].to(device)
        image_features = model.vision_encoder(images)
        text_features = model.text_encoder(text, mask=mask)

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        _, indices = torch.max(similarity, 1)

        pred = torch.stack([tokenizer(val_dataset.captions[int(i)])[0] for i in indices]).to(device)
        correct += int(sum(torch.sum((pred == labels), dim=1) // len(pred[0])))
        total += len(labels)

print(f'\nModel Accuracy: {100 * correct // total} %')

The tokenized ground truth caption labels and predicted labels were compared, and we obtained 85% accuracy on the validation dataset.

E-Commerce Image Retrieval App with Gradio

Now we move on to the application part. Up to the similarity calculation, the logic is the same as in the previous block. Using topk, we obtain the indices of the 20 samples with the highest similarity scores.

# Load the retrieval model weights
retrieval_model = CLIP(emb_dim, vit_layers, vit_d_model, img_size, patch_size, n_channels, vit_heads,
                       vocab_size, max_seq_length, text_heads, text_layers, text_d_model, retrieval=True).to(device)
retrieval_model.load_state_dict(torch.load("clip.pt", map_location=device))

# Function to process the query and return the top 20 images
def retrieve_images(query):
    # Step 1: Encode the text query
    query_text, query_mask = tokenizer(query)
    query_text = query_text.unsqueeze(0).to(device)  # Add batch dimension
    query_mask = query_mask.unsqueeze(0).to(device)

    with torch.no_grad():
        query_features = retrieval_model.text_encoder(query_text, mask=query_mask)
        query_features /= query_features.norm(dim=-1, keepdim=True)

    # Step 2: Encode all images in the dataset and store features
    image_features_list = []
    image_paths = []
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=5)

    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(device)
            features = retrieval_model.vision_encoder(images)
            features /= features.norm(dim=-1, keepdim=True)
            image_features_list.append(features)
            image_paths.extend(batch["id"])  # The batch carries the image paths

    # Concatenate all image features
    image_features = torch.cat(image_features_list, dim=0)

    # Step 3: Compute similarity using the CLIP model's logic
    similarities = (query_features @ image_features.T) * torch.exp(retrieval_model.temperature)
    similarities = similarities.softmax(dim=-1)

    # Retrieve the top 20 matches
    top_values, top_indices = similarities.topk(20)

    # Step 4: Load and return the top N images
    images_to_display = []
    for value, index in zip(top_values[0], top_indices[0]):
        img_path = image_paths[index]
        img = Image.open(img_path).convert("RGB")
        images_to_display.append(np.array(img))

    return images_to_display

Then we import Gradio and create a simple app with the gradio_app() function, which accepts a query from the input box and returns the retrieved images. The Gradio interface has several components: a query textbox, a search button, and a gallery to show the retrieved images.

import gradio as gr

# Define the Gradio interface
def gradio_app(query):
    images = retrieve_images(query)
    return images

# Create Gradio Interface
with gr.Blocks() as interface:
    # Centered title
    gr.Markdown("<h1 style='text-align: center;'> 👒 Image Retrieval with CLIP - 👔👖 E-commerce Fashion 👚🥻</h1>")

    with gr.Row():
        # Textbox for query input
        query_input = gr.Textbox(placeholder="Enter your search query...", show_label=False)
        # Small submit button
        submit_btn = gr.Button("Search 🔍", elem_id="small-submit-btn")

    # Gallery output for displaying images
    gallery_output = gr.Gallery(label="Top 20 Matches").style(grid=[4], container=True)

    # Link the submit button to the function
    submit_btn.click(fn=gradio_app, inputs=query_input, outputs=gallery_output)

    # Custom CSS to make the submit button small
    gr.HTML("""
    <style>
        #small-submit-btn {
            padding: 0.5rem 1rem;
            font-size: 0.8rem;
        }
    </style>
    """)

# Launch the app
interface.launch()

Key Takeaways

  • Our results were satisfactory, with decent retrieval for well-represented classes. However, classes with fewer samples weren't learned effectively. Another observation is that querying with uppercase vs. lowercase characters made an obvious difference in retrieval quality because of the character-level tokenization.
  • As the focus of this article is simply to introduce training a naive CLIP-like architecture from scratch, improving the search results and model accuracy is left to further experiments on different datasets. We could also use pretrained vision encoders like ViT or ResNet and adapt BERT as the text encoder, which might improve the accuracy and performance of this pipeline to a large extent.

TRIVIA

Failure Cases like Digit Recognition

Pre-trained CLIP struggles to correctly differentiate between MNIST digits, a task that is relatively simple for an MLP. This might be due to CLIP's broader focus on natural image features rather than the simple fine-grained features specific to digit recognition.

labels = ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]

Conclusion

The motivation behind training the CLIP architecture from scratch is to appreciate the beauty of the mathematical intuition and the learnings from the CLIP and SigLIP papers. A pretrained CLIP or SigLIP can be combined with a decoder (LLM) and extended to Vision Language Model tasks. Apart from classification, CLIP also shines in zero-shot segmentation (ZegCLIP).

If you are intrigued by the exceptional zero-shot image understanding capabilities of VLMs like LLaVA, Florence-2, and PaliGemma, starting with applications of CLIP or SigLIP can be a great choice. If you are planning to build anything interesting with VLMs, we would love to hear about it in the comments.

References

  1. Special thanks to Matt Nguyen for his gem of an article on CLIP.
  2. Learning Transferable Visual Models From Natural Language Supervision
  3. Attention is All You Need
  4. SigLip
  5. Umar Jamil – Coding PaliGemma From Scratch
  6. RASA – Transformers

