Implicit Quantile Network (IQN)

Implicit Quantile Networks (IQN) are a distributional RL method that models the distribution of returns using quantile regression. The method was introduced in the paper [arxiv:1806.06923]; it replaces the fixed quantile parametrization of the q-function in Quantile-Regression DQN (QR-DQN) [arxiv:1710.10044] with uniformly sampled quantile fractions.
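At the core of both QR-DQN and IQN is the quantile (Huber) regression loss, which pushes each predicted quantile value towards the target distribution with an asymmetric weight that depends on its quantile fraction. The sketch below only illustrates this loss; the function name quantile_huber_loss, the shape conventions and the kappa parameter are illustrative assumptions, not part of the coax API used in the example below (coax applies its loss internally when updating the q-function).

import jax.numpy as jnp


def quantile_huber_loss(td_errors, quantile_fractions, kappa=1.0):
    # td_errors: (batch_size, num_quantiles, num_target_quantiles), one TD error per
    #            (predicted quantile, target quantile) pair
    # quantile_fractions: (batch_size, num_quantiles), the tau's of the predicted quantiles
    abs_err = jnp.abs(td_errors)
    huber = jnp.where(abs_err <= kappa,
                      0.5 * jnp.square(td_errors),
                      kappa * (abs_err - 0.5 * kappa))
    tau = quantile_fractions[..., None]                               # broadcast over target quantiles
    weight = jnp.abs(tau - (td_errors < 0.).astype(td_errors.dtype))  # |tau - 1{delta < 0}|
    per_quantile = jnp.mean(weight * huber / kappa, axis=-1)          # average over target quantiles
    return jnp.mean(jnp.sum(per_quantile, axis=-1))                   # sum over quantiles, mean over batch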

To generate equally spaced quantile fractions as in QR-DQN, coax provides coax.utils.quantiles; for uniformly sampled quantile fractions as in IQN, use coax.utils.quantiles_uniform.
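A minimal usage sketch of these two helpers is shown below. The exact call signatures (positional batch_size, keyword num_quantiles, and the PRNG key argument for the uniform variant) are assumptions and may differ between coax versions; check the coax.utils API reference.

import jax
import coax

batch_size, num_quantiles = 4, 8

# equally spaced quantile fractions, as in QR-DQN (assumed signature)
taus_fixed = coax.utils.quantiles(batch_size, num_quantiles=num_quantiles)

# uniformly sampled quantile fractions, as in IQN (assumed signature; needs a PRNG key)
rng = jax.random.PRNGKey(0)
taus_uniform = coax.utils.quantiles_uniform(rng, batch_size, num_quantiles=num_quantiles)

# both arrays have shape (batch_size, num_quantiles), with quantile fractions between 0 and 1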


iqn.py

import gymnasium
import coax
import optax
import haiku as hk
import jax
import jax.numpy as jnp


# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)

# choose iqn hyperparameters
quantile_embedding_dim = 32
num_quantiles = 32


def func_type1(S, A, is_training):
    # custom haiku function: s,a -> q(s,a)
    net = hk.Sequential([...])
    X = jax.vmap(jnp.kron)(S, A)  # or jnp.concatenate((S, A), axis=-1) or whatever you like
    quantile_values, quantile_fractions = net(X)
    return {'values': quantile_values,  # output shape: (batch_size, num_quantiles)
            'quantile_fractions': quantile_fractions}


def func_type2(S, is_training):
    # custom haiku function: s -> q(s,.)
    net = hk.Sequential([...])
    quantile_values, quantile_fractions = net(S)
    return {'values': quantile_values,  # output shape: (batch_size, num_actions, num_quantiles)
            'quantile_fractions': quantile_fractions}


# function approximator
func = ...  # func_type1 or func_type2

# quantile value function and its derived policy
q = coax.StochasticQ(func, env, num_bins=num_quantiles, value_range=None)
pi = coax.BoltzmannPolicy(q)

# target network
q_targ = q.copy()

# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=100000)

# specify how to update q-function
qlearning = coax.td_learning.QLearning(q, q_targ=q_targ, optimizer=optax.adam(0.001))


for ep in range(1000):
    s, info = env.reset()

    for t in range(env.spec.max_episode_steps):
        a = pi(s)
        s_next, r, done, truncated, info = env.step(a)

        # trace rewards and add transition to replay buffer
        tracer.add(s, a, r, done or truncated)
        while tracer:
            buffer.add(tracer.pop())

        # learn
        if len(buffer) >= 100:
            transition_batch = buffer.sample(batch_size=32)
            metrics = qlearning.update(transition_batch)
            env.record_metrics(metrics)

        # sync target network
        q_targ.soft_update(q, tau=0.01)

        if done or truncated:
            break

        s = s_next

    # early stopping
    if env.avg_G > env.spec.reward_threshold:
        break