Introduction

Getting started with word embeddings was a little confusing for me at first, so I decided to put together a notebook that provides simple, coded examples. When I was introducing myself to embeddings, quite frankly, the official PyTorch documentation was a little sparse, and I wanted something a little more dense, much like the word embeddings I set out to build. Vector jokes aside, what follows is an introduction to word embeddings in PyTorch: how to represent words as dense vectors and how to leverage state-of-the-art pre-trained embeddings.

First and foremost, arguably the most important factor when tackling a natural language processing task is how you choose to represent words. The ability to capture not only a word’s occurrence but also its contextual and semantic meaning is key to building state-of-the-art NLP models. Capturing this information in a dense vector of some N dimensions lets us project words into a meaningful space, where similarity between words reflects different semantic qualities (e.g. tense, plurality or gender). In contrast, the crudest representation of words is one-hot encoding. This is known as a ‘localist’ representation: words are represented as sparse vectors whose dimensionality equals the size of the vocabulary, $\mathbb{R}^{|V|\times 1}$. This representation has all sorts of drawbacks, chief among them that we cannot measure similarity between any pair of words, since every pair of one-hot vectors is orthogonal. For example:

$$(w^{table})^T{w}^{dog}= (w^{table})^Tw^{cat}=0$$
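
To make that concrete, here is a quick standalone sketch using a toy three-word vocabulary: every word is a sparse one-hot vector, and the dot product between any two distinct words is zero, so the representation tells us nothing about similarity.

import torch

# Toy vocabulary, one-hot encoded: each word is a sparse |V|-dimensional vector
vocab = {"table": 0, "dog": 1, "cat": 2}
one_hot = torch.eye(len(vocab))

w_table, w_dog, w_cat = one_hot[vocab["table"]], one_hot[vocab["dog"]], one_hot[vocab["cat"]]

# Orthogonality: both dot products are zero, so "table" is no more
# similar to "dog" than it is to "cat"
print(torch.dot(w_table, w_dog), torch.dot(w_table, w_cat))  # tensor(0.) tensor(0.)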

So dense vectors are great for representing words. They allow you to pack a heap of semantic information into a more manageable dimensional space. And given the incredible breakthroughs over the last 10 years, deep learning has been leveraged to represent words based on the context in which they are appear. And the best part is that PyTorch provides a plethora of options when it comes to building both our own embeddings or leveraging pre-trained state-of-the-art embeddings.

# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchtext

torch.manual_seed(1)
<torch._C.Generator at 0x182be38a1b0>

Let’s get started with the simplest example. First up, we need to define a dictionary with an index for each unique word in our text data set. This will act as a lookup table when we want to dig up the embedding for a specific word. Next, we need to define the size of our embeddings; in this case we have a vocabulary of 4 words and 5-dimensional embeddings. The embedding layer is now set up and ready to generate our dense vectors.

# Setup our vocabulary
word_to_ix = {"the": 0, "mighty": 1, "big": 2, "birds": 3}

# Initialize embedding object with random weights
embeds = nn.Embedding(4, 5)  # 4 words in vocab, 5-dimensional embeddings

# Dig up a word and generate the embedding
lookup_tensor = torch.tensor([word_to_ix["mighty"]], dtype=torch.long)

mighty_embed = embeds(lookup_tensor)
print(mighty_embed)
tensor([[ 0.2673, -0.4212, -0.5107, -1.5727, -0.1232]],
       grad_fn=<EmbeddingBackward>)
# Show the initialized weights
embeds.weight
Parameter containing:
tensor([[-0.3030, -1.7618,  0.6348, -0.8044,  0.7575],
        [-0.4068, -0.1277,  0.2804,  0.0375, -0.6378],
        [-0.8148, -0.6895,  0.7705, -1.0739, -0.2015],
        [-0.5603,  0.6817, -0.5170,  1.7902,  0.5877]], requires_grad=True)

Here the nn.Embedding object initializes our vocabulary as dense vectors of random weights, and we can then look up the specific vector for a given word. This lays the groundwork for learning meaningful weights using approaches such as Word2vec. As I said in the introduction, researchers have worked tirelessly to build architectures that optimize these weights for us; the sketch below gives a rough flavour of what that training looks like.
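
This is only a toy illustration, not Word2vec itself: a skip-gram-style loop over a few made-up (centre, context) pairs that nudges the random nn.Embedding weights above so each centre word becomes predictive of its context. The pairs, learning rate and epoch count are invented purely for demonstration.

# A minimal, illustrative skip-gram-style training loop (pairs and hyperparameters are made up)
pairs = [("the", "mighty"), ("mighty", "big"), ("big", "birds")]  # (centre, context)

scorer = nn.Linear(5, len(word_to_ix))  # maps an embedding to a score per vocab word
optimizer = optim.SGD(list(embeds.parameters()) + list(scorer.parameters()), lr=0.1)

for epoch in range(10):
    for centre, context in pairs:
        centre_ix = torch.tensor([word_to_ix[centre]], dtype=torch.long)
        target = torch.tensor([word_to_ix[context]], dtype=torch.long)

        optimizer.zero_grad()
        scores = scorer(embeds(centre_ix))      # shape: (1, vocab_size)
        loss = F.cross_entropy(scores, target)  # push up the score of the true context word
        loss.backward()
        optimizer.step()

With that intuition in place, let’s leverage some of the pre-trained word embeddings offered by PyTorch.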

# Let's download our pre-trained GloVe embeddings
glove_e = torchtext.vocab.GloVe(name='6B', dim=300)
.vector_cache\glove.6B.zip: 862MB [34:54, 412kB/s]                                                                                          
100%|████████████████████████████████████████████████████████████████████████████████████████████▉| 399596/400000 [02:10<00:00, 4846.87it/s]
glove_e.get_vecs_by_tokens('mighty')
tensor([ 3.4493e-01, -8.2858e-02,  2.2480e-01, -5.7307e-01, -4.9530e-01,
         7.8402e-01,  3.0057e-01,  7.1671e-01,  7.4948e-02,  2.0687e-01,
         1.8449e-01, -3.6602e-01,  5.3008e-02, -3.2370e-02, -2.3264e-01,
         5.0531e-01,  2.6115e-01,  1.0184e-01,  2.0003e-01,  8.8639e-01,
         8.1781e-02,  1.3622e-01,  2.8007e-01,  4.9476e-01, -2.7150e-01,
         1.3703e-01,  3.4939e-01, -7.1551e-01, -2.2424e-01, -1.8928e-02,
        -3.9568e-01, -3.6419e-01, -4.2634e-01, -1.1286e-01, -5.0333e-01,
         1.4272e-01, -4.4973e-01,  1.5146e-01,  1.4907e-01,  1.9503e-01,
        -3.2335e-03,  4.1720e-02, -4.8672e-01,  1.7150e-01,  1.5776e-01,
         1.0764e-01,  3.0741e-01, -4.9108e-01,  2.8661e-01, -2.3774e-01,
        -4.8432e-02,  3.4752e-01, -2.2934e-01, -1.5170e-01, -1.6065e-02,
        -2.7918e-01,  1.6342e-01,  1.6535e-01,  5.2585e-01,  3.0065e-01,
         3.7926e-02,  2.3617e-02, -3.4103e-02, -3.2240e-01, -1.7905e-01,
         1.3291e-01, -2.6915e-01,  2.6770e-01, -2.2907e-01,  1.1490e-01,
        -1.4468e-01,  1.0187e+00, -3.9975e-01,  4.5466e-01,  5.9253e-02,
         1.2598e-01, -4.2799e-01, -2.5723e-01,  6.4704e-01, -9.5196e-02,
        -8.6474e-01,  4.6565e-03, -4.2461e-01, -7.2993e-01, -3.6044e-01,
         4.8774e-01, -3.1956e-01,  6.5800e-01, -3.1860e-01, -7.0973e-01,
         5.2862e-01,  2.1545e-01,  3.7057e-01, -6.1895e-01, -1.7001e-01,
         1.9096e-01, -9.7453e-03, -3.5662e-01,  8.3948e-02,  3.3018e-01,
         2.7374e-01,  5.4790e-02,  1.7859e-01, -2.7695e-01,  2.9216e-01,
        -4.7656e-01, -2.3649e-02, -3.7655e-01,  5.8836e-02,  1.5329e-03,
        -2.1578e-01,  1.9514e-01, -4.5238e-01,  4.3061e-02,  8.1670e-02,
        -1.5773e-01,  1.8213e-01, -5.2056e-02,  2.4244e-01, -3.1294e-01,
         4.3311e-02, -6.3707e-01, -1.3066e-01,  3.6283e-01,  1.2065e-01,
         2.2420e-01, -3.9730e-01,  4.3084e-01, -8.3899e-02, -7.2651e-02,
         3.2916e-01,  5.2869e-02,  4.5304e-02, -2.8657e-01, -2.2814e-01,
        -2.1079e-01,  5.4317e-01, -2.8364e-01, -5.8576e-02,  3.3784e-01,
        -1.2021e-01,  4.3256e-01, -1.7689e-01,  3.9046e-01,  1.8407e-01,
        -8.8278e-02,  2.2861e-01, -9.3254e-02, -2.7015e-01,  8.2049e-02,
        -8.6143e-02, -4.7037e-01,  7.3181e-02,  1.7888e-01,  7.9018e-01,
        -2.5279e-01,  5.1713e-01, -2.1898e-01, -4.7730e-01, -4.7788e-02,
         1.8743e-01,  5.2972e-01,  6.8950e-01, -4.0064e-02, -3.0558e-01,
         3.4918e-02, -2.4100e-01, -5.8816e-02, -3.3122e-01,  2.0327e-01,
        -4.0185e-01, -2.5401e-01, -5.6088e-01,  5.8727e-02,  5.6775e-02,
        -4.5034e-01, -1.6924e-01,  2.9315e-01, -1.1529e-01,  7.0037e-01,
         5.9482e-03,  7.0533e-02,  8.9786e-02, -2.5654e-01, -8.9956e-02,
        -6.0867e-01, -1.7586e-01, -2.1448e-01,  3.3010e-01,  6.9301e-02,
         1.2401e-01,  1.2466e-01,  2.4337e-01,  4.2527e-01,  5.4302e-01,
        -1.6637e-01,  2.2132e-02, -1.9205e-01,  1.2017e-01,  2.6494e-01,
         1.8716e+00,  5.3612e-01,  1.7644e-01, -2.8449e-01, -6.9979e-01,
        -4.1062e-01,  1.9707e-01,  6.7704e-01, -1.1671e-01, -5.9469e-02,
        -4.1428e-01,  7.3477e-01, -3.4258e-01, -2.1703e-01,  7.0706e-02,
        -2.1172e-01,  2.8692e-01,  1.5552e-01, -2.0091e-01, -1.1618e-01,
        -2.6394e-01,  7.8578e-02, -1.6133e-01, -4.6151e-01, -3.8410e-01,
         4.7084e-02, -1.8008e-01,  4.0676e-01, -2.4275e-01, -2.6171e-01,
        -6.3435e-01,  4.1012e-01, -4.1747e-01, -1.0495e-01,  1.1516e-01,
         2.7490e-01,  1.1669e-01,  1.5617e-01,  4.2554e-01,  7.7032e-03,
        -3.6714e-01,  3.8705e-01, -3.4780e-01, -1.0861e-01, -3.8036e-01,
         4.6560e-03, -1.9371e-01,  4.0009e-01, -1.3126e-01,  1.1151e-01,
        -9.3251e-02,  1.7237e-01,  3.5482e-01,  1.0572e-01,  2.7734e-01,
        -1.9776e-01, -1.2454e-01, -3.7321e-01, -1.6871e-01, -9.4672e-02,
         3.9411e-01, -3.5628e-01, -8.3707e-02, -2.7465e-02,  2.6049e-02,
         4.8365e-01, -4.0487e-02, -1.9930e-01, -1.1348e-01,  3.3046e-01,
        -1.6522e-01, -5.5299e-01, -3.8460e-01, -8.1189e-02, -1.5261e-01,
         8.2685e-01, -9.7262e-02,  1.3995e-01, -1.9765e-01,  4.2032e-01,
         1.9911e-01,  8.7246e-01,  2.1387e-01, -2.4937e-01,  1.4783e-02,
        -2.5855e-03, -2.5508e-01, -1.2610e-01, -3.1158e-01, -3.0251e-01,
        -2.7762e-01,  4.8606e-01, -2.7652e-01, -8.8533e-02, -2.0091e-01,
         6.4036e-02, -7.9180e-02, -4.4887e-01, -4.6391e-02,  4.0895e-01])
# Let's look up the first 10 weights for "mighty"
glove_e.get_vecs_by_tokens("mighty")[:10]
tensor([ 0.3449, -0.0829,  0.2248, -0.5731, -0.4953,  0.7840,  0.3006,  0.7167,
         0.0749,  0.2069])
As a quick sanity check, we can find the words whose GloVe vectors sit closest to a given vector (by Euclidean distance) and see whether the neighbours make semantic sense.

def print_closest_words(vec, n=5):
    # Compute the Euclidean distance from vec to every vector in the vocabulary
    dists = torch.norm(glove_e.vectors - vec, dim=1)
    lst = sorted(enumerate(dists.numpy()), key=lambda x: x[1])  # sort by distance
    for idx, difference in lst[1:n+1]:  # skip the word itself, then take the n nearest
        print(glove_e.itos[idx], difference)

print_closest_words(glove_e["birds"], n=10)
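
Finally, once the pre-trained matrix is downloaded, dropping it into a model is a one-liner. The snippet below is a minimal sketch (whether to freeze the weights, and what model sits on top, is up to you) of seeding an nn.Embedding layer with the GloVe vectors via nn.Embedding.from_pretrained.

# Seed an embedding layer with the pre-trained GloVe weights
pretrained_embeds = nn.Embedding.from_pretrained(glove_e.vectors, freeze=True)

# Look up "mighty" by its GloVe index -- same vector as before
mighty_ix = torch.tensor([glove_e.stoi["mighty"]], dtype=torch.long)
print(pretrained_embeds(mighty_ix)[0, :10])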