Advantage-Induced Policy Alignment

Building on classic results on reward weighted regression and their more recent adaptations to deep learning, a new algorithm called advantage-induced policy alignment (APA) is proposed for aligning language models (LMs) with human preferences. APA is claimed to significantly outperform state-of-the-art methods such as PPO. This pill also contains a short overview of the family of algorithms from which APA emerged.

Prelude

This paper pill covers recent work by Zhu et al. [Zhu23F] and puts it into a short historical/algorithmic context. The advantage-induced policy alignment (APA) algorithm proposed therein is the latest installment in a long line of extensions of simple methods in reinforcement learning (RL), which was started by the reward weighted regression algorithm (derived from an expectation-maximization framework by Peters and Schaal [Pet07R]). This pill contains more references than usual, because I believe the context and ideas surrounding the algorithms mentioned here may be of general interest.

APA not only holds the promise of improving the state-of-the-art for aligning language models (LM) to human preferences, but could also be of general use in RL. What I find particularly fascinating about the works mentioned in this pill (and about the modern history of RL in general), is how powerful algorithms arise from small modifications to previously published (and often simple) methods that proved to work well in restricted settings. It leaves me wondering what other gems are still hidden in the literature, waiting to be rediscovered as state-of-the-art algorithms.

As a motivation for practitioners, we proceed with the experimental results of APA [Zhu23F], before diving deeper into the theoretical background and derivations.

Experimental Results

The APA algorithm follows the same training routines as typical reinforcement learning from human feedback (RLHF) methods (see e.g. Ziegler et al. [Zie20F]), but with a different loss. Thus, the authors can directly compare APA to RLHF with PPO on an already trained and fixed reward model. They also include a comparison to another loss derived from advantage weighted regression (AWR), which to my knowledge is the first time this method has been applied to language modeling. Since one of the main novelties of the original AWR paper was the extension of the AWR loss to off-policy learning, the poor performance of AWR in the on-policy setting is perhaps not too surprising. Thus, the main practically relevant comparison is between APA and PPO.

Figure 1 from [Zhu23F]. Comparison of the performance of three methods on the Helpfulness and Harmlessness dataset. Top: The x-axis represents the total steps, which are proportional to the amount of data used in the training procedure. The y-axis is the reward evaluated by the same reward model. Bottom: The x-axis represents the total steps. The y-axis is the KL divergence between the trained model and the initial model.

We can see that APA outperforms PPO in terms of reward and offers comparable control of the KL divergence to the non-finetuned model $\pi_{\text{init}}$. However, I believe it is premature to conclude from these results alone that APA is the new state-of-the-art for training language models. Improving on a learned reward metric does not necessarily mean that the resulting model is better at the task we care about. In fact, we know that models can achieve very high reward with gibberish output (Ziegler et al. [Zie20F] report high-reward completions of the type “! These These These sound flowed instantly easily easily easily easily!” when training without the KL penalty), which is precisely the reason for including the KL divergence to the initial policy in the optimization problem.
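For reference, a common way of writing the resulting KL-regularized objective (following e.g. Ziegler et al. [Zie20F]; sign and coefficient conventions vary between papers) is

\begin{equation} \max_\theta \ \mathbb{E}_{s \sim d^{\pi_\theta}} \ \mathbb{E}_{a \sim \pi_\theta(a \mid s)} \left[ r(s, a) - \lambda \log \frac{\pi_\theta(a \mid s)}{\pi_{\text{init}}(a \mid s)} \right], \end{equation}

where $r$ is the learned reward model and $\lambda$ controls how strongly the fine-tuned policy is tied to $\pi_{\text{init}}$.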

For the largest of the tested models (which is still small in comparison to LLMs in use today), APA was able to achieve higher reward at the cost of a higher KL divergence. Moreover, looking at the graphs, there seems to be a tendency to deviate more strongly from the initial policy as the model size increases. On the other hand, for smaller models, APA is unambiguously better than PPO.

The practical conclusion from these experiments is unclear to me at this moment. One would at least need to perform an evaluation on standard downstream tasks to see whether the performance of the model has actually improved. Unfortunately, the authors provide no such evaluation. This does not mean that APA is not a promising method, and I hope a future evaluation with larger models and a variety of downstream tasks will shed more light on it.

From Weighted Regression to APA

Notation

We use standard notation for RL, where $\pi$ denotes a policy, $s$ a state, $a$ an action, $d^\pi$ a probability measure (of states or state-action pairs) induced by $\pi$, $A^\pi$ (an estimate of) the advantage function for the policy $\pi$, and $\mathcal{D}$ a dataset of trajectories sampled from an environment with some policy. We will not go into the details of what each subscript means, as it should be clear from the context. For a more complete overview we recommend reading through the cited references.

Core Ideas of Weighted Regression

RL algorithms from the weighted regression family are generally based on the following idea: given some policy $\pi_{\text{old}}$ and a dataset of trajectories $\mathcal{D}$ sampled from it, one wants to find an improved (in terms of expected reward) parameterized policy $\pi_\theta$ that is close to $\pi_{\text{old}}$ in some sense (note that this is not quite the same problem typically addressed in policy gradient (PG) or value iteration settings, where the goal is to find a policy that maximizes the expected reward without regard to $\pi_{\text{old}}$; despite that, implementations of weighted regression may look very similar to implementations of other RL methods, often differing only by the loss used in gradient descent). E.g., in an on-policy weighted regression setting, one typically performs alternating rounds of incremental updates to a parameterized policy (which takes on the role of $\pi_{\text{old}}$) and sampling of new data with the updated $\pi_\theta$, just like in standard iterative schemes. Conversely, practical PG methods often use some form of constraint to avoid deviating too strongly from the sampling policy in order to stabilize learning, just like weighted regression does. The resulting optimization problems often admit closed-form, non-parametric solutions - let’s call them $\pi^*$. The policy-improvement algorithms typically involve a projection of a parameterized policy $\pi_\theta$ onto $\pi^*$. These incremental improvements to $\pi_\theta$ can then be combined with sampling of new data and repeated until convergence. The analytic solutions $\pi^*$ generally take the form of $\pi_{\text{old}}$ weighted by some factors, and the projection step is performed with regression (i.e. supervised learning), hence the name.
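The following Python skeleton is a rough structural sketch of this alternation (the callables `sample_with`, `estimate_advantages`, and `projection_loss` are hypothetical placeholders for the algorithm-specific ingredients, not functions from any of the cited codebases):

```python
def weighted_regression_training(policy, optimizer, sample_with, estimate_advantages,
                                 projection_loss, n_rounds=100, n_updates_per_round=10):
    """Generic weighted-regression scheme: alternate between sampling with the
    current policy (which plays the role of pi_old) and projecting pi_theta
    onto the weighted target pi* via supervised regression."""
    for _ in range(n_rounds):
        # 1) Collect trajectories with the current policy; it acts as pi_old for this round.
        batch = sample_with(policy)
        # 2) Estimate A^{pi_old}(s, a) for the collected state-action pairs.
        advantages = estimate_advantages(policy, batch)
        # 3) Projection: regress pi_theta towards the non-parametric target pi*
        #    by minimizing some instantiation of Equation 4 below.
        for _ in range(n_updates_per_round):
            loss = projection_loss(policy, batch, advantages)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return policy
```

Concrete members of this family differ mainly in how `projection_loss` is instantiated, and in whether `batch` may also contain data reused from earlier rounds (off-policy variants).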

A particularly clean manifestation of this scheme can be found in the AWR paper [Pen19A]. It goes roughly as follows:

  1. Start with the first-order approximation of the expected improvement of a policy $\pi$ over the sampling policy $\pi_{\text{old}}$:

    \begin{equation} J(\pi) - J(\pi_{\text{old}}) \approx \mathbb{E}_{s \sim d^{ \pi_{ \text{old} } } } \ \mathbb{E}_{a \sim \pi(a\mid s)} \left[ A^{\pi_{\text{old}}}(s,a) \right], \tag1\end{equation}

    where $J$ is the expectation of the discounted sum of future rewards (i.e. the objective one wants to maximize), and $A^{\pi_{\text{old}}}(s,a)$ is the advantage function of $\pi_{\text{old}}$ (this expression is the starting point of several policy gradient algorithms, in particular trust region policy optimization and proximal policy optimization; a derivation can be found e.g. in our blog post on this topic).

  2. Restrict the solution space to policies that are close to $\pi_{\text{old}}$ in terms of KL divergence, resulting in the Lagrangian

    \begin{equation} \mathcal{L} = \mathbb{E}_{s \sim d^{ \pi_{ \text{old} } } } \ \mathbb{E}_{a \sim \pi(a\mid s)} \left[ A^{\pi_{\text{old}}}(s,a) \right] + \beta \left( \epsilon - \text{KL}(\pi_{\text{old}} \mid \mid \pi) \right), \tag2\end{equation}

    where $\beta$ is the Lagrange multiplier and $\epsilon$ some small constant.

  3. The analytic solution of the above problem is:

    \begin{equation} \pi^*(a \mid s) = \frac{1}{Z(s)} \pi_{\text{old}}(a \mid s) \exp \left( \frac{A^{\pi_{\text{old}}}(s,a)}{\beta} \right), \tag3\end{equation}

    where $Z(s)$ is a state-dependent normalization factor (the partition function). As noted above, this is a weighted version of $\pi_{\text{old}}$.

  4. Find an improved set of parameters $\theta^*$ by minimizing the difference between $\pi_\theta$ and $\pi^*$ on the dataset $\mathcal{D}$ according to some distance function $\text{dist}$ for probability distributions (a concrete instantiation follows right after this list):


    \begin{equation} \theta^* = \arg \min_\theta \mathbb{E}_{s \sim \mathcal{D} } \left[ \text{dist}(\pi_\theta(\cdot \mid s), \pi^*(\cdot \mid s)) \right]. \tag4\end{equation}
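Plugging the analytic solution (3) into the projection step (4) with the KL divergence $\text{KL}(\pi^* \mid \mid \pi_\theta)$ as the distance function, and dropping all terms that do not depend on $\theta$, gives (up to the partition function $Z(s)$, whose treatment is discussed below) the sample-based objective

\begin{equation} \theta^* = \arg \max_\theta \ \mathbb{E}_{(s, a) \sim \mathcal{D} } \left[ \exp \left( \frac{A^{\pi_{\text{old}}}(s,a)}{\beta} \right) \log \pi_\theta(a \mid s) \right], \end{equation}

i.e. a maximum-likelihood regression onto the sampled actions, weighted by exponentiated advantages.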

Weighted Regression in Practice

There are multiple variations of this basic scheme. In the original AWR paper, the authors use the KL divergence as the distance function and assume that the partition function $Z(s)$ can be set to one for practical purposes (there is a small mistake in the derivation of the AWR algorithm in the original paper, where $Z(s)$ just silently disappears from the equations; the authors have acknowledged this in a private communication. The role of this factor has been discussed by Wang et al. [Wan18E] in the context of imitation learning and by Nair et al. [Nai21A] for offline RL). Moreover, they introduce a strategy for reusing samples from previous iterations, thereby turning advantage-weighted regression into an off-policy algorithm. Other works have used instantiations of (4) for imitation learning [Wan18E], or as an ingredient in offline RL algorithms [Nai21A, Kos21O].
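A minimal PyTorch-style sketch of this sample-based AWR loss could look as follows (tensor names and shapes are illustrative assumptions, not taken from any of the cited codebases):

```python
import torch

def awr_loss(log_probs: torch.Tensor, advantages: torch.Tensor,
             beta: float = 1.0, weight_clip: float = 20.0) -> torch.Tensor:
    """Advantage-weighted regression loss with the Z(s) ~= 1 assumption.

    log_probs:   log pi_theta(a | s) for the sampled state-action pairs, shape (batch,).
    advantages:  advantage estimates A^{pi_old}(s, a) for the same pairs, shape (batch,).
    beta:        temperature of the exponential weighting (the Lagrange multiplier above).
    weight_clip: cap on the exponentiated weights, a common stabilization trick
                 for large advantages.
    """
    # The weights come from the non-parametric solution pi* and are treated as constants.
    weights = torch.clamp(torch.exp(advantages.detach() / beta), max=weight_clip)
    # Weighted negative log-likelihood of the sampled actions.
    return -(weights * log_probs).mean()
```

In the on-policy setting, `log_probs` would be recomputed with the current $\pi_\theta$ on freshly sampled data at every round; AWR's off-policy variant instead draws the state-action pairs from a replay buffer filled by earlier policies.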

Early works by Peters, Schaal, and Neumann [Pet07R, Neu08F] arrived at similar schemes by using the expectation-maximization (EM) framework on a lower bound on the expected reward (see also Strupl et al. [Str22R] for a nice theoretical analysis and extension of the original reward weighted regression). The well-known REPS algorithm [Pet10R] and its follow-up works make use of the analytic solution to an objective similar to Equation 2. The widely used trust region and proximal policy optimization algorithms [Sch15T, Sch17P] directly aim at optimizing Equation 2 or a variation thereof, but without using the analytic solution. Instead, they focus on policy gradient methods and practical implementation details (see e.g. our blog post Natural, Trust Region and Proximal Policy Optimization for a detailed discussion).

The APA Algorithm

The APA algorithm is the newest addition to the family of weighted regression algorithms, and the first to my knowledge to tackle language modeling. It makes the following modifications to the basic scheme outlined above:

  1. For language modeling, one wants the fine-tuned model to be close to the original pre-trained model $\pi_{\text{init}}$ in terms of KL divergence, i.e. $\pi_{\text{old}}$ is replaced with $\pi_{\text{init}}$. Therefore, the constraint in Equation 2 is adjusted accordingly, resulting in the optimal policy

    \begin{equation} \pi^*(a \mid s) = \frac{1}{Z(s)} \pi_{\text{init}}(a \mid s) \exp \left( \frac{A^{\pi_{\text{old}}}(s,a) }{\beta} \right). \end{equation}
  2. The authors argue that, for LMs, one can safely assume $Z(s) \approx 1$. (Their intuition is that, to first order in $\frac{A^{\pi_{\text{old}}}(s,a)}{\beta}$, the partition function is $1 + \frac{1}{\beta}\mathbb{E}_{a \sim \pi_{\text{init}}(a \mid s)} [A^{\pi_{\text{old}}}(s,a)]$. Assuming that the tuned policy stays close enough to the initial policy for the advantages to be similar, one gets $Z(s) \approx 1 + \frac{1}{\beta}\mathbb{E}_{a \sim \pi_{\text{init}}(a \mid s)} [A^{\pi_{\text{init}}}(s,a)] = 1$, since the expected advantage of a policy under its own action distribution vanishes. During training, the authors observe that the loss decreases only very little, suggesting that the tuned policy indeed stays very close to $\pi_{\text{init}}$. On a personal note, while I don't necessarily see the validity of these intermediate approximations (a value of $\beta=0.1$ is used in the experiments, which is not particularly large), the final justification is that the resulting algorithm works well.) As the distance function $\text{dist}$ for the projection step Equation 4, they use the squared error between log-probabilities, weighted by previously sampled data, i.e.


    \begin{equation} \text{dist}_{\text{APA}}(\pi_\theta(\cdot \mid s), \pi^*(\cdot \mid s)) \mathrel{\mathop:}= \mathbb{E}_{a \sim \pi_{\text{old}}(a \mid s)} \left[ \left( \log \pi_\theta(a \mid s) - \log \pi^*(a \mid s) \right)^2 \right]. \end{equation}

These design choices result in the APA loss:

\begin{equation} \mathcal{L}_{\text{APA}} \mathrel{\mathop:}= \mathbb{E}_{(s, a) \sim \mathcal{D}_{\text{old}} } \left[ \log^2 \left( \frac{\pi_\theta(a \mid s)}{\pi_{\text{init}}(a \mid s)} \exp \left(- \frac{A^{\pi_{\text{old}}}(s,a)}{\beta} \right) \right) \right]. \tag5\end{equation}
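A correspondingly minimal PyTorch-style sketch of Equation 5 might look like this (again with illustrative tensor names; the authors' actual implementation may differ in details):

```python
import torch

def apa_loss(log_probs: torch.Tensor, log_probs_init: torch.Tensor,
             advantages: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    """Advantage-induced policy alignment loss (Equation 5) with Z(s) ~= 1.

    log_probs:      log pi_theta(a | s) under the model being trained, shape (batch,).
    log_probs_init: log pi_init(a | s) under the frozen initial model, shape (batch,).
    advantages:     advantage estimates A^{pi_old}(s, a), shape (batch,).
    beta:           KL-constraint coefficient (0.1 in the paper's experiments).
    """
    # Squared regression error between log pi_theta and the (unnormalized) target
    # log pi* = log pi_init + A / beta.
    residual = log_probs - log_probs_init.detach() - advantages.detach() / beta
    return (residual ** 2).mean()
```

Note that, compared to the AWR loss above, the advantage enters additively inside a squared error rather than as an exponential weight on the log-likelihood.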

When applying APA for training LMs, Zhu et al. use the weighted regression scheme in its on-policy version, i.e. they do not reuse samples from previous iterations. Thus, the method can be directly compared to other on-policy algorithms, differing from them only in the loss function Equation 5.

Discussion

From a theoretical point of view, the new advantage-induced policy alignment algorithm results from a small modification of advantage-weighted regression, and as such is not a major breakthrough.

On the practical side, however, this modification seems to result in a significant improvement in performance. Therefore, this work may be important for practical purposes, as well as for rejuvenating interest in the family of weighted regression algorithms.

I found it interesting to follow the discussions that led to the rejection of the AWR paper, as well as of popular follow-up works like AWAC, from the conferences they were submitted to; see the corresponding discussions on openreview. The main reasons for rejection were that the new algorithms were neither sufficiently different from existing weighted regression approaches, nor more performant than state-of-the-art algorithms. We see, however, that simple and efficient ideas, while perhaps not immediately beating the state-of-the-art, can still be highly influential and eventually lead to significant improvements in performance. Many now-famous papers in machine learning are based on such small theoretical modifications, so I believe it is fair to say that the authors of AWR were somewhat unlucky in the timing of their submissions.

Software-wise, APA is built on top of the modern trlx library for combining transformers with RL, and as such it could be a useful starting point for custom RLHF algorithms that follow good software engineering practices. It is refreshing to see research code moving in the direction of reusability and extensibility. The code is freely available on GitHub under the MIT license.

References
