Unable to understand the encoder-decoder dimensions

Hello all,

out = infer(net, 'INDIA', 30)

Encoder input torch.Size([6, 1, 27])
Encoder output torch.Size([6, 1, 256])
Encoder hidden torch.Size([1, 1, 256])
Decoder state torch.Size([1, 1, 256])
Decoder input torch.Size([1, 1, 129])
Decoder intermediate output torch.Size([1, 1, 256])
Decoder output torch.Size([1, 1, 129])

Here, I am not able to understand the dimension sizes, especially the last axis with value 256. The other dimensions are clear to me.
Is there an intuitive explanation for this?

The 3D shape of all the tensors follows the pattern: (sequence_length, batch_size, num_neurons)

In the example logs that you posted, we can first notice three things:

  • The batch_size is 1.
  • The sequence_length of the encoder is 6 (time steps).
    • Why? Because “INDIA” padded with 1 start_token is 6 characters long.
  • The sequence_length of the decoder is 1.
    • That is, we decode manually, one timestep at a time.
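
PyTorch’s recurrent layers (nn.RNN, nn.GRU, nn.LSTM) expect (seq_len, batch, features) by default (batch_first=False). A minimal sketch of how a single word ends up with batch_size 1, and how a single decoder step ends up with sequence_length 1; the sizes come from your logs, the tensors themselves are just placeholders.

```python
import torch

# One word of 6 time steps, each a 27-dim vector: (seq_len, features)
encoder_steps = torch.zeros(6, 27)
# Insert a batch axis of size 1 at position 1 -> (seq_len, batch, features)
encoder_input = encoder_steps.unsqueeze(1)
print(encoder_input.shape)  # torch.Size([6, 1, 27])

# A single decoder step: one 129-dim vector reshaped to (seq_len=1, batch=1, 129)
decoder_step = torch.zeros(129)
decoder_input = decoder_step.view(1, 1, -1)
print(decoder_input.shape)  # torch.Size([1, 1, 129])
```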

With that, let’s look at the last dimension of each of the shapes you posted.

Encoder input torch.Size([6, 1, 27])

It means that the input is a one-hot encoding over 27 symbols (A-Z + start_token).
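
A sketch of such an encoding, assuming a hypothetical vocabulary where index 0 is the start token and 1..26 are A..Z, with the start token placed in front of the word. The real code may order the vocabulary or place the extra token differently; the shape comes out the same either way.

```python
import string
import torch

START_TOKEN = "<s>"  # hypothetical name for the start token
# Hypothetical vocabulary: index 0 = start token, 1..26 = 'A'..'Z'
letter2index = {START_TOKEN: 0}
for i, ch in enumerate(string.ascii_uppercase, start=1):
    letter2index[ch] = i

def word_rep(word):
    """One-hot encode `word` as (len(word) + 1, 1, 27), start token first."""
    tokens = [START_TOKEN] + list(word)
    rep = torch.zeros(len(tokens), 1, len(letter2index))
    for t, tok in enumerate(tokens):
        rep[t, 0, letter2index[tok]] = 1.0  # exactly one 1 per time step
    return rep

x = word_rep("INDIA")
print(x.shape)  # torch.Size([6, 1, 27])
print(x.argmax(dim=2).flatten().tolist())  # [0, 9, 14, 4, 9, 1]
```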

Encoder output torch.Size([6, 1, 256])

The encoder produces a 256-dimensional vector for each input timestep
(256 is simply the hidden size, i.e. the number of hidden units, of the encoder RNN).
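
To see where 256 comes from, here is a sketch assuming the encoder is a single-layer nn.GRU with hidden_size=256 (the actual net may use nn.RNN or nn.LSTM instead; the output shape would be the same, though an LSTM returns its hidden state as an (h, c) tuple).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Assumption: single-layer GRU encoder; batch_first=False by default
encoder = nn.GRU(input_size=27, hidden_size=256)

# One-hot "INDIA" with a start token in front: (6, 27) -> (6, 1, 27)
x = F.one_hot(torch.tensor([0, 9, 14, 4, 9, 1]), num_classes=27).float().unsqueeze(1)

output, hidden = encoder(x)
print(output.shape)  # torch.Size([6, 1, 256])  one 256-dim vector per time step
print(hidden.shape)  # torch.Size([1, 1, 256])  (num_layers, batch, hidden_size)
```

Changing hidden_size is the only thing that changes the 256; the 27 is consumed inside the layer’s weight matrices.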

Encoder hidden torch.Size([1, 1, 256])

This is the hidden state after the last time step alone. Note that here the first axis is the number of RNN layers (1), not the sequence length.

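A quick check of that claim, assuming a single-layer, unidirectional nn.GRU: the final hidden state equals the last time step of the encoder output, and adding layers grows only the first axis of the hidden state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.GRU(input_size=27, hidden_size=256)  # assumed single-layer GRU
x = torch.rand(6, 1, 27)  # placeholder input

output, hidden = encoder(x)
# One layer, one direction: final hidden state == last time step of `output`
print(torch.allclose(output[-1], hidden[0]))  # True

# With 2 layers the first axis of `hidden` becomes 2; `output` keeps its shape
encoder2 = nn.GRU(input_size=27, hidden_size=256, num_layers=2)
output2, hidden2 = encoder2(x)
print(output2.shape, hidden2.shape)  # torch.Size([6, 1, 256]) torch.Size([2, 1, 256])
```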

Decoder state torch.Size([1, 1, 256])

Note that the encoder RNN and the decoder RNN have the same hidden size (256).
This tensor holds the hidden vector that we pass to the decoder RNN at each timestep.
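
Why the sizes must match: a sketch assuming the common seq2seq setup where the encoder’s final hidden state is used as the decoder’s initial state (both modelled as nn.GRU here). A decoder with a different hidden size rejects that state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
HIDDEN = 256
# Assumed architecture: GRU encoder and GRU decoder with the same hidden size
encoder = nn.GRU(input_size=27, hidden_size=HIDDEN)
decoder = nn.GRU(input_size=129, hidden_size=HIDDEN)

_, encoder_hidden = encoder(torch.rand(6, 1, 27))  # (1, 1, 256)
decoder_input = torch.zeros(1, 1, 129)             # one decoder time step

_, decoder_state = decoder(decoder_input, encoder_hidden)
print(decoder_state.shape)  # torch.Size([1, 1, 256])

# A decoder with hidden_size=128 cannot accept a (1, 1, 256) initial state
bad_decoder = nn.GRU(input_size=129, hidden_size=128)
mismatch_raised = False
try:
    bad_decoder(decoder_input, encoder_hidden)
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```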

Decoder intermediate output torch.Size([1, 1, 256])

This is the output of the decoder RNN at each timestep.
(For a single-layer RNN, this output is exactly the new hidden vector, so it becomes the decoder state above for the next timestep.)
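
One decoder step, sketched with an assumed single-layer nn.GRU decoder: the 129-dim input goes in, and both the step output and the new state come out 256-dim and identical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
decoder = nn.GRU(input_size=129, hidden_size=256)  # assumed single-layer GRU decoder

state = torch.rand(1, 1, 256)        # "Decoder state"
step_input = torch.zeros(1, 1, 129)  # "Decoder input" for one time step
step_input[0, 0, 0] = 1.0            # placeholder one-hot

intermediate, new_state = decoder(step_input, state)
print(intermediate.shape)  # torch.Size([1, 1, 256])
# One layer and one time step: the step output is the new hidden state
print(torch.allclose(intermediate, new_state))  # True
```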

Decoder output torch.Size([1, 1, 129])

After the decoder RNN, there is probably a feed-forward (linear) layer with softmax that scores which output class is most likely.
In the example above, 129 is the size of the output vocabulary: probably 127 Devanagari tokens plus 2 special tokens (stop and pad).
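
A sketch of that output head, assuming a hypothetical linear layer named h2o followed by softmax over the last axis. The real net might use LogSoftmax instead; the shape is the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
HIDDEN, VOCAB = 256, 129
# Hypothetical output head: 256 -> 129 scores, then softmax over the class axis
h2o = nn.Linear(HIDDEN, VOCAB)  # applied to the last axis of a 3D tensor
softmax = nn.Softmax(dim=2)

intermediate = torch.rand(1, 1, HIDDEN)  # placeholder "Decoder intermediate output"
probs = softmax(h2o(intermediate))
print(probs.shape)  # torch.Size([1, 1, 129])
print(round(probs.sum().item(), 4))  # 1.0
```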

Taking the most likely class (argmax) of this tensor gives the output character at each decoder timestep.
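
Picking the character index from a (1, 1, 129) tensor might look like this; the logits are random placeholders, and mapping the index back to a Devanagari character would use the model’s own index-to-character table, which isn’t shown in the logs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 1, 129)  # placeholder scores from the output layer
probs = F.softmax(logits, dim=2)

# Index of the most likely class at this decoder time step
idx = probs.argmax(dim=2)  # shape (1, 1)
print(idx.shape)  # torch.Size([1, 1])
# Softmax preserves ordering, so argmax over the raw logits gives the same index
print(idx.item() == logits.argmax(dim=2).item())  # True
```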

Decoder input torch.Size([1, 1, 129])

Note that the input to the decoder RNN is nothing but the output we got from the previous timestep.
Hence its last dimension is also 129.
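
Putting the pieces together, here is a greedy decoding loop under several assumptions: a GRU decoder, a linear output layer, a hypothetical STOP_INDEX of 1, an all-zeros first input, and the `30` in infer(net, 'INDIA', 30) read as the maximum output length. The prediction is fed back as a one-hot vector; the original net might feed the softmax vector directly instead, and either way the last axis stays 129.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
HIDDEN, VOCAB = 256, 129
STOP_INDEX = 1  # hypothetical index of the stop token
MAX_LEN = 30    # assumed meaning of the `30` argument

decoder = nn.GRU(input_size=VOCAB, hidden_size=HIDDEN)  # assumed GRU decoder
h2o = nn.Linear(HIDDEN, VOCAB)                          # assumed output layer

state = torch.rand(1, 1, HIDDEN)       # would come from the encoder's final hidden state
step_input = torch.zeros(1, 1, VOCAB)  # hypothetical first decoder input
predicted = []

with torch.no_grad():
    for _ in range(MAX_LEN):
        out, state = decoder(step_input, state)  # (1, 1, 256) each
        scores = F.softmax(h2o(out), dim=2)      # (1, 1, 129)
        idx = scores.argmax(dim=2)               # (1, 1)
        predicted.append(idx.item())
        if idx.item() == STOP_INDEX:
            break
        # Feed the prediction back as a one-hot (1, 1, 129) input
        step_input = F.one_hot(idx, num_classes=VOCAB).float()

print(step_input.shape)  # torch.Size([1, 1, 129])
print(len(predicted) <= MAX_LEN)  # True
```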


Everything is clear now. Thank you so much for explaining all the dimensions in such detail. 🙂
