Demo - skip connections
This lesson is part of the Fundamentals of transformers - Live Workshop course.
[00:00 - 00:05] OK, so here we've got the next concept: an additional way to deal with position.
[00:06 - 00:13] As we mentioned before, weighted sums can reorder tokens. Now, we have one more way to combat this.
[00:14 - 00:24] One way, as we mentioned before, was to add the positional encoding. But one other thing we can do is add what's called a skip connection.
[00:25 - 00:28] So this is a diagram we had before. We added a positional encoding.
[00:29 - 00:36] We added the attention, and then we added the MLP. Now what we're going to do is add something called a skip connection.
[00:37 - 00:42] So I'll explain what this looks like in code in just a second. But in short, a skip connection has two main purposes.
[00:43 - 00:51] Number one, it ensures that reordering doesn't occur. And second, more importantly, it improves convergence of the model.
[00:52 - 01:02] It makes it possible to train much faster. Now, the reason why skip connections make training converge much faster is a little more complex and hard to explain.
[01:03 - 01:12] Mostly because the original authors who proposed it, maybe five or six years ago, didn't really know the reason either. This is just what they observed empirically.
[01:13 - 01:19] OK, so let me actually show you what this looks like in code. What does a skip connection mean?
[01:20 - 01:39] So let's go back to our... actually, let's now demo a manual skip connection. OK, let's grab the code for this, because this one is fairly simple.
[01:40 - 01:48] Or it's not simple, but it's simpler than attention, at least. OK, so here we've got the MLP, and we've got the input right over here.
[01:49 - 01:53] This is attention 3. The skip connection is pretty simple.
[01:54 - 01:59] I'm just going to add attention 3. I know it's a little underwhelming, but that's pretty much it for a skip connection.
[02:00 - 02:06] If we go back to my diagram, I have a skip connection that takes the input of my MLP and adds it to the output of the MLP.
[02:07 - 02:12] The input here is attention 3, so I just added attention 3 back to the output.
[02:13 - 02:14] And that's it. I just added this.
[02:15 - 02:23] So this seems very simple, but this is actually very critical for the model to converge much faster. This is what we call a skip connection.
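A minimal sketch of the same idea in plain Python, assuming a toy MLP with hand-picked weights. The names `mlp`, `mlp_with_skip`, `W1`, and `W2` are hypothetical stand-ins, not the lesson's project code; the weights were chosen only so the numbers match the example used later in the lesson. The only thing that turns the block into a skip connection is the final element-wise addition of the input back onto the output.

```python
def relu(v):
    # Element-wise ReLU.
    return [max(0.0, x) for x in v]

def matvec(w, v):
    # Multiply a matrix (list of rows) by a vector.
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

# Hypothetical weights for a tiny 3 -> 3 -> 3 MLP.
W1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
W2 = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]

def mlp(v):
    return matvec(W2, relu(matvec(W1, v)))

def mlp_with_skip(v):
    # Skip connection: add the MLP's input back onto its output.
    out = mlp(v)
    return [o + x for o, x in zip(out, v)]

attention3 = [1.0, 2.0, 3.0]  # hypothetical output of the attention layer
print(mlp(attention3))            # [3.0, 1.0, 2.0]
print(mlp_with_skip(attention3))  # [4.0, 3.0, 5.0]
```

The MLP itself is unchanged; the skip connection is one extra addition wrapped around it.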
[02:24 - 02:30] And so the skip connection is this input being added to the output. And this is what we call the residual branch.
[02:31 - 02:41] The residual branch is the original set of vectors being passed through. Oh, yes, of course.
[02:42 - 02:48] Let me re-share this. There we go.
[02:49 - 02:53] OK, cool. So this is the skip connection.
[02:54 - 02:56] Let me go back to here. OK, so that's what these diagrams mean.
[02:57 - 03:12] Anytime I draw something like this, it just means we're taking the input and adding it to the output. The reason I'm going through all these details, by the way, is that they may not make a whole lot of intuitive sense yet, but in a second I'm going to walk through the diagram from the original transformer paper that everybody copies and pastes everywhere.
[03:13 - 03:21] And then you'll see that we've actually covered all of those components. There are also a lot of sources online that just copy and paste those diagrams, but I think they are missing critical parts of the explanation.
[03:22 - 03:28] So I'm hoping that you can get the actual explanation here. Definitely let me know if these diagrams end up being confusing.
[03:29 - 03:31] Yeah. OK, so here's one observation.
[03:32 - 03:35] Let's say that I did that. So in my code, I added attention 3 to the output.
[03:36 - 03:46] Every time I do that, though, I can end up doubling the magnitude of the output. Let me explain what I mean by doubling the magnitude.
[03:47 - 03:57] So let's say that attention 3 here is something like... oops, actually, let me rename this. This is MLP3 plus attention 3.
[03:58 - 04:05] That'll make it a little clearer. So here, let's say that attention 3 has the values 1, 2, 3.
[04:06 - 04:15] And then maybe MLP3 has the values 3, 1, 2. Previously, my output was 3, 1, 2, and the largest value was 3.
[04:16 - 04:26] But now if I add attention 3 to it, the output becomes 4, 3, 5, and the largest value becomes 5. If the largest entries of both vectors had lined up in the same position, it would have become 6. In essence, adding the skip connection can double my magnitude.
[04:27 - 04:36] In the worst case, it doubles the magnitude of my largest value. So we need to fix that somehow.
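A quick check of the example's arithmetic in plain Python. With attention 3 = [1, 2, 3] and MLP3 = [3, 1, 2], the element-wise sum is [4, 3, 5], so the largest value grows from 3 to 5. It reaches a full doubling (6) only when the largest entries line up; the `aligned` vector below is a hypothetical worst case, not a value from the lesson.

```python
attention3 = [1, 2, 3]
mlp3 = [3, 1, 2]

# Skip connection: element-wise sum of the MLP output and its input.
output = [m + a for m, a in zip(mlp3, attention3)]
print(output)                  # [4, 3, 5]
print(max(mlp3), max(output))  # 3 5

# Hypothetical worst case: the MLP output's largest entry sits in the
# same position as the input's largest entry, so the maximum doubles.
aligned = [1, 2, 3]
worst = [m + a for m, a in zip(aligned, attention3)]
print(max(worst))              # 6
```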
[04:37 - 04:46] The reason we need to fix that is that neural networks in general, including large language models, learn faster when the activations are near 0. OK, so what are activations?
[04:47 - 04:49] Activations are features. They are embeddings.
[04:50 - 04:55] They are the vectors we've been talking about. So "activations" is just another name for the exact same thing.
[04:56 - 05:18] So when these values are all close to 0, the neural network learns much faster and usually converges to a better result. Knowing that, the fact that we're doubling the magnitude after the skip connection is a problem.
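Transformers stack many blocks, each with its own skip connection, so the growth compounds. The sketch below assumes the worst case from the example, a hypothetical branch whose output exactly matches its input (`copy_branch`); real attention and MLP branches will not behave this way, but it shows how the largest magnitude can grow as 2 to the number of layers.

```python
def residual_stack(x, num_layers, branch):
    # Apply num_layers residual blocks: x <- x + branch(x).
    for _ in range(num_layers):
        x = [xi + bi for xi, bi in zip(x, branch(x))]
    return x

def copy_branch(v):
    # Hypothetical worst-case branch: output lines up exactly with its input.
    return list(v)

x = [1.0, 2.0, 3.0]
for n in range(5):
    largest = max(abs(v) for v in residual_stack(x, n, copy_branch))
    print(n, largest)  # 3.0, 6.0, 12.0, 24.0, 48.0 for n = 0..4
```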
[05:19 - 05:23] So here's how to fix that. We can use something called normalization.
[05:24 - 05:31] So this is our diagram from before. We can use something called batch norm to actually normalize these samples.
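A minimal sketch of the core batch-norm computation in training mode, assuming a small batch given as a list of samples (each a list of feature values). `batch_norm` is a hypothetical helper: it normalizes each feature across the batch to mean 0 and variance 1, and it omits the learnable scale and shift and the running statistics that a full batch-norm layer keeps. The sample values are made up for illustration.

```python
import math

def batch_norm(batch, eps=1e-5):
    """Normalize each feature across the batch to mean 0, variance 1.

    Omits the learnable scale (gamma) and shift (beta) and the
    running statistics used at inference time.
    """
    num_samples = len(batch)
    num_features = len(batch[0])
    out = [[0.0] * num_features for _ in range(num_samples)]
    for j in range(num_features):
        column = [sample[j] for sample in batch]
        mean = sum(column) / num_samples
        var = sum((c - mean) ** 2 for c in column) / num_samples
        for i in range(num_samples):
            # eps avoids division by zero when a feature is constant.
            out[i][j] = (batch[i][j] - mean) / math.sqrt(var + eps)
    return out

batch = [[4.0, 3.0, 5.0], [2.0, 6.0, 1.0]]
normed = batch_norm(batch)
print(normed)  # every feature is now centered on 0, values close to +/-1
```

However large the inputs grow after the skip connection, each feature comes out centered on 0. In PyTorch, `torch.nn.BatchNorm1d(num_features)` performs this normalization and also learns a per-feature scale and shift.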