Implicit Q Learning

A core challenge when applying dynamic-programming-based approaches to offline reinforcement learning is bootstrapping error. This pill presents a paper proposing an algorithm called implicit Q-learning, which mitigates this issue by modifying the maximization in the Bellman backup.

This paper pill covers implicit Q-learning (IQL) [Kos21O], a state-of-the-art offline reinforcement learning (RL) method that can also be used for subsequent online fine-tuning.

Background

Let us consider sequential decision-making in a Markov decision process $(S, A, P, r, \gamma)$ with state space $S$, action space $A$, transition dynamics $P$, reward function $r$ and discount factor $\gamma$. We will be concerned with the offline RL setting, where the agent has no access to the environment but instead learns from a fixed buffer of transitions $D = \{s_i, a_i, r_i, s'_{i}\}_{i=1}^N$ of size $N$. These transitions were collected with a behavior policy $\pi_{\beta}$, which might be a mixture of several suboptimal policies.
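To make the data format concrete, here is a minimal sketch of such a buffer as plain arrays. All dimensions, the random placeholder data and the uniform sampling are illustrative assumptions, not tied to any particular benchmark.

```python
# Hypothetical offline buffer D of N transitions (s, a, r, s'), stored as arrays.
import numpy as np

N, obs_dim, act_dim = 10_000, 17, 6  # illustrative sizes
rng = np.random.default_rng(0)

dataset = {
    "observations":      rng.standard_normal((N, obs_dim)).astype(np.float32),
    "actions":           rng.uniform(-1.0, 1.0, (N, act_dim)).astype(np.float32),
    "rewards":           rng.standard_normal(N).astype(np.float32),
    "next_observations": rng.standard_normal((N, obs_dim)).astype(np.float32),
}

def sample_batch(batch_size: int = 256) -> dict:
    """Sample a minibatch of (s, a, r, s') uniformly from the fixed buffer."""
    idx = rng.integers(0, N, size=batch_size)
    return {key: value[idx] for key, value in dataset.items()}
```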

Applying classic Q-learning to the offline data $D$ amounts to training the critic parameters $\theta$ by minimizing the mean squared Bellman error obtained with dynamic programming: \begin{equation} L_{TD}(\theta) = \mathbb{E}_{(s,a,r,s') \sim D} \left[\left(r + \gamma \max_{a' \in A} Q_{\hat {\theta}}(s',a') - Q_{\theta}(s,a) \right)^2\right] \tag1\end{equation} and extracting a policy $\pi_\phi$ by maximizing the critic's value of the selected actions for states in the dataset:

\begin{equation} L(\phi) = \mathbb{E}_{s \sim D} \left[Q_\theta(s,\pi_\phi(s)) \right]. \tag2\end{equation}

Here TD stands for the temporal-difference error, and $r(s,a) + \gamma \max_{a' \in A} Q_{\hat{\theta}}(s',a')$ is called the TD target.
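As a reference point, here is a minimal sketch of objective (1) in PyTorch. It assumes a discrete action space so that the maximization over $a'$ reduces to a max over the Q-network's outputs; the network sizes, names and hyperparameters are illustrative, not taken from the paper.

```python
# Sketch of the offline Q-learning TD loss (Eq. 1) for a discrete action space.
import torch
import torch.nn as nn

obs_dim, n_actions, gamma = 17, 5, 0.99

q_net = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, n_actions))
target_q_net = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, n_actions))
target_q_net.load_state_dict(q_net.state_dict())  # Q_theta_hat: a frozen copy of the critic

def q_learning_td_loss(s, a, r, s_next):
    # TD target: r + gamma * max_a' Q_theta_hat(s', a'), with no gradient through the target.
    with torch.no_grad():
        target = r + gamma * target_q_net(s_next).max(dim=1).values
    q_sa = q_net(s).gather(1, a.long().unsqueeze(1)).squeeze(1)  # Q_theta(s, a)
    return ((target - q_sa) ** 2).mean()

# Example usage with random stand-ins for a batch sampled from D.
s, s_next = torch.randn(256, obs_dim), torch.randn(256, obs_dim)
a, r = torch.randint(0, n_actions, (256,)), torch.randn(256)
loss = q_learning_td_loss(s, a, r, s_next)
```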

The $\max$ operation in the bootstrapped TD target, as well as in the policy extraction step, is susceptible to extrapolation errors. Such errors occur when the critic is queried on actions far from the data in the buffer $D$, where its estimates are unreliable. While in online RL the agent can collect additional experience and correct itself (optimism is even provably efficient), this is not an option in the offline RL setting. Thus, a core challenge in offline RL is avoiding, or at least mitigating the impact of, extrapolation errors. Bootstrapping, coupled with the maximization over the action space, carries a high risk of selecting actions that are not within the support of the dataset. The SARSA approach gives an alternative objective that avoids querying the critic for actions outside the offline dataset's support:

\begin{equation} L_{TD}(\theta) = \mathbb{E}_{(s,a,r,s',a') \sim D} \left[\left(r + \gamma Q_{\hat {\theta}}(s',a') - Q_{\theta}(s,a) \right)^2\right]. \tag3\end{equation}

Note that in this case the buffer contains transitions of the eponymous $(s,a,r,s',a')$ type, whereas the offline Q-learning update requires only $(s,a,r,s')$.
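For contrast, a sketch of the SARSA-style objective (3) under the same illustrative discrete-action setup as above: the only change is that the TD target evaluates the logged next action $a'$ instead of maximizing over all actions.

```python
# Sketch of the SARSA-style TD loss (Eq. 3): the critic is only queried at dataset actions.
import torch
import torch.nn as nn

obs_dim, n_actions, gamma = 17, 5, 0.99

q_net = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, n_actions))
target_q_net = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, n_actions))
target_q_net.load_state_dict(q_net.state_dict())

def sarsa_td_loss(s, a, r, s_next, a_next):
    with torch.no_grad():
        q_next = target_q_net(s_next).gather(1, a_next.long().unsqueeze(1)).squeeze(1)
        target = r + gamma * q_next  # uses the logged next action a', not a max
    q_sa = q_net(s).gather(1, a.long().unsqueeze(1)).squeeze(1)
    return ((target - q_sa) ** 2).mean()
```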

Expectile Regression

Let $\tau \in (0, 1)$ and $L_2^\tau(u) = |\tau - \mathbb{1}(u < 0)|\, u^2$. The $\tau$-expectile of a random variable $X$ is defined as the solution to the asymmetric least squares problem: \begin{equation} \underset{m_\tau}{\text{argmin}} \ \mathbb {E}_{x \sim X} \left[ L_2^\tau (x-m_\tau)\right]. \tag4\end{equation} For $\tau=0.5$ the symmetric least squares problem is recovered. For $\tau \in (0.5,1)$ the contributions of points with $x > m_\tau$ are weighted more heavily, while negative differences are down-weighted. The asymmetric weighting is flipped for $\tau \in (0,0.5)$.

Expectile regression can also be applied to conditional, predictive distributions. The $\tau$-expectile of the conditional distribution of $y$ given $x$, for $(x,y) \sim D$, is given by the solution of:

\begin{equation} \underset{m_\tau(x)}{\text{argmin}} \ \mathbb{E}_{(x,y)\sim D} \left[L_2^\tau (y - m_\tau(x))\right]. \tag5\end{equation} Note that the maximum operator over in-distribution values of $y$ is approximated by $\tau \approx 1$.
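The following self-contained sketch illustrates expectile regression numerically: the asymmetric loss from (4) is minimized by gradient descent for several values of $\tau$, and the resulting estimate moves from the sample mean toward the sample maximum as $\tau \to 1$. The step size and iteration count are arbitrary illustrative choices.

```python
# Sketch of expectile regression via the asymmetric squared loss L_2^tau.
import torch

def expectile_loss(diff, tau):
    """L_2^tau(u) = |tau - 1(u < 0)| * u^2, applied elementwise to diff = x - m."""
    weight = torch.where(diff < 0, 1.0 - tau, tau)
    return (weight * diff ** 2).mean()

x = torch.randn(10_000)  # samples of a random variable X
for tau in (0.5, 0.7, 0.9, 0.99):
    m = torch.zeros(1, requires_grad=True)       # expectile estimate m_tau
    opt = torch.optim.SGD([m], lr=0.5)
    for _ in range(2000):
        opt.zero_grad()
        expectile_loss(x - m, tau).backward()
        opt.step()
    print(f"tau={tau}: expectile estimate {m.item():.3f} (sample max {x.max().item():.3f})")
```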

Figure 1 from [Kos21O]. Left: The asymmetric squared loss used for expectile regression. $\tau = 0.5$ corresponds to the standard mean squared error loss, while $\tau = 0.9$ gives more weight to positive differences. Center: Expectiles of a normal distribution. Right: an example of estimating state-conditional expectiles of a two-dimensional random variable. Each $x$ corresponds to a distribution over $y$. We can approximate the maximum of this random variable with expectile regression: $\tau = 0.5$ corresponds to the conditional mean of the distribution, while $\tau \approx 1$ approximates the maximum operator over in-support values of $y$.

The IQL Algorithm

The core technical contribution of the paper is to use expectile regression within the policy evaluation step.

Predicting an upper expectile of the TD target approximates the maximum of $r(s, a) + \gamma Q_{\hat{\theta}}(s', a')$ over actions $a'$ constrained to the dataset actions. The resulting training objective for the critic is:

\begin{equation} L_{TD}(\theta) = \mathbb{E}_{(s,a,r,s') \sim D} \left[\left(r + \gamma \max_{a' \in A(s',\pi_\beta)} Q_{\hat{\theta}}(s',a') - Q_{\theta}(s,a) \right)^2\right], \tag6\end{equation} where $A(s',\pi_\beta) = \{a \in A \mid \pi_\beta(a|s')>0\}$ is the set of actions supported by the behavior policy $\pi_\beta$ in state $s'$, as represented in the dataset $D$. While this objective mitigates the problems with the $\max$, there is still an issue: the expectiles are estimated with respect to both states and actions. Even though only actions in the support of the data are considered, the objective also incorporates the stochasticity coming from the environment dynamics $s' \sim P(\cdot \mid s, a)$. Therefore, a large target value might not necessarily reflect the existence of a single action that achieves that value, but rather a “fortunate” sample transitioning into a good state. The proposed solution is to introduce a separate value function $V_\psi$ approximating an expectile purely with respect to the action distribution. This leads to the final training objective for the value function:

\begin{equation} L_V(\psi) = \mathbb{E}_{(s,a) \sim D} \left[L_2^\tau (Q_{\hat{\theta}}(s, a) - V_\psi (s))\right]. \tag7\end{equation} This value function is then used to train the Q-function with a mean squared error loss, averaging over the stochasticity of the transitions and avoiding the aforementioned “lucky” sample issue: \begin{equation} L_Q(\theta)=\mathbb{E}_{(s,a,r,s') \sim D}\left[\left(r+\gamma V_\psi(s')-Q_\theta(s,a)\right)^2\right]. \tag8\end{equation}
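Putting the pieces together, here is a condensed sketch of the IQL critic updates (7) and (8) in PyTorch. It omits practical details of the official implementation, such as soft target-network updates and the use of multiple Q-networks, and all architectures, optimizers and hyperparameters are illustrative choices.

```python
# Sketch of the IQL critic updates: expectile regression for V_psi (Eq. 7)
# and MSE regression of Q_theta onto r + gamma * V_psi(s') (Eq. 8).
import torch
import torch.nn as nn

obs_dim, act_dim, gamma, tau = 17, 6, 0.99, 0.9

def mlp(in_dim, out_dim):
    return nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(), nn.Linear(256, out_dim))

q_net = mlp(obs_dim + act_dim, 1)          # Q_theta(s, a)
target_q_net = mlp(obs_dim + act_dim, 1)   # Q_theta_hat(s, a); kept frozen in this sketch
target_q_net.load_state_dict(q_net.state_dict())
v_net = mlp(obs_dim, 1)                    # V_psi(s)

v_opt = torch.optim.Adam(v_net.parameters(), lr=3e-4)
q_opt = torch.optim.Adam(q_net.parameters(), lr=3e-4)

def expectile_loss(diff, tau):
    weight = torch.where(diff < 0, 1.0 - tau, tau)
    return (weight * diff ** 2).mean()

def critic_update(s, a, r, s_next):
    # Eq. 7: V_psi(s) predicts the tau-expectile of Q_theta_hat(s, a) over dataset actions.
    with torch.no_grad():
        q_target = target_q_net(torch.cat([s, a], dim=-1)).squeeze(-1)
    v_loss = expectile_loss(q_target - v_net(s).squeeze(-1), tau)
    v_opt.zero_grad(); v_loss.backward(); v_opt.step()

    # Eq. 8: Q_theta regresses onto r + gamma * V_psi(s'), averaging over dynamics noise.
    with torch.no_grad():
        td_target = r + gamma * v_net(s_next).squeeze(-1)
    q_loss = ((td_target - q_net(torch.cat([s, a], dim=-1)).squeeze(-1)) ** 2).mean()
    q_opt.zero_grad(); q_loss.backward(); q_opt.step()
    return v_loss.item(), q_loss.item()

# Example call with random stand-ins for a batch sampled from D.
s, a = torch.randn(256, obs_dim), torch.randn(256, act_dim)
r, s_next = torch.randn(256), torch.randn(256, obs_dim)
critic_update(s, a, r, s_next)
```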

Note that the critic can be trained independently of the actor. Thus, it suffices to extract an actor policy once critic training is completed. Extracting the actor is done via advantage weighted regression (AWR) [Pen19A] (see also our previous pill). The idea is to select actions that are good under the carefully trained critic while avoiding extrapolation errors by keeping the policy close, via a KL-divergence constraint, to the behavior policy,

\begin{equation} L(\phi) = \mathbb{E}_{s, a \sim D} \left[ \exp(\beta(Q_\theta(s, a) - V_\psi(s))) \log\pi_\phi(a|s) \right], \tag9\end{equation}

where $\beta$ denotes an inverse temperature; this objective is maximized with respect to the policy parameters $\phi$.
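A minimal sketch of the extraction step (9) with a Gaussian policy follows. The advantage weights are computed under the frozen critic and detached, so no gradient flows back into $Q_\theta$ or $V_\psi$; the weight clipping, the policy parameterization and all hyperparameters are illustrative practical choices, not prescribed by the paper.

```python
# Sketch of AWR-style policy extraction (Eq. 9): dataset actions are reweighted
# by the exponentiated advantage under the frozen critic.
import torch
import torch.nn as nn

obs_dim, act_dim, beta = 17, 6, 3.0

policy_mean = nn.Sequential(nn.Linear(obs_dim, 256), nn.ReLU(), nn.Linear(256, act_dim))
log_std = nn.Parameter(torch.zeros(act_dim))
pi_opt = torch.optim.Adam(list(policy_mean.parameters()) + [log_std], lr=3e-4)

def awr_policy_loss(s, a, q_sa, v_s):
    # exp(beta * advantage), treated as a fixed weight; clipped for numerical stability.
    with torch.no_grad():
        weights = torch.exp(beta * (q_sa - v_s)).clamp(max=100.0)
    dist = torch.distributions.Normal(policy_mean(s), log_std.exp())
    log_prob = dist.log_prob(a).sum(dim=-1)   # log pi_phi(a | s)
    return -(weights * log_prob).mean()       # negated: maximize the weighted log-likelihood

# Example usage with random stand-ins for a batch and frozen critic values.
s, a = torch.randn(256, obs_dim), torch.randn(256, act_dim)
q_sa, v_s = torch.randn(256), torch.randn(256)
loss = awr_policy_loss(s, a, q_sa, v_s)
pi_opt.zero_grad(); loss.backward(); pi_opt.step()
```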

Experimental Results

In a toy experiment using mazes, the authors collect trajectories that are in part heavily suboptimal. Out of all tested algorithms, IQL is the only one that successfully “stitches” together the optimal parts.

Figure 2 from [Kos21O]. Evaluation of our algorithm on a toy u-maze environment (a). When the static dataset is heavily corrupted by suboptimal actions, one-step policy evaluation results in a value function that degrades to zero far from the rewarding states too quickly (c). Our algorithm aims to learn a near-optimal value function, combining the best properties of SARSA-style evaluation with the ability to perform multistep dynamic programming, leading to value functions that are much closer to optimality (shown in (b)) and producing a much better policy (d).

The choice of a sufficiently large expectile $\tau$ is crucial for the emergence of “stitching” behavior, see Figure 3.

Figure 3 from [Kos21O]. Estimating a larger expectile $\tau$ is crucial for antmaze tasks that require dynamic programming (“stitching”).

Additional experiments on the D4RL offline RL benchmark (which has in the meantime migrated to Minari), both in the purely offline setting and in the hybrid setting of offline pretraining followed by online fine-tuning, demonstrate state-of-the-art performance.

Discussion

Offline RL is concerned with finding a policy that improves over the behavior policies that collected the data. To counterbalance the effect of extrapolation errors, most works add a constraint to stay close (in some metric) to the (potential mix of) such behavior policies while improving over them (i.e., maximizing some critic). Using the statistical tool of expectile regression, IQL is one more instantiation of this common paradigm. Note that although expectile regression elegantly mitigates extrapolation errors, the policy extraction via AWR still implements a “stay close to the data collection policy” constraint. Personally, I see some room for improvement in the latter, as staying close to the average data collection policy might be harmful in the case of very poor data.

Software-wise, the authors provide an official implementation based on JAXRL. Plenty of other projects like pytorchRL contain IQL as well.

References
