[DPG] Distributional Reinforcement Learning for Energy-Based Sequential Models
https://arxiv.org/pdf/1912.08517
Abstract
Global Autoregressive Models (GAMs) are a recent proposal [15] for exploiting global properties of sequences for data-efficient learning of seq2seq models. In the first phase of training, an Energy-Based model (EBM) [10] over sequences is derived. This EBM has high representational power, but is unnormalized and cannot be directly exploited for sampling. To address this issue, [15] proposes a distillation technique, which can only be applied under limited conditions. By relating this problem to Policy Gradient techniques in RL, but in a distributional rather than optimization perspective, we propose a general approach applicable to any sequential EBM. Its effectiveness is illustrated on GAM-based experiments.
1. Introduction
The mainstream autoregressive sequence models [6, 22, 5, 24] form a subclass of sequential energy-based models (sequential EBMs) [10]. While the former are locally normalized and easy to train and sample from, the latter allow global constraints, greater expressivity, and potentially better sample efficiency, but lead to unnormalized distributions and are more difficult to use for inference and evaluation. We exploit a recently introduced class of energy-based models, Global Autoregressive Models (GAMs) [15], which combine a locally normalized component (that is, a first, standard, autoregressive model, denoted r) with a global component, and use these to explore some core research questions about sequential EBMs, focusing our experiments on synthetic data for which we can directly control experimental conditions. We dissociate the (relatively easy) task of learning from the available data an energy-based representation (Training-1) from the more challenging task of exploiting that representation to produce samples or evaluations (Training-2).
In this paper, we provide a short self-contained introduction to GAMs and to their two-stage training procedure. However, our main focus is on Training-2. For that task, [15] proposed a distillation technique to project the energy-based representation (denoted Pλ) obtained at the end of Training-1 into a final autoregressive model (denoted πθ), with better test perplexity than the initial r, but this technique was limited to cases where it was possible to sample from Pλ at training time. One key observation of the current submission is that Training-2, considered as the general problem of deriving an autoregressive model from an energy-based model (not necessarily obtained through Training-1), has strong similarities with the training of policies in Reinforcement Learning (RL), but in a distributional rather than an optimization perspective as in standard RL. We then propose a distributional variant of the Policy Gradient technique (Distributional Policy Gradient: DPG) which has wider applicability than distillation. We conduct GAM-based experiments to compare this technique with distillation, in synthetic-data conditions where distillation is feasible, and show that DPG works as well as distillation. In both cases, in small-data conditions, the policies (aka autoregressive models) πθ obtained at the end of the process are very similar and show strong perplexity reduction over the standard autoregressive models.
Section 2 provides an overview of GAMs. Section 3 explains the training procedure, with focus on EBMs and relations to RL. Section 4 presents experiments and results. For space reasons we use the Supplementary Material (Sup. Mat.) to provide some details and to discuss related work.
2. Model
2.1. Background
2.2. GAMs
3. Training
We assume that we are given a training data set D (resp. a validation set V, a test set T) of sequences x, and a finite collection of real-valued feature functions φ1, . . . , φk. The GAM training procedure is then performed in two stages (see Fig. 1).
3.1. Training-1: from data to energy-based representation
3.2. Training-2: from energy-based representation to distributional policy
The output of the previous stage is an unnormalized EBM, which allows us to compute the potential P(x) = Pλ(x) of any given x, but not directly to compute the partition function Z = Σ_x P(x), nor the normalized distribution p(x) = (1/Z) P(x) = pλ(x), nor to sample from it. In RL terms, the score P(x) can be seen as a reward. The standard RL-as-optimization view would lead us to search for a way to maximize the expectation of this reward, in other words for a policy πθ∗ with θ∗ = argmax_θ E_{x∼πθ(·)}[P(x)], which would tend to concentrate all its mass on a few sequences.
By contrast, our RL-as-sampling (distributional) view consists in trying to find a policy πθ∗ that approximates the distribution p as closely as possible, in terms of cross-entropy CE. We are thus trying to solve θ∗ = argmin_θ CE(p, πθ).
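The paper's displayed equations are not reproduced in this note; written out from the definitions above, the cross-entropy and its gradient take the following standard form, which is presumably the formula (4) referenced below:

```latex
% Cross-entropy between the normalized EBM p and the policy pi_theta, and its gradient
% (presumably formula (4) in the paper): an expectation of the score function under p.
\mathrm{CE}(p, \pi_\theta) \;=\; -\sum_{x} p(x)\,\log \pi_\theta(x),
\qquad
\nabla_\theta \,\mathrm{CE}(p, \pi_\theta) \;=\; -\,\mathbb{E}_{x \sim p(\cdot)}\,\nabla_\theta \log \pi_\theta(x).
```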
We can apply (4) for SGD optimization, using different approaches.
The simplest approach, Distillation, can be employed in situations where we are able to draw, in reasonable training time, a large number of samples x1, . . . , xK from p. We can then exploit (4) directly to update θ, which is in fact equivalent to performing standard supervised log-likelihood SGD training on the set {x1, . . . , xK}. This is the approach to Training-2 taken in [15], using rejection sampling at training time to obtain the samples, and then training θ on these samples to obtain a final AM πθ, which can be used for efficient sampling at test time and for evaluation. The advantage of this approach is that supervised training of this sort is very successful for standard autoregressive models, with good stability and convergence properties, and an efficient use of the training data through epoch iteration. However, the big disadvantage is its limited applicability, due to the restrictive conditions for rejection sampling, as explained earlier.
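To make this concrete, here is a minimal, self-contained Python sketch of the distillation idea in a hypothetical toy setting (not the paper's code): the EBM potential is P(x) = r(x)·exp(LAM·φ(x)) with r a uniform per-position proposal, samples from p(x) ∝ P(x) are obtained by rejection sampling with proposal r, and πθ (here a per-position categorical, a stand-in for a real autoregressive model) is fit by maximum likelihood on those samples.

```python
# Toy sketch of Training-2 by distillation (hypothetical setting, not the paper's code).
import numpy as np

rng = np.random.default_rng(0)
VOCAB, LENGTH, LAM = 4, 6, 0.5        # toy vocabulary size, sequence length, feature weight

def phi(x):                            # toy global feature: number of occurrences of token 0
    return np.sum(x == 0)

def rejection_sample_from_p(n_wanted):
    """Accept x ~ r with probability P(x) / (M r(x)) = exp(LAM * (phi(x) - LENGTH)),
    using the bound M = exp(LAM * LENGTH), valid since phi(x) <= LENGTH and LAM >= 0."""
    accepted = []
    while len(accepted) < n_wanted:
        xs = rng.integers(0, VOCAB, size=(1000, LENGTH))          # proposals from uniform r
        accept_prob = np.exp(LAM * (np.array([phi(x) for x in xs]) - LENGTH))
        keep = rng.random(len(xs)) < accept_prob
        accepted.extend(xs[keep])
    return np.array(accepted[:n_wanted])

# "Distill" into pi_theta: here simply a smoothed per-position MLE fit on the samples.
samples = rejection_sample_from_p(5000)
counts = np.stack([np.bincount(samples[:, t], minlength=VOCAB) for t in range(LENGTH)])
pi_theta = (counts + 1.0) / (counts + 1.0).sum(axis=1, keepdims=True)

print("pi_theta at position 0:", np.round(pi_theta[0], 3))        # token 0 gets extra mass
```

The acceptance rate of the rejection sampler degrades quickly as the feature weight or sequence length grows, which is exactly the limited-applicability issue mentioned above.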
A central contribution of the present paper is to propose another class of approaches, which does not involve sampling from p, and which relates to standard techniques in RL. The idea is to rewrite the last formula of (4) as a form of Importance Sampling (formula (6) in the paper), in which the sampling policy q is different from the policy πθ being learnt: q plays the role of the proposal and is typically chosen to be an approximation to p.
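Since p(x) = P(x)/Z, the importance-sampling rewrite presumably takes the following form (a reconstruction from the surrounding description, as the displayed formula (6) is not reproduced above):

```latex
% Importance-sampling form of the gradient (presumably formula (6)): episodes are drawn
% from the proposal q and reweighted by p(x)/q(x) = P(x)/(Z q(x)); the constant 1/Z can
% be absorbed into the learning rate.
\nabla_\theta \,\mathrm{CE}(p, \pi_\theta)
  \;=\; -\,\mathbb{E}_{x \sim q(\cdot)}\,\frac{p(x)}{q(x)}\,\nabla_\theta \log \pi_\theta(x)
  \;=\; -\,\frac{1}{Z}\,\mathbb{E}_{x \sim q(\cdot)}\,\frac{P(x)}{q(x)}\,\nabla_\theta \log \pi_\theta(x).
```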
We did some initial experiments with DPG_on, the on-policy variant in which the proposal is the current πθ itself, but found that the method had difficulty converging, probably due in part to the instability induced by the constant change of sampling distribution (namely πθ). A similar phenomenon is well documented in the case of vanilla Policy Gradient in standard RL, and techniques such as TRPO [20] or PPO [21] have been developed to control the rate of change of the sampling distribution. In order to avoid such instability, we decided to focus on DPG_off, an off-policy variant in which the proposal is kept fixed between infrequent updates, based on Algorithm 1 below.
In this algorithm, we suppose that we have as input a potential function P, and an initial proposal distribution q; in the case of GAMs, we take P = Pλ and a good πθ0 is provided by r. We then iterate the collection of episodes x sampled with the same q (line 4), and perform SGD updates (line 5) according to (6) (α(θ) is the learning rate). We do update the proposal q at certain times (line 7), but only based on the condition that the current πθ is superior to q in terms of perplexity measured on the validation set V, thus ensuring a certain stability of the proposal.
This algorithm worked much better than the DPG_on version, and we retained it as our implementation of DPG in all our experiments.
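As an illustration, here is a toy, self-contained Python sketch of the DPG_off loop described above, not the paper's implementation. Assumptions: the "autoregressive" policy is a per-position softmax over a small vocabulary, the EBM potential is P(x) = r(x)·exp(LAM·φ(x)) with r uniform, and a validation set from p is assumed to be available (here built by rejection sampling, purely to have a perplexity yardstick for the proposal update).

```python
# Toy sketch of DPG_off (Algorithm 1 as described above); hypothetical stand-ins throughout.
import numpy as np

rng = np.random.default_rng(0)
VOCAB, LENGTH, LAM = 4, 6, 0.5

def phi(x):                                  # toy global feature: count of token 0
    return np.sum(x == 0)

def log_P(x):                                # log potential: log r(x) + LAM * phi(x)
    return -LENGTH * np.log(VOCAB) + LAM * phi(x)

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

class Policy:
    """Per-position categorical model (toy stand-in for an autoregressive policy)."""
    def __init__(self):
        self.logits = np.zeros((LENGTH, VOCAB))          # initialized to r (uniform)
    def probs(self):
        return softmax(self.logits)
    def log_prob(self, x):
        p = self.probs()
        return float(sum(np.log(p[t, x[t]]) for t in range(LENGTH)))
    def sample(self, n):
        p = self.probs()
        return np.stack([rng.choice(VOCAB, size=n, p=p[t]) for t in range(LENGTH)], axis=1)
    def grad_log_prob(self, x):                          # d log pi(x) / d logits
        g = -self.probs().copy()
        g[np.arange(LENGTH), x] += 1.0
        return g

def perplexity(policy, data):
    return np.exp(-np.mean([policy.log_prob(x) for x in data]) / LENGTH)

def rejection_sample(n):                                 # only used to build a toy validation set
    out = []
    while len(out) < n:
        xs = rng.integers(0, VOCAB, size=(1000, LENGTH))
        accept = rng.random(1000) < np.exp(LAM * (np.array([phi(x) for x in xs]) - LENGTH))
        out.extend(xs[accept])
    return np.array(out[:n])

val_set = rejection_sample(300)
pi_theta, q, lr = Policy(), Policy(), 0.2                # both start as r (uniform)

for it in range(200):
    batch = q.sample(64)                                 # line 4: episodes x sampled from q
    grad = np.zeros_like(pi_theta.logits)
    for x in batch:
        w = np.exp(log_P(x) - q.log_prob(x))             # importance weight P(x)/q(x)
        grad += w * pi_theta.grad_log_prob(x)
    pi_theta.logits += lr * grad / len(batch)            # line 5: SGD update following (6)
    if perplexity(pi_theta, val_set) < perplexity(q, val_set):
        q.logits = pi_theta.logits.copy()                # line 7: promote pi_theta to proposal

print("perplexity of r       :", round(perplexity(Policy(), val_set), 3))
print("perplexity of pi_theta:", round(perplexity(pi_theta, val_set), 3))
```

The proposal q is only overwritten when πθ beats it on validation perplexity, mirroring the stability condition of line 7.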