Trace Trick
Equation
Suppose a scalar function $f: \mathbb{R}^{n \times n} \to \mathbb{R}$ is defined on the space of $n \times n$ real matrices, and let $\mathrm{d}f$ be the total differential of $f$. Then the following holds.
$$ \mathrm{d}f = \Tr (\mathrm{d}f) = \mathrm{d}\Tr (f) \tag{1} $$
Here $\Tr$ denotes the trace.
Proof
The value of $f$ is a scalar, which can be regarded as a $1 \times 1$ matrix, so taking its trace leaves it unchanged. Therefore, we obtain:
$$ f = \Tr(f) \implies \mathrm{d}f = \mathrm{d}\Tr(f) $$
Since $\mathrm{d}f$ is also a scalar, taking its trace leaves it unchanged as well:
$$ \mathrm{d}f = \Tr (\mathrm{d}f) $$
Combining the two results gives the following expression.
$$ \Tr (\mathrm{d}f) = \mathrm{d}f = \mathrm{d}\Tr(f) $$
■
Trace Trick
The trace trick is a computational method that allows for easy computation of the gradient matrix $\nabla_{\mathbf{X}}f$. Compared to the derivative $\dfrac{d f}{d x}$ of a univariate function or the gradient $\nabla_{\mathbf{x}}f = \begin{bmatrix} \dfrac{\partial f}{\partial x_{1}} & \cdots & \dfrac{\partial f}{\partial x_{n}} \end{bmatrix}^{\mathsf{T}}$, the gradient matrix has many components that make calculations cumbersome. Especially when matrix products and transpositions are involved, taking indices into account can be confusing. The trace trick allows for easy computation without the need to consider indices. The summary of the computational process is as follows:
- Begin with $\mathrm{d}f = \Tr( \mathrm{d}f )$ or $\mathrm{d}f = \mathrm{d}\Tr(f)$.
- Convert the expression into the form of $\mathrm{d}f = \Tr \left( A^{\mathsf{T}} \mathrm{d}\mathbf{X} \right)$.
- The total differential of $f$ is $\mathrm{d}f = \Tr \left( \left( \nabla_{\mathbf{X}}f \right)^{\mathsf{T}} \mathrm{d}\mathbf{X} \right)$, thus $A = \nabla_{\mathbf{X}}f$.
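To see why the last step identifies $A$ with the gradient, it may help to expand the trace entrywise. The sketch below assumes the convention $(\nabla_{\mathbf{X}}f)_{ij} = \partial f / \partial x_{ij}$, where $x_{ij}$ denotes the $(i,j)$ entry of $\mathbf{X}$; this index notation is introduced here only for illustration.

$$ \begin{align*} \mathrm{d}f &= \sum_{i=1}^{n}\sum_{j=1}^{n} \frac{\partial f}{\partial x_{ij}} \mathrm{d}x_{ij} \\ \Tr \left( A^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) &= \sum_{j=1}^{n} \left( A^{\mathsf{T}} \mathrm{d}\mathbf{X} \right)_{jj} = \sum_{j=1}^{n}\sum_{i=1}^{n} A_{ij} \mathrm{d}x_{ij} \end{align*} $$

Matching the coefficients of each $\mathrm{d}x_{ij}$ gives $A_{ij} = \partial f / \partial x_{ij}$, that is, $A = \nabla_{\mathbf{X}}f$.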
In this process, $(1)$, the properties of the trace, and the rules of matrix calculus do most of the work. As the examples below illustrate, there is hardly any "calculation" involved: the main task is to rearrange the order of matrix products using the properties of the trace. Most importantly, no derivative has to be computed entry by entry. (The matrix differentials themselves are, of course, obtained from rules such as the product rule.)
For derivative formulas computed using the trace trick, refer to the matrix differentiation table for scalar functions.
Example
The following rules are used in the calculations.
Properties of Matrix Differential Elements
Given variable matrices $\mathbf{X}, \mathbf{Y} \in \mathbb{R}^{n \times n}$, a scalar $\alpha \in \mathbb{R}$, and a constant matrix $\mathbf{A} \in \mathbb{R}^{n \times n}$, the following holds:
- $\mathrm{d}(\alpha \mathbf{X}) = \alpha \mathrm{d}\mathbf{X}$
- $\mathrm{d}(\mathbf{X}^{\mathsf{T}}) = (\mathrm{d}\mathbf{X})^{\mathsf{T}}$
- $\mathrm{d}(\mathbf{A}\mathbf{X}) = \mathbf{A} \mathrm{d}\mathbf{X}$ and $\mathrm{d}(\mathbf{X}\mathbf{A}) = (\mathrm{d}\mathbf{X}) \mathbf{A}$
- $\mathrm{d}(\mathbf{X} + \mathbf{Y}) = \mathrm{d}\mathbf{X} + \mathrm{d}\mathbf{Y}$
- $\mathrm{d}(\mathbf{X}\mathbf{Y}) = (\mathrm{d}\mathbf{X})\mathbf{Y} + \mathbf{X} \mathrm{d}\mathbf{Y}$
Properties of Trace
Given matrices $\mathbf{X}, \mathbf{Y}, \mathbf{Z} \in \mathbb{R}^{n \times n}$ and a scalar $\alpha \in \mathbb{R}$, the following holds:
- $\Tr (\alpha \mathbf{X}) = \alpha \Tr (\mathbf{X})$
- Linearity: $\Tr (\mathbf{X} + \mathbf{Y}) = \Tr (\mathbf{X}) + \Tr (\mathbf{Y})$
- Cyclicity: $\Tr (\mathbf{X}\mathbf{Y}\mathbf{Z}) = \Tr (\mathbf{Y}\mathbf{Z}\mathbf{X}) = \Tr (\mathbf{Z}\mathbf{X}\mathbf{Y})$
- Invariance of Transposition: $\Tr (\mathbf{X}^{\mathsf{T}}) = \Tr (\mathbf{X})$
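The trace rules are easy to spot-check numerically. The following is a minimal sketch in plain Python on arbitrarily chosen $2 \times 2$ matrices; the helper names `matmul`, `transpose`, and `trace` and all matrix values are assumptions made for illustration. It also shows that cyclicity permits only cyclic shifts of the factors, not arbitrary reorderings.

```python
# Spot-check of cyclicity and transposition invariance of the trace.
# All matrices are arbitrary illustrative values stored as nested lists.

def matmul(P, Q):
    """Matrix product of two nested-list matrices."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(P):
    """Transpose of a nested-list matrix."""
    return [list(row) for row in zip(*P)]

def trace(P):
    """Sum of the diagonal entries of a square matrix."""
    return sum(P[i][i] for i in range(len(P)))

X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[0.0, -1.0], [5.0, 2.0]]
Z = [[2.0, 1.0], [1.0, 3.0]]

# Cyclicity: Tr(XYZ) = Tr(YZX) = Tr(ZXY)
t_xyz = trace(matmul(matmul(X, Y), Z))
t_yzx = trace(matmul(matmul(Y, Z), X))
t_zxy = trace(matmul(matmul(Z, X), Y))

# Invariance under transposition: Tr(X^T) = Tr(X)
t_x = trace(X)
t_xt = trace(transpose(X))

# A non-cyclic reordering such as XZY generally gives a different trace.
t_xzy = trace(matmul(matmul(X, Z), Y))

print(t_xyz, t_yzx, t_zxy)  # three equal values
print(t_x, t_xt)            # two equal values
print(t_xzy)                # differs from t_xyz for these matrices
```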
$f(\mathbf{X}) = \Tr(\mathbf{A}\mathbf{X})$
Following the steps described above, the computation proceeds as follows.
$$ \begin{align*} \mathrm{d}f &= \mathrm{d}\Tr \left( \mathbf{A}\mathbf{X} \right) \\ &= \Tr \left( \mathrm{d}\left( \mathbf{A}\mathbf{X} \right) \right) & \text{by } (1) \\ &= \Tr \left( \mathbf{A} \mathrm{d}\mathbf{X} \right) \\ &= \Tr \left( \left( \mathbf{A}^{\mathsf{T}} \right)^{\mathsf{T}} \mathrm{d}\mathbf{X} \right)\\ &= \Tr \left( (\nabla_{\mathbf{X}}f)^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \end{align*} $$
Thus, we arrive at:
$$ \nabla_{\mathbf{X}}f = \mathbf{A}^{\mathsf{T}} $$
This is much simpler than direct computation.
■
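A result like this can be sanity-checked with finite differences. The sketch below is an illustrative check rather than part of the derivation: the matrices `A` and `X`, the step size `h`, and the helper `numerical_gradient` are all assumptions chosen for the example. It compares a central-difference estimate of each $\partial f / \partial x_{ij}$ with the corresponding entry of $\mathbf{A}^{\mathsf{T}}$.

```python
# Finite-difference check that the gradient of f(X) = Tr(AX) is A^T.
# A, X, and the step size h are arbitrary illustrative choices.

def f(A, X):
    """f(X) = Tr(AX) = sum over i, k of A[i][k] * X[k][i]."""
    n = len(A)
    return sum(A[i][k] * X[k][i] for i in range(n) for k in range(n))

def numerical_gradient(func, X, h=1e-6):
    """Central-difference estimate of d func / d X[i][j] for every entry."""
    rows, cols = len(X), len(X[0])
    grad = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            plus = [row[:] for row in X]
            minus = [row[:] for row in X]
            plus[i][j] += h
            minus[i][j] -= h
            grad[i][j] = (func(plus) - func(minus)) / (2 * h)
    return grad

A = [[1.0, 2.0], [3.0, 4.0]]
X = [[0.5, -1.0], [2.0, 0.0]]

grad_fd = numerical_gradient(lambda M: f(A, M), X)
grad_trace_trick = [[A[j][i] for j in range(2)] for i in range(2)]  # A^T

max_error = max(abs(grad_fd[i][j] - grad_trace_trick[i][j])
                for i in range(2) for j in range(2))
print(max_error)  # should be close to zero
```

Because $f$ is linear in $\mathbf{X}$, the central difference is exact up to floating-point rounding, so any visible discrepancy would indicate a transposition mistake.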
$f(\mathbf{X}) = \mathbf{a}^{\mathsf{T}}\mathbf{X}\mathbf{b}$
The computation is similar to the example above. By the cyclicity of the trace, $\Tr (\mathbf{a}^{\mathsf{T}}\mathbf{X}\mathbf{b}) = \Tr (\mathbf{b}\mathbf{a}^{\mathsf{T}}\mathbf{X})$, so we can compute as follows:
$$ \begin{align*} \mathrm{d}f &= \mathrm{d}\Tr (f) &\text{by } (1) \\ &= \mathrm{d}\Tr \left( \mathbf{a}^{\mathsf{T}}\mathbf{X}\mathbf{b} \right) \\ &= \mathrm{d}\Tr \left( (\mathbf{b}\mathbf{a}^{\mathsf{T}})\mathbf{X} \right) \\ &= \mathrm{d}\Tr \left( (\mathbf{a}\mathbf{b}^{\mathsf{T}})^{\mathsf{T}}\mathbf{X} \right) \end{align*} $$
By the result of the first example with $\mathbf{A} = (\mathbf{a}\mathbf{b}^{\mathsf{T}})^{\mathsf{T}}$,
$$ \nabla_{\mathbf{X}}f = \mathbf{a}\mathbf{b}^{\mathsf{T}} $$
■
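This example is simple enough to cross-check with the direct index computation that the trace trick avoids. Writing $a_{i}$, $b_{j}$, and $x_{ij}$ for the entries of $\mathbf{a}$, $\mathbf{b}$, and $\mathbf{X}$ (notation assumed here for illustration):

$$ f(\mathbf{X}) = \sum_{i=1}^{n}\sum_{j=1}^{n} a_{i} x_{ij} b_{j} \implies \frac{\partial f}{\partial x_{ij}} = a_{i} b_{j} = \left( \mathbf{a}\mathbf{b}^{\mathsf{T}} \right)_{ij} $$

This agrees with the result obtained from the trace trick.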
$f(\mathbf{X}) = \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}}\mathbf{X} \mathbf{b}$
Compared to the second example, only the extra factor $\mathbf{X}^{\mathsf{T}}$ appears, yet a direct hand calculation becomes considerably more tedious. Using the trace trick, it can be computed as follows. First, since $\Tr(\mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathbf{X} \mathbf{b}) = \Tr(\mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathbf{X})$,
$$ \begin{align*} \mathrm{d}f &= \mathrm{d}\Tr (f) &\text{by } (1) \\ &= \mathrm{d}\Tr(\mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathbf{X} \mathbf{b}) \\ &= \mathrm{d}\Tr(\mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathbf{X}) \\ &= \Tr\left( \mathrm{d}(\mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathbf{X}) \right) & \text{by } (1) \\ &= \Tr\left( \mathbf{b} \mathbf{a}^{\mathsf{T}} \mathrm{d}(\mathbf{X}^{\mathsf{T}}) \mathbf{X} + \mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \\ &\overset{6\text{th}}{=} \Tr\left( \mathrm{d}(\mathbf{X}^{\mathsf{T}})\mathbf{X}\mathbf{b} \mathbf{a}^{\mathsf{T}} + \mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \\ &= \Tr\left( (\mathrm{d}\mathbf{X})^{\mathsf{T}}(\mathbf{X}\mathbf{b} \mathbf{a}^{\mathsf{T}}) + \mathbf{b} \mathbf{a}^{\mathsf{T}} \mathbf{X}^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \\ &\overset{8\text{th}}{=} \Tr\left( (\mathbf{X}\mathbf{b} \mathbf{a}^{\mathsf{T}})^{\mathsf{T}}\mathrm{d}\mathbf{X} + (\mathbf{X} \mathbf{a} \mathbf{b}^{\mathsf{T}})^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \\ &= \Tr\left( \left[ \mathbf{X}\mathbf{b} \mathbf{a}^{\mathsf{T}} + \mathbf{X} \mathbf{a} \mathbf{b}^{\mathsf{T}} \right]^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \\ &= \Tr \left( (\nabla_{\mathbf{X}}f)^{\mathsf{T}} \mathrm{d}\mathbf{X} \right) \end{align*} $$
The sixth equality holds by the linearity and cyclicity of the trace, and the eighth by linearity and the invariance of the trace under transposition. Therefore, we arrive at:
$$ \nabla_{\mathbf{X}}f = \mathbf{X}\mathbf{b} \mathbf{a}^{\mathsf{T}} + \mathbf{X} \mathbf{a} \mathbf{b}^{\mathsf{T}} = \mathbf{X} \left( \mathbf{a} \mathbf{b}^{\mathsf{T}} + \mathbf{b} \mathbf{a}^{\mathsf{T}} \right) $$
■
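Since this gradient is the easiest of the three to get wrong, a numerical check is worthwhile. The sketch below is illustrative only: the values of `X`, `a`, and `b`, the step size `h`, and the helper names are assumptions. It uses the entrywise form $\left( \mathbf{X}\mathbf{a}\mathbf{b}^{\mathsf{T}} + \mathbf{X}\mathbf{b}\mathbf{a}^{\mathsf{T}} \right)_{ij} = (\mathbf{X}\mathbf{a})_{i} b_{j} + (\mathbf{X}\mathbf{b})_{i} a_{j}$ and compares it with central differences.

```python
# Finite-difference check of grad f = X (a b^T + b a^T)
# for f(X) = a^T X^T X b. All numeric values are arbitrary illustrations.

def matvec(X, v):
    """Matrix-vector product X v for a nested-list matrix and a list vector."""
    return [sum(row[k] * v[k] for k in range(len(v))) for row in X]

def f(X, a, b):
    """f(X) = a^T X^T X b, computed as the dot product (Xa) . (Xb)."""
    return sum(u * w for u, w in zip(matvec(X, a), matvec(X, b)))

def trace_trick_gradient(X, a, b):
    """Entrywise X (a b^T + b a^T): (Xa)_i b_j + (Xb)_i a_j."""
    Xa, Xb = matvec(X, a), matvec(X, b)
    return [[Xa[i] * b[j] + Xb[i] * a[j] for j in range(len(a))]
            for i in range(len(X))]

def numerical_gradient(func, X, h=1e-6):
    """Central-difference estimate of d func / d X[i][j] for every entry."""
    rows, cols = len(X), len(X[0])
    grad = [[0.0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            plus = [row[:] for row in X]
            minus = [row[:] for row in X]
            plus[i][j] += h
            minus[i][j] -= h
            grad[i][j] = (func(plus) - func(minus)) / (2 * h)
    return grad

X = [[1.0, 2.0], [3.0, 4.0]]
a = [1.0, -1.0]
b = [2.0, 0.5]

grad_exact = trace_trick_gradient(X, a, b)
grad_fd = numerical_gradient(lambda M: f(M, a, b), X)

max_error = max(abs(grad_exact[i][j] - grad_fd[i][j])
                for i in range(2) for j in range(2))
print(max_error)  # should be close to zero
```

Because $f$ is quadratic in $\mathbf{X}$, central differences carry no truncation error here, so the remaining discrepancy is due to floating-point rounding only.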