5.4. Contextual Word Embeddings and Text Embeddings#

  • Last update: 11.11.2024

  • Author: Johannes Maucher

Word Embeddings have been described in section Representations. These Word Embeddings, e.g. CBOW, Skipgram, GloVe or the FastText subword embeddings, are not contextual. This means that a word is mapped to a unique vector, independent of the context in which the word appears. This is a drawback, because the meaning (semantics) of a word depends on the context (the surrounding words) in which it appears. That's where contextual Word Embeddings come into play. In a contextual Word Embedding the vector representation of a single word or token varies with the word's context. For example, the word play will be mapped to a different vector in the context of the sentence

  • the members of the house of parliament play a crucial role in this debate

than in the context of the sentence

  • girls like to play with dolls

Contextual word embeddings can be learned by transformers, as will be described in section Transformers. A common approach is to apply the pre-trained encoder-only transformer BERT. This approach will be demonstrated below.

5.4.1. Load BERT Model and corresponding Tokenizer#

In this section we will load a pre-trained BERT model and the corresponding tokenizer from HuggingFace. BERT itself will be described in section Transformers.

Import required Python modules:

import random
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
random_seed = 1234
random.seed(random_seed)

# Set a random seed for PyTorch (for GPU as well)
torch.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

Load pre-trained BERT Model and associated Tokenizer:

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

5.4.2. Define Input Sentence and Tokenize it#

We start from an arbitrary sentence, which is assigned to the variable text:

text = "The lesson on artificial intelligence will be held on Monday"

Next, we tokenize the example sentence. For this we apply the batch_encode_plus()-method, which can also be applied to a batch of texts, each represented as a string. The method automatically adds the special tokens CLS and SEP.

encoding = tokenizer.batch_encode_plus(
    [text],                    # List of input texts
    padding=True,              # Pad to the maximum sequence length
    truncation=True,           # Truncate to the maximum sequence length if necessary
    return_tensors='pt',       # Return PyTorch tensors
    add_special_tokens=True    # Add special tokens CLS and SEP
)

The tokenizer method returns a dictionary (assigned to the variable encoding above), which contains the keys input_ids and attention_mask. The values of these keys are printed below:

input_ids = encoding['input_ids']  # Token IDs
# print input IDs
print(f"Input ID: {input_ids}")
attention_mask = encoding['attention_mask']  # Attention mask
# print attention mask
print(f"Attention mask: {attention_mask}")
Input ID: tensor([[  101,  1996, 10800,  2006,  7976,  4454,  2097,  2022,  2218,  2006,
          6928,   102]])
Attention mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

The attention mask indicates which tokens shall be attended to: real tokens are marked with 1, padding tokens with 0. Since in our example all elements of attention_mask are 1 (no padding was necessary), every token is taken into account in the calculation of the contextual embeddings.
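In order to see a non-trivial attention mask, one can encode several texts of different length in one batch: the shorter texts are padded and the padded positions are marked with 0 in the mask, so that BERT ignores them. Below is a minimal sketch of this, reusing the tokenizer loaded above; the two example strings are made up for illustration:

# Sketch: when texts of different length are encoded together, the shorter one is padded
demo_encoding = tokenizer.batch_encode_plus(
    ["a short text", "a somewhat longer text with a few more tokens"],
    padding=True,
    truncation=True,
    return_tensors='pt',
    add_special_tokens=True
)
# The first row of the attention mask ends in zeros: these positions are padding
print(demo_encoding['attention_mask'])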

5.4.3. Generate Word-Embeddings#

# Generate embeddings using BERT model
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    word_embeddings = outputs.last_hidden_state  # This contains the embeddings

# Output the shape of word embeddings
print(f"Shape of Word Embeddings: {word_embeddings.shape}")
Shape of Word Embeddings: torch.Size([1, 12, 768])

As can be seen in the output above, BERT outputs a contextual word embedding of length 768 for each of the 12 tokens in the input (the 10 words plus the special tokens CLS and SEP).
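To see which tokens these 12 embeddings belong to, the token IDs can be mapped back to tokens, e.g. with the tokenizer's convert_ids_to_tokens()-method. A small sketch:

# Sketch: map the 12 token IDs back to their tokens, including [CLS] and [SEP]
print(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()))
# ['[CLS]', 'the', 'lesson', 'on', 'artificial', 'intelligence', 'will', 'be', 'held', 'on', 'monday', '[SEP]']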

Below we just demonstrate the tokenizer's encode()- and decode()-methods:

# Decode the token IDs back to text
decoded_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
#print decoded text
print(f"Decoded Text: {decoded_text}")
# Tokenize the text again for reference
tokenized_text = tokenizer.tokenize(decoded_text)
#print tokenized text
print(f"tokenized Text: {tokenized_text}")
# Encode the text
encoded_text = tokenizer.encode(text, return_tensors='pt')  # Returns a tensor
# Print encoded text
print(f"Encoded Text: {encoded_text}")
Decoded Text: the lesson on artificial intelligence will be held on monday
tokenized Text: ['the', 'lesson', 'on', 'artificial', 'intelligence', 'will', 'be', 'held', 'on', 'monday']
Encoded Text: tensor([[  101,  1996, 10800,  2006,  7976,  4454,  2097,  2022,  2218,  2006,
          6928,   102]])

Run the following code cell only if you would like to see the 12 embeddings, each of length 768. We do not execute this cell here because of its very long output:

# Uncomment to print each token together with its contextual embedding
#for token, embedding in zip(tokenizer.convert_ids_to_tokens(input_ids[0].tolist()), word_embeddings[0]):
#    print(f"Token: {token}")
#    print(f"Embedding: {embedding}")
#    print("\n")

Next, we calculate the text embedding from the contextual word embeddings at the BERT encoder's output. Note that there are different methods to derive a text embedding from the list of word embeddings. The most popular options are:

  1. The encoder's output at the position of the CLS-token (the first position in the input sequence) is a representation of the entire text and can therefore be used as text embedding (see the sketch after this list).

  2. Just calculate the mean of all contextual word embeddings at the output of the transformer-encoder.
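For the first option, a minimal sketch could look as follows; it assumes that the variable outputs from the cell above is still available and uses the fact that the CLS-token sits at index 0 of the sequence dimension:

# Sketch of option 1: take the encoder output at the CLS position (index 0) as text embedding
cls_embedding = outputs.last_hidden_state[:, 0]
print(cls_embedding.shape)  # torch.Size([1, 768])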

Below we implement the second option:

sentence_embedding = word_embeddings.mean(dim=1)
print(f"Shape of Sentence Embedding: {sentence_embedding.shape}")
# Print the sentence embedding
print("Sentence Embedding:")
print(sentence_embedding)
Shape of Sentence Embedding: torch.Size([1, 768])
Sentence Embedding:
tensor([[-3.5001e-01, -3.3019e-01,  2.6465e-01,  7.5814e-02,  2.6252e-01,
         -4.2781e-01,  3.2828e-01,  6.8058e-01, -1.1294e-01, -1.3759e-01,
         -2.1079e-01, -2.7575e-01,  1.1764e-01,  1.5474e-01, -1.3306e-01,
         -6.5089e-03,  1.0366e-01, -2.5773e-02,  6.5012e-02,  4.5614e-02,
         -8.1534e-02, -9.1516e-02, -9.6652e-02,  6.5698e-01,  3.4971e-01,
          1.5745e-01, -2.1174e-02,  4.4196e-01, -1.7211e-01, -1.0172e-01,
         -2.7239e-02,  8.8936e-02, -7.1396e-02,  1.2314e-01,  1.8995e-01,
          1.0141e-01,  1.3849e-01, -2.6853e-01, -3.7551e-01,  5.1383e-01,
         -4.5088e-01, -8.5676e-02,  3.5888e-01, -8.5598e-02, -3.4887e-01,
         -3.2465e-01,  2.8293e-01, -2.6047e-01,  9.4687e-02, -3.1299e-01,
         -6.5420e-01,  3.6622e-01,  2.0525e-01,  1.6598e-01, -1.8717e-02,
          3.6724e-01,  4.8422e-02, -1.5004e-01, -2.3705e-01, -1.4570e-01,
          3.3358e-02,  1.8378e-01, -3.5432e-01, -3.2449e-01,  4.1561e-01,
          2.9656e-01,  2.2482e-02,  6.1011e-02, -5.8971e-01, -5.3985e-02,
         -7.3305e-01,  2.3496e-02,  3.8414e-02,  1.3239e-01, -2.0427e-01,
         -1.1042e-01,  1.4433e-01,  2.5816e-01,  1.5394e-01, -2.2436e-01,
         -2.3106e-02,  4.3833e-01,  3.8140e-01,  2.4284e-01,  3.2330e-01,
          4.8820e-01, -3.6262e-01, -4.5515e-02, -4.6194e-01,  5.3348e-01,
         -1.6715e-01, -1.2979e-01, -1.6823e-02,  2.2146e-01,  2.5272e-01,
         -4.0833e-03, -8.4667e-02,  2.3374e-02, -1.8031e-01,  5.2358e-01,
          3.6706e-01, -3.6619e-01, -1.0291e-01,  3.6938e-01,  3.5335e-01,
         -1.5034e-01, -3.5996e-01, -6.9851e-01,  1.5057e-03,  8.2513e-02,
         -1.0837e-01, -1.4382e-01,  2.8829e-02, -4.0238e-01, -3.9316e-01,
          1.7526e-01, -2.8134e-02, -7.4224e-02, -2.6436e-01,  8.5258e-02,
         -2.2381e-01, -3.5146e-02,  2.9077e-01,  8.7016e-01, -3.0988e-01,
          2.5399e-01,  6.1589e-02,  3.6281e-01,  2.3399e-01, -4.6509e-01,
          3.7167e-01,  2.6715e-01,  6.8925e-01, -4.8475e-01, -1.4224e-02,
          2.1432e-01,  1.4330e-01, -2.2511e-01, -3.2580e-01,  4.6141e-01,
         -2.1279e-01, -1.2618e-01,  2.7136e-01, -5.1448e-01,  3.4573e-01,
          1.5227e-03, -2.3293e-01, -3.1338e-01, -1.7379e-01,  7.6039e-02,
         -5.6773e-05,  2.7782e-01, -6.1375e-01, -2.3112e-01, -2.6738e-01,
         -1.6650e-01, -1.7668e-01,  3.2612e-01,  1.8533e-02,  2.2091e-01,
          2.4410e-01, -3.2489e-01, -4.4623e-02,  2.8956e-01, -2.4651e-01,
          4.0018e-01,  1.2866e-01,  8.7871e-02,  2.9744e-01,  1.9692e-01,
         -3.2523e-01, -1.4092e-01,  5.1488e-01, -8.8644e-02, -1.4185e-01,
         -1.8024e-01, -3.3525e-01, -1.0060e-01,  5.8080e-02,  4.4633e-02,
         -9.4789e-01,  3.4185e-02,  1.8970e-01, -3.8252e-01,  7.1670e-02,
         -4.9928e-01,  1.6674e-01, -4.2020e-01, -2.9474e-01,  4.1578e-01,
         -3.0630e-01,  1.6033e-01, -4.2800e-01, -3.7009e-01,  5.4866e-01,
         -3.3470e-01,  1.0239e-01,  1.5431e-01, -4.5977e-01,  3.8354e-01,
          3.6915e-01, -7.4490e-02,  5.9416e-02,  3.2153e-01, -3.1410e-01,
          6.1120e-02, -8.8887e-02, -4.9073e-01, -1.9812e-01,  6.4406e-02,
         -2.9210e-01,  4.5945e-01,  7.2102e-02, -2.6303e-02,  2.4272e-01,
         -2.4792e-01,  4.9993e-02,  2.1967e-01, -3.7672e-01, -1.7933e-01,
          2.7616e-01,  2.6674e-01, -4.0560e-01,  4.4782e-01,  3.0446e-02,
          3.3901e-01, -1.0932e-02,  9.4912e-02,  2.4630e-01,  5.3611e-02,
         -2.0549e-01, -2.9059e-01,  2.4120e-01,  3.3106e-02, -1.3610e-02,
         -2.5216e-01, -7.8993e-01, -2.2043e-01,  1.0860e-01, -4.3867e-01,
         -1.3193e-01,  1.8721e-01,  1.1624e-01,  1.2369e-01,  3.4737e-01,
         -1.7266e-01,  2.1957e-01,  4.3804e-01, -9.4424e-02, -3.5881e-01,
         -4.2728e-01, -9.7560e-01, -5.6313e-02, -4.3410e-01, -6.1033e-01,
          4.9483e-02, -4.7547e-01, -3.3615e-02,  2.7971e-01,  4.4012e-01,
          2.6933e-01,  2.4647e-01,  3.2145e-01,  1.9147e-01, -1.8743e-01,
         -5.1153e-01,  4.6308e-02, -1.8047e-01,  2.4499e-01,  4.6406e-02,
          2.6046e-02, -6.1334e-01,  4.6881e-01,  1.5311e-01, -4.2407e-01,
         -5.3723e-01,  3.9752e-01,  3.2234e-02, -2.2177e-01,  3.1721e-03,
          1.5712e-01,  4.1237e-01, -3.2868e-01,  4.7417e-01,  6.7273e-02,
         -3.8105e-01,  1.6562e-01, -2.2771e-01, -2.1024e-01, -2.7373e-01,
         -7.0715e-02,  5.3727e-01, -1.6749e-01, -2.6956e-01,  9.9304e-02,
          1.2754e-01,  1.5457e-01, -1.9934e-01,  2.5971e-01, -2.5183e-01,
          3.2693e-02,  4.1218e-02, -3.1145e-02,  1.1443e-01, -1.2888e-01,
         -3.2929e-03, -1.3154e-01, -2.1232e-01, -4.0269e+00, -6.0261e-02,
         -1.2884e-01,  1.3271e-01,  2.1156e-01,  6.5302e-02,  7.5073e-03,
         -8.8058e-02, -5.0960e-01, -2.6623e-01, -4.8730e-01, -2.1142e-01,
          3.3794e-01,  1.7242e-01,  2.9090e-01,  1.3112e-01,  5.8265e-01,
         -2.3833e-01,  1.0395e-01,  5.4858e-01, -2.4176e-01, -2.4934e-01,
         -2.0685e-02, -7.9416e-02,  4.1731e-01,  1.0205e-01, -5.3768e-01,
          1.9558e-01, -3.5325e-01, -7.3941e-03,  1.4926e-01, -5.9997e-02,
          3.4697e-01,  2.0086e-01,  1.8412e-01, -1.5491e-02, -1.0581e-01,
         -1.1161e-03,  1.9717e-01,  1.1820e-02, -4.6035e-02,  9.1656e-02,
         -2.7942e-01, -3.2009e-03,  9.4838e-01, -1.4958e-01,  4.1843e-02,
         -2.7940e-01, -6.7275e-02,  6.5053e-02, -1.0996e-01,  1.7743e-01,
         -1.5813e-01, -7.0556e-02, -2.1256e-01, -1.4250e-01,  4.0637e-01,
          1.9027e-01, -1.8110e-01, -3.8689e-01,  4.0752e-01, -1.9994e-01,
         -3.3271e-01,  2.1929e-02, -1.0427e-01, -2.4982e-01, -6.1833e-01,
         -3.3414e-01,  9.3490e-02,  1.8034e-01,  2.6657e-01, -3.7683e-02,
         -5.9592e-01, -9.9460e-01, -3.1050e-01,  4.9618e-02, -1.8509e-02,
         -1.9436e-01,  3.6091e-01,  1.1406e-01,  9.9842e-02, -6.3950e-01,
         -5.0970e-04,  1.7271e-01,  6.2532e-02, -2.7805e-01,  9.7467e-02,
         -1.0300e-01, -1.7447e-01, -4.6657e-01,  6.9973e-02,  4.7557e-01,
          2.8477e-02, -5.7240e-02,  1.1543e-02,  9.7917e-02,  4.2295e-01,
         -2.7138e-01,  1.5050e-01,  2.2147e-01, -1.7118e-01, -6.0872e-01,
          2.8294e-01,  9.7952e-02,  3.4805e-01,  5.8905e-02, -3.4898e-01,
          2.9768e-01,  3.2487e-02,  9.4686e-02,  5.3735e-01, -4.1341e-01,
          3.9576e-01, -1.5443e-01,  5.1577e-03,  1.7772e-01, -2.7315e-02,
          5.1828e-01, -7.3625e-02, -8.1733e-02, -9.8084e-02,  2.2278e-01,
         -8.4016e-03,  1.1859e-01, -2.8934e-01, -1.2390e-01, -6.1686e-02,
         -2.9135e-01,  1.3129e-01, -1.8670e-01,  3.4517e-02,  8.4539e-02,
         -1.2491e-01,  3.3284e-01,  2.3134e-01, -1.2112e-01, -9.6196e-02,
         -2.2862e-01,  1.7466e-01,  1.1974e-01,  1.5807e-01,  3.9687e-01,
          2.7971e-01, -2.4582e-02,  8.7744e-02,  2.1067e-01, -1.4803e-01,
         -1.3783e-02, -2.0617e-01,  1.8067e-01, -2.8827e-01, -2.3446e-01,
         -2.0790e-01, -3.0983e-01, -1.0412e-01,  1.5453e-01,  2.1305e-01,
         -1.9187e-01,  5.0736e-02, -1.3197e-01, -1.9838e-01, -1.4006e-01,
         -1.2290e-01,  3.0371e-01, -4.0152e-02,  7.6067e-01, -2.7431e-01,
         -6.8675e-02, -3.8117e-01,  7.5752e-02, -1.5208e-01, -1.4443e-01,
         -2.9547e-02, -3.5503e-01,  1.8192e-01, -2.5108e-02,  1.9621e-02,
         -5.7128e-02, -1.7207e-01,  3.7231e-01, -6.8502e-02, -2.5235e-01,
         -4.1659e-01,  1.0478e-01,  8.4711e-02,  3.4261e-01,  8.1984e-02,
          6.0531e-02,  1.1377e-01,  2.9045e-02, -7.8550e-02,  2.8527e-01,
         -3.6355e-01, -1.4366e-01, -5.3405e-01, -3.4555e-01,  5.4226e-01,
          6.1720e-02,  2.8459e-01,  6.7661e-02,  2.7893e-01,  6.7947e-02,
         -2.9389e-01,  6.1849e-01,  3.4890e-01, -4.4523e-02,  1.2058e-02,
          1.8805e-01, -1.7503e-01,  1.1556e-01, -3.1747e-01, -1.3132e-01,
         -3.1923e-01, -2.9955e-01,  1.1621e-01, -4.7510e-01, -2.4903e-01,
          5.0876e-02, -1.7723e-01,  4.4682e-01, -3.9996e-01,  1.3958e-01,
          3.1837e-01,  4.3305e-02, -2.4848e-01, -2.6041e-01,  1.5642e-02,
         -1.1489e-01, -1.9163e-01, -4.1618e-02,  1.1694e-02, -1.2724e-01,
          2.3057e-01,  2.1819e-01, -1.8967e-01, -1.9907e-02, -5.0891e-01,
         -3.5877e-01,  7.2083e-02, -1.8683e-01, -1.2663e-01, -2.9630e-01,
         -3.1224e-02, -2.2234e-01, -2.2673e-01, -6.4638e-02, -1.9064e-01,
          1.9712e-01,  9.9471e-04, -2.8300e-02, -1.4606e-01,  1.4804e-01,
         -2.5294e-01, -6.8988e-02, -4.0995e-01, -6.0582e-01,  2.8313e-02,
         -3.8813e-02, -7.6255e-02, -3.3953e-02, -5.8762e-02, -1.8847e-01,
         -2.8234e-01,  3.4244e-01,  2.2865e-01,  4.3288e-01,  5.5052e-02,
         -1.6130e-01,  2.2415e-01,  4.4876e-01, -1.9497e-01, -2.1223e-01,
         -1.0491e-01, -1.6519e-01, -1.6083e-01,  1.3944e-01,  6.6475e-02,
         -5.8687e-01, -2.9872e-01, -1.7971e-01,  1.1832e-01,  2.1972e-01,
         -2.9425e-01, -1.0049e-01,  3.6542e-01, -4.7722e-01, -3.1582e-01,
          5.3376e-01, -1.1965e-01, -1.2218e-01, -2.5652e-02,  1.4879e-01,
         -6.7233e-02,  9.2874e-02,  1.2487e-01,  4.9830e-01,  2.5855e-01,
         -1.4293e-01, -1.1309e-01,  1.8774e-01, -1.4081e-01,  4.3225e-01,
          7.1667e-02,  1.8105e-01, -4.7223e-01,  3.2808e-01,  1.0556e-01,
         -2.2698e-01,  3.4445e-01,  3.5342e-02, -1.3785e-01,  3.2490e-01,
          5.4810e-01,  7.1389e-01, -4.8402e-01, -3.2673e-02,  6.6809e-02,
         -5.0849e-01, -3.4669e-01,  4.0453e-01, -8.7078e-02, -6.8191e-02,
          4.1731e-01, -5.4170e-02, -4.6859e-02,  4.7115e-01, -1.7083e-01,
          9.9105e-02, -1.2825e-01,  2.7454e-01, -2.0900e-01,  7.4443e-01,
         -1.7870e-01,  5.1283e-01, -2.5385e-01, -5.6341e-01,  7.2013e-02,
         -1.8784e-01,  2.7014e-01,  6.1444e-02,  4.0886e-01,  4.9157e-01,
          3.8884e-01,  2.7868e-01,  1.3078e-01,  3.0113e-02,  9.5825e-02,
          3.1781e-01,  4.1155e-01,  5.6720e-01,  2.3901e-01,  5.9801e-02,
          1.2964e-01, -2.9834e-01,  5.2350e-01,  1.1932e-01, -1.5596e-01,
          6.4928e-01,  3.5083e-01,  4.2034e-01,  5.1993e-01,  3.2745e-01,
          5.0950e-01, -5.3826e-01, -2.9767e-01,  4.2256e-01,  7.3183e-02,
          1.7045e-01, -3.5584e-01,  2.6905e-01,  2.8558e-01,  2.5299e-01,
          4.5645e-02, -3.0377e-01,  6.7047e-02,  2.7107e-01, -1.6785e-01,
         -3.8109e-01, -3.7910e-02, -1.0643e-01,  4.1419e-01,  9.4466e-02,
         -4.4282e-01, -4.5154e-02, -2.8278e-02, -4.7644e-01, -5.1879e-01,
          2.0594e-01,  7.6130e-03, -1.2532e-01, -5.3660e-02,  2.2753e-01,
          4.1270e-01, -5.2486e-01, -2.3312e-02, -1.2330e-01,  1.8143e-01,
         -4.6706e-02,  4.2303e-01,  2.0783e-01,  3.4429e-01, -3.5113e-01,
         -1.3195e-01, -2.3233e-02,  6.8567e-02, -1.8425e-01, -1.0333e-01,
          1.8997e-01, -4.5926e-01, -2.3890e-01,  3.7245e-02,  2.4184e-01,
         -7.7843e-01, -6.9345e-02,  2.4815e-01,  2.4974e-01, -1.9161e-01,
         -1.1773e-01, -1.5767e-01, -2.4952e-01, -2.2291e-01, -2.2567e-01,
          1.7225e-01,  3.7873e-01, -2.4091e-01,  3.8792e-01, -9.1646e-03,
          7.4300e-02,  8.9568e-02, -3.7235e-01,  2.3798e-01,  7.5409e-02,
          3.5144e-01, -1.3033e-01,  1.0079e-01, -6.0263e-02, -3.6118e-01,
          3.2302e-01, -8.2100e-02, -1.8806e-01,  4.9064e-02, -1.2287e-01,
          2.8884e-01, -9.0585e-02, -2.6741e-01, -3.8880e-01, -1.4117e-01,
         -9.0293e-02, -2.0516e-01, -4.1247e-01,  1.4714e-01,  6.7941e-02,
          4.5147e-01, -3.0173e-01, -1.4309e-01,  1.5541e-02, -3.4537e-02,
         -9.3550e-02,  3.1032e-02,  2.7587e-02]])

5.4.4. Calculate Similarity between texts#

Up to now, we have generated a text embedding of our sample text, stored in the variable text:

text
'The lesson on artificial intelligence will be held on Monday'

Now we would like to compare this text with other texts. For this, we first calculate the text embedding of each new text and then compute the cosine similarity between the embedding of our original text and each of these new embeddings. We expect to obtain the largest cosine similarity score for the text whose semantics is closest to the semantics of the original text.
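As a reminder, the cosine similarity of two vectors is their dot product after L2-normalisation, i.e. the cosine of the angle between the vectors. The following sketch illustrates the formula with two random vectors of the embedding dimension used above and compares the result with PyTorch's built-in function:

# Sketch: cosine similarity = dot product of the L2-normalised vectors
import torch.nn.functional as F
a = torch.randn(768)
b = torch.randn(768)
manual = torch.dot(F.normalize(a, dim=0), F.normalize(b, dim=0))
builtin = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0))
print(manual.item(), builtin.item())  # both values agree up to floating-point precision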

# Example sentence for similarity comparison
example_sentence0 = "Cuttlery and plates are piled up on the table"
example_sentence1 = "European cities suffer from severe air pollution"
example_sentence2 = "The machine learning module is in the master's program"
example_sentence3 = "Artificial intelligence is the future of technology"
example_sentence4 = "Waste of energy is a major concern for the environment"
example_sentences=[example_sentence0,example_sentence1,example_sentence2,example_sentence3,example_sentence4]
# Tokenize and encode the example sentence
example_encodings = tokenizer.batch_encode_plus(
    example_sentences,
    padding=True,
    truncation=True,
    return_tensors='pt',
    add_special_tokens=True
)
example_input_ids = example_encodings['input_ids']
#example_input_ids = example_encodings['input_ids'][1]
example_input_ids
tensor([[  101,  3013, 25091,  2100,  1998,  7766,  2024, 17835,  2039,  2006,
          1996,  2795,   102],
        [  101,  2647,  3655,  9015,  2013,  5729,  2250, 10796,   102,     0,
             0,     0,     0],
        [  101,  1996,  3698,  4083, 11336,  2003,  1999,  1996,  3040,  1005,
          1055,  2565,   102],
        [  101,  7976,  4454,  2003,  1996,  2925,  1997,  2974,   102,     0,
             0,     0,     0],
        [  101,  5949,  1997,  2943,  2003,  1037,  2350,  5142,  2005,  1996,
          4044,   102,     0]])
example_attention_mask = example_encodings['attention_mask']
# Generate embeddings for the example sentence
with torch.no_grad():
    example_outputs = model(example_input_ids, attention_mask=example_attention_mask)
    example_sentence_embedding = example_outputs.last_hidden_state.mean(dim=1)
    #example_sentence_embedding = example_outputs.last_hidden_state[0,0]
# Compute cosine similarity between the original sentence embedding and the example sentence embedding
similarity_score = cosine_similarity(sentence_embedding.reshape(1, -1), example_sentence_embedding)

# Print the similarity score
print(f"Similarities of {len(example_sentences)} example sentences and the original text \n\t",text)
for idx,sample in enumerate(example_sentences):
    print(f"\t\t{idx} : {sample} : Similarity Score = {similarity_score[0][idx]}")
Similarities of 5 example sentences and the original text 
	 The lesson on artificial intelligence will be held on Monday
		0 : Cuttlery and plates are piled up on the table : Similarity Score = 0.560899555683136
		1 : European cities suffer from severe air pollution : Similarity Score = 0.552908182144165
		2 : The machine learning module is in the master's program : Similarity Score = 0.724198043346405
		3 : Artificial intelligence is the future of technology : Similarity Score = 0.6622300744056702
		4 : Waste of energy is a major concern for the environment : Similarity Score = 0.5618582367897034

As can be seen in the output of the previous cell, we indeed obtain the largest similarity scores for the sentences that are semantically closest to the original text.

5.4.5. Calculate Similarity between different contextual embeddings of the same word#

BERT outputs contextual word embeddings, i.e. a given word is not mapped to a unique vector. Instead the vector representations of a single word vary with the context of the word.

In the experiment below we generate 6 different sentences. In each of these sentences the word show appears as the second word. We would like to investigate and compare the 6 different vectors generated at the BERT output for the word show.

s0 = "I show you the easiest way to solve this problem."
s1 = "They show remarkable resilience in difficult situations."
s2 = "We show our appreciation by sending a thank-you note."
s3 = "You show kindness to everyone you meet."
s4 = "Teachers show their students how to find reliable information."
s5 = "Friends show support when times get tough."
sents=[s0,s1,s2,s3,s4,s5]

Encode the 6 sentences:

# Tokenize and encode the example sentence
example_encodings = tokenizer.batch_encode_plus(
    sents,
    padding=True,
    truncation=True,
    return_tensors='pt',
    add_special_tokens=True
)

By checking the output of the following cell, we can see that in each of the 6 sentences the ID at position 3 (index = 2) is the same. This is because the second word in each sentence is the word show. Note that the second word is at the third position, because the special token CLS has been prepended by the tokenisation method above.

example_input_ids = example_encodings['input_ids']
#example_input_ids = example_encodings['input_ids'][1]
example_input_ids
tensor([[  101,  1045,  2265,  2017,  1996, 25551,  2126,  2000,  9611,  2023,
          3291,  1012,   102,     0],
        [  101,  2027,  2265,  9487, 24501, 18622, 10127,  1999,  3697,  8146,
          1012,   102,     0,     0],
        [  101,  2057,  2265,  2256, 12284,  2011,  6016,  1037,  4067,  1011,
          2017,  3602,  1012,   102],
        [  101,  2017,  2265, 16056,  2000,  3071,  2017,  3113,  1012,   102,
             0,     0,     0,     0],
        [  101,  5089,  2265,  2037,  2493,  2129,  2000,  2424, 10539,  2592,
          1012,   102,     0,     0],
        [  101,  2814,  2265,  2490,  2043,  2335,  2131,  7823,  1012,   102,
             0,     0,     0,     0]])
example_attention_mask = example_encodings['attention_mask']
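As a small sanity check, one can also verify programmatically that the ID at index 2 is the ID of the token show in every row. A short sketch:

# Sketch: the token at index 2 of every row should be the ID of "show"
show_id = tokenizer.convert_tokens_to_ids('show')
print(show_id)                                      # 2265
print((example_input_ids[:, 2] == show_id).all())   # tensor(True)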

Next, we generate for each of the 6 sentences the corresponding BERT output. As before, for each sentence the i-th output vector is the contextual embedding of the i-th input token. Again, we have to take into account that the special token CLS has been added at the first position of each sentence.

In contrast to the previous experiments, we now do not calculate the mean over all contextual word embeddings of a sentence. Instead, we just pick out the contextual word embedding at the position with index 2, because at this position each of the 6 sentences contains the same word show.

# Generate embeddings for the second word (IDX=2) of the example sentences
IDX = 2
with torch.no_grad():
    example_outputs = model(example_input_ids, attention_mask=example_attention_mask)
    #example_sentence_embedding = example_outputs.last_hidden_state.mean(dim=1)
    example_sentence_embedding = example_outputs.last_hidden_state[:, IDX]

The output example_sentence_embedding is a tensor of shape (6, 768), where 6 is the number of example sentences and 768 is the dimension of the embeddings.

example_sentence_embedding.shape
torch.Size([6, 768])

Next, we calculate the pairwise cosine similarities between the 6 contextual embeddings of the word show:

similarity_score = cosine_similarity(example_sentence_embedding)

We obtain a \(6 \times 6\) array, whose value at row \(i\), column \(j\) is the cosine similarity between the embedding of show in sentence \(i\) and the embedding of show in sentence \(j\).

similarity_score.shape
(6, 6)
import numpy as np
np.set_printoptions(precision=3)
similarity_score
array([[1.   , 0.39 , 0.487, 0.508, 0.685, 0.472],
       [0.39 , 1.   , 0.495, 0.655, 0.483, 0.6  ],
       [0.487, 0.495, 1.   , 0.538, 0.543, 0.533],
       [0.508, 0.655, 0.538, 1.   , 0.536, 0.709],
       [0.685, 0.483, 0.543, 0.536, 1.   , 0.59 ],
       [0.472, 0.6  , 0.533, 0.709, 0.59 , 1.   ]], dtype=float32)

Finally we visualise the pairwise similarities in a heatmap:

import seaborn as sns
sns.heatmap(similarity_score, annot=True, cmap='viridis', xticklabels=sents, yticklabels=sents)
<Axes: >
[Figure: heatmap of the pairwise cosine similarities between the contextual embeddings of the word show in the 6 sentences]

The heatmap shows that the cosine similarity between the embedding of the word show in sentence \(i\) and the embedding of the same word in sentence \(j=i\) is 1. However, for \(j \neq i\) the cosine similarities are \(<1\), which demonstrates that the word show has distinct embeddings in distinct contexts. The embeddings in sentence 0 and sentence 1 have the smallest similarity score. This is plausible, because in sentence 1 the word show has a significantly different meaning than in sentence 0. The highest value for different sentences is obtained for sentence 3 and sentence 5. This is also plausible, because in these two sentences the meaning of the word show is similar.
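The qualitative observation above can also be checked programmatically. The following sketch picks the least and the most similar pair of distinct sentences from the similarity matrix:

# Sketch: find the least and most similar pairs of distinct sentences
pairs = [(i, j) for i in range(6) for j in range(6) if i < j]
least_similar = min(pairs, key=lambda p: similarity_score[p])
most_similar = max(pairs, key=lambda p: similarity_score[p])
print(least_similar, most_similar)  # expected: (0, 1) and (3, 5), matching the discussion above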