Batch norm



This lesson preview is part of the Fundamentals of transformers - Live Workshop course and can be unlocked immediately with a single-time purchase. Already have access to this course? Log in here.

  • [00:00 - 00:05] I'm going to implement batch norm in a second. Actually, I will do that now.

    [00:06 - 00:09] So let's go into here. Let's insert here.

    [00:10 - 00:14] Okay, so you'll notice that these demos are getting shorter. We will have a longer one later.

    [00:15 - 00:24] So here we'll have a demo: demo batch norm. Okay, all right, so what does batch norm look like?

    [00:25 - 00:30] Let's say we have a tensor that looks like the following.

    [00:31 - 00:37] So let's say the batch size is equal to three. Okay, so what is batch size here?

    [00:38 - 00:52] Batch size means how many prompts I'm passing in to the LLM all at once. So let's say that maybe Ken sends a prompt, Julius sends a prompt, I send a prompt, that's three, right?

    [00:53 - 01:06] So each of us sends a prompt to ChatGPT. And then ChatGPT says, I'm going to take those prompts, the text, convert them into vectors, and then stack all those vectors together.

    [01:07 - 01:14] Right, so that means we have a tensor with a batch size of three. So, a question from Maya: are these all distinct layers in the model?

    [01:15 - 01:19] Yes, exactly. So if I come back to here, these are all distinct layers.

    [01:20 - 01:22] And you'll be able to see them in the code actually. You know what?

    [01:23 - 01:26] Actually, I can just go to the code right now: Llama on GitHub.

    [01:27 - 01:32] I can show you; this is going to be a lot of code, and you don't have to read all of it. I just want to show you that these all have a presence somewhere in the code.

    [01:33 - 01:35] The first one that we have is attention. The second one is the MLP.

    [01:36 - 01:40] The third one is positional encoding. In Llama, positional encoding is called RoPE.

    [01:41 - 01:44] Okay, so I guess it's somewhere here. But let me come back to that.

    [01:45 - 01:50] Actually, let me go to attention first. So here you have the attention module.

    [01:51 - 01:55] And let's see where this is actually called or defined. Okay, so you create the attention module, right?

    [01:56 - 02:01] Inside of your transformer, right? That's what we expect. You also create your feed-forward layer, right?

    [02:02 - 02:04] Inside of your transformer. Again, that's what we expect.

    [02:05 - 02:16] So now, where do we call this attention? Down here, right? The forward function of any PyTorch layer tells you what is actually executed during inference.

    [02:17 - 02:21] So here we're executing the attention. And remember this X plus, that's our skip connection.

    [02:22 - 02:25] This H plus, that's also our skip connection. And then here we got our feed forward.

    [02:26 - 02:32] We also see there's this ffn_norm and this attention_norm. These are the norms that we're just about to talk about, but let's ignore them for now.
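The forward pass just described (attention, then the `x +` skip connection; feed-forward, then the `h +` skip connection, each sublayer preceded by a norm) can be sketched in a few lines. This is a minimal NumPy sketch, not the Llama code: `toy_norm`, `toy_attention`, and `toy_feed_forward` are hypothetical stand-ins, and only the wiring mirrors what we saw.

```python
import numpy as np

def toy_norm(x: np.ndarray) -> np.ndarray:
    # Hypothetical placeholder norm: center and scale each token vector.
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + 1e-5)

def toy_attention(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for attention: mixes information across tokens
    # by giving every token the average of all token vectors.
    return np.broadcast_to(x.mean(axis=0, keepdims=True), x.shape)

def toy_feed_forward(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the MLP: a fixed elementwise nonlinearity.
    return np.maximum(x, 0.0)

def block_forward(x: np.ndarray) -> np.ndarray:
    h = x + toy_attention(toy_norm(x))       # attention, then skip connection
    out = h + toy_feed_forward(toy_norm(h))  # feed-forward, then skip connection
    return out

x = np.random.default_rng(0).standard_normal((4, 8))  # (num_tokens, hidden_size)
print(block_forward(x).shape)  # (4, 8)
```

The skip connections are why the output keeps the input's shape: each sublayer's result is added back onto what came in.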

    [02:33 - 02:37] Okay, so we've got skip connections. We've got the attention, we've got skip connection, we've got the MLP.

    [02:38 - 02:47] Now let's look for that positional encoding. Inside of our attention, we should have our positional encoding.

    [02:48 - 03:03] So this right here, apply_rotary_emb, is their version of the positional encoding. It's going to look slightly weird, but basically... type_as, yeah, okay.

    [03:04 - 03:24] So I guess this isn't the most readable, unfortunately, because Llama does a clever thing where it precomputes some values, and that precomputation is where the bulk of the logic is. Okay, so anyway, you'll just have to take my word for it that apply_rotary_emb is the positional encoding.

    [03:25 - 03:35] And we can come back to that once we've talked about all the different parts, and then we'll walk through the Llama code to see where all these parts are. Okay, all right, so back to here.

    [03:36 - 03:47] Let's talk about this norm, because it's very, very important: it keeps all of our values close to zero, which makes the model faster to converge or easier to train. Okay, so the first dimension is our batch size.

    [03:48 - 03:55] As I mentioned before, batch size is the number of prompts that we're feeding into the model all at once.

    [03:56 - 04:09] And the second dimension that we care about here is going to be the number of tokens. All right, so the number of tokens will have to be the maximum number of tokens across all three of our requests.

    [04:10 - 04:26] So, using Julius and Ken from my example before: if my question has a hundred words, Julius's has 50 words, and Ken's has 200 words, then 200 is the maximum. So the number of tokens here would have to be 200.
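To make "the maximum number of tokens" concrete, here is a small sketch that pads three tokenized prompts to the length of the longest one. It treats each word as one token, as the example above does; the token IDs and the pad ID of 0 are made up for illustration (real tokenizers define their own pad token).

```python
def pad_batch(token_lists, pad_id=0):
    """Right-pad every prompt to the length of the longest prompt."""
    num_tokens = max(len(tokens) for tokens in token_lists)
    padded = [tokens + [pad_id] * (num_tokens - len(tokens)) for tokens in token_lists]
    return padded, num_tokens

# Hypothetical token IDs with the example's lengths: 100, 50, and 200 tokens.
my_prompt = list(range(1, 101))
julius_prompt = list(range(1, 51))
ken_prompt = list(range(1, 201))

padded, num_tokens = pad_batch([my_prompt, julius_prompt, ken_prompt])
print(num_tokens)                # 200
print([len(p) for p in padded])  # [200, 200, 200]
```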

    [04:27 - 04:35] And then finally, we have the dimensionality. From before, we know that our token dimension is 300.

    [04:36 - 04:52] This is usually called the model dimension in code. So actually, if you go into Hugging Face and look at a transformer, any one of these, you can actually see... oh, actually, this is not what I want.

    [04:53 - 05:00] Let's go to Llama again, since it's the most popular. Okay, I actually want the Llama model on Hugging Face.

    [05:01 - 05:11] Okay, perfect. So we have the Llama model here, and you'll have something in here called model dim.

    [05:12 - 05:15] Oh, you know what? Maybe it's no longer called model dim.

    [05:16 - 05:18] Oh, okay, it's called hidden_size. Okay, that's a fantastic name.

    [05:19 - 05:23] This is way better. So instead of model dim, it's going to be called hidden size.

    [05:24 - 05:26] So let's follow the same convention and call it hidden size.

    [05:27 - 05:32] All right, okay. So now let's create that tensor with shape batch size, number of tokens, hidden size.
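Stacking per-prompt matrices into one batch tensor can be sketched in NumPy, where `np.stack` plays the role of `torch.stack` (the lesson itself uses PyTorch). The sizes 3, 200, and 300 come from the running example; the random values stand in for real token embeddings.

```python
import numpy as np

batch_size, num_tokens, hidden_size = 3, 200, 300
rng = np.random.default_rng(0)

# One (num_tokens, hidden_size) matrix per prompt, already padded to num_tokens.
prompts = [rng.random((num_tokens, hidden_size)) for _ in range(batch_size)]

# Stack along a new first axis: the batch (number-of-prompts) dimension.
x = np.stack(prompts, axis=0)
print(x.shape)  # (3, 200, 300)
```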

    [05:33 - 05:40] All right, and this is our input x. Now, how do we compute batch norm?

    [05:41 - 05:51] Well, to compute batch norm, we subtract the mean and then divide by the standard deviation. So in short, we're centering all of our inputs and then scaling them.

    [05:52 - 05:57] That way they have unit variance. All right, so we're going to shift and then scale.

    [05:58 - 06:07] Now, here's what I could do. I could just do something like this: take x minus the mean and then divide it by the standard deviation.

    [06:08 - 06:14] Right. Now, batch norm, though, operates in a very specific way.

    [06:15 - 06:28] It actually takes this mean and standard deviation along the batch dimension. What that means is we do this for every single token and for every single value within that token.

    [06:29 - 06:47] So for every position, we take the average across, let's say, my first token, Ken's first token, and Julius's first token. We take the average of those three tokens, now we have one average token, and that's what we subtract.

    [06:48 - 06:54] So that's what it means to take the average across the batch dimension. So here I write dim=0.
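Here is the same computation as a minimal NumPy sketch, where `axis=0` corresponds to `dim=0` in the PyTorch code. Like the demo, it leaves out the extra details of a full batch norm layer; the shapes are the running example's.

```python
import numpy as np

def simple_batch_norm(x: np.ndarray) -> np.ndarray:
    """Normalize across the batch (first) axis: subtract mean, divide by std."""
    mean = x.mean(axis=0)    # shape (num_tokens, hidden_size): one "average prompt"
    std = x.std(axis=0)      # same shape as mean
    return (x - mean) / std  # broadcasting applies the same mean/std to every prompt

rng = np.random.default_rng(0)
x = rng.random((3, 200, 300))  # (num_prompts, num_tokens, hidden_size)
y = simple_batch_norm(x)

# Every (token, feature) position now has mean ~0 and std ~1 across the batch.
print(np.allclose(y.mean(axis=0), 0.0), np.allclose(y.std(axis=0), 1.0))  # True True
```

One detail worth knowing if you port this back to PyTorch: NumPy's `std` defaults to the population (biased) estimate, while `torch.Tensor.std` defaults to the sample (unbiased) estimate. PyTorch's batch norm layers normalize with the biased estimate.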

    [06:55 - 07:00] dim=0. Actually, let me double-check that I wrote this code correctly.

    [07:01 - 07:03] x.mean(dim=0).shape. Oopsie.

    [07:04 - 07:12] Oh, I see. I want to switch that to torch.rand. Okay. All right.

    [07:13 - 07:19] So this is correct. So once I take the average across the batch dimension, that batch dimension should disappear.

    [07:20 - 07:33] So think of it this way. If I took the average of three vectors, so if I took the average of three arrays where every array has length 10, then in the end I should have one array of length 10.

    [07:34 - 07:37] Right. So that's what I mean by that batch dimension disappears.

    [07:38 - 07:41] Instead of three arrays, I now just have one. And in this case, instead of a batch dimension of size three, I have no batch dimension at all.

    [07:42 - 07:45] Right. Okay. So this is correct.
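A quick way to see the batch dimension disappear is to check shapes, both for the three-arrays-of-length-10 example and for the full three-axis tensor (NumPy here, where `axis` plays the role of PyTorch's `dim`).

```python
import numpy as np

three_arrays = np.ones((3, 10))         # three arrays, each of length 10
print(three_arrays.mean(axis=0).shape)  # (10,): one array of length 10

x = np.ones((3, 200, 300))              # (num_prompts, num_tokens, hidden_size)
print(x.mean(axis=0).shape)             # (200, 300): the batch axis is gone
print(x.mean(axis=0, keepdims=True).shape)  # (1, 200, 300) if you keep it
```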

    [07:46 - 07:48] So this is how batch norm operates at a very simple level.

    [07:49 - 08:01] I've omitted some details, but they don't really matter for the sake of this conversation. Right. We've now got a tensor with a uniform scale that is centered.
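For reference, the details usually left out of the simple version are a small constant (epsilon) added to the variance for numerical stability, plus a learnable per-feature scale and shift (often written gamma and beta). The sketch below adds those while keeping the lesson's simplification of normalizing over the batch axis only. It is an illustration under those assumptions, not PyTorch's `BatchNorm` implementation, and it ignores the running statistics a real layer keeps for use at inference time.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Batch norm over axis 0 with an epsilon and a learnable scale/shift."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)                      # biased (population) variance
    x_hat = (x - mean) / np.sqrt(var + eps)  # eps avoids division by zero
    return gamma * x_hat + beta              # gamma, beta broadcast over hidden_size

hidden_size = 4
rng = np.random.default_rng(0)
x = rng.random((3, 5, hidden_size))  # (num_prompts, num_tokens, hidden_size)
gamma = np.ones(hidden_size)         # common initialization: scale of 1
beta = np.zeros(hidden_size)         # common initialization: shift of 0
y = batch_norm(x, gamma, beta)
print(y.shape)  # (3, 5, 4)
```

With gamma at 1 and beta at 0, this matches the simple version up to epsilon; during training the model can learn other values if a centered, unit-scale output is not what it needs.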

    [08:02 - 08:05] Okay. These are all near zero.

    [08:06 - 08:15] Right. So this negative exponent right here, I'm sure you understand the scientific notation, right? So these are all close to zero, basically zero.

    [08:16 - 08:28] Now, let's do the standard deviation with dim=0, and you'll see all of these are equal to one. So basically, we've got a centered and scaled version of the elements.

    [08:29 - 08:33] And that's our batch norm. Okay. So now here's a problem, though.

    [08:34 - 08:44] These elements are so big that I don't actually have the memory required to pass in multiple prompts during training. I can't fit them in.

    [08:45 - 08:50] In reality, during training, the number of tokens would be really, really large, something like 2,000, 4,000, or 8,000.

    [08:51 - 08:55] So I don't have room for multiple prompts during training, and I end up with a batch size equal to one.

    [08:56 - 09:07] Now, what happens if my batch size is equal to one? Batch norm is now taking the average of one vector, which is just itself.
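The batch-size-one problem is easy to reproduce. With a single prompt, the mean across the batch is that prompt itself, so centering leaves nothing but zeros, and the standard deviation across the batch is zero as well. A NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((1, 200, 300))  # batch size 1: (num_prompts, num_tokens, hidden_size)

mean = x.mean(axis=0)          # the "average" of one prompt is the prompt itself
std = x.std(axis=0)            # the spread of a single value is zero

print(np.array_equal(mean, x[0]))  # True
print(np.count_nonzero(x - mean))  # 0: every centered value is zero
print(np.count_nonzero(std))       # 0: dividing by this would give 0/0 = nan
```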

    [09:08 - 09:11] Batch norm is no longer doing anything useful. Yeah. Oh, okay.

    [09:12 - 09:19] So the question is, "I don't understand the batch operation." I assume you mean this x.mean(dim=0).

    [09:20 - 09:21] Is that what you mean by the batch operation?

    [09:22 - 09:26] Oh, okay. So the concept of a batch.

    [09:27 - 09:34] Okay, that's a good question. Maybe what I can write here instead is number of prompts.

    [09:35 - 09:38] Right. So number of prompts.

    [09:39 - 09:44] And so let me show you where that comes up. So in the very, very beginning, let's go back to the very, very top here.

    [09:45 - 09:52] Remember this example where we actually ran one prompt through the LLM. It turns out that the LLM can actually take in multiple prompts all at once.

    [09:53 - 10:06] So here I can actually... I can just copy this all the way down to the very bottom. Okay. So here, instead of running just one of these, I can actually pass in multiple prompts all at once.

    [10:07 - 10:25] So "A list of colors: red," and I could say "The capital of the US is," and I can say "My favorite food is." Oh, why is it so upset? Let me see what this is. Oh, I see, I see.

    [10:26 - 10:31] Okay. So basically I need to set padding=True.

    [10:32 - 10:44] Okay. All right. So here I've got three different outputs, right?

    [10:45 - 10:48] And I'm actually gonna, okay. So here's what we've got.

    [10:49 - 10:54] We've got three different prompts, right? And then let me do this.

    [10:55 - 11:01] I've got three different prompts. For the list of colors, we got back "A list of colors: red," comma. So the model is repeating itself.

    [11:02 - 11:08] And then we've got "The capital of the US is," and the LLM gave us New York City. That's funny.

    [11:09 - 11:18] And then my favorite food is, and it just gave us a blank. So maybe if I forget food is, let's try like, turn me to, and let's see what it gives us.

    [11:19 - 11:20] Stream me to add iced tea. Okay. Fantastic.

    [11:21 - 11:31] So Facebook's OPT-125M really likes turn me to an iced tea. So you can see here that I can pass multiple prompts to my LLM.

    [11:32 - 11:35] And so that's what I meant by batch size. In this case, I've got three different prompts.

    [11:36 - 11:40] So my batch size is three. To hopefully make that clear, I'll just write number of prompts here, right?

    [11:41 - 11:49] So the first dimension of my tensor is number of prompts. And then I can take the average across those prompts.

    [11:50 - 11:56] So actually, instead of using this as a dimension, what I could do is the following. Let's just say I have the following.

    [11:57 - 12:01] I have three inputs. I have x1, which is torch.rand(num_tokens, hidden_size).

    [12:02 - 12:09] And I can just duplicate this. And then batch norm is just taking the mean of these three.

    [12:10 - 12:26] x1, x2, x3. And then it divides by... okay. So actually, y1 would be equal to x1 minus the mean, divided by the standard deviation.

    [12:27 - 12:34] So mean is this, standard deviation is this. Yeah.

    [12:35 - 12:39] Okay. So you could equivalently write something like this.

    [12:40 - 12:46] y3 equals x3 minus the mean,

    [12:47 - 12:48] divided by the standard deviation. Okay, cool.

    [12:49 - 12:51] And this is basically what batch norm does. Batch norm produces these values.
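Writing the three prompts out as separate arrays, as in the walkthrough, gives the same result as normalizing the stacked tensor along its first axis. A NumPy sketch, where `x1`, `x2`, `x3` and `y1`, `y2`, `y3` follow the names used above and the small sizes are arbitrary:

```python
import numpy as np

num_tokens, hidden_size = 4, 6
rng = np.random.default_rng(0)
x1, x2, x3 = (rng.random((num_tokens, hidden_size)) for _ in range(3))

# Statistics are computed across the three prompts, position by position.
mean = (x1 + x2 + x3) / 3
std = np.sqrt(((x1 - mean) ** 2 + (x2 - mean) ** 2 + (x3 - mean) ** 2) / 3)

y1 = (x1 - mean) / std
y2 = (x2 - mean) / std
y3 = (x3 - mean) / std

# Same thing with the prompts stacked into one tensor and axis=0.
x = np.stack([x1, x2, x3], axis=0)
y = (x - x.mean(axis=0)) / x.std(axis=0)
print(np.allclose(y, np.stack([y1, y2, y3])))  # True
```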

    [12:52 - 13:04] Oh, we do this. And so now you can imagine, if I only have one value, if I only have one prompt, I'm taking the mean of just one item and I'm taking the standard deviation of just one item, both of which are very, very uninteresting.

    [13:05 - 13:08] Yeah. Okay. Does that answer the question?

    [13:09 - 13:19] Anybody else? Cool.

    [13:20 - 13:26] I'll assume yes, but let me know if anybody's still confused. Just post in the chat and I'll go over it again.

    [13:27 - 13:33] Okay. So now we've seen, yeah, pretty much, it's a batch of prompts.

    [13:34 - 13:39] And that's what I mean by batch. Yeah. It's just like a, it's just a set of prompts, a group of prompts.

    [13:40 - 13:45] Yeah. And so in this case, this was the group of prompts, or the batch of prompts, that I gave to the LLM.

    [13:46 - 13:53] Yeah. And in this particular case, these are all prompts, but converted into vectors.

    [13:54 - 14:00] Yeah. How does that sound?

    [14:01 - 14:13] Let me know if that's still confusing. Okay. All right.

    [14:14 - 14:19] So I'll keep going, but definitely post more questions if anybody in the room has any. That was batch norm.

    [14:20 - 14:31] That's where we normalize: we subtract the mean and divide by the standard deviation, and the mean and standard deviation are computed across the different prompts in the batch.