TinyASR - Part 1

2024-01-25

Introduction [1]

I've wanted to scratch an itch for a while and see if I could build a decent speech recognition system from the ground up without too much "cheating" (i.e. looking up in detail how other people have done it). Coming from working in the text-to-speech space, I've used ASR models a lot (wav2vec 2.0, HuBERT, Whisper etc.) and so have a higher than average exposure to the audio/speech modality, but I've never actually needed to build or train them myself. Admittedly it's a little bit contrived doing this since these powerful models already exist and do a great job, but I always learn the most by being hands-on.

Whisper

The ASR model that I've used the most is Whisper, and I have to admit that I've read the paper and looked at the code pretty closely before. Like a lot of people at the time, I was surprised at how good Whisper was considering it was just a vanilla encoder-decoder transformer trained on a lot (680k hours) of data [2]. But I think it took a lot of insight to reduce the problem to that, and this approach has resonated with me in general since.

Now, I don't have access to nearly that much data, and even if I did I wouldn't have enough compute to get through it in any reasonable amount of time, so I'll have to stick to being Tiny in both model size and data. It's a bit of an unknown to me how much data I should have. I do know, for example, that models like TinyStories have many multiples more tokens than I do in the datasets I'll be using, but this will be interesting to see.

Architecture

I'll stick with transformers, since I know them relatively well and I like how simple they are once you get used to them. Unlike Whisper I'll be going with a decoder-only approach, which seems fairly standard now in the multi-modal space (see Fuyu for example). I'll also use rotary positional embeddings because everyone seems to nowadays.

A stripped down version of the forward pass looks like this


# waveforms and mels are padded to some fixed size with silence so we always
# have the same prefix length offset
mels = self.encoder(mels)
prefix_len = mels.size(1)
tokens = self.embedding(token_ids)
# concat mels and tokens along the sequence axis
x = torch.cat((mels, tokens), dim=1)
out = self.decoder(x)
logits = self.lm_head(out[:, prefix_len:])
# take the cross-entropy loss between logits and target_ids

There are a couple of design decisions in here:

  • encoder here is a bit of a misnomer - it's a small convolutional net to do some light feature extraction, compression, and normalization of the incoming mel spectrograms, nothing more. Mel spectrograms themselves are already a very standard (and useful) representation, so I think it's justifiable that not much more processing is done here.
  • I'm using a causal mask across the entire embedded sequence. This is the easiest thing out of the box with current Flash Attention implementations. The major alternative in design here would be to use a "prefix-LM" mask, which allows for bidirectional attention on the "prefix" (the encoded mels for us) and then causal on the output sequence (the text tokens we'd like to decode). This seems closer in principle to an encoder-decoder setup. The UL2 paper explores a lot of these choices in detail.
  • I'll stick with the standard language modeling cross-entropy loss (I know there are a lot of models trained with more specialised losses such as CTC, but I'll go with the simplest thing first). Note that we're only taking the loss over the logits that come after the encoded mel prefix.
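
To make the causal vs "prefix-LM" distinction from the list above concrete, here is a minimal sketch of the two masks. The function names and toy sizes are made up for illustration and are not taken from the TinyASR code; True means "this query position may attend to this key position".

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask where True means query i may attend to key j (j <= i)."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def prefix_lm_mask(prefix_len: int, seq_len: int) -> torch.Tensor:
    """Causal mask, except that all prefix positions can see each other."""
    mask = causal_mask(seq_len)
    # open up the top-left block: bidirectional attention within the prefix
    mask[:prefix_len, :prefix_len] = True
    return mask


# toy sizes: 3 encoded mel frames followed by 2 text tokens
print(causal_mask(5).int())
print(prefix_lm_mask(prefix_len=3, seq_len=5).int())
```

The only difference is the top-left prefix block. A boolean mask with True meaning "may attend" is the convention that torch.nn.functional.scaled_dot_product_attention uses for its attn_mask argument, so a mask like this could in principle be passed there instead of relying on the built-in causal option.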

Here's the full model code if you'd like to take a closer look.

Text tokenizer

Part of me dislikes the use of subword tokenizers in general - they seem like a hack that should go away at some point - so for a first attempt with TinyASR I'll use raw UTF-8 [3] bytes instead. This is the same approach that I know was used in recent models like ByT5, although I can also think back to that fantastic Karpathy RNN blog post from 2015 using them.

I do really like that tokenization essentially simplifies to this one-liner


token_ids = torch.tensor(list(text.encode("utf-8")))

and you don't have to "train" your own subword tokenizer and carry that around.
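
For completeness, here is a small sketch of the round trip in plain Python. The special-token ids (PAD_ID, EOS_ID) and the encode/decode helper names are my own assumptions for illustration; the post only states that the tokens are raw bytes.

```python
# Hypothetical special-token ids placed just above the 256 possible byte values.
PAD_ID = 256
EOS_ID = 257
VOCAB_SIZE = 258


def encode(text: str) -> list[int]:
    """UTF-8 bytes as token ids, followed by an end-of-sequence marker."""
    return list(text.encode("utf-8")) + [EOS_ID]


def decode(token_ids: list[int]) -> str:
    """Drop special tokens and turn the remaining bytes back into text."""
    raw = bytes(t for t in token_ids if t < 256)
    # errors="replace" keeps decoding from crashing on invalid byte sequences,
    # which a half-trained model can easily produce
    return raw.decode("utf-8", errors="replace")


ids = encode("café")
print(ids)          # "é" takes two bytes, so 5 byte ids plus EOS
print(decode(ids))
```

One thing to keep in mind with bytes is that non-ASCII characters span several tokens, so the model can emit sequences that are not valid UTF-8, hence the lenient decoding.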

Data

For the time being I'll limit myself to a recent English split of the Common Voice dataset.


from datasets import load_dataset
ds = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="train")

This is a relatively diverse dataset across accents and audio quality, so I don't think it's a bad choice to approximate real-world usage, although volunteers are manually recording speech so it's not always the most natural sounding. Luckily the audio segments are on the shorter side too (average duration comes in below 10s), which means I can train faster.

I won't mess around with dynamic sequence lengths in batches or anything - I'll follow Whisper and just pad audio to a fixed max duration (for me 15s) and then right-pad the text tokens.
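
A rough sketch of what this fixed-size padding could look like. The sample rate, the token cap, the padding id and the truncation of over-long inputs are all assumptions on my part; only the 15s max duration comes from the text.

```python
import torch
import torch.nn.functional as F

SAMPLE_RATE = 16_000   # assumed sample rate
MAX_SECONDS = 15       # fixed max duration from the post
MAX_TOKENS = 256       # assumed cap on transcript length in bytes
PAD_ID = 256           # hypothetical padding id, outside the byte range


def pad_audio(wave: torch.Tensor) -> torch.Tensor:
    """Truncate or right-pad a 1-D waveform with zeros (silence)."""
    target = SAMPLE_RATE * MAX_SECONDS
    wave = wave[:target]
    return F.pad(wave, (0, target - wave.numel()))


def pad_tokens(token_ids: list[int]) -> torch.Tensor:
    """Truncate or right-pad token ids with PAD_ID."""
    ids = token_ids[:MAX_TOKENS]
    ids = ids + [PAD_ID] * (MAX_TOKENS - len(ids))
    return torch.tensor(ids, dtype=torch.long)


wave = torch.randn(4 * SAMPLE_RATE)  # stand-in for a 4 second clip
tokens = list("hello world".encode("utf-8"))
print(pad_audio(wave).shape)     # 15 s * 16 kHz = 240000 samples
print(pad_tokens(tokens).shape)  # always MAX_TOKENS long
```

Because every example ends up with the same shape, batches can be stacked directly and the mel prefix always has the same length.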

Training and experiments

I tend to copy and paste across the same training loop and default optimization settings between projects/experiments. It's almost always AdamW with these settings along with gradient clipping


lr: 1e-4 # maybe 3e-4, peak value - warmup to this and then cosine decay
weight_decay: 0.1
betas: [0.9, 0.95]
max_grad_norm: 1.0

I'll use 64 as the batch size, which easily fits on the 24GB L4 I'm renting on GCP.
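
The "warmup to the peak, then cosine decay" schedule from the settings above can be written as a plain function of the step. This is a sketch: the warmup length and the final learning rate are assumed values, not ones stated in the post.

```python
import math

PEAK_LR = 1e-4         # peak value from the settings above
WARMUP_STEPS = 1_000   # assumed; the post does not give a warmup length
TOTAL_STEPS = 100_000  # run length mentioned in the post
MIN_LR = 0.0           # assumed floor at the end of the decay


def lr_at(step: int) -> float:
    """Linear warmup to PEAK_LR, then cosine decay down to MIN_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * (step + 1) / WARMUP_STEPS
    progress = min(1.0, (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS))
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1.0 + math.cos(math.pi * progress))


for step in (0, 999, 50_000, 100_000):
    print(step, lr_at(step))
```

In PyTorch, one way to use this is to pass lr_at(step) / PEAK_LR as the multiplier in torch.optim.lr_scheduler.LambdaLR on top of an AdamW optimizer, and to call torch.nn.utils.clip_grad_norm_ with max_norm=1.0 before each optimizer step.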

False start

I forgot to account for padding tokens in my cross-entropy loss and the decoded transcripts were awful. Always use ignore_index if you are padding!


loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1), ignore_index=pad_token_id)
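
To see what ignore_index actually changes, here is a toy comparison on random logits. The padding id and vocabulary size are hypothetical values chosen for the example.

```python
import torch
import torch.nn.functional as F

PAD_ID = 256       # hypothetical padding id
VOCAB_SIZE = 258   # hypothetical: 256 bytes plus two special tokens

torch.manual_seed(0)
logits = torch.randn(1, 6, VOCAB_SIZE)  # (batch, seq, vocab)
target_ids = torch.tensor([[104, 105, 33, PAD_ID, PAD_ID, PAD_ID]])

flat_logits = logits.view(-1, logits.size(-1))
flat_targets = target_ids.view(-1)

# averaged over all 6 positions, padding included
loss_with_pad = F.cross_entropy(flat_logits, flat_targets)
# averaged over the 3 real tokens only
loss_masked = F.cross_entropy(flat_logits, flat_targets, ignore_index=PAD_ID)
# the same thing done by hand: drop the padded positions first
keep = flat_targets != PAD_ID
loss_manual = F.cross_entropy(flat_logits[keep], flat_targets[keep])

print(loss_with_pad.item(), loss_masked.item(), loss_manual.item())
```

Without ignore_index the model is also rewarded for predicting padding, and since transcripts are short relative to the padded length, those positions can dominate the loss.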

First training runs

I ran two model sizes that I still consider to be in the "tiny" size category:

Model (W&B run)   n_layers   d_model   n_heads   n_parameters
Base              6          512       8         20M
Small             12         768       12        88M

Loss curves

[Figure: loss curves from the first training runs]

These looked relatively healthy. The only thing that stood out was the periodic drops in the loss, which are most visible in the Small model. To me this had the signs of overfitting from repeated epochs, and sure enough the drops coincide with the number of steps per epoch (~17k). It's probably made worse by the fact that I'm not using any dropout (I'm used to bigger datasets and always set dropout = 0.0 as the default).
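
As a quick back-of-the-envelope check using only numbers from this post (batch size 64, ~17k steps per epoch, 100k training steps):

```python
batch_size = 64           # from the training setup
steps_per_epoch = 17_000  # approximate period of the loss drops
total_steps = 100_000     # length of the run

clips_per_epoch = batch_size * steps_per_epoch
epochs_seen = total_steps / steps_per_epoch

print(f"~{clips_per_epoch:,} clips per epoch")
print(f"~{epochs_seen:.1f} passes over the data")
```

So each epoch is roughly a million clips, and the model sees every clip close to six times over the course of a run.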

Decoding/sampling

I went with greedy sampling at temperature = 1.0. Here's a very un-scientific overview of the first 5 eval lines

target: "Well, why don't you go to Mecca now?" asked the boy.
Base:   Well, while on you go to Mancanal, asked the boy.
Small:  While wind down Yogauto Lankanao, asked the boy.

target: The men were terrified at his sorcery.
Base:   The man were that he fight at his source and early.
Small:  The man were televised at his soursery.

target: He spoke the words "Columbia", Houston.
Base:   He spoke to Works, Colombia, Houston.
Small:  He spoke the words "Columbia" Fuston.

target: The renovated train station also has a small cafe with hot food.
Base:   The renovated Train Station also has a small castard foot.
Small:  The renovated train station also has a small curfee without foot.

target: The focal depth can be calculated from measurements based on seismic wave phenomena.
Base:   The focal background beat Cartilate from National Histon says "mid-way fennow".
Small:  The local DepercanBouk calculated from measurements based on seismic waves and omino.

I'm happy enough that there aren't any serious bugs in the implementation: the models are at least learning and sort of transcribing things (they seem to get the sounds right even when they get the text wrong). You can see the full evaluation results in the linked W&B runs above.
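
For reference, greedy decoding boils down to repeatedly taking the argmax of the next-token logits. The sketch below uses a toy stand-in instead of the real model; next_token_logits, EOS_ID and MAX_NEW are hypothetical names and values, not part of the TinyASR code.

```python
import torch

EOS_ID = 257   # hypothetical end-of-sequence id
MAX_NEW = 32   # cap on the number of generated tokens


@torch.no_grad()
def greedy_decode(next_token_logits, prefix: torch.Tensor) -> list[int]:
    """next_token_logits(prefix, token_ids) returns (vocab,) logits for the next token."""
    token_ids: list[int] = []
    for _ in range(MAX_NEW):
        logits = next_token_logits(prefix, token_ids)
        next_id = int(torch.argmax(logits))  # greedy: always pick the top token
        if next_id == EOS_ID:
            break
        token_ids.append(next_id)
    return token_ids


# Toy stand-in for a trained model: it "transcribes" a fixed string.
TARGET = list("hi there".encode("utf-8")) + [EOS_ID]


def toy_model(prefix, token_ids):
    logits = torch.zeros(258)
    logits[TARGET[len(token_ids)]] = 1.0
    return logits


out = greedy_decode(toy_model, prefix=torch.zeros(1, 10, 8))
print(bytes(out).decode("utf-8"))
```

With a pure argmax the temperature has no effect on which token is picked, since dividing the logits by a positive constant does not change their order.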

To be continued...

That's enough for a Part 1. It looks like it's sort of working, even if (unsurprisingly) it's not great or nearly as good as parameter-matched Whisper models. At least now we've got some sort of a baseline to try and beat.

Some obvious things to investigate in the future:

  • work out if we are indeed overfitting, add a validation loss, and probably add some dropout since we are repeating data almost 6 times in 100k steps

  • is using a subword tokenizer (shortening the sequence) going to be better than decoding to bytes? I know Whisper uses their GPT-2 tokenizer.

  • add word error rate (WER) as a metric instead of just going off vibes that the transcriptions look good (or not)
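
As a starting point for the WER item, here is a minimal sketch: word-level edit distance divided by the number of reference words. It only lowercases and splits on whitespace, which is a simplification; proper evaluations usually normalise punctuation and casing more carefully.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """(substitutions + deletions + insertions) / number of reference words."""
    ref = reference.lower().split()
    hyp = hypothesis.lower().split()
    # prev[j] = edit distance between the ref words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution or match
            )
        prev = curr
    return prev[-1] / max(len(ref), 1)


target = "The men were terrified at his sorcery."
small = "The man were televised at his soursery."
print(f"{word_error_rate(target, small):.2f}")  # 3 wrong words out of 7
```

On the second eval line above, the Small model gets 3 of the 7 words wrong, which gives a WER of about 0.43.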

P.S. Here's the mandatory cute mascot for the ML project.

Customary cute ML project mascot. Thanks Stable Diffusion XL

Footnotes

  1. Posts like this are a semi-structured (and edited) log of my thought process and what I did. They're semi-didactic in nature just to spell things out for myself as much as the reader, but you'll probably have to fill in some gaps yourself. I'll try to include some background and references (that I know of), but this will be nowhere near the standards of academic papers - apologies in advance!

  2. I'm still not completely sure anyone knows precisely the source of it, though.

  3. Which itself is still a human crafted tokenizer of sorts with its own biases.