Recycling GPT2 to generate audio
Language models — like ChatGPT, Llama and Mistral — are a hot topic nowadays and everybody seems to be trying them out. But what about language models for audio generation? In this article, I'll show you some of the experiments I did to generate audio using GPT2, EnCodec and EnCodecMAE. But first, let's talk about language models (LMs).
Language models estimate probabilities of sequences of words given a corpus. So, what’s the probability of a sequence? Given a sequence $S$ of length $T$:
\[S = [x_1, x_2, \dots, x_T]\]
its probability is calculated as the joint probability of every element in the sequence:
\[P(S) = P(x_1, x_2, \dots, x_T)\]
Using the chain rule of probability, and assuming an ordering of the elements, it can be expressed as:
\[P(S) = P(x_1)P(x_2|x_1)P(x_3|x_1,x_2) \cdots P(x_T|x_1, \dots, x_{T-1})\]
Let's explain this with a concrete example. Let's say we have the following sentence:
"I play football"
The first step is to turn this sentence into a sequence. There are many possible ways to do it:
"[I, play, football]"
or
"[I, play, foot#, #ball]"
or
"[I, ,p,l,a,y, ,f,o,o,t,b,a,l,l]"
Tokenization is the process of turning a corpus into sequences of symbols. The symbols could be words, as in the first example; subwords, like decomposing football into foot and ball, as in the second; or characters, as in the last example.
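For instance, we can look at how GPT2's own tokenizer splits our sentence into subwords (a quick illustration using the transformers library; the exact split depends on the learned vocabulary):
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# The resulting pieces are subwords from GPT2's learned BPE vocabulary;
# the exact segmentation is determined by that vocabulary.
print(tokenizer.tokenize("I play football"))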
Let's stick with the first example of tokenization and calculate the probability of the sentence. According to the chain rule, we have
\[P(S) = P(x_1)P(x_2|x_1)P(x_3|x_1,x_2) \cdots P(x_T|x_1, \dots, x_{T-1})\]
\[P(S) = P(\text{I})\,P(\text{play}|\text{I})\,P(\text{football}|\text{I},\text{play})\]
Each term in the equation tells us the probability of a word given the previous words. If we train a neural network to predict the next word given the previous ones, and we use cross-entropy as the loss function, then the outputs will correspond to each of the terms in the equation. The important detail is that our model has to be unable to 'see' into the future; if we peek into the future, then the probabilities will no longer be conditioned only on the past elements of the sequence.
For recurrent neural networks this restriction is inherent to the model, as tokens are processed in sequence order. In the case of transformers, attention masks are used so that attention is computed between each query and only the past key-values.
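As a minimal illustration (not tied to any particular model), a causal attention mask in PyTorch is just a lower-triangular boolean matrix: position t is allowed to attend only to positions up to t:
import torch

T = 5  # sequence length
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# Row t has True only up to column t, so queries cannot attend to future keys.
print(causal_mask)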
Pre-deep learning methods, like n-grams, assume that the current output only depends on the previous n-1 outputs. Causal convolutional neural networks, like WaveNet, also respect this restriction by convolving only over past samples.
The idea explained in the above section can be extended to model any sequence, not only text. Let’s try to use it for audio!
The first step is to turn the audio into a sequence $S$, as we did with the sentence. Digital audio is itself a sequence of discrete values, where the sampling rate defines how many values there are in 1 second. For speech, a common value is 16000 Hz, and for music 44100 Hz. This means that if we have 1 second of music and want to predict the next sample, the model has to take the 44100 previous values into account. That's really hard and computationally expensive, especially for recurrent neural networks (RNNs) and transformers, which are the two leading models for sequence modelling in NLP.
One workaround is to learn a more compact representation of the audio. Let's take a look at EnCodec, a neural audio codec that does exactly this.
As can be seen in the figure, EnCodec has an encoder and a decoder, and its objective is to reconstruct the input audio (it's an autoencoder). The loss is a weighted sum of several terms: a reconstruction loss in the waveform domain ($l_t$), one in the spectrogram domain ($l_s$), and an adversarial loss to reduce artifacts. The important bit is that the autoencoder has a very restricted bottleneck: the encoder downsamples the waveform from 24 kHz to 75 Hz (320x). This might seem like a lot of compression, but the caveat is that those 75 elements per second are 128-D vectors. So at this point, the actual reduction is of just 2.5x.
To further reduce the size of the waveform, EnCodec quantizes the bottleneck. The idea is that a quantization layer maps each of these 128-D vectors to one of 1024 possible integers. This is done by learning a codebook, which is just a lookup table with 1024 128-D vectors (codes). During inference, the output of the encoder is compared against each of the codes and the index of the closest one is returned. 1024 possible values can be represented using 10 bits, so 1 second of audio can be represented as a sequence of 75 10-bit elements, which gives a bitrate of 750 bits per second (bps). Our original audio had a sampling rate of 24000 Hz with 16-bit depth; that is 24000*16 bits per second, or 384 kbps. With this quantization we would achieve a 512x reduction in size! However, there is a problem: the quality of the compressed audio will be very poor.
To solve this problem, EnCodec uses a residual vector quantizer (RVQ): the residual between the encoder output and the closest code of the first quantizer is quantized by a second quantizer, and so on. This allows using multiple quantizers that successively refine the output of the previous ones. If we use 8 quantizers, we get decent audio quality and represent 1 second of audio with 8 sequences of 75 10-bit elements. That is 6 kbps, a 64x reduction in size. Nice!
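To make the bitrate arithmetic explicit, here is the back-of-the-envelope calculation from the previous paragraphs as a small Python snippet:
raw_bps = 24000 * 16            # 16-bit PCM at 24 kHz: 384000 bps (384 kbps)
frames_per_second = 75          # EnCodec's encoder downsamples 24 kHz by 320x
bits_per_code = 10              # 1024 codebook entries -> 10 bits per code
one_q_bps = frames_per_second * bits_per_code      # 750 bps with 1 quantizer
eight_q_bps = 8 * one_q_bps                        # 6000 bps (6 kbps) with 8 quantizers
print(raw_bps / one_q_bps, raw_bps / eight_q_bps)  # 512.0x and 64.0x compression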
Now, let’s write some code in Python to tokenize audio using EnCodec:
from encodec import EncodecModel
import librosa
import torch

def tokenize(filename, model, sr=24000, num_quantizers=8):
    with torch.no_grad():
        x, fs = librosa.core.load(filename, sr=sr)
        codes = model.encode(torch.from_numpy(x)[None,None,:])[0][0]
    return codes

def detokenize(codes, model):
    with torch.no_grad():
        decoded_x = model.decode([(codes,None)])
    return decoded_x.detach().cpu().numpy()[0,0]

NUM_QUANTIZERS = 8
model = EncodecModel.encodec_model_24khz()
codes = tokenize('/mnt/hdd6T/jamendo/00/565800.mp3', model)
reconstruction_8q = detokenize(codes[:,:NUM_QUANTIZERS], model)
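As a quick check, we can reuse the same codes and keep fewer quantizers when decoding, to hear the quality/bitrate trade-off (a small sketch reusing codes and model from the snippet above):
reconstruction_2q = detokenize(codes[:,:2], model)  # 2 quantizers -> 1.5 kbps
reconstruction_4q = detokenize(codes[:,:4], model)  # 4 quantizers -> 3 kbps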
These are some audio samples of encoding/decoding the same audio with a different number of quantizers:
So now we have a way to tokenize audio efficiently: we just encode it with EnCodec, and instead of a very long sequence of 24000 elements per second, we have only 75.
But there is still a problem, related to the RVQ: instead of a single sequence representing the audio, we have as many sequences as quantizers. In our experiments we are going to use 8 quantizers, so that means 8 sequences.
If the sequences were independent, they could be predicted in parallel. In the same way that the model predicts the next word given the previous ones, we could predict the 8 tokens in parallel, given the previous ones.
The problem is that the output of the second quantizer depends on the output of the first one, as it models its residual. So assuming independence between the sequences might not be ideal. There are several approaches to overcome this problem, and they are discussed in the MusicGen paper:
Flattening pattern: the $Q$ sequences are flattened into a single one by interleaving the quantizers at each timestep. This allows the model to predict the next element based on the previous quantizers' outputs from the same timestep and on all the outputs from previous timesteps, so the dependence between quantizers can be modelled. The drawback is that the resulting sequence is $Q$ times longer, making it computationally more expensive and harder to model.
Parallel pattern: this pattern is roughly the opposite of the flattening one; the quantizers are assumed to be independent, so all the sequences are predicted in parallel. This approach doesn't increase the sequence length, but it might lead to worse results because of the independence assumption.
Vall-E pattern: this is the pattern used in Vall-E, where the first quantizer's sequence is predicted autoregressively and the remaining quantizers are then predicted non-autoregressively, conditioned on the previously generated quantizers.
Delay pattern: this is the pattern used by MusicGen, where the sequence of quantizer $q$ is delayed by $q$ timesteps. The model still predicts one token per quantizer at each step, but thanks to the delay, the token of quantizer $q$ for a given frame is generated after the tokens of the lower quantizers for that same frame, which it can condition on.
Now, let’s write the Python code to apply the delay pattern to the EnCodec codes and also to remove the delay:
def roll_along(arr, shifts, dim):
    # From: https://stackoverflow.com/a/76920720
    # Rolls each slice of `arr` along `dim` by a per-slice amount given in `shifts`.
    assert arr.ndim - 1 == shifts.ndim
    dim %= arr.ndim
    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
    dim_indices = torch.arange(arr.shape[dim], device=arr.device).reshape(shape)
    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
    return torch.gather(arr, dim, indices)

def apply_delay_pattern(codes):
    # Offset the codes by 1 so that 0 can act as the 'empty' token, pad with Q-1
    # empty steps at the end, and delay quantizer q by q timesteps.
    codes = torch.nn.functional.pad(codes + 1, (0, codes.shape[1] - 1))
    codes = roll_along(codes, torch.arange(0, codes.shape[1], device=codes.device)[None,:].tile((codes.shape[0],1)), 2)
    return codes

def unapply_delay_pattern(codes):
    # Undo the per-quantizer delay and drop the padded steps at the end.
    codes = roll_along(codes, -torch.arange(0, codes.shape[1], device=codes.device)[None,:].tile((codes.shape[0],1)), 2)
    codes = codes[:,:,:-codes.shape[1]]
    return codes
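To see what the delay pattern does, we can apply it to a small toy tensor with 4 quantizers and 3 timesteps (the values are arbitrary; 0 acts as the 'empty' padding token):
toy_codes = torch.arange(1, 13).reshape(1, 4, 3)  # (batch=1, Q=4, T=3)
print(apply_delay_pattern(toy_codes)[0])
# Quantizer q is shifted right by q steps (values are offset by +1, 0 marks 'empty'):
# [[ 2,  3,  4,  0,  0,  0],
#  [ 0,  5,  6,  7,  0,  0],
#  [ 0,  0,  8,  9, 10,  0],
#  [ 0,  0,  0, 11, 12, 13]]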
It can be seen that this approach adds only $Q-1$ elements to the original sequence, which in this case is only 7 extra elements.
GPT2
Now that we can turn audio into token sequences, we need a language model to predict the next token. Instead of training one from scratch, let's recycle a pretrained GPT2 from HuggingFace:
from transformers import GPT2Model
gpt = GPT2Model.from_pretrained('gpt2')
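The only thing we need from it for now is its hidden size, since our new input embeddings and output head will have to match it:
print(gpt.embed_dim)  # 768 for the base 'gpt2' checkpoint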
We will have to make some modifications to this model to make it work with our inputs: the word embedding layer is replaced by a new embedding table for the audio tokens, whose outputs are summed across the $Q$ quantizers and fed to GPT2 through inputs_embeds; and a linear classification head is added on top, predicting at every timestep one of the 1025 possible values (1024 codes plus the 'empty' token) for each quantizer.
Also, we want to have some control over the generated audio at inference time. One way to achieve this is by prepending a prompt to the input tokens.
The prompt could be a sequence of words describing the audio, an image, another audio, etc…
I've been working on ways to represent audio recently, and proposed EnCodecMAE, which we will use here to turn the prompt audio into a single embedding vector.
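As a rough sketch (assuming an already-loaded EnCodecMAE model named encodecmae, using the same extract_features_from_file call that appears in the generation code further below), the prompt is simply the time-averaged EnCodecMAE embedding of a reference audio:
# 'reference.wav' is a placeholder path; encodecmae is an already-loaded EnCodecMAE model.
prompt_features = encodecmae.extract_features_from_file('reference.wav')  # (frames, embed_dim)
prompt = prompt_features.mean(axis=0)                                     # one global embedding vector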
Now, we can make a PyTorch Lightning module to encapsulate everything in a class:
import pytorch_lightning as pl
import gin  # used below to recover the scheduler bindings in configure_optimizers

class EnCodecGPT(pl.LightningModule):
    def __init__(self, num_q=8, optimizer=None, lr_scheduler=None):
        super().__init__()
        self.encodec_model = EncodecModel.encodec_model_24khz()
        self.num_q = num_q
        self.num_codes = 1024
        self.vocab_size = self.num_q * (self.num_codes + 1)
        self.gpt = GPT2Model.from_pretrained('gpt2')
        self.vocab_embed = torch.nn.Embedding(self.vocab_size, self.gpt.embed_dim)
        self.classification_head = torch.nn.Linear(self.gpt.embed_dim, self.vocab_size)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def forward(self, x, prompt=None):
        # EnCodec tokens:
        with torch.no_grad():
            codes = self.encodec_model.encode(x)[0][0][:,:self.num_q]
        # Delay pattern:
        codes = apply_delay_pattern(codes)
        # Offset quantizer q by q*(num_codes+1) and pass through the embedding LUT:
        input_vocab_embed = torch.arange(self.num_q, device=x.device)[None,:,None]*(self.num_codes + 1) + codes
        gpt_in = self.vocab_embed(input_vocab_embed).sum(axis=1)
        # Prepend prompt:
        if prompt is not None:
            gpt_in = torch.cat([prompt[:,None,:], gpt_in], axis=1)
        # Pass through GPT:
        gpt_out = self.gpt(inputs_embeds=gpt_in)
        # Make classification:
        preds = self.classification_head(gpt_out['last_hidden_state'])
        preds = preds.view(preds.shape[0], preds.shape[1], self.num_q, self.num_codes+1)
        return preds, codes

    def training_step(self, x, batch_idx):
        wav = x['wav'].unsqueeze(1)
        preds, targets = self(wav, x['prompt'])
        preds = preds.transpose(1,3)[:,:,:,:-1]
        loss = torch.nn.functional.cross_entropy(preds, targets)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, x, batch_idx):
        wav = x['wav'].unsqueeze(1)
        preds, targets = self(wav, x['prompt'])
        preds = preds.transpose(1,3)[:,:,:,:-1]
        loss = torch.nn.functional.cross_entropy(preds, targets)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        opt = self.optimizer(self.trainer.model.parameters())
        if self.lr_scheduler is not None:
            if self.lr_scheduler.__name__ == 'SequentialLR':
                binds = gin.get_bindings('torch.optim.lr_scheduler.SequentialLR')
                lr_scheduler = self.lr_scheduler(opt, schedulers=[s(opt) for s in binds['schedulers']])
            else:
                lr_scheduler = self.lr_scheduler(opt)
        else:
            lr_scheduler = None
        del self.optimizer
        del self.lr_scheduler
        opt_config = {'optimizer': opt}
        if lr_scheduler is not None:
            opt_config['lr_scheduler'] = {'scheduler': lr_scheduler,
                                          'interval': 'step',
                                          'frequency': 1}
        return opt_config
The full code for training can be found here.
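For orientation, a minimal training sketch could look like the following (train_loader and val_loader are hypothetical DataLoaders yielding dicts with 'wav' and 'prompt' keys, as expected by training_step):
import torch
import pytorch_lightning as pl

# Hypothetical sketch: train_loader/val_loader yield {'wav': 24 kHz waveforms, 'prompt': EnCodecMAE embeddings}.
model = EnCodecGPT(num_q=8, optimizer=lambda params: torch.optim.AdamW(params, lr=1e-4))
trainer = pl.Trainer(max_steps=100_000, accelerator='gpu', devices=1)
trainer.fit(model, train_loader, val_loader)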
Once the model is trained, we want to generate audio with it. To sample from language models we have to do what is called autoregressive sampling.
WARNING
For this experiment we didn't use an end-of-audio token, as all the samples are 4 seconds long, so it is expected that after 4 seconds the model will return silence.
A nice thing about the HuggingFace GPT2 implementation is that it supports key-value caching. This means we don't need to recompute attention over all previous positions at every step during inference, as intermediate results are cached at each autoregressive sampling iteration. This saves a lot of computation, reducing the time required to generate the sequences. This is not a minor detail, as autoregressive sampling is expensive because it cannot be parallelized.
One very important detail to discuss is how to pass from probability outputs to an actual output generated by the language model. There are many approaches, and those are discussed in more depth here. Some options are greedy decoding, temperature sampling, top-k sampling and nucleus (top-p) sampling.
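Temperature sampling simply divides the logits by a temperature before the softmax; a minimal illustration:
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
temperature = 0.7
probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
token = torch.multinomial(probs, num_samples=1)
# Lower temperatures approach greedy decoding; higher temperatures approach uniform sampling.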
In the following Python snippet we will use temperature sampling:
from tqdm import tqdm

def generate(prompt_filename, encodecmae, lm_model, temperature=1.0, generation_steps=300):
    # Extract the EnCodecMAE prompt embedding and average it over time:
    prompt = encodecmae.extract_features_from_file(prompt_filename)
    prompt = prompt.mean(axis=0)
    with torch.no_grad():
        prompt = torch.from_numpy(prompt).to(lm_model.device)
        gpt_in = prompt.unsqueeze(0).unsqueeze(0)
        past_keys = None
        generation = []
        for i in tqdm(range(generation_steps)):
            outs = lm_model.gpt(inputs_embeds=gpt_in, past_key_values=past_keys)
            past_keys = outs['past_key_values']
            preds = lm_model.classification_head(outs['last_hidden_state'][0])
            preds = preds.view(preds.shape[0], lm_model.num_q, lm_model.num_codes+1)
            # Temperature sampling, independently for each quantizer:
            sampled_idxs = torch.cat([torch.multinomial(torch.nn.functional.softmax(preds[0,q,:]/temperature), 1) for q in range(lm_model.num_q)])
            generation.append(sampled_idxs)
            # Embed the sampled tokens to use them as the next input:
            in_idxs = torch.arange(lm_model.num_q, device=lm_model.device)*(lm_model.num_codes + 1) + sampled_idxs
            gpt_in = lm_model.vocab_embed(in_idxs).sum(axis=0).unsqueeze(0).unsqueeze(0)
        generation = torch.stack(generation)
        # Undo the delay pattern, remove the 'empty' token offset and decode with EnCodec:
        generation = roll_along(generation, -torch.arange(0, lm_model.num_q, device=generation.device), 0)
        audio = lm_model.encodec_model.decode([(torch.maximum(generation-1, torch.tensor(0, device=lm_model.device))[:-lm_model.num_q].T.unsqueeze(0), None)])
        audio = audio[0].cpu().detach().numpy()
    return audio
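Generating 4 seconds of audio (300 steps at 75 frames per second) from a prompt file then looks like this ('prompt.wav' is a placeholder path, encodecmae an already-loaded EnCodecMAE model and model the trained EnCodecGPT module):
audio = generate('prompt.wav', encodecmae, model, temperature=0.7, generation_steps=300)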
For the initial experiments, I trained the language model on NSynth, a dataset with many synth samples spanning different musical instruments and notes. All the samples are 4 seconds long and quite standardized, which makes it an ideal dataset for a first toy experiment.
Here are some examples of what happens when we vary the temperature. This example is not in the training set of the model.
We can hear that the generated samples resemble the prompt. However, when the temperature is too low (0.01 and 0.1), we can hear artifacts caused by the outputs looping between tokens. This suggests that greedy search might be a bad idea. Increasing the temperature leads to more organic results, but more noise is also added. When the temperature is too high (>1.0), the generated samples start to sound random and very different from the prompt.
Then, I did some experiments morphing between 2 sounds (let's call them A and B). The most straightforward way to do it is to extract the prompts from A and B, generate new prompts that are linear combinations of them, and use these prompts to generate new audios. Let's listen to some examples with 15 prompts between A and B concatenated.
What if we want a continuous interpolation between the 2 prompts? Is it possible? Well, yes, although it's a bit more complicated because our model was trained on 4-second audios only. It can still be done by using a buffer slightly shorter than 4 seconds (to avoid the silence at the end of NSynth samples). Let's listen to some examples of this type of morphing:
What happens if we use as prompt something that is not a synth sound? Let’s find out:
In the next experiment I wanted to deal with a more complex signal: speech. For that I used a small subset of around 200 hours of LibriLight. The prompt is still the same, and I hope it can convey information about speaker identity, speaking style and maybe content. In spite of the extra complexity of the dataset, the training and validation losses look very smooth:
It sounds like the identity is being informed by the prompt, and the style is also maintained. The model sometimes manages to generate words or 'word-like' sounds, and the prosody and rhythm of the speech sound quite coherent too. However, the audio quality is not the best, as some artifacts can be heard.
Next, I explored the effect of temperature, but instead of sampling from a set of temperatures, I continuously increased it from 0.01 to 2, so we can listen to how the speech heats up.