Debugging output helpers (varipeps.utils.debug_print)

varipeps.utils.debug_print.debug_print(fmt: str, *args, ordered: bool = False, **kwargs) → None

Prints values, including from inside staged-out JAX functions (for example, functions transformed with jax.jit).

Adapted from jax.debug.print so that the output works together with tqdm. See jax.debug.print for the original authors and implementation.
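The following is a minimal sketch of why a staged-out-aware print is needed. It uses jax.debug.print, the function this helper is adapted from, as a stand-in; that debug_print can be swapped in with the same arguments is an assumption based on the signature above. The function name scaled_norm is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_norm(x):
    # Hypothetical example function, not part of varipeps.
    norm = jnp.linalg.norm(x)
    # A plain print(norm) would run once at trace time and show a tracer.
    # jax.debug.print is deferred to run time and shows the concrete value.
    # Assumption: varipeps' debug_print accepts the same call.
    jax.debug.print("norm = {n}", n=norm)
    return 2.0 * norm


result = scaled_norm(jnp.array([3.0, 4.0]))
# Wait until pending debug output has been emitted.
jax.effects_barrier()
```

The call to jax.effects_barrier() is only there so that the deferred output is flushed before the script continues.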

Parameters:
  • fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments.

  • *args – A list of positional arguments to be formatted.

Keyword Arguments:
  • ordered (bool, optional) – A keyword-only argument indicating whether the staged-out computation enforces the ordering of this debug_print call with respect to other ordered debug_print calls. Default: False

  • **kwargs – Additional keyword arguments to be formatted.
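As a rough illustration of how the arguments above combine, the sketch below mixes a positional argument, keyword arguments, and ordered=True. It again uses jax.debug.print as a stand-in, assuming debug_print follows the same str.format-style substitution and the same ordered semantics. The function name step is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical example function, not part of varipeps.
    y = x + 1.0
    # "{}" is filled from *args, "{y}" from **kwargs.
    # ordered=True asks JAX to keep this call in sequence relative to
    # the other ordered call below.
    jax.debug.print("input {} gives y = {y}", x, y=y, ordered=True)
    z = y * y
    jax.debug.print("input {} gives z = {z}", x, z=z, ordered=True)
    return z


out = step(jnp.array(2.0))
# Wait until pending debug output has been emitted.
jax.effects_barrier()
```

With the default ordered=False, the two messages are not guaranteed to appear in program order, which is usually acceptable for occasional diagnostics and places fewer constraints on the compiled computation.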