In the previous post I described the Batch Normalization layer and implemented its backward pass with a computational graph, instead of using calculus to obtain a more efficient implementation. In this post I will derive the batch normalization gradient expression and implement it in Python.
The Batch Norm layer normalizes its input to have zero mean and unit variance, just as we usually do to the input data:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
Where $\mu$ is the per-feature mean of the input batch and $\sigma^2$ is the variance. As this operation could limit the representational power of the network, the authors wanted to make sure that the layer could learn the identity function:

$$y_i = \gamma \hat{x}_i + \beta$$
The layer is initialized with $\gamma = 1$ and $\beta = 0$ in order to effectively normalize its input. But when $\gamma = \sqrt{\sigma^2 + \epsilon}$ and $\beta = \mu$, the output of the layer is the same as its input. As such, the layer can learn what's best: whether it should normalize the data to have zero mean and unit variance, whether it should output the same values as its input, or whether it should scale and translate the input with some other values.
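To make this concrete, here is a minimal numpy sketch of the forward pass (my own, not code from the original paper), checking that with $\gamma = \sqrt{\sigma^2 + \epsilon}$ and $\beta = \mu$ the layer reproduces its input:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # Normalize each feature over the batch, then scale and shift.
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    xhat = (x - mu) / np.sqrt(var + eps)
    return gamma * xhat + beta

x = np.random.RandomState(0).randn(32, 5)

# gamma = sqrt(var + eps) and beta = mu undo the normalization:
out = batchnorm_forward(x, np.sqrt(x.var(axis=0) + 1e-5), x.mean(axis=0))
assert np.allclose(out, x)  # the layer computes the identity
```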
There are two strategies to derive the gradient expression needed to update the network’s weights:
- Break down the mathematical expressions into atomic operations in order to build a computational graph and make it easier to use the chain rule.
- Derive the gradient analytically.
The first strategy offers us a simple abstraction to compute the gradient but it is usually not the most efficient implementation. By deriving the gradient expression, it is usually possible to simplify it and remove unnecessary terms. It turns out that the second strategy yields, indeed, a more efficient gradient expression for the Batch Norm layer.
Let’s start with the simpler parameters: $\gamma$ and $\beta$. In these cases, the computational graph already gives us the best possible expression, but this will help us remember the basics. Again, we will make use of the chain rule to derive the partial derivative of the loss with respect to $\gamma$:

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \gamma}$$
Why do we need to use the chain rule? We have the value of $\frac{\partial L}{\partial y_i}$, as it is provided to us by the next layer, and we can actually derive $\frac{\partial y_i}{\partial \gamma}$. That way, we do not need to know anything about the loss function that was used, nor what the next layers are. The gradient expression becomes self-contained, provided that we are given $\frac{\partial L}{\partial y}$.
Why the summation? $\gamma$ is a $D$-dimensional vector that is multiplied by each of the $N$ vectors $\hat{x}_i$ of dimension $D$, and so its contributions must be summed. As I am a very visual person, I better understood this by thinking in terms of a computational graph:
When we backpropagate through the computational graph, the gradients flowing from each of the nodes where $\gamma$ was used all arrive at the same $\gamma$ node and, as such, they must be summed.
The gradient expressions for the partial derivatives of the loss with respect to $\gamma$ and $\beta$ become:

$$\frac{\partial L}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i} \hat{x}_i$$

$$\frac{\partial L}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial L}{\partial y_i}$$
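In numpy, with an upstream gradient `dout` of shape `(N, D)` and the cached normalized inputs `xhat` (variable names are mine, chosen to match the code at the end of the post), both summations become a sum over the batch axis:

```python
import numpy as np

# Toy shapes for illustration: N = 4 examples, D = 3 features.
rng = np.random.RandomState(1)
dout = rng.randn(4, 3)   # dL/dy, provided by the next layer
xhat = rng.randn(4, 3)   # normalized inputs cached in the forward pass

dgamma = np.sum(dout * xhat, axis=0)  # dL/dgamma, shape (D,)
dbeta = np.sum(dout, axis=0)          # dL/dbeta, shape (D,)
```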
What we need to compute next is the partial derivative of the loss with respect to the inputs $x_i$, so that the previous layers can compute their gradients and update their parameters. We need to gather all the expressions where $x_i$ is used and influences the result. Do not forget that:

$$\mu = \frac{1}{N} \sum_{i=1}^{N} x_i$$

$$\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$
We can conclude that $x_i$ is used to compute $\hat{x}_i$, $\mu$ and $\sigma^2$, and therefore:

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial \mu} \frac{\partial \mu}{\partial x_i} + \frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i}$$
Let’s compute and simplify each of these terms individually, and then come back to this expression.
The first term is the easiest one to derive. Using the same chain rule process as before, we arrive at the following expressions:

$$\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \gamma$$

$$\frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}$$
So far so good. The next expressions are a bit longer, but once you get the hang of the chain rule they are just as simple.
As with the gradients of $\gamma$ and $\beta$, to compute the gradient of $\sigma^2$ we need to sum over the contributions of all the elements of the batch. The same happens with the gradient of $\mu$, as it is also a $D$-dimensional vector:

$$\frac{\partial L}{\partial \sigma^2} = \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} (x_i - \mu) \left(-\frac{1}{2}\right) (\sigma^2 + \epsilon)^{-3/2}$$

$$\frac{\partial L}{\partial \mu} = \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{x}_i} \left(-\frac{1}{\sqrt{\sigma^2 + \epsilon}}\right)$$
In case you are wondering why I deleted the second term of the $\frac{\partial L}{\partial \mu}$ expression, it turns out that $\frac{\partial \sigma^2}{\partial \mu} = -\frac{2}{N} \sum_{i=1}^{N} (x_i - \mu)$ is equal to $0$. We are translating all the points so that their mean is equal to $0$. By looking at the mean’s formula, we can see that the only way for the points to have mean $0$ is when their sum is also equal to $0$:

$$\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu) = 0 \iff \sum_{i=1}^{N} (x_i - \mu) = 0$$
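A quick numerical sanity check of this fact: centering any batch of points makes each feature sum to zero (up to floating-point error):

```python
import numpy as np

x = np.random.RandomState(2).randn(100, 4)
centered = x - x.mean(axis=0)  # x_i - mu, per feature

# Centered points sum to (numerically) zero in every feature.
assert np.allclose(centered.sum(axis=0), 0.0)
```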
Now we can easily compute the rest of the terms:

$$\frac{\partial \mu}{\partial x_i} = \frac{1}{N}$$

$$\frac{\partial \sigma^2}{\partial x_i} = \frac{2 (x_i - \mu)}{N}$$
EDITED ON 17/12/2016: Thanks to Ishijima Seiichiro for pointing out to me that this equation could be further simplified. Let’s focus on this term of the equation:

$$\frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i} = \frac{2 (x_i - \mu)}{N} \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j} (x_j - \mu) \left(-\frac{1}{2}\right) (\sigma^2 + \epsilon)^{-3/2}$$
Notice that $\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} = \hat{x}_i$, and so:

$$\frac{\partial L}{\partial \sigma^2} \frac{\partial \sigma^2}{\partial x_i} = -\frac{1}{N \sqrt{\sigma^2 + \epsilon}} \hat{x}_i \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j$$
Merging everything together:

$$\frac{\partial L}{\partial x_i} = \frac{1}{N \sqrt{\sigma^2 + \epsilon}} \left( N \frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^{N} \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j \right)$$

And substituting $\frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \gamma$:

$$\frac{\partial L}{\partial x_i} = \frac{\gamma}{N \sqrt{\sigma^2 + \epsilon}} \left( N \frac{\partial L}{\partial y_i} - \sum_{j=1}^{N} \frac{\partial L}{\partial y_j} - \hat{x}_i \sum_{j=1}^{N} \frac{\partial L}{\partial y_j} \hat{x}_j \right)$$
Finally, we have a single mathematical expression for the partial derivative of the loss with respect to each input. It is simpler than the one we get by backpropagating through the computational graph.
Translating this to Python, we end up with a much more compact method:

```python
def batchnorm_backward_alt(dout, cache):
    # cache holds gamma, the normalized inputs xhat and the inverse
    # standard deviation istd = 1 / sqrt(var + eps) computed in the
    # forward pass.
    gamma, xhat, istd = cache
    N, _ = dout.shape

    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(xhat * dout, axis=0)
    dx = (gamma * istd / N) * (N * dout - xhat * dgamma - dbeta)

    return dx, dgamma, dbeta
```
This method is 2 to 4 times faster than the one presented in the previous post. It might not seem like much, but when deep neural networks can take weeks to train, every little improvement ends up making a huge difference.
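As a sanity check, we can verify numerically that the compact expression matches a term-by-term backward pass. The sketch below is my own (the function names and staging are assumptions, not the exact code from the previous post):

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    # Forward pass, caching what the backward pass needs.
    mu = x.mean(axis=0)
    istd = 1.0 / np.sqrt(x.var(axis=0) + eps)
    xhat = (x - mu) * istd
    return gamma * xhat + beta, (gamma, xhat, istd)

def batchnorm_backward_graph(dout, cache):
    # Backward pass written term by term, following the derivation above.
    gamma, xhat, istd = cache
    N = dout.shape[0]
    xmu = xhat / istd                    # recover x - mu from the cache
    dxhat = dout * gamma
    dvar = np.sum(dxhat * xmu, axis=0) * -0.5 * istd ** 3
    dmu = np.sum(-dxhat * istd, axis=0)  # second term vanishes: sum(xmu) = 0
    dx = dxhat * istd + dvar * 2.0 * xmu / N + dmu / N
    dgamma = np.sum(dout * xhat, axis=0)
    dbeta = np.sum(dout, axis=0)
    return dx, dgamma, dbeta

rng = np.random.RandomState(0)
x = rng.randn(64, 10)
dout = rng.randn(64, 10)
_, cache = batchnorm_forward(x, rng.randn(10), rng.randn(10))

dx_graph, dgamma, dbeta = batchnorm_backward_graph(dout, cache)

# The single compact expression derived in this post:
gamma, xhat, istd = cache
dx_alt = (gamma * istd / 64) * (64 * dout - xhat * dgamma - dbeta)
assert np.allclose(dx_graph, dx_alt)
```

Both versions produce the same gradients; the compact one simply avoids materializing the intermediate terms.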