Detach() function in encoder-decoder seq modeling

I understand that the `.detach()` function won't allow backpropagation to happen through a specific connection. But why? Isn't backpropagation needed for training our model?

PyTorch keeps track of all the functions and operations we apply, and whenever we do a training step, backpropagation runs through all of these steps. Sometimes, though, we do not want backpropagation to happen through every function, or through certain parts of the computation graph.
So, for instance, here the output of, let's say, invocation 1 goes as input to invocation 2. If we also let backpropagation flow through this link, it can lead to very complicated training. Backpropagation is already happening through all the hidden states and so on, and we do not want further backpropagation to happen through this chaining. So when we call `.detach()` we essentially say: don't pass gradients through this particular tensor.
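A minimal sketch of this idea (the two matrix multiplications here stand in for "invocation 1" and "invocation 2"; they are an illustration, not the notebook's actual model): detaching the output of the first step means the gradient only reflects the second step.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(1, 3)

h1 = x @ w            # "invocation 1": recorded in the autograd graph
h2 = h1.detach() @ w  # "invocation 2": h1 is cut out of the graph

loss = h2.sum()
loss.backward()

# w.grad comes only from the second matmul; no gradient
# flowed back through the chaining into invocation 1.
print(w.grad)
```

If we had written `h2 = h1 @ w` instead, the gradient would also flow back through the first multiplication, i.e. through the chained connection we wanted to block.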


Can you please point out the part of the code where backpropagation happens through the hidden states inside the Transliteration_EncoderDecoder class?

The screenshot you shared shows the Transliteration_EncoderDecoder class, inside of which we have defined two functions: `__init__` (initialising the different parameters and their dimensions) and `forward` (defining our encoder-decoder model with the attention mechanism). In the other cell we have defined the train function, where we compute the loss and then update the parameters with backpropagation when we call `opt.step()`. This is where backpropagation happens for the `forward` function we defined inside Transliteration_EncoderDecoder. And since inside `forward` we also wrote `decoder_input = one_hot.detach()`, backpropagation does not happen through that tensor.
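The training step described above can be sketched roughly like this (a simplified stand-in: `net`, `opt`, `criterion`, and the data are assumptions, with a plain `Linear` layer in place of the actual Transliteration_EncoderDecoder model):

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(4, 4)                 # stand-in for Transliteration_EncoderDecoder
opt = torch.optim.Adam(net.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

x, y = torch.randn(8, 4), torch.randn(8, 4)

opt.zero_grad()
out = net(x)                 # forward pass (the class's forward())
loss = criterion(out, y)     # compute the loss
loss.backward()              # backprop through everything that was NOT detached
opt.step()                   # update the parameters
```

Anything produced via `.detach()` inside `forward` (like `decoder_input = one_hot.detach()`) is simply skipped by the `loss.backward()` call here.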
