How to implement attention with batching in PyTorch

I am unable to figure out how to implement the attention mechanism with batching, especially the step in which we take a linear combination of the attention weights and the encoder states. I cannot find a batched substitute for this part of the code:

attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

Please share if anyone has implemented it.

I believe the above snippet is taken from the following tutorial:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#attention-decoder

For implementing attention with batching, please follow this tutorial notebook:

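In case it helps in the meantime, here is a minimal sketch of that same step with a real batch dimension. The shapes and variable names are my assumptions (they mirror the tutorial, but the tutorial works with a batch of 1): attention scores of shape (batch, max_len) and encoder outputs of shape (batch, max_len, hidden_size). Instead of adding a fake batch dimension with unsqueeze(0), you unsqueeze the weights to (batch, 1, max_len) so that torch.bmm multiplies each batch element's weights with its own encoder states.

import torch
import torch.nn.functional as F

batch_size, max_len, hidden_size = 4, 10, 256

# Hypothetical inputs standing in for the decoder's attention scores
# and the stacked encoder hidden states (assumed shapes, not from the tutorial).
attn_scores = torch.randn(batch_size, max_len)
encoder_outputs = torch.randn(batch_size, max_len, hidden_size)

# Normalize the scores into attention weights along the sequence axis.
attn_weights = F.softmax(attn_scores, dim=1)  # (batch, max_len)

# Batched matrix multiply:
# (batch, 1, max_len) x (batch, max_len, hidden_size) -> (batch, 1, hidden_size),
# i.e. one context vector (linear combination of encoder states) per batch element.
attn_applied = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)

print(attn_applied.shape)  # torch.Size([4, 1, 256])

The key point is that torch.bmm already operates per batch element, so no loop over the batch is needed; only the unsqueeze dimension changes compared with the single-example tutorial code.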