Demo - Manual LLM inference
This lesson is part of the Fundamentals of transformers - Live Workshop course.
[00:00 - 00:12] OK, all right, so now we can look at a demo of autoregressive decoding. Let's go over here.
[00:13 - 00:26] Previously we ran the LLM.generate function that Hugging Face provides for us. Now what we want to do is run inference manually, so you can see how autoregressive decoding works.
[00:27 - 00:31] So let's do that. Let's actually create a new cell here.
[00:32 - 00:41] This is going to be autoregressive decoding. OK, so we're going to start off by pulling the input IDs.
[00:42 - 00:50] This is just a data structure that Hugging Face came up with, and it contains the list of IDs we talked about. Let me show you what that looks like.
[00:51 - 01:00] We printed these out earlier, and it's just a list of numbers like this. We're going to feed that list of numbers into our LLM.
[01:01 - 01:10] So now we run our LLM on those input IDs and get a bunch of outputs. The outputs, though, are not yet token IDs, right?
[01:11 - 01:19] Instead, they're something we call logits, which we'll talk about in more detail later.
[01:20 - 01:35] They're basically a list of vectors, and I'll explain what vectors are and why we're talking about them. For now, you can ignore what this line is doing functionally; at a high level, this line converts the logits into IDs for us,
[01:36 - 01:44] or rather into a single ID. OK, so this is the token ID of the output text.
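The line the speaker skips over turns logits into a single token ID. A minimal sketch, assuming greedy decoding (always pick the highest-scoring token), takes the last row of the logits and computes its argmax. In PyTorch this is often written as logits[:, -1, :].argmax(dim=-1), though the notebook's exact line isn't shown here. The pure-Python version below uses a made-up logits grid so it runs without a model.

```python
def greedy_next_id(logits):
    """Pick the next token ID from a [seq_len][vocab_size] grid of logits.

    Only the last row is used: it scores every vocabulary entry as the
    continuation of the whole input sequence.
    """
    last_row = logits[-1]
    # argmax: the index of the largest score in the last row
    return max(range(len(last_row)), key=lambda i: last_row[i])


# Hypothetical logits for a 3-token input and a 5-entry vocabulary.
logits = [
    [0.1, 2.0, 0.3, -1.0, 0.0],
    [1.5, 0.2, 0.1, 0.4, 0.9],
    [0.0, -0.5, 0.2, 3.1, 1.0],
]
print(greedy_next_id(logits))  # → 3
```

Earlier rows are ignored here on purpose; what they contain is what the demo explores later when it looks at every position's output.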
[01:45 - 02:03] Now, let's see what that text is using our tokenizer. To recap, our tokenizer converts text into IDs and then IDs back into text, right? So right now we have an output ID, and we need to convert it back into text.
[02:04 - 02:10] Now, let's see what this token looks like. Ha ha, it's a comma, and that agrees with what we saw previously.
[02:11 - 02:21] If you scroll back up here, our prompt to the LLM was "A list of colors: red", and the first token it predicted was a comma.
[02:22 - 02:27] Right. So down here we got exactly what we saw up there, so we know we wrote this code correctly.
[02:28 - 02:36] Now, someone asked earlier what the token ID for punctuation is, or how we output punctuation. So let me show you what the token ID is.
[02:37 - 02:41] It's just six. Yeah.
[02:42 - 02:43] Cool. All right.
[02:44 - 02:53] So we've done one round of prediction. Now, what I said earlier was that you need to take this output, add it back to the input, and then run inference with the model one more time.
[02:54 - 02:58] So let's do that. Here I'm going to write the following.
[02:59 - 03:10] I'm going to convert this into a list. This is currently a PyTorch tensor,
[03:11 - 03:16] an instance of the Tensor class, and I need to convert it to a list using tolist().
[03:17 - 03:22] Right. OK. So I'm going to run prediction for five steps.
[03:23 - 03:27] So normally you would run prediction until you see an end of sequence token. I'm not going to do that.
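The usual stopping rule, running until an end-of-sequence token appears with a step cap as a safety net, can be sketched as below. Assumptions: toy_next_token is a hypothetical rule-based stand-in for "run the LLM and take the argmax of the last position", hard-coded so that it continues the demo's color list; it is not a real model, and decode_until_eos is an illustrative name, not a Hugging Face function.

```python
COLORS = ("red", "green", "blue")


def toy_next_token(tokens):
    """Hypothetical stand-in for one LLM prediction step (rule-based, not learned)."""
    last = tokens[-1]
    if last == ":":
        return "red"
    if last in COLORS:
        return ","
    if last == ",":
        remaining = [c for c in COLORS if c not in tokens]
        return remaining[0] if remaining else "<eos>"
    return "<eos>"


def decode_until_eos(prompt, max_new_tokens=20, eos="<eos>"):
    tokens = list(prompt)  # copy so the caller's prompt is untouched
    for _ in range(max_new_tokens):
        next_tok = toy_next_token(tokens)
        tokens.append(next_tok)  # feed the prediction back in as input
        if next_tok == eos:      # stop as soon as the model says it is done
            break
    return tokens[len(prompt):]


prompt = ["<sos>", "A", "list", "of", "colors", ":", "red"]
print(decode_until_eos(prompt))                    # → [',', 'green', ',', 'blue', ',', '<eos>']
print(decode_until_eos(prompt, max_new_tokens=5))  # → [',', 'green', ',', 'blue', ',']
```

With max_new_tokens=5 the loop behaves like the fixed five-step run in the demo and never reaches the end-of-sequence token.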
[03:28 - 03:34] I'm just going to run for five steps here. Unfortunately, the first thing I have to do is, what do you call this,
[03:35 - 03:40] it's just logistics. I basically need to wrap this list in a torch tensor object.
[03:41 - 03:43] That's not really important. There's no interesting insights there.
[03:44 - 03:53] But this is the first line where I do some real work. This is where I actually pass the inputs into the LLM, using the list of inputs I've collected so far.
[03:54 - 03:55] Right. All right.
[03:56 - 04:01] Now I take the output just like I did before. This is the exact same line as up here.
[04:02 - 04:03] Right. OK.
[04:04 - 04:13] So now for my last step, I'm going to do what I promised you earlier. I'm going to add the newly outputted ID back to the input.
[04:14 - 04:21] And why is there red underlining here? OK, I'm going to...
[04:22 - 04:25] Oh, I see, it's just upset because I didn't import torch.
[04:26 - 04:30] OK, cool. So now that I've done this, I've added the output to the input,
[04:31 - 04:36] run the model on it, and executed the LLM five times. Now let's see what that output looks like.
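The five-step loop from the demo can be sketched without PyTorch or a real model. Assumptions: toy_model is a hypothetical rule-based stand-in that, like the real LLM call, returns one row of logits per input position, where row i only depends on the IDs up to position i. In the notebook, each step instead wraps the list with torch.tensor(...) before calling the model; that code isn't reproduced here. What carries over is the shape of the loop: run the model on every ID so far, keep only the last row's argmax, append it, repeat.

```python
# Hypothetical toy vocabulary standing in for the real tokenizer's vocab.
VOCAB = ["<sos>", "<eos>", "A", "list", "of", "colors", ":", "red", ",", "green", "blue"]
ID = {tok: i for i, tok in enumerate(VOCAB)}
COLORS = ("red", "green", "blue")


def toy_model(input_ids):
    """Hypothetical stand-in for the LLM: one row of logits per input position,
    where row i is computed only from input_ids[: i + 1]."""
    logits = []
    for i in range(len(input_ids)):
        prefix = [VOCAB[t] for t in input_ids[: i + 1]]
        last = prefix[-1]
        if last == ":":
            guess = "red"
        elif last in COLORS:
            guess = ","
        elif last == ",":
            remaining = [c for c in COLORS if c not in prefix]
            guess = remaining[0] if remaining else "<eos>"
        else:
            guess = "<eos>"
        row = [0.0] * len(VOCAB)
        row[ID[guess]] = 1.0  # the guessed token gets the highest score
        logits.append(row)
    return logits


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


input_ids = [ID[t] for t in ["<sos>", "A", "list", "of", "colors", ":", "red"]]
for _ in range(5):                 # fixed number of steps, as in the demo
    logits = toy_model(input_ids)  # run the model on everything so far
    next_id = argmax(logits[-1])   # keep only the last position's prediction
    input_ids.append(next_id)      # feed the new ID back in as input

print(" ".join(VOCAB[i] for i in input_ids[1:]))
# → A list of colors : red , green , blue ,
```

The toy model produces ", green , blue ," only because its rules were written to; the point is the loop, not the predictions.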
[04:37 - 04:42] Right. So here we decode the output IDs, skipping special tokens.
[04:43 - 04:48] Right. So while this is executing, we can actually, oh, great, perfect.
[04:49 - 04:53] So now we get red, comma, green, comma, blue, comma. Let's compare that to what we saw before.
[04:54 - 04:56] And it looks exactly the same. Right.
[04:57 - 05:05] So now we know we've implemented Hugging Face's LLM.generate function pretty faithfully. This is essentially what their function does.
[05:06 - 05:09] So now you've implemented autoregressive decoding from scratch. OK.
[05:10 - 05:11] OK. All right.
[05:12 - 05:21] So now I want to show you something weird. Let's say that I took my inputs and fed them into my model, just like we did here.
[05:22 - 05:27] But so far I've only been taking the last token; I've only been printing out the last token.
[05:28 - 05:31] What happens if I print out all the tokens? Right.
[05:32 - 05:37] What does the LLM predict? We mentioned before that the LLM outputs n tokens for every n inputs.
[05:38 - 05:40] Why can't I just read the outputs directly? Right.
[05:41 - 05:43] Why do this autoregressive decoding? Let me show you.
[05:44 - 05:53] Here I can take the outputs of the LLM and decode the argmax of the logits.
[05:54 - 06:08] So again, don't worry too much about what the code is doing functionally. At a high level, I'm passing the inputs into my model and looking at all of the outputs, not just the last word it output.
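Looking at all of the outputs just means taking the argmax of every row of the logits instead of only the last one. In PyTorch this is commonly something like logits.argmax(dim=-1), which may differ from the notebook's exact line. The sketch below uses a hypothetical 4-position, 3-entry logits grid: n input positions give n predicted IDs, and autoregressive decoding keeps only the final one.

```python
def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def predictions_per_position(logits):
    """One predicted ID per input position (argmax over the vocabulary axis)."""
    return [argmax(row) for row in logits]


# Hypothetical logits: 4 input positions x 3 vocabulary entries.
logits = [
    [2.0, 0.1, 0.3],
    [0.2, 1.7, 0.0],
    [0.5, 0.4, 0.9],
    [0.0, 3.0, 1.0],
]
all_ids = predictions_per_position(logits)
print(all_ids)      # → [0, 1, 2, 1]
print(all_ids[-1])  # → 1  (the only one autoregressive decoding keeps)
```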
[06:09 - 06:12] OK, this is total gibberish: "I lot of the and".
[06:13 - 06:17] OK. Why did my model produce a bunch of random gibberish
[06:18 - 06:22] when I don't just use the last word? Let's look at the slides for that.
[06:23 - 06:26] I need to explain what's going on here. OK, so here's what we got.
[06:27 - 06:33] I input "<start of sequence> A list of colors", and the model output "I lot of the and".
[06:34 - 06:40] Now, what if I told you that this output actually makes perfect sense, and it's just a matter of how we read it?
[06:41 - 06:50] So let me explain. We need to read each word of this output as a continuation of all the input words up to that point.
[06:51 - 06:54] So here we would read "I" as "<start of sequence> I".
[06:55 - 06:56] We would read "lot" as "<start of sequence> A lot".
[06:57 - 07:07] We would read "of" as "<start of sequence> A list of".
[07:08 - 07:16] We would read "the" as "<start of sequence> A list of the", and "and" as "<start of sequence> A list of colors and". Right, so you can see now that the outputs actually make sense if we read them this way, right?
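The reading rule from the slide, "output i continues the inputs up to and including position i", can be written as a short helper. The inputs and outputs below are the ones from the slide; read_as_continuations is an illustrative name, not part of any library.

```python
inputs = ["<sos>", "A", "list", "of", "colors"]
outputs = ["I", "lot", "of", "the", "and"]  # the model's per-position predictions


def read_as_continuations(inputs, outputs):
    """Pair each prediction with the prefix it continues: output i extends inputs[: i + 1]."""
    return [" ".join(inputs[: i + 1] + [out]) for i, out in enumerate(outputs)]


for line in read_as_continuations(inputs, outputs):
    print(line)
# <sos> I
# <sos> A lot
# <sos> A list of
# <sos> A list of the
# <sos> A list of colors and
```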
[07:17 - 07:26] And this is because of how the LLM predicts, right? The LLM always predicts in this manner, where each of these outputs is conditioned only on the words that come before it.
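In decoder-only transformers, the "only the words that come before it" property is typically enforced with a causal (lower-triangular) mask inside attention. The sketch below only builds such a mask and lists what each output position is allowed to see; it is an illustration of the rule, not the attention computation itself.

```python
def causal_mask(n):
    """mask[i][j] is True when output position i may look at input position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


tokens = ["<sos>", "A", "list", "of", "colors"]
mask = causal_mask(len(tokens))
for i, row in enumerate(mask):
    visible = [tok for tok, ok in zip(tokens, row) if ok]
    print(f"output {i} is conditioned on: {visible}")
# output 0 is conditioned on: ['<sos>']
# ...
# output 4 is conditioned on: ['<sos>', 'A', 'list', 'of', 'colors']
```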
[07:27 - 07:34] And this will be important later on as well. This is what we call prompt processing, right?
[07:35 - 07:42] We take in an entire input and get a series of outputs. Then we ignore most of those and just pick the last one.
[07:43 - 07:51] Right, so unlike before, where we fed in one word at a time, here you feed in the entire input and look at the entire output. Right, so this is important.
[07:52 - 07:58] There are two phases of inference for LLMs: you have prompt processing, and then you have autoregressive decoding.
[07:59 - 08:08] Prompt processing is where you process your massive prompt, maybe the question you gave ChatGPT. Autoregressive decoding is where you then predict step by step, word by word.
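The two phases can be laid out side by side in a sketch. Assumptions: toy_forward is a hypothetical rule-based stand-in for one forward pass that returns a guess for every position; like the demo, this toy re-runs the model on the full sequence at every decoding step, so it separates the phases structurally rather than showing how production systems speed up the second phase.

```python
COLORS = ("red", "green", "blue")


def toy_forward(tokens):
    """Hypothetical stand-in for one LLM forward pass: a next-token guess for
    every position, each using only tokens[: i + 1]."""
    guesses = []
    for i, tok in enumerate(tokens):
        if tok == ":":
            guesses.append("red")
        elif tok in COLORS:
            guesses.append(",")
        elif tok == ",":
            remaining = [c for c in COLORS if c not in tokens[: i + 1]]
            guesses.append(remaining[0] if remaining else "<eos>")
        else:
            guesses.append("<eos>")
    return guesses


def run(prompt, max_new_tokens):
    if max_new_tokens <= 0:
        return []
    tokens = list(prompt)
    # Phase 1, prompt processing: one pass over the whole prompt. Every
    # position gets a guess, but only the last one continues the prompt.
    tokens.append(toy_forward(tokens)[-1])
    # Phase 2, autoregressive decoding: one new token per step.
    while len(tokens) - len(prompt) < max_new_tokens and tokens[-1] != "<eos>":
        tokens.append(toy_forward(tokens)[-1])
    return tokens[len(prompt):]


prompt = ["<sos>", "A", "list", "of", "colors", ":", "red"]
print(run(prompt, max_new_tokens=5))
# → [',', 'green', ',', 'blue', ',']
```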
[08:09 - 08:23] All right, so I want to pause here to see if folks have questions for me. Or how are people doing? A thumbs up or a thumbs down suffices.
[08:24 - 08:26] Is this confusing? Are people following?
[08:27 - 08:29] Does it make sense? Okay, great.
[08:30 - 08:31] Perfect. I'm seeing some thumbs up.
[08:32 - 08:35] Great. All right, so let's keep going then.
[08:36 - 08:42] Okay, so how do language models predict? So far we've talked about the general process of autoregressive decoding.
[08:43 - 08:47] There's a question here: a real product wouldn't output the last comma, right? That's a good question.
[08:48 - 09:06] A real product would most likely keep running inference until it reached an end-of-sequence token, in which case the comma wouldn't be the last thing it outputs, but it would still output the comma at some point. What it actually does is stream the output from its servers back to your interface.
[09:07 - 09:20] Right, so this would be over some sort of RTC or WebSocket connection, where it streams these outputs token by token, and then on your side you see ChatGPT "typing", right? Yeah.
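Streaming fits naturally with autoregressive decoding, because each token is final as soon as it is produced. A minimal sketch uses a Python generator that yields each token the moment it is decoded. Assumptions: toy_next_token is a hypothetical rule-based stand-in for one decoding step, and the network transport (a WebSocket or similar) is not shown; the print loop plays the role of the client rendering text as it arrives.

```python
COLORS = ("red", "green", "blue")


def toy_next_token(tokens):
    """Hypothetical stand-in for one LLM decoding step (rule-based, not learned)."""
    last = tokens[-1]
    if last == ":":
        return "red"
    if last in COLORS:
        return ","
    if last == ",":
        remaining = [c for c in COLORS if c not in tokens]
        return remaining[0] if remaining else "<eos>"
    return "<eos>"


def stream_generate(prompt, max_new_tokens=20):
    """Yield each new token as soon as it is decoded, instead of returning
    the whole reply at the end."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tok = toy_next_token(tokens)
        if tok == "<eos>":
            return  # end of sequence: close the stream
        tokens.append(tok)
        yield tok  # a server would send this over the open connection now


prompt = ["<sos>", "A", "list", "of", "colors", ":", "red"]
shown = "red"
for tok in stream_generate(prompt):
    shown += tok if tok == "," else " " + tok
    print(shown)  # what the user sees "typing" out, one token at a time
# red,
# red, green
# red, green,
# red, green, blue
# red, green, blue,
```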
[09:21 - 09:26] Cool. All right.