Directional derivatives are the conceptual tool for measuring how a function changes when its input moves in any direction within the input space. They can be computed with the Jacobian-vector product, which the automatic differentiation library JAX implements.
Partial derivatives \(\partial f/ \partial x_i\) give us the rate of change of \(f\) when we slightly modify the ith element of the input vector \({\bf x}\) by \(h\), keeping the rest constant.
\[ \frac{\partial f}{\partial x_1} = \lim_{h \to 0} \frac{f(x_1 + h, x_2, \dots, x_n) - f({\bf x})}{h} \\ \vdots \\ \frac{\partial f}{\partial x_n} = \lim_{h \to 0} \frac{f(x_1, x_2, \dots, x_n + h) - f({\bf x})}{h} \]
The above definition can be written more compactly using vector notation.
\[
\frac{\partial f}{\partial x_i}({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf e_i}) - f({\bf x_0})}{h}
\]
The vector \({\bf e_i}\) is a unit vector in the direction of \(i\) with the same number of dimensions as \({\bf x_0}\). The only nonzero element of \({\bf e_i}\) is the ith element, which has a value of 1.
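The limit definition can be approximated numerically with a small but finite \(h\). The following is a minimal sketch, not part of the original walkthrough: the helper name `partial_fd`, the step `h=1e-3`, and the example function are assumptions chosen for illustration.

```python
import jax.numpy as jnp

def f(x):
    # Illustrative function (an assumption): f(x) = x_1^2 * x_2
    return x[0] ** 2 * x[1]

def partial_fd(f, x0, i, h=1e-3):
    # Hypothetical helper: forward-difference approximation of df/dx_i at x0.
    # e_i is zero everywhere except for a 1 at position i.
    e_i = jnp.zeros_like(x0).at[i].set(1.0)
    return (f(x0 + h * e_i) - f(x0)) / h

x0 = jnp.array([1.0, 1.0])
d0 = partial_fd(f, x0, 0)  # approximates 2 * x_1 * x_2 = 2
d1 = partial_fd(f, x0, 1)  # approximates x_1^2 = 1
print(d0, d1)
```

The results are only approximations: the step cannot shrink all the way to zero in floating point, which is one reason to prefer automatic differentiation.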
As you can see in the initial diagram, in a 2D input space, there are two partial derivatives:
Computing derivatives using unit vectors such as \({\bf e_i}\) gives us the change of \(f\) in the direction of \(i\), that is, parallel to the i-axis. How can we compute the derivative of \(f\) given a slight nudge of the inputs in an arbitrary direction?
The directional derivative is a way to compute the rate of change of \(f\) in the direction of a vector \({\bf v}\).
\[
\nabla_{{\bf v}}f({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf v}) - f({\bf x_0})}{h}
\]
Think of \({\bf v}\) as a weighted combination of the n axis directions of the input space. We are no longer limited to changes of \(f\) along directions parallel to the axes.
For a differentiable \(f\), we can compute directional derivatives using the dot product between the gradient \(\nabla f\) (the Jacobian of a scalar-valued function) and the vector \({\bf v}\). For instance, for a two-dimensional input space, \({\bf v}= (v_1, v_2)\), and an arbitrary point \(p\):
\[\nabla_{\bf v} f(p) = \nabla f(p) \cdot {\bf v} = \frac{\partial f}{\partial x_1}(p) v_1 + \frac{\partial f}{\partial x_2}(p)v_2\]
More generally:
\[\nabla_{\bf v} f(p) = \nabla f(p) \cdot {\bf v} = \sum^{n}_{i=1} \frac{\partial f}{\partial x_i}(p) v_i\]
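As a quick sanity check of this identity, the sketch below compares the dot product of the gradient with \({\bf v}\) against a finite-difference approximation of the limit definition. The example function, the point, the direction, and the step size `h` are assumptions chosen for illustration; `jax.grad` is used here only to obtain \(\nabla f\).

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative function (an assumption): f(x) = x_1^2 * x_2
    return x[0] ** 2 * x[1]

p = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 2.0])

# Gradient at p, (2 * x_1 * x_2, x_1^2), followed by the dot product with v
grad_p = jax.grad(f)(p)
dir_deriv = jnp.dot(grad_p, v)

# Finite-difference approximation of the limit definition with a small h
h = 1e-3
fd = (f(p + h * v) - f(p)) / h

print(dir_deriv, fd)  # the two values should agree up to the finite-difference error
```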
Let’s focus on computing the above using the function jax.jvp, where jvp stands for Jacobian-vector product.
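The function jax.jvp computes the directional derivative. Its arguments are the function to differentiate, the primals (the point at which the function is evaluated), and the tangents (the direction vector \({\bf v}\)); primals and tangents are passed as tuples or lists with one entry per function argument.
jax.jvp returns a tuple with \((f(p), \nabla_{\bf v} f(p))\).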
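When the function takes a single array argument, the primals and tangents are one-element tuples. The following sketch illustrates that calling convention; the function `g` and the vectors are hypothetical examples, not part of the original walkthrough.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical example: sum of squares, whose gradient is 2 * x
    return jnp.sum(x ** 2)

x0 = jnp.array([1.0, 2.0, 3.0])   # primal: point of evaluation
v = jnp.array([1.0, 0.0, -1.0])   # tangent: direction

# One function argument, so primals and tangents are one-element tuples
value, dir_deriv = jax.jvp(g, (x0,), (v,))

print(value)      # g(x0) = 1 + 4 + 9 = 14
print(dir_deriv)  # (2 * x0) . v = 2 + 0 - 6 = -4
```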
We compute the directional derivative of \(f(x, y)=x^2y\) by hand-coding all the necessary elements, and then check the result against the one given by jax.jvp.
import jax
import jax.numpy as jnp

def fun(x, y): return x**2 * y
def fun_dx(x, y): return 2*x*y
def fun_dy(x, y): return x**2
We define the primal vector \({\bf p}\), the point at which we evaluate \(f\), and the tangent vector \({\bf v}\), the direction along which we want to compute the directional derivative.
p = [1., 1.]
v = [1., 2.]
Evaluate \(f(p)\):
# *p unpacks the list (or tuple) p into the positional arguments x, y
fun(*p)
> 1.0
Compute the directional derivative using fun_dx and fun_dy.
fun_dx(*p) * v[0] + fun_dy(*p) * v[1]
> 4.0
Now, using jax.jvp, we obtain the same results: \(f({\bf p})\) and \(\nabla_{\bf v}f({\bf p})\).
jax.jvp(fun, p, v)
> (DeviceArray(1., dtype=float32, weak_type=True),
DeviceArray(4., dtype=float32, weak_type=True))
A surface plot will show the output space, and a contour plot the input space of \(f(x,y)=x^2y\). We will compute the directional derivatives for three points and their respective directional vectors.
Look at the directional vectors in the plot (tangent vectors, as JAX refers to them): they have different lengths. If we want the “slope definition” of the directional derivative, that is, the rate of change per unit of distance travelled in the input space, we need to turn \({\bf v}\) into a unit-length vector, or equivalently divide the directional derivative definition by \(||{\bf v}||\). Remember that partial derivatives are computed using unit vectors (\({\bf e_i}\)).
\[ \nabla_{\hat{\bf v}}f({\bf x}) = \frac{\nabla_{{\bf v}}f({\bf x})}{||{\bf v}||} = \lim_{h \to 0} \frac{f({\bf x} + h{\bf v}) - f({\bf x})}{h||{\bf v}||}, \qquad \hat{\bf v} = \frac{{\bf v}}{||{\bf v}||} \]
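Because the directional derivative is linear in the direction vector, normalizing \({\bf v}\) before calling jax.jvp should give the same slope as dividing the unnormalized result by \(||{\bf v}||\). The sketch below checks this at the point and direction used earlier; the variable names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def fun(x, y):
    return x ** 2 * y

p = [1.0, 1.0]
v = jnp.array([1.0, 2.0])
norm_v = jnp.linalg.norm(v)  # ||v|| = sqrt(5)

# Option 1: normalize the direction first, then differentiate
_, slope_unit = jax.jvp(fun, p, (v / norm_v).tolist())

# Option 2: differentiate along v, then divide the result by ||v||
_, raw = jax.jvp(fun, p, v.tolist())
slope_scaled = raw / norm_v

print(slope_unit, slope_scaled)  # both should be close to 4 / sqrt(5)
```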
primal_a = jnp.array([-5., 3.2])
primal_b = jnp.array([5., -3.2])
primal_c = jnp.array([0., 0.])
va = jnp.array([-7.5, 5.7])
vb = jnp.array([7.5, -5.7])
vc = jnp.array([-1.0, -0.7])
unit_va = va/va.dot(va)**.5
unit_vb = vb/vb.dot(vb)**.5
unit_vc = vc/vc.dot(vc)**.5
# Compute the slopes using the unit-length directional vectors
_, slope_a = jax.jvp(fun, primal_a.tolist(), unit_va.tolist())
_, slope_b = jax.jvp(fun, primal_b.tolist(), unit_vb.tolist())
_, slope_c = jax.jvp(fun, primal_c.tolist(), unit_vc.tolist())
slope_a, slope_b, slope_c
> (DeviceArray(40.60427, dtype=float32, weak_type=True),
DeviceArray(-40.60427, dtype=float32, weak_type=True),
DeviceArray(-0., dtype=float32, weak_type=True))
We can make some observations from the points and their directional derivatives.