Computer Vision News - February 2021

Computer Vision Tool

It is often difficult to debug functions, or even simple compiled functions. JAX provides a simple way to trace where NaNs occur in your functions or gradients: turning on its NaN checker through the jax_debug_nans configuration option. The following example from the JAX documentation shows the checker at work:

    In [1]: import jax.numpy as jnp

    In [2]: jnp.divide(0., 0.)
    ---------------------------------------------------------------------------
    FloatingPointError                        Traceback (most recent call last)
    <ipython-input-2-f2e2c413b437> in <module>()
    ----> 1 jnp.divide(0., 0.)

    .../jax/jax/numpy/lax_numpy.pyc in divide(x1, x2)
        343     return floor_divide(x1, x2)
        344   else:
    --> 345     return true_divide(x1, x2)
        346
        347

    .../jax/jax/numpy/lax_numpy.pyc in true_divide(x1, x2)
        332   x1, x2 = _promote_shapes(x1, x2)
        333   return lax.div(lax.convert_element_type(x1, result_dtype),
    --> 334                  lax.convert_element_type(x2, result_dtype))
        335
        336

    .../jax/jax/lax.pyc in div(x, y)
        244 def div(x, y):
        245   r"""Elementwise division: :math:`x \over y`."""
    --> 246   return div_p.bind(x, y)
        247
        248 def rem(x, y):

    ... stack trace ...

    .../jax/jax/interpreters/xla.pyc in handle_result(device_buffer)
        103         py_val = device_buffer.to_py()
        104         if np.any(np.isnan(py_val)):
    --> 105           raise FloatingPointError("invalid value")
        106         else:
        107           return DeviceArray(device_buffer, *result_shape)

    FloatingPointError: invalid value

Above, the NaN produced by dividing 0 by 0 is caught and raised as a FloatingPointError with the message "invalid value". What if we need a debugger for functions under @jit? The solution is to run %debug, as shown below.
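To make the effect of the checker concrete, here is a minimal sketch, assuming a recent JAX release in which the option is switched on with jax.config.update("jax_debug_nans", True) (setting the environment variable JAX_DEBUG_NANS=True is an alternative). The helper safe_run and the function normalize are hypothetical names used only for illustration. The sketch shows that a NaN passes silently by default, that enabling the flag turns the 0/0 from the documentation example into a FloatingPointError, and that the check also applies to a function compiled with @jax.jit.

```python
import jax
import jax.numpy as jnp


def safe_run(fn, *args):
    """Hypothetical helper: call fn(*args) and report whether the NaN checker fired."""
    try:
        return "ok", fn(*args)
    except FloatingPointError as err:
        # Raised by JAX when jax_debug_nans is on and an operation yields NaN.
        return "nan", str(err)


# Default behaviour (checker off): 0/0 silently produces NaN.
jax.config.update("jax_debug_nans", False)
status, value = safe_run(jnp.divide, 0.0, 0.0)
print(status, float(value))  # prints: ok nan

# Checker on: the same division now raises FloatingPointError.
jax.config.update("jax_debug_nans", True)
print(safe_run(jnp.divide, 0.0, 0.0)[0])  # prints: nan


@jax.jit
def normalize(x):
    # Hypothetical example: an all-zero input gives 0 / 0 = NaN in every entry.
    return x / jnp.sum(x)


print(safe_run(normalize, jnp.ones(3))[0])   # prints: ok
print(safe_run(normalize, jnp.zeros(3))[0])  # prints: nan (caught under @jit too)
```

Because the checker inspects the result of every operation, it adds runtime overhead, so it is usually worth enabling only while debugging.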
