Total Internal Reflection


Transformers using PyTorch : Worklog Part 2

Avishek Sen Gupta on 14 January 2023

We continue looking at the Transformer architecture from where we left off in Part 1. At that point, we had set up the Encoder stack, but had not yet added positional encoding or started work on the Decoder stack. In this post, we will focus on setting up the training cycle.

Specifically, we will cover:

We will also lay out the dimensional analysis a little more clearly, and add the unit tests needed to verify the intended functionality. The code is available here.

Positional Encoding

You can see the code for visualising the positional encoding here. Both images below show the encoding map at different levels of zoom.

Position Encoding zoomed out

Position Encoding zoomed in
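
As a minimal sketch, assuming the standard sinusoidal scheme from the original Transformer paper (even dimensions use sine, odd dimensions use cosine, with a base of 10000), the encoding table can be computed in plain Python as below. The function name sinusoidal_positional_encoding is hypothetical and not part of the post's code, which works with PyTorch tensors.

```python
import math


def sinusoidal_positional_encoding(num_positions: int, d_model: int) -> list[list[float]]:
    """Return a num_positions x d_model table of sinusoidal positional encodings.

    Even dimensions use sine and odd dimensions use cosine; each pair of
    dimensions (2i, 2i+1) shares one frequency, 1 / 10000^(2i / d_model).
    """
    table = []
    for pos in range(num_positions):
        row = []
        for dim in range(d_model):
            # Paired dimensions (2i, 2i+1) share the same frequency.
            pair_index = dim // 2
            angle = pos / (10000 ** (2 * pair_index / d_model))
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


encoding = sinusoidal_positional_encoding(num_positions=4, d_model=8)
# Position 0 has sin(0) = 0 in the even slots and cos(0) = 1 in the odd slots.
print(encoding[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Because the table has the same shape as the embedded input (n x 512 for the Encoder), it is added element-wise, which is why the Positional Encoding stage leaves the dimensions unchanged. The low-index dimensions oscillate fastest, which is consistent with the fine stripes on one side of the zoomed-in map.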

The code below is the Decoder's forward() method from the main Transformer implementation; it shows how the Encoder output is threaded through each Decoder in the stack.


    # The encoder output is injected directly into the sublayer of every Decoder. To build up the chain of Decoders
    # in PyTorch, so that we can put the full stack inside a Sequential block, we simply inject the encoder output
    # into the root Decoder, and have it output the encoder output (together with the actual Decoder output) as part of
    # the Decoder's actual output to make it easy for the next Decoder in the stack to consume the Encoder and Decoder
    # outputs
    def forward(self, input):
        encoder_output, previous_stage_output = input
        # Masked self-attention over the Decoder's own input
        masked_mh_output = self.masked_multiheaded_attention_layer(
            self.masked_qkv_source(previous_stage_output))
        # Queries come from the masked attention output; keys and values come from the Encoder output
        input_qkv = self.unmasked_qkv_source((encoder_output, masked_mh_output))
        mh_output = self.multiheaded_attention_layer(input_qkv)
        # Adds the residual connection to the output of the attention layer
        layer_normed_multihead_output = self.layer_norm(mh_output + previous_stage_output)
        # Applies the position-wise feedforward network to each position's vector independently
        ffnn_outputs = torch.stack(
            [self.feedforward_layer(attention_vector) for attention_vector in layer_normed_multihead_output])
        # Residual connection around the feedforward sublayer, followed by layer normalisation
        layer_normed_ffnn_output = self.layer_norm(ffnn_outputs + layer_normed_multihead_output)
        return (encoder_output, layer_normed_ffnn_output)
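
The comment above forward() relies on the fact that torch.nn.Sequential passes the single return value of each module straight into the next one, so returning an (encoder_output, decoder_output) tuple is enough to carry the Encoder output through the whole stack. The plain-Python sketch below mimics that chaining with functools.reduce; ToyDecoder and run_sequential are hypothetical stand-ins, not classes from the post.

```python
from functools import reduce


class ToyDecoder:
    """Hypothetical stand-in for a Decoder layer: takes a tuple and returns a tuple."""

    def __init__(self, name):
        self.name = name

    def __call__(self, inputs):
        encoder_output, previous_stage_output = inputs
        # A real Decoder would attend over encoder_output here; we only record the hop.
        new_output = previous_stage_output + [self.name]
        # Pass the encoder output through unchanged so the next layer can use it too.
        return encoder_output, new_output


def run_sequential(layers, inputs):
    """Mimic nn.Sequential: feed each layer's single output into the next layer."""
    return reduce(lambda acc, layer: layer(acc), layers, inputs)


stack = [ToyDecoder(f"decoder_{i}") for i in range(3)]
final_encoder_out, final_decoder_out = run_sequential(stack, ("ENC", []))
print(final_encoder_out)  # ENC
print(final_decoder_out)  # ['decoder_0', 'decoder_1', 'decoder_2']
```

In PyTorch, the equivalent would be something along the lines of nn.Sequential(*decoders)((encoder_output, decoder_input)), where decoders is a list of Decoder modules; the final element of the returned tuple is the output of the last Decoder.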



Data Flow

The diagram below (you’ll need to zoom in) shows the data flow for a single Encoder/Decoder, with 8 attention blocks per multihead attention layer. \(n\) represents the number of words passed into the Encoder. \(m\) represents the number of words passed into the Decoder. \(V\) represents the length of the full vocabulary.

The dimensions of the data at each stage are depicted to facilitate understanding.
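
The dimensions in the diagram follow from a handful of matrix products. The sketch below tracks only shapes through the Decoder's unmasked attention and the output layer, assuming d_model = 512 and 8 heads (hence 64 dimensions per head, as in the diagram). The concrete values of n, m and V are illustrative, and matmul_shape is a hypothetical helper.

```python
def matmul_shape(a, b):
    """Shape of A @ B for 2-D shapes; raises if the inner dimensions disagree."""
    (rows_a, cols_a), (rows_b, cols_b) = a, b
    if cols_a != rows_b:
        raise ValueError(f"cannot multiply {a} by {b}")
    return (rows_a, cols_b)


def transpose(shape):
    rows, cols = shape
    return (cols, rows)


d_model, num_heads = 512, 8
d_head = d_model // num_heads  # 64, matching the Q/K/V edges in the diagram
n, m, V = 10, 7, 30000  # illustrative: source length, target length, vocabulary size

# Unmasked attention in the Decoder: queries come from the Decoder (m rows),
# keys and values come from the Encoder output (n rows).
q = matmul_shape((m, d_model), (d_model, d_head))        # (m, 64)
k = matmul_shape((n, d_model), (d_model, d_head))        # (n, 64)
v = matmul_shape((n, d_model), (d_model, d_head))        # (n, 64)
scores = matmul_shape(q, transpose(k))                   # (m, n): one weight per source position
head_out = matmul_shape(scores, v)                       # (m, 64)
concat = (m, head_out[1] * num_heads)                    # (m, 512) after concatenating 8 heads
reprojected = matmul_shape(concat, (d_model, d_model))   # (m, 512) after linear reprojection
logits = matmul_shape(reprojected, (d_model, V))         # (m, V) from the 512 x V output layer
print(scores, head_out, concat, logits)
```

The (m, n) score matrix is what lets a target of length m attend over a source of length n: each of the m target positions receives one attention weight per source position, and the result collapses back to m rows.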

graph LR;
    subgraph Encoder
        encoder_src[Source Text]--nx512-->pos_encoding[Positional Encoding];
        pos_encoding--nx512-->qkv_encoder[QKV Layer]
        qkv_encoder--Q=nx64-->multihead_attn_1[Attention 1]
        qkv_encoder--K=nx64-->multihead_attn_1
        qkv_encoder--V=nx64-->multihead_attn_1
        qkv_encoder--Q=nx64-->multihead_attn_2[Attention 2]
        qkv_encoder--K=nx64-->multihead_attn_2
        qkv_encoder--V=nx64-->multihead_attn_2
        qkv_encoder--Q=nx64-->multihead_attn_3[Attention 3]
        qkv_encoder--K=nx64-->multihead_attn_3
        qkv_encoder--V=nx64-->multihead_attn_3
        qkv_encoder--Q=nx64-->multihead_attn_4[Attention 4]
        qkv_encoder--K=nx64-->multihead_attn_4
        qkv_encoder--V=nx64-->multihead_attn_4
        qkv_encoder--Q=nx64-->multihead_attn_5[Attention 5]
        qkv_encoder--K=nx64-->multihead_attn_5
        qkv_encoder--V=nx64-->multihead_attn_5
        qkv_encoder--Q=nx64-->multihead_attn_6[Attention 6]
        qkv_encoder--K=nx64-->multihead_attn_6
        qkv_encoder--V=nx64-->multihead_attn_6
        qkv_encoder--Q=nx64-->multihead_attn_7[Attention 7]
        qkv_encoder--K=nx64-->multihead_attn_7
        qkv_encoder--V=nx64-->multihead_attn_7
        qkv_encoder--Q=nx64-->multihead_attn_8[Attention 8]
        qkv_encoder--K=nx64-->multihead_attn_8
        qkv_encoder--V=nx64-->multihead_attn_8
        subgraph EncoderMultiheadAttention[Encoder Multihead Attention]
            multihead_attn_1--nx64-->concat((Concatenate))
            multihead_attn_2--nx64-->concat
            multihead_attn_3--nx64-->concat
            multihead_attn_4--nx64-->concat
            multihead_attn_5--nx64-->concat
            multihead_attn_6--nx64-->concat
            multihead_attn_7--nx64-->concat
            multihead_attn_8--nx64-->concat
        end
        concat--nx512-->linear_reproject[Linear Reprojection]
        linear_reproject--1x512-->ffnn_encoder_1[FFNN 1]
        linear_reproject--1x512-->ffnn_encoder_2[FFNN 2]
        linear_reproject--1x512-->ffnn_encoder_t[FFNN x]
        linear_reproject--1x512-->ffnn_encoder_n[FFNN n]
        subgraph FfnnEncoder[Feed Forward Neural Network]
            ffnn_encoder_1--1x512-->stack_encoder((Stack))
            ffnn_encoder_2--1x512-->stack_encoder
            ffnn_encoder_t--1x512-->stack_encoder
            ffnn_encoder_n--1x512-->stack_encoder
        end
        stack_encoder--nx512-->encoder_output[Encoder Output]
    end
    subgraph Decoder
        decoder_target[Decoder Target]--mx512-->pos_encoding_2[Positional Encoding]
        pos_encoding_2--mx512-->qkv_decoder_1[QKV Layer]
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_1[Attention 1]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_1
        qkv_decoder_1--V=mx64-->multihead_attn_masked_1
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_2[Attention 2]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_2
        qkv_decoder_1--V=mx64-->multihead_attn_masked_2
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_3[Attention 3]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_3
        qkv_decoder_1--V=mx64-->multihead_attn_masked_3
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_4[Attention 4]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_4
        qkv_decoder_1--V=mx64-->multihead_attn_masked_4
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_5[Attention 5]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_5
        qkv_decoder_1--V=mx64-->multihead_attn_masked_5
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_6[Attention 6]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_6
        qkv_decoder_1--V=mx64-->multihead_attn_masked_6
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_7[Attention 7]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_7
        qkv_decoder_1--V=mx64-->multihead_attn_masked_7
        qkv_decoder_1--Q=mx64-->multihead_attn_masked_8[Attention 8]
        qkv_decoder_1--K=mx64-->multihead_attn_masked_8
        qkv_decoder_1--V=mx64-->multihead_attn_masked_8
        subgraph DecoderMaskedMultiheadAttention[Decoder Masked Multihead Attention]
            multihead_attn_masked_1--mx64-->concat_masked((Concatenate))
            multihead_attn_masked_2--mx64-->concat_masked
            multihead_attn_masked_3--mx64-->concat_masked
            multihead_attn_masked_4--mx64-->concat_masked
            multihead_attn_masked_5--mx64-->concat_masked
            multihead_attn_masked_6--mx64-->concat_masked
            multihead_attn_masked_7--mx64-->concat_masked
            multihead_attn_masked_8--mx64-->concat_masked
        end
        concat_masked--mx512-->linear_reproject_masked[Linear Reprojection]
        linear_reproject_masked--mx512-->query_project[Query Projection]
        encoder_output--nx512-->kv_project_decoder[Key-Value Projection]
        query_project--Q=mx64-->multihead_attn_unmasked_1[Attention 1]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_1
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_1
        query_project--Q=mx64-->multihead_attn_unmasked_2[Attention 2]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_2
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_2
        query_project--Q=mx64-->multihead_attn_unmasked_3[Attention 3]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_3
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_3
        query_project--Q=mx64-->multihead_attn_unmasked_4[Attention 4]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_4
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_4
        query_project--Q=mx64-->multihead_attn_unmasked_5[Attention 5]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_5
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_5
        query_project--Q=mx64-->multihead_attn_unmasked_6[Attention 6]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_6
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_6
        query_project--Q=mx64-->multihead_attn_unmasked_7[Attention 7]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_7
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_7
        query_project--Q=mx64-->multihead_attn_unmasked_8[Attention 8]
        kv_project_decoder--K=nx64-->multihead_attn_unmasked_8
        kv_project_decoder--V=nx64-->multihead_attn_unmasked_8
        subgraph DecoderUnmaskedMultiheadAttention[Decoder Unmasked Multihead Attention]
            multihead_attn_unmasked_1--mx64-->concat_unmasked((Concatenate))
            multihead_attn_unmasked_2--mx64-->concat_unmasked
            multihead_attn_unmasked_3--mx64-->concat_unmasked
            multihead_attn_unmasked_4--mx64-->concat_unmasked
            multihead_attn_unmasked_5--mx64-->concat_unmasked
            multihead_attn_unmasked_6--mx64-->concat_unmasked
            multihead_attn_unmasked_7--mx64-->concat_unmasked
            multihead_attn_unmasked_8--mx64-->concat_unmasked
        end
        concat_unmasked--mx512-->linear_reproject_unmasked[Linear Reprojection]
        linear_reproject_unmasked--1x512-->ffnn_decoder_1_unmasked[FFNN 1]
        linear_reproject_unmasked--1x512-->ffnn_decoder_2_unmasked[FFNN 2]
        linear_reproject_unmasked--1x512-->ffnn_decoder_t_unmasked[FFNN x]
        linear_reproject_unmasked--1x512-->ffnn_decoder_n_unmasked[FFNN n]
        subgraph FfnnDecoderUnmasked[Feed Forward Neural Network]
            ffnn_decoder_1_unmasked--1x512-->stack_decoder_unmasked((Stack))
            ffnn_decoder_2_unmasked--1x512-->stack_decoder_unmasked
            ffnn_decoder_t_unmasked--1x512-->stack_decoder_unmasked
            ffnn_decoder_n_unmasked--1x512-->stack_decoder_unmasked
        end
    end
    stack_decoder_unmasked--mx512-->linear[Linear=512xV]
    subgraph OutputLayer[Output Layer]
        linear--mxV-->softmax[Softmax]
        softmax--mxV-->select_max_probabilities[Select Maximum Probability Token for each Position]
        select_max_probabilities--1xm-->transformer_output[Transformer Output]
    end
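
In the Decoder's masked attention blocks, each position may only attend to itself and to earlier positions. A minimal single-head sketch of scaled dot-product attention, written in plain Python rather than PyTorch for readability, shows where such a mask applies: with causal=True, the scores for future positions are set to negative infinity before the softmax, so they receive zero weight. The function names are hypothetical, and the post's own implementation may apply the mask differently.

```python
import math


def softmax(row):
    """Numerically stable softmax; entries of -inf get probability 0."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v, causal=False):
    """Compute softmax(Q K^T / sqrt(d_k)) V for lists of row vectors.

    With causal=True, query position i may only attend to key positions j <= i,
    as in the Decoder's masked self-attention (so q and k should have equal length).
    """
    d_k = len(k[0])
    output = []
    for i, query in enumerate(q):
        scores = []
        for j, key in enumerate(k):
            if causal and j > i:
                scores.append(float("-inf"))  # hide future positions from position i
            else:
                scores.append(sum(a * b for a, b in zip(query, key)) / math.sqrt(d_k))
        weights = softmax(scores)
        # Weighted sum of the value rows, one output column at a time
        output.append([sum(w * value[c] for w, value in zip(weights, v)) for c in range(len(v[0]))])
    return output


q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]
masked = scaled_dot_product_attention(q, k, v, causal=True)
print(masked[0])  # [1.0, 0.0]: the first position can only see itself
```

For the Decoder's unmasked attention, the same function would be called with causal=False, with q taken from the Decoder (m rows) and k, v taken from the Encoder output (n rows).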

Notes on the Code

Transformer Implementation Class Diagram

Conclusion

We have built and tested the basic Transformer architecture. However, we still need to do the following:

We will work on all of the above in the sequel to this post.



tags: Machine Learning - PyTorch - Programming - Deep Learning - Transformers