Computer Vision News - February 2021
Using the function grad, one can take derivatives of any order. For example, applying grad three times gives the third derivative:

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

print(grad(grad(grad(tanh)))(1.0))  # Third derivative of tanh at x = 1.0
# prints 0.62162673
```

Hessian matrices can be computed with the hessian function, which supports nested Python containers as inputs and outputs. In this example, f maps a dict holding arrays a and b to a dict whose entry c is a raised element-wise to the power b, so the Hessian comes back as a dict nested first by output key and then by each pair of input keys:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': DeviceArray([[[ 2.,  0.], [ 0.,  0.]],
                               [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': DeviceArray([[[0.      , 0.      ], [0.      , 0.      ]],
                               [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}
```

Compiling expressions

JAX expressions can be compiled end-to-end with XLA by using jit. There are two ways to use jit: as a decorator (@jit) or as a higher-order function, as in this example:

```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
fast_f(x).block_until_ready()  # Warm-up call, so compilation is not timed

# JAX dispatches work asynchronously; block_until_ready() makes the
# IPython %timeit magic measure the computation itself.
%timeit -n10 -r3 fast_f(x).block_until_ready()  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x).block_until_ready()  # ~ 14.5 ms / loop (also on GPU via JAX)
```

Debugging

One great feature of JAX is its support for debugging, something that is very often missing in deep learning, or at least not given the weight it receives from the broader software-engineering perspective.
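As a minimal sketch, assuming a recent JAX release (the function name scaled_sum is a hypothetical illustration, not from the article), the snippet below shows jit in its decorator form together with one debugging aid JAX provides, the jax.disable_jit() context manager. Inside a jitted function, Python side effects such as print run only while the function is being traced, so they show abstract tracers instead of numbers; disabling jit runs the same function eagerly, so the print shows concrete values.

```python
import jax
import jax.numpy as jnp
from jax import jit


@jit  # Decorator form: equivalent to scaled_sum = jit(scaled_sum)
def scaled_sum(x):
    y = x * 2.0
    # Under jit this print runs only at trace time, so it shows
    # an abstract tracer rather than the actual array values.
    print("y =", y)
    return jnp.sum(y)


x = jnp.arange(3.0)  # [0., 1., 2.]

# First call traces and compiles the function; the print shows a tracer.
compiled_out = scaled_sum(x)

# jax.disable_jit() makes jit-decorated functions run eagerly, op by op,
# so the same print now shows the concrete array.
with jax.disable_jit():
    eager_out = scaled_sum(x)

print(compiled_out, eager_out)  # Both results are 6.0
```

Because the function body then executes as ordinary Python, the same approach should also let a standard debugger such as pdb inspect intermediate values.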