Source code for coax.value_transforms._log_transform
import jax.numpy as jnp
from ._base import ValueTransform
[docs]class LogTransform(ValueTransform):
r"""
A simple invertible log-transform.
.. math::
x\ \mapsto\ y\ =\ \lambda\,\text{sign}(x)\,
\log\left(1+\frac{|x|}{\lambda}\right)
with inverse:
.. math::
y\ \mapsto\ x\ =\ \lambda\,\text{sign}(y)\,
\left(\text{e}^{|y|/\lambda} - 1\right)
This transform logarithmically supresses large values :math:`|x|\gg1` and smoothly interpolates
to the identity transform for small values :math:`|x|\sim1` (see figure below).
.. image:: /_static/img/log_transform.svg
:alt: Invertible log-transform
:width: 640px
Parameters
----------
scale : positive float, optional
The scale :math:`\lambda>0` of the linear-to-log cross-over. Smaller
values for :math:`\lambda` translate into earlier onset of the
cross-over.
"""
__slots__ = ValueTransform.__slots__ + ('scale',)
def __init__(self, scale=1.0):
assert scale > 0
self.scale = scale
def transform_func(x):
return jnp.sign(x) * scale * jnp.log(1 + jnp.abs(x) / scale)
def inverse_func(x):
return jnp.sign(x) * scale * (jnp.exp(jnp.abs(x) / scale) - 1)
self._transform_func = transform_func
self._inverse_func = inverse_func