Directional derivatives and JAX

PUBLISHED ON FEB 8, 2022 / 4 MIN READ — AUTODIFF, CALCULUS, JAX, MML, PYTHON


Directional derivatives measure how a function changes when its input moves in an arbitrary direction of the input space. They can be computed with the Jacobian-vector product, which the JAX automatic differentiation library implements.

The partial derivative \(\partial f/ \partial x_i\) gives the rate of change of \(f\) when we nudge the ith element of the input vector \({\bf x}\) by a small amount \(h\) while holding the other elements constant.


\[ \frac{\partial f}{\partial x_1} = \lim_{h \to 0} \frac{f(x_1 + h, x_2, \dots, x_n) - f({\bf x})}{h} \\ \vdots \\ \frac{\partial f}{\partial x_n} = \lim_{h \to 0} \frac{f(x_1, x_2, \dots, x_n + h) - f({\bf x})}{h} \]


The above definition can be written more compactly using vector notation.


\[ \frac{\partial f}{\partial x_i}({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf e_i}) - f({\bf x_0})}{h} \]

The vector \({\bf e_i}\) is the unit vector along the ith axis, with the same number of dimensions as \({\bf x_0}\). Its only nonzero element is the ith one, which has a value of 1.
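The limit definition can be approximated numerically by picking a small but finite \(h\). The snippet below is a minimal sketch in plain Python; the helper names `basis_vector` and `partial_fd` and the step size `h=1e-6` are illustrative assumptions, and the function is the \(f(x, y)=x^2y\) example used later in this post.

```python
def f(x):
    # Example function from this post: f(x, y) = x^2 * y
    return x[0] ** 2 * x[1]

def basis_vector(i, n):
    # e_i: 1.0 in position i, 0.0 everywhere else (hypothetical helper)
    return [1.0 if j == i else 0.0 for j in range(n)]

def partial_fd(f, x0, i, h=1e-6):
    # Forward difference (f(x0 + h*e_i) - f(x0)) / h, a finite-h stand-in for the limit
    e_i = basis_vector(i, len(x0))
    shifted = [x + h * e for x, e in zip(x0, e_i)]
    return (f(shifted) - f(x0)) / h

x0 = [1.0, 1.0]
df_dx = partial_fd(f, x0, 0)  # approximates 2*x*y at x0
df_dy = partial_fd(f, x0, 1)  # approximates x**2 at x0
```

Both values are approximations: the error shrinks as `h` gets smaller until floating-point round-off takes over.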

As you can see in the initial diagram, in a 2D input space, there are two partial derivatives:

  • \(\partial f / \partial x\): computed parallel to the x-axis (\(e_1\), typically known as \(\hat{i}\))
  • \(\partial f / \partial y\): computed parallel to the y-axis (\(e_2\), typically known as \(\hat{j}\)).

Computing derivatives with unit vectors such as \(e_i\) gives us the change of \(f\) in the direction of \(i\), that is, parallel to the i-axis. How can we compute the derivative of \(f\) given a slight nudge of the inputs in an arbitrary direction?

The directional derivative gives the rate of change of \(f\) in the direction of \({\bf v}\).


\[ \nabla_{{\bf v}}f({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf v}) - f({\bf x_0})}{h} \]

Think of \({\bf v}\) as a weighted combination of the n axis directions of the input space. We are no longer limited to changes of \(f\) along directions parallel to the axes.
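As a rough numerical illustration of this definition, the sketch below replaces the limit with a small finite \(h\). The helper name `directional_fd` and the step size are assumptions made for this example; the function is the \(f(x, y)=x^2y\) used later in the post.

```python
def f(x):
    # Example function from this post: f(x, y) = x^2 * y
    return x[0] ** 2 * x[1]

def directional_fd(f, x0, v, h=1e-6):
    # Forward difference (f(x0 + h*v) - f(x0)) / h for a small but finite h
    shifted = [x + h * d for x, d in zip(x0, v)]
    return (f(shifted) - f(x0)) / h

x0 = [1.0, 1.0]
along_v = directional_fd(f, x0, [1.0, 2.0])   # an arbitrary direction
along_e1 = directional_fd(f, x0, [1.0, 0.0])  # v = e_1 gives back the partial derivative in x
```

Choosing \({\bf v} = {\bf e_1}\) reduces the directional derivative to the partial derivative with respect to the first input, which is why `along_e1` approximates \(2xy\) at the point.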

When \(f\) is differentiable, we can compute directional derivatives as the dot product between the gradient \(\nabla f\) (the Jacobian of a scalar-valued function) and the vector \({\bf v}\). For instance, for a two-dimensional input space, \({\bf v}= (v_1, v_2)\), and any arbitrary point \(p\):

\[\nabla_{\bf v} f(p) = \nabla f(p) \cdot {\bf v} = \frac{\partial f}{\partial x_1}(p) v_1 + \frac{\partial f}{\partial x_2}(p)v_2\]

More generally:

\[\nabla_{\bf v} f(p) = \nabla f(p) \cdot {\bf v} = \sum^{n}_{i=1} \frac{\partial f}{\partial x_i}(p) v_i\]
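The dot-product formula can be tried out directly with `jax.grad`, which returns the gradient of a scalar-valued function. This is only a sketch: the function `f` taking a single array argument and the point and direction values are choices made for this example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Example function f(x, y) = x^2 * y, written for a single array input
    return x[0] ** 2 * x[1]

p = jnp.array([1.0, 1.0])  # point of evaluation
v = jnp.array([1.0, 2.0])  # direction

grad_at_p = jax.grad(f)(p)           # (2xy, x^2) evaluated at p
directional = jnp.dot(grad_at_p, v)  # sum_i (df/dx_i)(p) * v_i
```

This route materializes the whole gradient first and then takes the dot product with \({\bf v}\).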

Let’s focus on computing the above with the function jax.jvp, where jvp stands for Jacobian-vector product.

The function jax.jvp computes the directional derivative. Its arguments are:

  1. A differentiable function \(f\), whose Jacobian \(\nabla f\) is used
  2. A primal vector \({\bf p}\), the point at which the Jacobian \(\nabla f({\bf p})\) is evaluated
  3. A tangent vector \({\bf v}\), which represents the direction in which we want to calculate the derivative.

jax.jvp returns the tuple \((f({\bf p}), \nabla_{\bf v} f({\bf p}))\).
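A minimal sketch of the call, assuming a function that takes one array argument (the name `f` and the values are illustrative): the primals and tangents are passed as tuples with one entry per argument of the function, so a single array argument is wrapped as `(p,)` and `(v,)`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function of one array argument: f(x, y) = x^2 * y
    return x[0] ** 2 * x[1]

p = jnp.array([1.0, 1.0])  # primal: where the derivative is evaluated
v = jnp.array([1.0, 2.0])  # tangent: direction of the derivative

# One tuple entry per positional argument of f
value, directional = jax.jvp(f, (p,), (v,))
```

Here `value` holds \(f({\bf p})\) and `directional` holds \(\nabla_{\bf v} f({\bf p})\), both obtained in a single forward pass.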

Example

We compute the directional derivative of \(f(x, y)=x^2y\) by hand-coding all the necessary elements, and then check the result against jax.jvp.

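import jax
import jax.numpy as jnp

def fun(x, y): return x**2 * y
def fun_dx(x, y): return 2*x*y    # partial derivative with respect to x
def fun_dy(x, y): return x**2     # partial derivative with respect to y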

We define the primal vector \({\bf p}\), the point of evaluation, and the tangent vector \({\bf v}\), the direction along which we want the directional derivative.

p = [1., 1.]
v = [1., 2.]

Evaluate \(f(p)\):

# *p unpacks the list into positional arguments: fun(p[0], p[1])
fun(*p)
> 1.0

Compute the directional derivative using fun_dx and fun_dy.

fun_dx(*p) * v[0] + fun_dy(*p) * v[1]
> 4.0

Now using jax.jvp we obtain the same results: \(f({\bf p})\) and \(\nabla_{\bf v}f({\bf p})\).

jax.jvp(fun, p, v)
> (DeviceArray(1., dtype=float32, weak_type=True),
   DeviceArray(4., dtype=float32, weak_type=True))
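As a further sanity check, passing the unit vectors \(e_1\) and \(e_2\) as tangents should recover the two hand-coded partial derivatives. The sketch below redefines the functions so it runs on its own; the evaluation point (3, 2) is an arbitrary choice for this example.

```python
import jax

def fun(x, y): return x**2 * y
def fun_dx(x, y): return 2*x*y
def fun_dy(x, y): return x**2

p = (3.0, 2.0)  # arbitrary evaluation point chosen for this check

# Tangent e_1 = (1, 0): only x is nudged, so we get the partial derivative in x
_, d_dx = jax.jvp(fun, p, (1.0, 0.0))
# Tangent e_2 = (0, 1): only y is nudged, so we get the partial derivative in y
_, d_dy = jax.jvp(fun, p, (0.0, 1.0))

matches = (float(d_dx) == fun_dx(*p), float(d_dy) == fun_dy(*p))
```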

A surface plot will show the output space, and a contour plot the input space of \(f(x,y)=x^2y\). We will compute the directional derivatives for three points and their respective directional vectors.

Look at the directional vectors in the plot (tangent vectors, as JAX calls them): they have different lengths. If we want the “slope” interpretation of the directional derivative, we need to turn \({\bf v}\) into a unit-length vector, which is equivalent to dividing the directional derivative definition by \(||{\bf v}||\). Remember that partial derivatives are computed using unit vectors (\(e_i\)).

\[ \nabla_{\hat{\bf v}}f({\bf x}) = \lim_{h \to 0} \frac{f({\bf x} + h{\bf v}) - f({\bf x})}{h||{\bf v}||} = \frac{\nabla f({\bf x}) \cdot {\bf v}}{||{\bf v}||}, \qquad \hat{\bf v} = \frac{{\bf v}}{||{\bf v}||} \]
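The effect of the vector's length can be checked numerically. The sketch below (with `f` rewritten to take one array argument, an assumption of this example) doubles the tangent vector: the raw output of jax.jvp doubles as well, while the value divided by the norm stays the same.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x, y) = x^2 * y, written for a single array input
    return x[0] ** 2 * x[1]

p = jnp.array([-5.0, 3.2])
v = jnp.array([-7.5, 5.7])

_, d_v = jax.jvp(f, (p,), (v,))         # derivative along v
_, d_2v = jax.jvp(f, (p,), (2.0 * v,))  # derivative along 2v: twice as large

# Dividing by the norm removes the dependence on the length of the vector
slope_v = d_v / jnp.linalg.norm(v)
slope_2v = d_2v / jnp.linalg.norm(2.0 * v)
```

This is why the length of the tangent matters when the result is read as a slope.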

primal_a = jnp.array([-5., 3.2])
primal_b = jnp.array([5., -3.2])
primal_c = jnp.array([0., 0.])
va = jnp.array([-7.5, 5.7])
vb = jnp.array([7.5, -5.7])
vc = jnp.array([-1.0, -0.7])
unit_va = va/va.dot(va)**.5
unit_vb = vb/vb.dot(vb)**.5
unit_vc = vc/vc.dot(vc)**.5
# Compute the slopes using the unit-length directional vectors
_, slope_a = jax.jvp(fun, primal_a.tolist(), unit_va.tolist())
_, slope_b = jax.jvp(fun, primal_b.tolist(), unit_vb.tolist())
_, slope_c = jax.jvp(fun, primal_c.tolist(), unit_vc.tolist())
slope_a, slope_b, slope_c
> (DeviceArray(40.60427, dtype=float32, weak_type=True),
   DeviceArray(-40.60427, dtype=float32, weak_type=True),
   DeviceArray(-0., dtype=float32, weak_type=True))
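The three calls can also be batched. The following is one possible sketch, not the approach used above: it assumes `f` is rewritten to take a single array, stacks the points and directions into arrays, and maps a hypothetical helper `slope` over the rows with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x, y) = x^2 * y, written for a single array input
    return x[0] ** 2 * x[1]

# One row per point / direction (A, B, C)
primals = jnp.array([[-5.0, 3.2], [5.0, -3.2], [0.0, 0.0]])
tangents = jnp.array([[-7.5, 5.7], [7.5, -5.7], [-1.0, -0.7]])
unit_tangents = tangents / jnp.linalg.norm(tangents, axis=1, keepdims=True)

def slope(p, v):
    # Directional derivative of f at p along v (hypothetical helper)
    _, d = jax.jvp(f, (p,), (v,))
    return d

slopes = jax.vmap(slope)(primals, unit_tangents)  # one slope per row
```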

A few observations about the points and their directional derivatives:

  • Point A: the directional derivative is 40.6, which agrees with the contour lines ahead of A. The surface starts to rise in the direction of \(\overrightarrow{v}_a\).
  • Point B: \(f\) decreases in the direction of \(\overrightarrow{v}_b\): the directional derivative is −40.6, the rate at which \(f\) changes for slight variations of the input along the directional vector. It has the same magnitude as the slope at point A but the opposite sign; the surface plot shows how the function increases and decreases in the same proportion across its diagonals.
  • Point C: the surface is practically flat around the point (0,0). Notice that the directional derivative along \(\overrightarrow{v}_c\) is 0.
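These observations can be traced back to the gradient itself. The sketch below evaluates `jax.grad` at the three points (with `f` rewritten to take one array, an assumption of this example). The gradients at A and B coincide, so reversing the direction vector flips the sign of the slope, and the gradient at C is zero, so the derivative vanishes in every direction there.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x, y) = x^2 * y, written for a single array input
    return x[0] ** 2 * x[1]

grad_f = jax.grad(f)  # gradient (2xy, x^2)

grad_a = grad_f(jnp.array([-5.0, 3.2]))  # point A
grad_b = grad_f(jnp.array([5.0, -3.2]))  # point B: same gradient as A
grad_c = grad_f(jnp.array([0.0, 0.0]))   # point C: zero gradient
```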