Opening AI's Black Box with Prof. David Bau, Koyena Pal, and Eric Todd of Northeastern University
In this episode, we dive deep into the inner workings of large language models with Professor David Bau and grad students Koyena Pal and Eric Todd from Northeastern University. Koyena shares insights from her Future Lens paper, which shows that even mid-sized language models think multiple tokens ahead. Eric discusses the fascinating concept of Function Vectors - complex patterns of activity spread across transformer layers that enable in-context learning. Professor Bau connects the dots between these projects and the lab's broader interpretability research agenda, identifying key abstractions that link low-level computations to higher-level model behaviors.
SPONSORS:
Oracle Cloud Infrastructure (OCI) is a single platform for your infrastructure, database, application development, and AI needs. OCI has four to eight times the bandwidth of other clouds, offers one consistent price, and nobody does data better than Oracle. If you want to do more and spend less, take a free test drive of OCI at https://oracle.com/cognitive
ODF is where top founders get their start. Apply to join the next cohort and go from idea to conviction, fast. ODF has helped over 1,000 companies like Traba, Levels, and Finch get their start. Is it your turn? Go to http://beondeck.com/revolution to learn more.
Omneky is an omnichannel creative generation platform that lets you launch hundreds of thousands of ad iterations that actually work, customized across all platforms, with the click of a button. Omneky combines generative AI and real-time advertising data. Mention "Cog Rev" for 10% off www.omneky.com
The Brave Search API can be used to assemble a data set to train your AI models and help with retrieval augmentation at the time of inference, all while remaining affordable with developer-first pricing. Integrating the Brave Search API into your workflow translates to more ethical data sourcing and more human-representative data sets. Try the Brave Search API for free for up to 2,000 queries per month at https://bit.ly/BraveTCR
Plumb is a no-code AI app builder designed for product teams who care about quality and speed. What is taking you weeks to hand-code today can be done confidently in hours. Check out https://bit.ly/PlumbTCR for early access.
Head to Squad to access global engineering without the headache and at a fraction of the cost: head to choosesquad.com and mention “Turpentine” to skip the waitlist.
TIMESTAMPS:
(00:00) Intro
(04:03) Reverse Engineering AI
(10:10) Factual Knowledge Localization
(16:34) Sponsors: Oracle | Omneky | On Deck
(18:27) Future Lens Paper Intro
(23:54) Choosing GPT-J
(26:54) Vocabulary Prediction
(30:36) Sponsors: Brave | Plumb | Squad
(33:13) Fixed Prompt
(35:57) Soft Prompt
(41:32) Future Lens Results Analysis
(51:50) Tooling & Open Source Code
(56:07) Larger Models & Mamba Probes
(1:04:10) Function Vectors Paper
(1:09:39) Extracting & Patching Vectors
(1:13:30) Encoding Task Understanding
(1:15:36) Expert Models Implications
(1:18:09) Conclusion
Music licenses:
Y7AMC6WGEGV0C1NX
UVOQFS7SEPZF7JJQ
-
EXCBARTTNYZA0KPB
8TDCM1XEGMAGOJRV
Full Transcript
S0 (0:00) Machine learning worked okay, but didn't really work in profound ways until the last 10 years or so. But now it's really working. It's really working remarkably well. And so we're facing a new type of software that we cannot use traditional computer science tools and traditional computer science methods for dealing with how to ensure that it's correct, that it does the stuff that we want.
S1 (0:22) Logit lens is essentially the tool where it projects these intermediate states into the decoding layer. So essentially, we can just see for that current token prediction, what is the model currently thinking about. We wanted to start with a quote unquote simpler solution, like having some sort of linearity with decoding future tokens. That would have been like a nice final solution, but we realized that no, it's a lot more complex than that, at least at the moment.
S2 (0:47) We tested like 40-plus different tasks, and it seems like they're all mediated by this small set of heads, which is cool: it seems like the model has this sort of path by which it communicates this task information. Even though in the prompts we're never explicitly telling it what the relationship between the demonstration and the label is, it's able to figure that out and communicate it forward.
S3 (1:10) Hello, and welcome to the Cognitive Revolution,
S4 (1:12) where we interview visionary researchers, entrepreneurs,
S3 (1:15) and builders working on the frontier of artificial intelligence. Each week, we'll explore their revolutionary ideas, and together we'll build a picture of how AI technology will transform work, life, and society in the coming years. I'm Nathan Labenz, joined by my cohost, Erik Torenberg.
S4 (1:44) Today, I'm joined by Professor David Bau, Koyena Pal, and Eric Todd, who made a surprise appearance to discuss their recent paper describing function vectors. Professor Bau left Google to establish this group with the goal of cracking open the black boxes that are modern AI models, so as to understand the internal mechanisms that underlie their key capabilities and ultimately gain robust control of increasingly powerful systems. And indeed, starting originally with vision models and now working more on LLMs, the Bau Lab has done pioneering work in this area, delivering a remarkable series of mechanistic interpretability papers over the last couple of years. In this conversation, we go deep on a couple of the group's recent publications. First, Koyena walks us through her Future Lens work, which develops techniques that show that even mid-sized language models seem to be thinking multiple tokens ahead. Then Eric Todd joins to tell us about function vectors, a complicated pattern of activity spread across multiple attention heads in different layers of the transformer that seems to enable in-context learning by encoding the nature of the task inputs and outputs. This work is fascinating for anyone who wants to develop their intuition for how today's models actually function under the hood, and it's a notable step on the path to overall AI interpretability, identifying new abstractions that translate low-level patterns of computation to higher-order behaviors. I was a fan of this work coming into the conversation, but I left as a big fan of Professor Bau as a research adviser as well. Throughout the conversation, he does a great job of connecting the dots between his students' individual projects, motivating the lab's broader interpretability agenda, and emphasizing key open questions in the field. As always, if you find this work valuable, please do share the episode with your friends. 
I'd especially suggest this one for anyone who's interested in mechanistic interpretability and considering a PhD in machine learning. And please don't hesitate to share your feedback or suggestions via our website, cognitiverevolution.ai, or by messaging me on your favorite social network. Now please enjoy this illuminating discussion about the inner workings of large language models with mechanistic interpretability researchers, Koyena Pal, Eric Todd, and
S3 (3:52) Professor David Bau. Koyena Pal and David Bau from Northeastern University, authors of Future Lens. Welcome to the Cognitive Revolution.
S1 (4:01) Thank you.
S0 (4:02) Really happy to be here, Nathan.
S3 (4:03) I'm excited for this conversation. So Professor Bau leads the group, and Koyena is part of the research group. And together with your colleagues, you guys have really developed a remarkable interpretability and sort of editability agenda with respect to large language models over the last couple of years. I have covered a couple of your papers previously, although not in the depth that we will get to very shortly with the Future Lens paper. But it's a really cool agenda that is shedding a lot of light on how these newfangled AIs are actually working, what's going on in there, kinda cracking open the black box problem. Certainly regular listeners will know that that's something that I really appreciate and value. For starters, you guys wanna just give us a little background on the overall agenda of the group and maybe highlight a couple of your favorite papers that you've published previously?
S0 (4:50) It's a pretty simple agenda. So the thing that motivates our group is the fact that machine learning is really a new era for computer scientists. Machine learning has been around for a while, but up until the last decade or so, it's been this dream of having a way of creating self-programming computers that learn from data, that people don't really have to program every line of code. Machine learning worked okay, but it didn't really work in profound ways until the last 10 years or so. But now it's really working. It's really working remarkably well. And so we're facing a new type of software that we cannot use traditional computer science tools and traditional computer science methods for dealing with how to ensure that it's correct, that it does the stuff that we want. All of our tools are very statistical. They're very black box. And so what the theme of the lab is, is to ask how do we confront this new type of software? Are there things that we can do to recover our ability to debug and understand the software that we create? Are there ways that we can take more responsibility for the behavior that emerges from the software? Can we edit the code like the way that you would change a line of code in a traditional program? And so this is actually pretty hard. It would be like asking the question, can you totally control the mechanisms of a tomato or some biological thing that emerged out of a process of evolution? And machine learning is similar that way. It emerges out of a training process. And so just like biologists have to reverse engineer how a tomato works if they wanna really understand what they can do to make it better, we're also trying to develop a science of how to understand the internals of these systems.
S3 (6:35) Now is that an agenda that you have changed course on relatively recently? I assume it has to be, inasmuch as the understanding of advanced AI systems in this way is predicated on their existence. And I would imagine that, at least probably prior to GPT-2, there was almost nothing powerful enough to do this kind of work?
S0 (6:55) Oh, no. I've been at it for a while. So my personal history is I worked at Google for many years, and I actually left Google to pursue research when it started to become clear around 2015 that the field was going to change. And the speed at which it's changed has surprised me. It surprised almost everybody in the field, but it was clear that there was a sea change happening. The moment for me was also the same moment for a lot of people when AlexNet came out in 2012 and showed that this really simple, really classical machine learning procedure, something from an idea from the 1950s, right, and an idea from the 1980s, training these neural networks, which really the community at large thought that such a simple procedure probably wouldn't be the thing that solves machine learning in the long run. People thought it would be something trickier or something more complex or something like that. But this idea that all you have to do is scale up neural networks and it works really well was a real eye opener. So the things that I studied at first, I studied generative vision models a lot like GANs, which were really surprising when they came out. And I thought, I'll study these GANs for a long time. Of course, after GANs, we have diffusion models. And, of course, we've got these amazing language models and other things going on. So it's been this incredible moving target.
S3 (8:13) It does seem to be a pattern that the earliest to the game started with more vision models. I know that's true for Chris Olah, for example, as well. So, yeah, fascinating, the kind of shared lineage there. You wanna touch briefly on the ROME and MEMIT papers? This is one that I've covered a little bit in a previous episode. And the way I remember those papers is as a demonstration of the locality of factual knowledge in a language model, and also the ability to then edit that language model in ways that preserve a lot of the key properties that you would want to preserve if you were editing a knowledge graph. For example, I think the example in the paper is: Michael Jordan played basketball, of course, is the real sentence. But if we wanted to edit that world knowledge to reflect an alternative reality where Michael Jordan played baseball, then we would want that to hold not just for that specific sentence, but be robust to different ways you might ask the question, different ways you might set up an auto-completion. And it's a process of kind of systematically working through the layers to figure out where you're corrupting bits as you go through the layers, figuring out where it starts to fail, then using that as the place to edit, and then achieving that kind of robust editing. It's been a couple months since I went deep on that one. Tell me how I did.
S0 (9:32) No. No. No. That's right. And I would back up a little bit, maybe half a step further, which is the fundamental question that we're asking, and it's actually related to Future Lens too, is is there like a physical reality to the idea of a concept or an idea or a unit of knowledge? We have these things in our head. We think you can go to somebody and you can say, that's something that you know. Like, you know where the Eiffel Tower is, or you've never heard of the Eiffel Tower, you have no idea where that is? No, it's something that you don't know. So we have this idea that people can know things or they can not know things, or they can have an idea or a concept or they can not have one. It just seems so crisp intuitively in our own heads. So the question is, is there a physical reality to this? If an artificial neural network seems to know something or doesn't know something, is there a physical reality to that knowledge? Or is it just diffuse and spread out in all the neurons and all the computations through the whole network? And I think that there's definitely a school of people who would generally believe that most things are just unimaginably diffuse, that everything's just spread out. There's no reason that there would be any locality. But there's a bunch of neuroscientists in the biological world and some computer scientists in the artificial neural network world who have taken a look over the years and noticed that a lot of times you see locality. I saw that kind of thing in GANs, but our question was, do you see this kind of organization in really large models? And do you see any that corresponds to the facts? And so what ROME was about is a bunch of experimental methods for narrowing this down and finding out where facts are localized in a large language model. And it sounds like you had another podcast on it, so maybe you can direct people to listen to that one. But, yeah, that's the general setting. 
And of course, there's a bunch of clever methods that you can use to really pin that down.
S3 (11:22) What is the state of that work now? I know that the successor to ROME was the MEMIT paper, right? And that one showed that you could edit up to like 10,000 facts at a time.
S0 (11:33) Sure. And then there's another paper that we've written that's actually also coming out this coming year at ICLR. It's called the LRE paper, the linear relational embedding paper. And that paper looks at a different slice of the problem. That paper really asks, if you have a relationship, like something is located in a city or somebody plays a certain instrument or somebody plays a sport or somebody holds a certain position or something like that, there's this general relationship. What's the representation of that general relationship? The ROME paper was asking what's the representation of, like, an individual association, but we really wanted to break that down. We can see that there's an individual association. We can see that the representation of the subject and the object seem pretty clear, but we wanted to understand what's the representation of the relation between the subject and the object. And the LRE paper explores that and finds that, in a bunch of the cases, the relation can actually be understood as a linear mapping. But it's actually sort of an incomplete story, because we found that for a lot of the cases, there's something more complex going on. There's something that looks nonlinear. And so it opens the door to future research to ask what's going on in the more complicated cases. These relations are sometimes really non-trivial. And I think that there's going to be some interesting, exciting work to do to figure out, like, why is it that the models are modeling some relations in a really complex way? What are they getting out of all that complexity? And is it enabling new types of reasoning that we haven't characterized yet?
S3 (13:09) Cool. That's awesome. I look forward to that one in more detail in the future as well. Koyena, let's get into the Future Lens paper, where you're the lead author. This one really called out to me pretty much immediately, because I think you were speaking to it a moment ago as well, where you're saying there's a school of thought out there that's like, this is all stochastic parroting, and there's no real understanding, and it's just next-token prediction. So I'm always interested in things that explore when mere next-token prediction can in fact be a bit more than meets the eye. And this paper is a really good example of that. So you wanna start off maybe by just talking about the specific motivation for this particular investigation within the broader agenda of the group?
S1 (13:54) Yeah. So actually, the way this problem started was a little different from how it was described before. Initially, I was interested in looking at memorized content in large language models, because when you use ChatGPT or any other models, some things that I've noticed, like license codes, for example: if you just give a couple tokens in, you kind of know the next 10 or 20 tokens that tend to be exactly the same as the training data. So initially, I was actually first exploring that. And while I was exploring that, I tangentially got to this point where, like, okay, well, I can see that it's saying these tokens pretty much word for word, but how about for smaller phrases like New York City and stuff like that, which, for us, we consider as one whole entity? Yeah, that's how it started off, in terms of just looking into whether, from a single hidden state, we can actually decode such information. And yeah, we were exploring ways that you have seen in the paper, in terms of to what extent we can decode that future information about a particular entity or, like, just anything in general at this point.
S3 (15:00) So the question, as I'm reading directly from the paper: in this work, we ask, to what extent can we extract information about the future, beyond the currently under-consideration token, from a single hidden token representation? So let me, again, just pitch you how I understand this work, and you can correct any misconceptions that I have. I think the vocabulary, especially for people that are not in the research, can sometimes be a barrier to entry. So if I understand correctly, what you're doing here is looking at activations within the forward pass of a large language model. As the numbers are getting crunched, and I like to use the most basic terms I can, the numbers are getting crunched from beginning to end. And between the layers where the actual intensive computation occurs, there are these kind of middle states, which are known as activations. You are taking an array of numbers, basically, that sits there at different places, and we can get into more details of the different places that you can look. You kind of look at every given layer, right, through the transformer to see a different activation. But the question you're always asking is: given this array of numbers, these activations, can I predict not just what the current token is gonna be as I'm working through that forward pass, but can I also see what the future tokens are going to be, just from the information that's in that one particular vector of numbers? Is that right?
S1 (16:25) That sounds about right. Yeah. Like, we just termed those activations hidden states because those are like hidden and they are part of the intermediate computation and stuff.
S5 (16:34) Hey, we'll continue our interview in a moment after a word from our sponsors.
S1 (16:39) So logit lens is essentially a tool where it projects these intermediate states into the decoding layer. So essentially, we can just see for that current token prediction, what is the model currently thinking about, like based on the distribution and when you decode it to the highest probable word, what is the model currently thinking?
S0 (16:58) I recommend anybody who's listening to the podcast who hasn't seen logit lens before. This is not our work. This is somebody else's work. There's a blogger, nostalgebraist, who created this idea of logit lens. You should Google it and take a look at what logit lens looks like. It's these beautiful grids of words that are overlaid on all the hidden states of the transformer, where every hidden state, every one of these vectors that's inside a transformer, all of these intermediate numbers, are all translated into a word. And then at the bottom layers, at the very earliest phase of the transformer, you can see what words the transformer is basically internally thinking. And then as you go through more and more layers, the words evolve until you get to the end. And you can see the predictions of words get more sophisticated as you get to the end. It's always guessing what the next word should be that it should say. And those guesses get better and better the deeper you go into the transformer. It's really interesting to look at the logit lens.
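To make the grid-of-words idea concrete, here is a minimal sketch of the logit lens in plain NumPy. All dimensions and values are toy stand-ins: in a real transformer, `W_U` would be the model's actual unembedding (decoder) matrix, typically applied after a final layer norm, and the hidden states would come from a real forward pass rather than a random generator.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab_size, n_layers = 8, 20, 4   # toy sizes, not real model dimensions

# Stand-ins: one hidden state per layer for the current token position,
# plus an unembedding ("decoder head") matrix mapping states to vocab logits.
W_U = rng.normal(size=(d_model, vocab_size))
hidden_states = [rng.normal(size=d_model) for _ in range(n_layers + 1)]

def logit_lens(h, W_U):
    """Project an intermediate hidden state directly into vocabulary space."""
    logits = h @ W_U
    probs = np.exp(logits - logits.max())   # numerically stable softmax
    probs /= probs.sum()
    return int(probs.argmax()), probs

# Read off the model's "current best guess" at every layer.
for layer, h in enumerate(hidden_states):
    top_token, _ = logit_lens(h, W_U)
    print(f"layer {layer}: top token id = {top_token}")
```

In a real model, the per-layer top tokens sharpen toward the final prediction as depth increases, which is exactly the evolving grid of words described above.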
S1 (17:57) Yeah. The big disadvantage is you only see what it's currently thinking, which may not be representative of what it might actually be thinking a couple of tokens ahead. It's not as informative in that sense. And so we are trying to, like, make it more informative, and seeing, like, okay, it's currently thinking of the word new. What word is it thinking after new? Like, is it thinking of New York, New Jersey? These are, like, some examples of things that could follow new, but that's kinda, like, not possible to see with logit lens, for that particular hidden state, at least.
S3 (18:27) Yeah. So there's two directions that we can consider. Tell me if this intuition jives with your understanding. In the course of the forward pass, basically, this is what the logit lens looks at. It's like, as we're going layer by layer, we're sort of gradually getting more sophisticated in the way that the model seems to be understanding the input, and it's gradually converging to the prediction that it's actually gonna make. So you can see at the beginning, it, like, has a low level of confidence and may be wrong if you just took that hidden state and went directly to decode. And that would be basically like, okay, let's say we do two layers of computation, then we skip all the rest, and then we decode whatever came out of the two layers instead of all the layers. What would we get? Right? What we get is not very accurate at two layers. It gets more and more accurate as you go through the layers as it finally reaches its output, which makes sense, right? Because that's presumably why we need all the layers: to be working through the information and gradually getting somewhere. And then there's also the token direction. Right? We've got, like, a sequence of tokens, and information can only flow forward, because that's the nature of the look-back attention mechanism. So it can only flow forward from tokens to future tokens, but that means, or at least it stands to reason, right, and this is what you will then set out to prove, that as this information is built up from prior tokens to the current token, it probably means that there's more there than just what would be needed to predict the immediate token. There's been, like, lots of interesting intuitions around this, with even just articles in English. Right? Like, if you set up a situation where it's very clear that the actual noun is gonna begin with a vowel versus a consonant, then you can see that it predicts the article correctly. 
And then you're like, that's pretty suggestive that maybe there's more information there. Otherwise, how would it be choosing the article? It seems to be anticipating what's coming beyond this current token as well. So I think of the information processing proceeding, like, vertically and the token information flow proceeding horizontally. And now we're in this moment where it's okay. At each layer through the forward pass, through the vertical dimension, we can check to see, like, would this be enough already to predict the current token? And, again, that's the logit lens kind of based on what you're building. Now we get to, okay, can we also detect that future token information is there? Right?
S1 (20:47) Yep.
S3 (20:47) Okay. Cool. So tell me how you did it. First of all, I noticed that this was done on GPT-J. That was kinda surprising, but maybe I'm just not sure what I don't know about GPT-J. It seems like these days Llama would be the default for this sort of thing. So how did you choose the model?
S1 (21:04) Again, since this was like a spin-off from the earlier project that I was talking about. So I was using GPT-J and I just continued with it, but it was definitely, like, one of those models where I could run it on, like, a single
S0 (21:17) GPU. Do we have a Future Lens on Llama now?
S1 (21:19) I think people are building it. If it has the same vector size, like 4,096, then you can technically do the whole transfer. But yes, generally, this method should be workable for other models as well.
S3 (21:31) Yeah. Let's go back to the other models in a little bit, because I have some questions around, like, how you think the different models might behave differently under this sort of examination, but it'll help to establish exactly what it is we're doing. And you said the activation vectors, this is actually one of my questions. The vector from which you are predicting, that's a 4,096-size vector?
S1 (21:54) And, like, times the number of tokens, but then that would be one for us. So yeah.
S3 (21:57) Okay. Cool. So then tell me about the dataset. So, as is so often the case, the first challenge in doing something like this is you've gotta collect a dataset. This one is kind of a subtle one to describe.
S1 (22:09) Yeah. So since I knew that GPT-J was trained on the Pile, I was like, okay, why don't I take the Pile dataset and sample a couple of texts from it and just get the hidden state from each of the sentences in there. And yeah, I'll run it. So I collected like 100,000 of them. One of my initial conditions was that if GPT-J itself is reading this one token, it's predicting the next token correctly, in that sense. And so that's one of the few sample conditions that I had. But otherwise, yeah, once I got like 100,000 of them, I stopped and then I was experimenting on that.
S0 (22:45) So it's 100,000 samples of texts. What's the filter that you used in that?
S1 (22:50) Oh, the filter that I used was that, given the last token of the prompt, the next one, it predicts correctly. The model predicts correctly. Yeah.
S3 (23:00) Okay. So there's 100,000 prompts identified by the condition that the model must get the next token correct per the actual source material. So what are you actually collecting for each of the 100,000? It is the activation, right, or the hidden state for each of the layers throughout the entire transformer?
S1 (23:20) Yeah. That's what eventually happens. But in terms of just storing purposes, I had just the CSV file of these 100,000 sentences. But then later, during the training process, I would gather the last hidden state at that particular layer for whatever probing method I'm trying.
S3 (23:37) Because all of the analysis downstream of this is going to be at a layer specific level. Right?
S1 (23:44) Yes. Ultimately, the actual input itself would be the 4,096 vector. But, yeah, when I'm collecting all of them, I'm collecting them based on the strings and stuff.
S0 (23:54) So, yeah.
S3 (23:54) I always like to get very clear in the inputs and outputs. So now we've got to the point where, okay, we got the 100,000 texts. And our goal is to say, for a given layer, can we figure out a way to take that 4,096 length vector and use that as the input and then predict not just the current token, but the next several tokens from that single input vector.
S1 (24:19) Yeah, almost there. So definitely layer, but also the last token. So if my prompt was 10 tokens, if it was layer-wise, it'll be 10 by 4,096. But since, again, it's just one token, and a single hidden state will be the last token of the input prompt, that's 1 x 4,096. Yeah.
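The filtering step Koyena describes, keeping a prompt only when the model's top next-token prediction matches the true continuation, can be sketched as follows. Here `model_top1` is a made-up stand-in for a real forward pass through GPT-J, and the tiny word-level corpus is purely illustrative; the paper works with real Pile text and real tokenization.

```python
def model_top1(prompt_tokens):
    # Stand-in for a real model call; in practice this would be a forward
    # pass through GPT-J returning the argmax next token. The lookup table
    # below is hypothetical, just to make the filter runnable.
    guesses = {"new": "york", "eiffel": "tower", "united": "states"}
    return guesses.get(prompt_tokens[-1])

def collect_samples(sentences, target_size):
    """Keep (prompt, next_token) pairs only when the model already
    predicts the true next token correctly: the sampling condition."""
    kept = []
    for tokens in sentences:
        prompt, true_next = tokens[:-1], tokens[-1]
        if model_top1(prompt) == true_next:
            kept.append((prompt, true_next))
        if len(kept) == target_size:
            break
    return kept

corpus = [["the", "eiffel", "tower"],
          ["new", "jersey"],          # model guesses "york" -> filtered out
          ["new", "york"],
          ["united", "nations"]]      # model guesses "states" -> filtered out
print(collect_samples(corpus, target_size=100_000))
```

For each kept prompt, the experiments then grab the last token's hidden state at a chosen layer as the 1 x 4,096 probe input.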
S3 (24:37) Okay. Cool. And then just as validation, there's another thousand that you'll do the testing on once you develop the methods. Okay. Cool. So we've got five distinct methods, where the first one is the null hypothesis, which is just the bigram, which is: okay, given the total frequency of word pairs in the entire Pile dataset, to what degree can I just look at the current word and predict the next word solely from that? And if you're guessing at home what that number is, I'll give you a second: it is 20%. Apparently, 20% of tokens in the Pile you can guess just from the current word; knowing the current word gets you basically to 20% accuracy. That was kind of surprising to me. I wouldn't have guessed it would be that high. There you have it. 20%.
S0 (25:22) Yeah. This language modeling is not so hard after all. We can just get 20% from counting and that's it. This n-gram thing was the state-of-the-art method until not too many years ago.
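The bigram baseline behind that 20% figure is easy to sketch: for every token in a corpus, count its successors, and always guess the most frequent one. The miniature corpus below is invented purely for illustration; the paper computes these counts over the full Pile.

```python
from collections import Counter, defaultdict

# A made-up toy corpus; the real baseline counts word pairs over the Pile.
corpus = ("new york city is in new york state and "
          "new jersey is next to new york").split()

# Count successor frequencies for each token: the bigram "model".
successors = defaultdict(Counter)
for cur, nxt in zip(corpus, corpus[1:]):
    successors[cur][nxt] += 1

def predict_next(token):
    """Guess the most frequently seen successor of this token."""
    if token not in successors:
        return None
    return successors[token].most_common(1)[0][0]

# Evaluate: fraction of positions where the bigram guess is right.
hits = sum(predict_next(cur) == nxt for cur, nxt in zip(corpus, corpus[1:]))
acc = hits / (len(corpus) - 1)
print(f"bigram accuracy: {acc:.2f}")
```

Any learned probe has to beat this kind of count-based guess before its results mean anything, which is why it serves as the null hypothesis.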
S3 (25:31) Yeah. But it's not the state-of-the-art method anymore. So let's talk about the four methods that you developed. The first one is direct vocabulary prediction. So here, you take that same 4,096 vector and just specifically try to jump directly to the logits, right? The final percentage. It's not exactly a percentage, but it's within a very short transformation from percentages of the actual tokens to be predicted for the current and then the next few. What is the nature of the thing that you're learning? This is because it's a learned method, right? You're gonna optimize and essentially learn a way to make this prediction across the 100,000 samples. What is the nature, though, of the vocabulary prediction thing that is being trained?
S1 (26:20) Yeah. So, like, the goal with this direct vocabulary prediction is essentially, can we like literally decode the token from that hidden state? I imagine the linear models would be the most basic model for probing in terms of just being able to grab that token. And so, yeah, that was essentially the goal. We wanted to see whether there was some sort of linearity in, like, being able to just decode what a future token is. Like, is that stored linearly in the hidden state?
S3 (26:47) So that is, to be very concrete, a one-layer transformation?
S1 (26:52) No. 1 layer, 1 token. So in some sense, it's, like, context free. But, yeah, it's 1 hidden state to 1 other hidden state, and each of them is, like, 1 x 4096.
S3 (27:01) The thing that is being trained that is doing the transformation from the hidden state to the prediction is just a simple linear transformation.
S1 (27:10) And sorry, just to correct myself earlier since we were talking about the vocabulary prediction, it's actually 1 x 4096 to 1 x 50,400. So essentially, the way this vocabulary prediction works is that with GPT-J, if we were to tokenize the words, the vocabulary space of GPT-J is about 50,000 tokens. And we are essentially trying to get that vector of, like, 50,000 tokens and choose whichever has the highest probability as the token that would be predicted. So that's for the linear direct vocabulary prediction. And so yeah, that's why for this particular method, we grab the hidden state, which is again a 4096 vector, and then map that to the 50,000-size vector for, like, 1 token.
S0 (27:56) You're literally training a decoder, your own decoder head on this problem. Right? Yeah. Essentially, yeah. Training a little linear classifier.
S3 (28:04) So the number of parameters in that linear classifier would be 4096 times the 50,000 vocab size. Yep. So, essentially, each number in the hidden state makes a contribution to each output token, and the intensity of its contribution is determined by the parameters, which are learned in the training of this linear classifier.
S1 (28:27) That sounds right. Yeah. Mhmm.
S3 (28:28) Okay. Cool. So that's direct vocabulary prediction. So we got a 100,000 of these 4096 vectors. We know what the actual next tokens are, and we're training this mapping of 4096 by the vocab size to do that mapping.
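A minimal PyTorch sketch of what a direct vocabulary probe like this could look like. The dimensions are shrunk from GPT-J's 4096 hidden size and ~50,400-token vocabulary so it runs instantly, and the random tensors merely stand in for the real (hidden state, future token) training pairs; this is an illustration of the technique, not the paper's code.

```python
import torch
import torch.nn as nn

# Stand-ins for GPT-J's hidden size (4096) and vocab size (~50,400),
# shrunk so the sketch trains instantly on random data.
D_HIDDEN, D_VOCAB = 64, 500

# The probe is one linear map: hidden state -> logits over the whole vocab.
probe = nn.Linear(D_HIDDEN, D_VOCAB, bias=False)

# Fake (hidden state, future token id) pairs; in the paper's setup these
# would be GPT-J hidden states paired with the token seen n positions ahead.
hidden_states = torch.randn(256, D_HIDDEN)
future_tokens = torch.randint(0, D_VOCAB, (256,))

optimizer = torch.optim.Adam(probe.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
for _ in range(50):
    optimizer.zero_grad()
    loss_fn(probe(hidden_states), future_tokens).backward()
    optimizer.step()

# Prediction: the highest-scoring vocab entry for a given hidden state.
pred = probe(hidden_states[:1]).argmax(dim=-1)
```

At full scale, the probe's weight matrix is exactly the 4096-by-vocab-size parameter block discussed above.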
S5 (28:46) Hey. We'll continue our interview in a moment after a word from our sponsors. Hey, all. I'm hearing more and more that founders want to get profitable and do more with less, especially with engineering. Listen. I love your 30 year old ex-FAANG senior software engineer as much as the next guy, but honestly, I can't afford them anymore. Founders everywhere are trying to turn to global talent, but boy, is it a hassle to do at scale, from sourcing to interviewing to on the ground operations and management. That's why I teamed up with Sean Lenehan, who's been building engineering teams in Vietnam at a very high level for over 5 years, to help you access global engineering without the headache. Squad, Sean's new company, takes care of sourcing, legal compliance, and local HR for global talent so you don't have to. With teams across Asia and South America, we can cover you no matter which time zone you operate in. Their engineers follow your process and use your tools. They work with React, Next.js, or your favorite front end frameworks. And on the back end, they're experts at Node, Python, Java, and anything under the sun. Full disclosure, it's going to cost more than the random person you found on Upwork that's doing 2 hours of work per week but billing you for 40. But you'll get premium quality at a fraction of the typical cost. Our engineers are vetted top 1% talent and actually working hard for you every day. Increase your velocity without amping up burn. Head to choosesquad.com and mention turpentine to skip the wait list.
S3 (30:13) Then there's kind of 2 variations of this. Right? The first 1 is the direct vocabulary prediction. The second 1 is the linear model approximation. And if I understand correctly, the difference is whether you're jumping straight to the output tokens as the thing you're predicting, or keeping the existing decoder and instead trying to predict something that the decoder would then properly decode.
S1 (30:35) Exactly. Yeah. 1 is to the vocabulary distribution, and then the other 1 is to the hidden state, which the decoder layer would decode into that ultimate token. Yeah.
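The hidden-state variant can be sketched the same way: a learned hidden-to-hidden map (4096 x 4096 at full scale, shrunk here) into the space that a frozen decoder head expects, rather than straight to the logits. The `decoder` below is a stand-in for GPT-J's own lm_head, and the random `h_early` tensor stands in for early-layer hidden states; this is an assumption-laden illustration, not the paper's implementation.

```python
import torch
import torch.nn as nn

D_HIDDEN, D_VOCAB = 64, 500   # stand-ins for 4096 and 50,400

# Frozen stand-in for GPT-J's own decoder head (lm_head).
decoder = nn.Linear(D_HIDDEN, D_VOCAB, bias=False)
for p in decoder.parameters():
    p.requires_grad = False

# The learned piece: a map from an early-layer hidden state to the *final*
# hidden state that the frozen decoder expects (4096 x 4096 in the paper).
approx = nn.Linear(D_HIDDEN, D_HIDDEN, bias=False)

h_early = torch.randn(8, D_HIDDEN)       # hidden states from some layer
h_final_pred = approx(h_early)           # predicted final hidden states
logits = decoder(h_final_pred)           # decoded by the frozen head
pred_tokens = logits.argmax(dim=-1)      # one predicted token per state
```

The design difference from the direct probe is just where the learned map ends: in hidden-state space, with the model's existing decoder finishing the job.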
S3 (30:44) And that's pretty interesting. I didn't have too much of an intuition for, like, why 1 might be thought to work better than the other. And then it seems like it basically turned out that they roughly work equally well, like, almost uncannily similarly.
S1 (30:56) Yeah. Actually, for me, that was a little surprising, because for the vocabulary distribution, it's also trying to learn how to decode, like, it's creating its own decoder thing, so I thought it would probably take more iterations, or basically a lot more data, for it to come to that stage. But it works similarly to just predicting the hidden state and having the decoder layer handle that. Yeah. It works similarly. So that was pretty cool to see.
S3 (31:24) Yeah. Is there something that we can understand about the broader system based on the fact that those 2 approaches work the same? Like, it seems you basically have 2 linear classifiers, and, like, 2 is not better than 1, which is kind of an interesting way to see that. Right? I guess 1 way to interpret that would be that you really are getting toward the max of the information that is there. I guess another thing you might try would be something bigger; you could train a full transformer to do this sort of thing, and maybe you did. But, I guess, if 2 linear classifiers aren't better than 1, then maybe you'd think we're already maxing out. Is that how you would interpret that?
S1 (32:01) Yeah. That's how I'd interpret it. We wanted to start with a quote unquote simpler solution, like having some sort of linearity with decoding future tokens. That would have been, like, a nice final solution, but we realized that, no, it's a lot more complex than that, at least at the moment. And, yeah, that's what we tried. We tried our best to max out the linearity aspect of our experiments.
S3 (32:21) Okay. Cool. So that's 1 class of thing, training these linear classifiers. And then the other class of thing is basically an approach where you take the activation out of the context in which it was generated and port it over to, and kind of stitch it into, a broader forward pass context, where that other forward pass context is engineered to be as neutral as possible or, I guess, ultimately engineered to allow that activation to shine through and get its content through to the predicted tokens as effectively as possible. So the first way I thought was, like, super intuitive to do that. You just set up the prompt with, hello, could you please tell me more about, and then port in the activation from the other context, where however much information has been aggregated at that final token is now being ported in. And the theory there is like, well, because hello, could you please tell me more about is so generic, then, like, in theory, whatever you put there should set it up for a kind of nice clean continuation. Unfortunately, that 1 didn't seem to work so well. So I was surprised by that. What do you think is up with that?
S1 (33:29) It's because perhaps it was too generic in that sense. While, yes, we were targeting generic prompts so that there's the possibility for whatever the intervened hidden state is to, like, ultimately have that token shine through, it was probably still too general. The prompt wasn't really optimized for it to be like, hey, I want the future tokens out from this particular state. And so that's why I feel like it didn't work as well.
S0 (33:56) We tried a few, right? Didn't find any that really worked that well.
S1 (33:59) Yeah. Before that, I gave a couple of sentences. I literally started with things like a start-of-text marker, just a dot, an opening bracket. Yeah. I tried a bunch of them, but they didn't seem to, like, work as well either. Yeah. It was working a bit better than the linear models, of course, but still not really as great.
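Mechanically, the transplantation being described can be sketched with PyTorch forward hooks: save a layer's output during one forward pass, then overwrite the same layer's output during a second pass. The toy model below just makes the plumbing visible; it is not GPT-J or the paper's code, and the random inputs stand in for the source and fixed-prompt contexts.

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer layer stack, so the patching mechanics
# are visible without loading GPT-J. Sizes are arbitrary.
class ToyModel(nn.Module):
    def __init__(self, d=16, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d, d) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = ToyModel()
patch_layer, saved = 2, {}

# 1) Source run: save the hidden state coming out of `patch_layer`.
def save_hook(module, inputs, output):
    saved["h"] = output.detach()

handle = model.layers[patch_layer].register_forward_hook(save_hook)
model(torch.randn(1, 16))          # the "source context" forward pass
handle.remove()

# 2) Target run: overwrite that layer's output with the saved state, as if
#    transplanting it into a generic "tell me more about ..." prompt.
def patch_hook(module, inputs, output):
    return saved["h"]              # returning a tensor replaces the output

handle = model.layers[patch_layer].register_forward_hook(patch_hook)
out = model(torch.randn(1, 16))    # the "fixed prompt" pass, now patched
handle.remove()
```

From the patched layer onward, the model computes as if the transplanted state had arisen in the new context, which is exactly the situation described above.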
S3 (34:16) Yeah. That's interesting. Okay. So now what comes next though is perhaps even more surprising in view of the fact that you tried multiple actual natural language prompts and none of them worked. The next approach is saying, okay. Well, you know, where we're going, we don't need language. It's the soft prompt technique, right, where this time you're learning what the embeddings should be from a sort of hypothetical abstracted prompt. Like, what would the ideal prompt be if we didn't have to represent it in tokens, and we could just go straight to the numbers and engineer it at that level so that we could get the tokens out the other end that we want. So in this case, the optimization is tweaking those embeddings across all of the activation to future token prediction pairs. And by optimizing those embeddings, now you can really start to get somewhere.
S1 (35:07) Yeah. That sounds right. That's exactly what we did.
S3 (35:09) That's surprising. What kind of jumps out at me, obviously, now looking at the results, is that that method works by far the best. And how do you understand why that works so much better? Is there any ability to say what the ultimately learned soft prompt means? Can we translate that back into something that I can understand?
S1 (35:26) We were actually attempting to do that. We were like, okay. We found a soft prompt that works great. Let's try to like get a discrete 1, but it was pretty gibberish. So it was like, okay, we'll keep it in like the vector space for now.
S0 (35:38) How gibberish was it?
S1 (35:39) It had, like, a couple of 'at' symbols somewhere, and it had some, like, they looked like words, but they didn't make sense kind of words. Those were what was there. Yeah, it was pretty gibberish. So we just went on with the vector space soft prompt. But yeah, in terms of the results, it was nice to see that there was a pretty significant increase. And the reason, I imagine, is because this prompt is trained so that whatever intervened state we have, it enhances whatever future context is encoded in there. Because in the manual versions, we were thinking of what the prompts could be, like, what we think would make sense. But over here, this is what made sense to the computer. Yeah.
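A heavily simplified sketch of soft-prompt training: the only trainable tensor is the 10 x d prompt, and everything standing in for the transformer stays frozen. The real method prepends the 10 learned embedding vectors to the token sequence of a frozen GPT-J; here that combination is collapsed into a simple addition, and the data is random, purely to keep the sketch tiny and self-contained.

```python
import torch
import torch.nn as nn

D, N_SOFT, VOCAB = 64, 10, 500   # stand-ins for 4096, 10 tokens, 50,400

# The only trainable parameters: a 10 x d prompt living in embedding space.
soft_prompt = nn.Parameter(torch.randn(N_SOFT, D) * 0.01)

# Frozen stand-in for the transformer plus its decoder head: in the real
# setup, GPT-J runs with all of its weights fixed.
frozen = nn.Sequential(nn.Linear(D, D), nn.Tanh(), nn.Linear(D, VOCAB))
for p in frozen.parameters():
    p.requires_grad = False

hidden_states = torch.randn(32, D)              # transplanted hidden states
future_tokens = torch.randint(0, VOCAB, (32,))  # tokens n positions ahead

optimizer = torch.optim.Adam([soft_prompt], lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
start = soft_prompt.detach().clone()

for _ in range(20):
    optimizer.zero_grad()
    # Toy combination: condition each transplanted state on the learned
    # prompt (the real method prepends all 10 vectors to the sequence).
    logits = frozen(hidden_states + soft_prompt.mean(dim=0))
    loss_fn(logits, future_tokens).backward()
    optimizer.step()
```

The key design point survives the simplification: gradients flow through the frozen model into the prompt embeddings alone, so the optimizer searches prompt space rather than weight space.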
S0 (36:19) It's amazing that basically, there's this little capability in the transformer to solve this problem, right, to tell you what the future tokens are. It has some little machine in it that knows how to decode it, and Koyena, with her soft prompt training, basically came up with a prompt that is like a key that, like, unlocks this capability to do it. At least that's the way I see it. But the key didn't seem to correspond to any real piece of text other than this weird gibberish that she saw. Right?
S1 (36:47) Yeah.
S0 (36:48) It wouldn't have been something that we could have predicted.
S3 (36:50) It reminds me very much of the universal jailbreak paper, which we did an episode on, and there was sort of a similar thing. Now, they were operating in token space, but they still often found these universal jailbreak prompts to be gibberish sequences. And sometimes they did have some that were, like, more human readable. Now, this was something that you guys did a lot. Right? Like, all of these techniques are happening at each layer. So for each of these techniques, you had to train like, for the soft prompts, there's what? How many layers in GPT-J? Is it 28? So you had to train 28 different soft prompts, 1 for each layer at which you're gonna patch in the activation.
S1 (37:30) So that was what the method was initially. But for the actual Future Lens, the tool, we just use 1 soft prompt. We use the best model, the best layer that worked. We did that for every state when we actually built the tool. So in the tool, it's like 1, but we were using this method to find which was the best soft prompt to use, if you will. Yeah.
S3 (37:49) Gotcha. Okay. So yeah, let's talk a little bit more in detail about the results. I think these are suggestive of many things. Certainly, it's a good graph to just sit and ponder for a little while. Again, folks can refer to the paper to see this. This is figure 4 from the paper. Basically, what you've got here is 4 different graphs. Each 1 corresponds to the token that is being predicted. So n equals 0, the first graph, is like predicting the current token. Then it's n equals 1, the next token, 2, the 1 after that, 3, the 1 after that. And then for each of these predictions, you're able to look at, okay, put the layers across the x axis and the success rate on the y axis. And you can see that as you proceed through the layers, the prediction success changes. For the first token, the current token, basically, you recap the logit lens results, right, where you can see that as you move through the layers, the predictions get more and more accurate by the time you're at the end. By definition-ish, or close to definition-ish, at the end, you're getting the answer right close to a 100% of the time.
S1 (38:49) Yeah. That was our sanity check for n equals 0. We were like, okay. Is this method actually working? Let's make sure it's actually decoding the current tokens. Yeah. That was more of a sanity check.
S3 (38:58) So then several things jump out at me about the next 3. 1 is that, again, we've already said that the 2 linear projection methods basically worked almost exactly the same. Like, you can see that those lines for tokens 1, 2, and 3 are in almost lockstep through every layer. Those did not work that well. And in fact, and this is kind of surprising too, on the first token, they beat the 20% baseline a bit. But then on the second and third tokens, they're actually worse than the baseline. I guess maybe the baseline is out of domain at that point anyway, but I wasn't expecting to see things go lower than the baseline really.
S1 (39:33) Yeah. I assume so. We created, like, the bigram baseline, and so for n equals 2, n equals 3, it would be like trigram and 4-gram models. They were too huge for us to compute at that time. And so, yeah, that's why we didn't include the 20% in there. But yes, it's more to do with the fact that it was trying to predict 2 tokens ahead, 3 tokens ahead.
S3 (39:53) So the baseline is not so relevant past just the immediate next token. So you probably don't have a paper with those methods. Right? Because they didn't work that great. The next 1 is the fixed prompt. This 1 actually does go below the baseline on the first 1. So, again, I'm like, man, that seems weird. Why would it be lower than the straight baseline? And to make that even more confusing, that method seems to get better for the subsequent tokens. Like, it's worse at predicting n equals 1, and it gets better at predicting n equals 2 and n equals 3. And that really surprised me.
S1 (40:28) So to provide a bit more context: when we were checking n equals 2, n equals 3, we assume that we give the correct token for the previous ones. So let's say we are talking about New York City, for example. Let's say n equals 0 is predicting New. N equals 1 is York. N equals 2 is City. And let's say, because there are so many possibilities after New, it doesn't really predict York. But if we give it New and York, can it actually then predict City? And so that's the idea behind checking for 2 and 3. Because generally, if 1 is wrong, then it's pretty sure that 2 and 3 will be wrong if you check whole-phrase-wise. So that's why, yeah, we gave 1 extra token of context in that sense. But with that, it's essentially the same thing with 1 additional token, but nothing else beyond whatever the current token was.
S3 (41:17) Can you help me understand that a little bit better when I am imagining porting an activation over to a different context? How should I understand that, like, mechanistically interacting with the idea that there's another kind of subsequent token included?
S1 (41:33) Yeah. So when you do transplant that hidden state to the fixed context, yes, it's actually completely from a different step at this point. Let's say we did this transplantation at, like, layer 14. So from layer 14 onwards, it would have the context that is present in that hidden state. But sometimes that's not enough, as we saw in the fixed prompt result for n equals 1. But let's say we gave that particular hidden state plus what the next word could have been, so, like, York. Maybe the reason why it's performing better is, okay, because it now has, like, a direction of where it could predict next. But even then, it's just that single hidden state and that particular 1 token. So
S3 (42:13) Interesting. So with that method, it's getting more information. Does that also help me understand? Because another thing that I was quite surprised by is that the performance of that fixed prompt method seems to be pretty clearly declining as you go through the layers. Whereas the first 2, the linear transformation ones, are, like, flattish, climbing a little bit, the fixed prompt 1 is declining. And I was like, that seems like a very odd reversal from the main logit lens. Do you have an intuition for that, and does that have something to do with the extra tokens that it's given?
S1 (42:49) So, yes, in some contexts. But usually, the way I imagine the first versus the last layers is that from the beginning, it's trying to build an understanding of what it's gonna say. Then in the middle, it has this set of ideas that it wants to say, and towards the end, it probably wants to say immediately what it thinks should be next. And so in that sense, I imagine there could be some information lost along those layers when it's ultimately trying to predict literally the next token. Yeah.
S3 (43:16) Yeah. And what you just described is definitely my intuition as well. The sort of gradual working up of inputs into higher order concepts through at least half the layers. And my general sense is that in many models it even goes deeper than that, because the rough picture in my head is more like through 80% of layers. And then, yeah, in the final layers, it's like, now we're condensing again. Like, now it's time to actually cash this out. All this sort of high order, high concept work that happened in the middle, now we gotta cash that out to a concrete prediction. And so you see a collapsing toward the end. And that is very much what we see in the soft prompt 1. So a couple of notable facts about the soft prompt results. First of all, they're just a lot better than everything else. And second, they more closely follow this pattern that we each just described, of seemingly more and more useful information buildup, a rising line from the beginning up until the middle layers, and then declining. And the sort of peak here is in the middle, and I've seen later peaks in other models. My guess is that's probably just a function of model size, where you'd have to hit that end in time, and that limits how deep into the model you can continue to build. Does that seem right to you?
S1 (44:25) Yeah. That definitely seems right to me. And the whole idea of, like, the middle layer importance, it is something that I've seen in other interpretability papers as well. And yeah, there's definitely something going on in those middle, middle-late layers as well.
S0 (44:37) It is interesting how it declines, right? It definitely suggests that the late layers are erasing some information, that it is using that space for something else. And that sort of explains its decline over time. So I feel like the whole phenomenon of a transformer erasing its own information is an interesting 1, and we see hints of it here. It is another cool thing to study.
S3 (44:58) Yeah. What else do you take away from this set of results that I haven't landed on myself?
S0 (45:04) 1 of the things here is that the information is present, but we weren't able to decode the information directly. Somehow the transformer's own circuits are essential for decoding the information. That's a pretty interesting insight. The mental model I have is, like, the transformer has its own dictionary in its weights somewhere of how to decode concepts into sequences of words. I love the example that you chose in the final version of your paper. Back to the Future, right? Like the movie title, Back to the Future, it was one of the longer sequences of words. And I imagine that the instructions for how to decode Back to the Future are not actually directly in the hidden state; it's more like a pointer. It's just this compact vector that you have to look up in the model weights somehow. And it's that soft prompt that Koyena learned that is triggering some mechanism that causes this lookup to happen, that eventually rolls out to the whole phrase Back to the Future. So I thought this was pretty interesting. If the decoding was simpler, we would have found it through 1 of these direct decoding methods. But the encoding is complex enough that the transformer itself is needed.
S3 (46:14) Yeah. So something I said earlier, maybe I'm now thinking is perhaps wrong. Because I had said, okay, 2 linear transformations don't seem better than 1, and I jumped to the conclusion that, like, you could train a whole transformer, but it probably wouldn't work. But now I'm thinking maybe that's exactly what would be needed. What does your intuition say? If you were to go back to those first 2 methods and scale up, but instead of using a linear transformation, what if you used a full transformer to try to do the processing? Do you think that would work?
S1 (46:43) Yeah. I think that should work better. Again, when you introduce a transformer, there's, like, the non-linearity thing, which was required for the soft prompting instances. So ideally it should work better, but that's not something we tried out.
S0 (46:56) I don't feel like we had enough data to try it. Yeah. So, you know, this was a relatively small model training exercise you did, right? Where every layer and every token and all these different settings, you're training lots and lots of models. If you wanted to make a prediction on a large model like this, you'd probably need to train on millions or hundreds of millions of examples. And at that point, it's a little bit different. It's almost like, for the task that you're trying to learn, you might be best off just memorizing the whole thing. You could just say this vector corresponds to these tokens a few ahead and just memorize the whole problem. The thing that really surprises me about what Koyena found was that she didn't need to do that. With a really pretty small amount of data, sort of 100,000 samples, she could unlock the dictionaries for all of these phrases. And it wasn't like we were really training the transformer to do something new. That's why it feels like we're unlocking the knowledge that's already in the transformer. Does that make sense?
S3 (47:53) Yeah. That's really interesting. That's super helpful. And I understand exactly what you're saying about a 100,000 data points would not be enough to retrain a transformer. Maybe it could do the memorizing, but it certainly couldn't possibly relearn all the stuff that the full GPT-J 6B knows, because there's no way to get there from such a small sample size. So in theory, it could maybe work if you had the full Pile and you ran the full Pile this way, but, obviously, that's not feasible on the resources available here. So you can't recreate that magic.
S1 (48:25) When we basically add in another transformer, that kinda ends up becoming like a chicken and egg problem. Like, we're trying to understand this transformer, but then we now have to try to understand the method transformer.
S3 (48:34) Yeah. It's a bit absurdist. No doubt. And it becomes almost anti-interpretability at some point as well. Yeah. That's definitely an intuition building exercise more so than a way to create insight into how the things are working. How big is the soft prompt? I assume that individual tokens are embedded in the same 4096 space?
S1 (48:54) Yeah. So the soft prompt size was, like, 10 x 4096, that is, 10 tokens with each of these tokens having a 4096 vector size. Yeah.
S0 (49:03) How's that compared to your linear models?
S1 (49:05) Oh, for the linear model, like, the input was, like, 1 x 4096.
S0 (49:08) Yeah. But those models were, like, 4096 x 4096. Like, your soft prompt is way smaller.
S1 (49:14) Yes. That is true. It is much smaller indeed. Yeah.
S0 (49:16) So even though it was the best performing thing, it was, like, the the smallest number of parameters. It'd be, like,
S3 (49:21) 1000 x difference. Right? Because the linear transformation were 4,000 times 50,000 parameters, and this is 4,000 times 10 parameters.
S1 (49:31) That's true. But I guess, like, the trade off is the linear model is trying to understand everything what we technically do have to We
S0 (49:36) have the transformer helping us out, but the weights are frozen.
S1 (49:39) That's true.
S0 (49:39) Yeah. That's right. It's the power of using the transformer to explain itself.
S1 (49:43) Yes.
S3 (49:44) Yeah. Okay. That is really definitely very insightful. Did you try also varying the length of the soft prompt? My mind goes to what if there were no soft prompt at all?
S1 (49:53) Yeah. No. My coauthor, Jude Ng, he definitely tried out a couple variations of the prompts and found that the size was, like, good enough in that sense for this. And I'm not sure if we tried a 0-length prompt, but I assume that would work similarly to the fixed prompt, just because it is, like, a fixed prompt.
S0 (50:11) Yeah, it's something we don't know in detail. We didn't see anecdotal evidence that it made a huge difference to make the prompt different sizes. So we just picked a relatively small prompt and tested it.
S3 (50:20) Can we go back to the fixed prompt again for a second? I'm still a little kind of not content with my understanding of why that would actually perform worse than the bigram. I don't know. I would expect it to do better. I'm struggling. That feels like such a natural idea that it seems like it would work better than just such a simple statistic, and yet it's worse. And all of the ones you tried were worse? Like, there were none. I I assume that you didn't pick the worst of
S0 (50:46) the No. No. This is the best 1. My intuition is this. So we show these various successes. The function vector paper is an example of this, where if you like go and intervene directly in the hidden states or the ROME paper, you go directly intervene in the weights of the model and we show, oh, it does something amazing. But actually, this isn't generally the case. Imagine sticking a random probes into somebody's brain and then saying, Hey, know, I wonder if this'll make you smell the scent of bananas. But, you know, generally it's not possible to have a success with an experiment like this. So despite our best efforts at prompt engineering, some very clever thing for the model to be set up to do this, just jamming in this hidden state at the end of these engineered prompts, it just didn't work. It's just doing some sort of brain damage. It's really not getting it to predict anything very useful. Obviously, like you said, like a natural thing to try.
S3 (51:43) Yeah. Fascinating. Certainly, we can say it's out of domain. Right? You're putting it into a state that is clearly very unnatural all of a sudden, and you still have more computation to go from that particular place of patching. So, yeah, it still feels weird, but it's far enough out of domain that it just can't handle it.
S0 (52:02) And it's possible that there is some way to make this work that we weren't able to figure out. So we tried our best with a bunch of different setups for this sort of manual prompts, and this is the best example that you can see. But it's certainly possible that there's some structure that we're missing or some clever way of setting it up that would make it work better. In fact, so we could talk about the function vectors paper also later if you want. There are some other tricks that you can find to get these kinds of interventions to actually work pretty well.
S3 (52:29) I was curious if people like me are still feeling like they want to go try another fixed prompt and see if they can't find the magic words to make it work. What does the tooling look like for all this sort of stuff? What does the coding look like?
S1 (52:42) There are no variations of ways to do about this. But when I was working on the fixed prompt stuff initially, I was using actually David's Bau Kit tool, which allows me to look into a model, trace it and grab the hidden state and put it in like another run. But then there is a very recent tool called insight, which also does this. It's basically like the successor of Baukit. And it allows us to do all these like patching interventions and all these just basically few lines of code. Yeah.
S0 (53:13) You have the open source release for this project. Right?
S1 (53:16) Yes. We call it future.baulab.info. And yeah, over there we do have the code as well. And yes, basically in the code, have like notebooks and scripts and yeah, people can refer to that in that sense.
S0 (53:29) So yeah, so definitely check out Koyena's repo for trying to people can try their hand at seeing if they can get a clever prompt that works better or some variation on the method. We'd be delighted to see if there's some clever trick that we missed. There's definitely room for creativity in this direction. Absolutely.
S3 (53:45) So you give a pretty good intuition already for what that would look like. You've got the open source code. You've got a library that's specifically designed to facilitate this sort of activation, extraction, and patching, and the the loops surrounding that. To actually run a test, I should be able to do this probably in just even, like, a Google collab notebook, right, I I would think.
S1 (54:05) So because of the model and everything, you might need a Google collab probe, which I haven't run the training part. But again, the testing part with the feature lens, I've tried it on collab pro and it works. Yeah.
S0 (54:14) Have you used the end if back end?
S1 (54:16) Yes and no. I used it on the server actually, but not necessarily remote equals true because I was like, I'm on the server anyway.
S0 (54:22) You're a little bit of a cheater because you're you're sitting next to the implementer that's using this. But actually, a lot of the students who are playing with future lands and these kinds of things, that library was developed by Jaden. His good friend Jaden. And the library supports a remote backend that we're trying to get funding to provide the scalable free backend for people to do interpretability research. We call it the end if backend. It's like a deep inference fabric. And so if you code your interpretability experiments using this idiom, this little library, then, well, if you have the GPUs to run your model, then you can just run them locally. But if you don't, if you need resources to run a model that is too big for your laptop or for your collab environment or something, then you can just flip a flag. You can just say use a different backend, and then the experiment will be run on a shared remote back end for you. So, yeah, there's we have GPT-J running, we've got different MAMA models.
S1 (55:23) Yeah. Think the 7,000,000,000 model as well as a
S0 (55:25) single Yeah. Like, MAMA 7DB, which is a pain in the neck to run on your own, but we have that running.
S3 (55:30) Cool. So you have a single library that kind of abstracts away the subtle I mean, all these things are transformers, of course, but they have differences in terms of their implementations, vector sizes, etcetera?
S0 (55:41) Oh, no. They're not all transformers. Like, we have folks using it to study Mamba, for example, these state space models.
S3 (55:48) That's definitely of strong interest to me.
S0 (55:50) Yeah, it's pretty fun stuff. Okay, so people who are interested in this library, it's embryonic, but it's pretty cool. And there's a really nice dynamic community around it. So the library is called nsight,nsite.net is the, URL to get to it. We haven't really promoted it, so I don't know if you search for it on Google, if it'll come up even. Right? It's a little bit of a secret alpha stage project right now. But if you go to insight.net, there's a, like, an icon in the corner to link to the Discord channel to join, and there's a bunch of tutorial and documentation on it. But the really valuable thing is to just join the Discord and ask around, and there's a really friendly community that's developing it.
S3 (56:29) That's awesome. Is there anything that we can say about the bigger or more intensively trained models as compared to GPT-J? My hypothesis would be that you would probably find higher success rates with bigger models because you would, in theory, hope that they have more semantically meaningful middle state representations?
S1 (56:52) Yeah. I would go along in a similar direction as well. It's just that with bigger models, you also need a bigger dataset to train on and be more exhaustive in that sense. But in terms of takeaways, at least so far — for example, the whole middle-layer finding — I think it should still be something that is pretty transferable to other models as well.
S0 (57:12) As the advisor person here, I'll say: this is a gap in our knowledge, and we need to run some more experiments. So you're right. I think the intuition that you have, which is that there might be more structure as you get to larger models, and they actually might become more interpretable in this way, is really interesting. It's a little counterintuitive. Traditionally, people would imagine that as models get bigger, they might become less interpretable and harder to understand. But this intuition that you're sharing here, that maybe they become more structured, more predictable, easier to understand — that might be true, but it would be quite a claim. So more research is definitely in order to systematically take a look at these kinds of questions.
S3 (57:51) Is there anything interesting that you can share about the probes that have been sent into the Mamba universe so far? I did, for reference, a two-and-a-half-hour monologue about the Mamba paper in December, and I'm counting, actually, along with another like-minded state space model fan: I think we're up to very close to, if not hitting, 20 papers now published downstream of the Mamba paper, with different remixes and hybrid versions all over the place. So, yeah, have you found anything interesting there?
S0 (58:22) You know, with all these attention-free architectures, you could basically ask: how are they doing it? You can try to lay things like logit lens on top of them, or try to build Future Lens on top of them. And to some extent, maybe in surprising ways, they show some similarities to transformers. Now, we're sort of halfway through doing some research to try to figure this out, and so I don't want to lead any of your listeners astray by saying something that's not actually true. But yeah, even though they are missing a traditional attention mechanism, they seem to be able to do some of the sophisticated types of computations that transformers use attention for. So exactly how it all works, we're still trying to untangle.
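The logit-lens idea mentioned here can be sketched in a few lines of toy code: take an intermediate hidden state and project it straight through an unembedding matrix to see which token it already favors, skipping the remaining layers. Everything below is a minimal illustration with random toy weights, not any real model's parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab = 16, 50                   # toy dimensions, not a real model
W_U = rng.normal(size=(d_model, vocab))   # toy unembedding matrix

def logit_lens(hidden_state, W_U):
    """Project an intermediate hidden state directly to vocabulary
    probabilities -- the core of the logit-lens trick."""
    logits = hidden_state @ W_U
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()                # softmax over the vocabulary

h_mid = rng.normal(size=d_model)          # pretend mid-layer residual stream
probs = logit_lens(h_mid, W_U)
top_token = int(np.argmax(probs))         # token the middle layer already "predicts"
```

Applied layer by layer, the same projection shows how a model's guess sharpens toward the final output, and the question in the conversation is whether attention-free architectures show a similar progression.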
S3 (59:06) Yeah. Cool. Alright. I'm looking forward to that. If you don't have a StripedHyena in that menagerie already, definitely throw one of those in there at some point as well. Because, I don't know if you've seen, just in the last few days there was a use of StripedHyena to build a DNA foundation model, which is starting to show some really interesting properties. First of all, super long context. And StripedHyena is a hybrid, so it has attention and also the state space component. And my sense definitely is that some sort of mashup is gonna be the winner for a lot of scenarios in the future. Certainly in this DNA paper that just came out, they showed that the transformer was the least effective, then Mamba was a bit better, but the hybrid of attention and state space was the best performing. And it is kind of uncanny, because if you believe the theory that language models are learning some sort of world model representation through next-token prediction, then by analogy this seems to be learning some sort of life model — or maybe better said, cell model — at the level of just next-base-pair prediction. It's a whole other world that obviously has its own tremendous interpretability challenges, but it's definitely a fascinating space. Plus one for the StripedHyena in your next round of experiments.
S0 (1:00:23) We will be investigating these things carefully. It's a very exciting time as the world potentially moves on to whatever comes after transformers. It'll be very interesting.
S3 (1:00:31) Yeah. I always say it's not the end of history. You want to talk a little bit about the function vectors paper?
S0 (1:00:36) So the basic question for function vectors was — Eric was very interested in asking: how does in-context learning work? It's just this amazing thing. I mean, the foundation for all the excitement around GPT-3 was this ability for a model to seemingly learn how to do a task after seeing half a dozen examples of the task being done. Then you give it another half-worked example, and it'll work it the rest of the way. It's just amazing that these models are so good at doing that. And so what Eric was looking for is: is there some localization in the in-context learning task? What he found was that there is a bottleneck in it, between when you give a set of demonstrations and when you ask a model to generalize. Yep, come join.
S3 (1:01:24) Eric Todd, welcome to the Cognitive Revolution.
S0 (1:01:26) So, yeah, Eric is the author of the function vectors paper. So actually, maybe now that you're here, you can just sort of explain what it is.
S2 (1:01:33) So we found that there is a small set of attention heads that, when you process an ICL prompt, mediate the identification of a task. For example, let's say you give it a bunch of examples of antonyms, like big and small, short and tall, and then you give it another query word like "bright". You can extract the output of a few attention heads from some other context where they saw a bunch of antonym pairs, take this output and stick it into this new context, and it'll give you the antonym of the word that it's processing. So if you take those same attention heads...
S0 (1:02:13) And you read out what those same attention heads are saying. If you said Paris, France; Moscow, Russia; Washington, DC, United States; and then you said Ottawa, blank — right? You read the attention heads at that moment, and then instead you have a totally new context where you have Ottawa or Madrid or something like that. Even though those attention heads were the ones that seemed to cause you to say antonyms in one of the experiments, in this experiment, if you read out the attention heads for the country-capital task and you stick them into the new context, it seems to encode that task
S0 (1:02:59) instead of doing antonyms or something like that. And what other types of tasks did you test? You tested how many?
S2 (1:03:03) Yeah, we tested 40-plus different tasks, and it seems like they're all mediated by this small set of heads, which is kind of cool. It seems like the model has this sort of path through which it communicates this task information — or this bottleneck, I guess — through which it's communicating what the task is that it's doing. Even though in the prompts we're never explicitly telling it what the relationship between the demonstration and the label is, it's able to figure that out and communicate it forward.
S3 (1:03:32) Yeah. This is cool. I'll just re-summarize it to make sure I have it right. So, first of all, you set up a task implicitly with a few-shot prompt — which is the original "GPTs are few-shot learners", right? The original paper title. Few-shot learners refers to the ability to understand, or infer, and actually do the task just from a few examples. So it's a pretty fundamental, notable behavior, without question. Okay, so you set up a couple of these examples, thus creating a few-shot prompt. One example is antonyms. Another example could be translation. So: arrive, depart; small, big; common, rare. And then, at that point, as the pattern has been established, you go looking inside the transformer — and I'd love to hear a little bit about how you think about the term bottleneck — for some sort of relatively small-dimensional thing that you could extract and patch over to another context. And when you find that thing, and obviously that's where all the hard work is, the demonstration is that once you have found it, you extract it and, now in a clean setting without any examples, you set something up like "The word fast means", and then inject this representation of "give me the antonym". And instead of doing what a language model would normally do, which is define what fast means, it now says "the word fast means slow", because you've forced the antonym behavior by patching from the few-shot context over to the clean context.
S2 (1:05:07) Yeah. You're right about bottleneck. So attention seems like a natural place to look, just because that's how transformers move information between tokens. And we figured that in order to gather information about the task from your context, you'd have to do that via attention. Like, if you're never explicitly told what the task is, you need to go back and look at all the other tokens to figure out what you're doing. And attention is nice because attention heads act on a small subspace, and so they're natural bottlenecks for, like, sparse activations.
S3 (1:05:44) So what is it exactly that you are extracting and patching over? Is it the output from an attention head?
S2 (1:05:47) Yeah. That's right. So you can take the outputs of this small set of heads and just add them all together. And that's the actual thing that we call a function vector: this sum of attention head outputs. And then you can add this vector into some other prompt at a particular layer in the model.
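As a rough sketch of the construction Eric describes: a function vector is the sum of a few selected heads' outputs, which can then be added into the residual stream of a new prompt. The dimensions and head indices below are made up for illustration; the real work is identifying which heads matter:

```python
import numpy as np

rng = np.random.default_rng(1)
n_heads, d_model = 12, 32                  # toy model dimensions

# Pretend outputs of every attention head at the final token of an
# ICL prompt, already projected into the residual-stream basis.
head_outputs = rng.normal(size=(n_heads, d_model))

# A hypothetical set of "task" heads identified by causal mediation analysis.
causal_heads = [1, 4, 7, 9]

# The function vector is simply the sum of those heads' outputs.
function_vector = head_outputs[causal_heads].sum(axis=0)

# In a new, zero-shot context, add the vector into the residual stream
# at a chosen middle layer to induce the task behavior.
resid_mid = rng.normal(size=d_model)       # hidden state in the clean run
patched = resid_mid + function_vector      # forward pass continues from here
```

The interesting empirical finding is that this single summed vector, despite blurring the individual heads together, is enough to causally steer the model's behavior in the new context.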
S3 (1:06:03) So it's one vector that represents the sum of multiple different attention heads, but you're selecting which attention heads to look at. And obviously, a big part of the challenge is figuring out which ones to look at. Once you have that sort of thing, does the relationship between tokens need to have some sort of length similarity to the original relationship? Or not really, because all of the information passing between tokens has already happened by the time the attention head is done, right?
S2 (1:06:32) It doesn't need to. It doesn't seem like that matters that much. It seems fairly robust to different settings. So we tried sticking it into, like, natural text settings, different templated settings. It seems like it works just the same for long or short. Yeah.
S0 (1:06:48) Isn't that cool?
S3 (1:06:50) How do you figure out which attention heads to use, and what can we say about the relationship between the different attention heads that together make the vector? It's nice to think of this, like, single vector, but then it's weird to think of these different heads contributing to, like, part of this sort of single conceptual vector.
S2 (1:07:08) Yeah. So we found that a lot of these heads have a similar attention pattern, which is kind of interesting. Most of them were attending to previous label tokens — which is like: you have demonstration, label, demonstration, label. So they're basically looking at all of the answers of all the pairs, which kind of suggests — and some other papers have found this — that when you do ICL, these label tokens seem to be places where the model stores important information. And so there must be something about the pairs where it's able to figure out the task, and it's transporting that task information from those labels onward. The way that we found these heads was through a process called causal mediation analysis; people also sometimes call it activation patching. So basically, the setup is: you have one sentence with a bunch of pairs of words demonstrating a task, and you have some other sentence where you can set it up so it's maybe demonstrating some other task. We decided to scramble all of the labels so that it wasn't really clear what the task was. And then what you do is you patch the activations from one attention head into this other setting where it's not clear what the task is. So you take it from the clean setting into the jumbled setting and measure how much it influences the model's output of the word that you would expect the behavior to induce. So for example, if in your jumbled setting the query word is "big" and you paste in an attention head output from an antonym setting, you'd expect that the model's output for "small" would go up slightly. And just patching one attention head isn't enough to flip the prediction, so we started trying multiple attention heads, and it seems once you get to around 10, you start to have a significant causal effect: getting the model to flip its prediction and start understanding what the task is.
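The patching loop Eric outlines reduces to: run the scrambled-label prompt, overwrite one head's output with its value from the clean prompt, and measure how much the expected answer's probability recovers. The toy "model" below is a stand-in so the mechanics are visible; it is illustrative only, not the paper's actual code:

```python
import numpy as np

rng = np.random.default_rng(2)
n_heads, d_model, vocab = 8, 16, 20
W_U = rng.normal(size=(d_model, vocab))        # toy unembedding matrix

def forward(head_outputs):
    """Stand-in forward pass: residual stream = sum of head outputs,
    then project to vocabulary probabilities."""
    resid = head_outputs.sum(axis=0)
    logits = resid @ W_U
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

clean = rng.normal(size=(n_heads, d_model))    # heads on the clean ICL prompt
corrupt = rng.normal(size=(n_heads, d_model))  # heads on the scrambled prompt
answer = 3                                     # toy index of the expected label

baseline = forward(corrupt)[answer]
effects = []
for h in range(n_heads):
    patched = corrupt.copy()
    patched[h] = clean[h]                      # patch one head from the clean run
    effects.append(forward(patched)[answer] - baseline)

# Heads ranked by how much patching them recovers the expected answer.
ranking = np.argsort(effects)[::-1]
```

In the real experiments this indirect-effect score is averaged over many prompt pairs, and the top-scoring heads are the ones summed into the function vector.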
S0 (1:09:09) So to me, the big surprise was: Eric got this working with about 10 attention heads for, like, antonyms. And then he looked at something else — I think in the paper we looked at English-to-Spanish translation or something. He had a bunch of different tasks, and he went and started looking at the attention heads for these other tasks. And we noticed that they were largely the same attention heads. And so if you just picked up your attention heads from one task and used them on a different task, whatever information is in those same attention heads would tend to induce the new task too, which was really surprising. Like, I could imagine that there are some antonym attention heads, but it was very interesting to find that there are these attention heads which are just, like, general function...
S3 (1:09:52) Task understanding. Yeah. Yeah. That's really interesting. And it looks like, from figure 3 in the paper, they're found mostly in the middle layers, but not all at the same layer. They're kind of mixed at different parts of the middle layers.
S2 (1:10:05) Yeah. Most of them are in the middle; maybe there was, like, one towards the end. But yeah, it matches our intuition. I sort of see transformers as having two processes. One is: figure out what I want to do from the context, and then use that to predict the next token. And so you could see maybe the first half of the network doing the figuring out, then the second half using that to predict a token. And it's cool that right in the middle, right before it starts to predict a token, is where this process is happening.
S3 (1:10:33) Yeah. That's fascinating. Obviously, recently, mixture of experts has been a huge trend, right? With GPT-4 allegedly being a mixture-of-experts model, and Mixtral, and of course now Gemini 1.5 has been declared to be a big mixture-of-experts model as well. So I've been looking into what is known, what is understood, about mixture of experts. And most of the mixture-of-experts papers are focused on swapping out the MLP blocks, because that's where most of the computation is, at least if the context window isn't huge. But now, obviously, with context windows getting huge, that's starting to even out, or maybe even tip the other direction. And so I guess people are also increasingly starting to do mixture of experts where the attention blocks are also switched out at runtime. So I wonder how this kind of finding should help me think about that. It certainly, at first blush, sounds like, jeez, that's a pretty complicated structure. You've got 10 different heads, and you could have had more, right? If I understand correctly, you kind of drew an arbitrary line and said, this gets us most of the effect — the 80/20, you know, Pareto-principle type of thing. So it definitely seems like the process of identifying the task is a distributed process. It's kind of mixed throughout layers. It is interesting: there's, like, two in the first layer where they appear, two in the second layer, one in the third, then one, one, one, one, and then it goes quiet, at least in your top 10. Maybe one way to think about it is that there's, like, one to two of these per layer. I don't know. It's an interesting mashup. I wonder if you have any intuition for how this would help us understand mixture of experts.
S0 (1:12:11) Let me answer in my interpretability-professor-advisor way. I guess the way I look at it is this: we don't know yet what the right level of abstraction is that we should be looking at to understand these things. We know the physical things that we made them out of — we have dimensions, we've got nonlinearities, we've got modules of different types, we've got these MLPs and attention layers and cool state space recurrences and other neat things that people have suggested. And in the end, it's all just matrix multiplications, right? We know what we're building these things out of, but we're not sure yet exactly the right way of looking at all of these computations. So ROME was a way of looking at the MLPs and saying, hey, there's this interesting association structure in these MLPs; it seems to be pretty informative. Although follow-up work sort of qualifies that and says, well, you know, it might be distributed over a few layers, and it's a little bit more subtle. And all of this great mechanistic interpretability work, I think, has been a way of looking at all these attention heads and saying, oh gosh, these attention heads also reveal a lot of interesting computational structure. And I think that motivated a lot of Eric's work. He says, oh, is there an attention head? Maybe there are a few. And so there's definitely structure in the attention heads. To me, it's not 100% clear that's the end of the story. When Eric finds that a dozen attention heads are working in concert with each other, it leads to the question: is there some abstraction, some way that we should be looking at these attention heads altogether? And I actually see function vectors as one proposal for a way of doing that. You can summarize all of these attention heads in a single vector.
And even though we've sort of blurred them together and we've lost a bunch of the fine-grained distinctions, that vector has pretty strong causal effects. It has some interesting properties that Eric investigated in the paper. To some extent, these vectors can be composed and do interesting things. And so I don't know if function vectors are gonna be one of the abstractions, in the end, that are the right way of looking at what these things are doing. But what it hints at is that there's potentially more structure to be found, maybe at a higher level even than what we've been looking at so far. And to me, the big game, the research game, the chase that we're all on in the interpretability world, is to figure out the answer to the question: what is the right level of abstraction?
S3 (1:14:47) That's great. That might be the perfect note to end on. You guys are doing an admirable and certainly laudable job of chipping away at this massive problem. I'm a big fan of all the work that I see. Anytime I see a Bau Lab paper, I'm excited to get to the website, and you guys do great visualization as well. So I definitely encourage people to go to the individual paper websites and see the visualizations, the animations.
S0 (1:15:09) For visualizations, let me give a shout-out to Nikhil, who just put his paper on arXiv. That's also gonna be at ICLR. We didn't talk about it today, but he has the best circuit visualizations out of anybody. He really has done a great job putting those together. So check out his paper. It's called "Fine-Tuning Enhances Existing Mechanisms". It's a whole study of the interplay between circuits and fine-tuning — an interesting paper, and really beautifully presented as well.
S3 (1:15:34) Yeah. We will save that one for next time. This has been a great conversation. Eric Todd, Koyena Pal, and Professor David Bau, thank you for being part of the Cognitive Revolution.
S0 (1:15:42) Thanks, Nathan, for having us.
S3 (1:15:44) It is both energizing and enlightening to hear why people listen and learn what they value about the show. So please don't hesitate to reach out via email at tcr@turpentine.co, or you can DM me on the social media platform of your choice.