Seeing is Believing with MIT’s Ziming Liu

Ziming Liu discusses Brain-Inspired Modular Training (BIMT) for neural networks, enhancing modularity and interpretability in AI and physics research.



Video Description

Ziming Liu is a Physics PhD student at MIT and IAIFI, advised by Prof. Max Tegmark. Ziming’s research is at the intersection of AI and physics. Today’s discussion goes in-depth on Liu’s paper “Seeing is Believing” where he presents Brain-Inspired Modular Training (BIMT), a method for making neural networks more modular and interpretable. The ability to see modules visually can complement current mechanistic interpretability strategies.

The Cognitive Revolution is a part of the Turpentine podcast network. To learn more: Turpentine.co

Links:
Seeing is Believing, authors: Ziming Liu, Eric Gan, and Max Tegmark https://arxiv.org/pdf/2305.08746.pdf

Ziming Liu: https://kindxiaoming.github.io/

TIMESTAMPS:
(00:00) Episode preview
(01:16) Nathan introduces Ziming Liu and his research
(04:50) What motivated Ziming's research into neural networks and modularity
(08:06) Key Visuals of "Seeing is Believing"
(08:44) Biology analogy & paper overview
(15:07) Sponsor: Omneky
(26:08) Research precedents
(30:00) Locality vs modularity
(37:40) Shoutout Tiny Stories - https://www.youtube.com/watch?v=mv3SIgDP_y4&t=115s
(39:00) Quantization Model of Neural Scaling
(1:00:00) How this work can impact mechanistic interpretability strategies
(1:27:00) Training
(1:46:30) Downscaling large-scale models

SPONSOR:
Thank you Omneky (www.omneky.com) for sponsoring The Cognitive Revolution. Omneky is an omnichannel creative generation platform that lets you launch hundreds of thousands of ad iterations that actually work, customized across all platforms, with a click of a button. Omneky combines generative AI and real-time advertising data. Mention "Cog Rev" for 10% off.

MUSIC CREDIT:
MusicLM


Full Transcript


Ziming Liu: 0:00 If we want neural networks to be interpretable, we should in training explicitly encourage them to be interpretable. That's when we posed this research question: what kind of training techniques do we need to induce modularity in otherwise non-modular neural networks? Actually, for all the results I got, I just used my Mac M1. There are no GPUs. Modular addition takes around 20 minutes; symbolic formulas take less than 1 minute. The existential risk is a common threat to all human beings, and no one can survive if we don't collaborate. So this could actually bring the US and China together to collaborate on AI safety.

Nathan Labenz: 0:50 Hello, and welcome to the Cognitive Revolution, where we interview visionary researchers, entrepreneurs, and builders working on the frontier of artificial intelligence. Each week, we'll explore their revolutionary ideas, and together, we'll build a picture of how AI technology will transform work, life, and society in the coming years. I'm Nathan Labenz, joined by my cohost, Erik Torenberg. Hello, and welcome back to the Cognitive Revolution. Today, my guest is Ziming Liu, grad student and physicist turned AI safety researcher in Max Tegmark's group at MIT, and first author of the recent paper, "Seeing is Believing: Brain-Inspired Modular Training for Mechanistic Interpretability." This paper immediately stood out to me for three big reasons. First, any work that makes deep learning neural networks easier to reverse engineer and interpret is, in my view, worth celebrating. The infamous black box problem, which in its simplest form simply means that we don't really know why AIs do what they do, is one of the biggest reasons to worry about what more powerful AI systems might do in the future. And personally, I see mechanistic interpretability work as one of the most promising paths to AI safety. Second, the technique in this paper is conceptually simple, but nevertheless profound. Taking inspiration from biology and the practical physical constraints that favor local connections and modular structures within animal brains, Ziming devised a straightforward modification to the loss function that encourages the development of similarly modular structures in digital neural networks. I suspect that this technique, perhaps more than any other I've seen so far in 2023, left many in the AI research community asking themselves, why didn't I think of that? 
And finally, as the title "Seeing is Believing" suggests, this paper includes phenomenal visual aids, which very quickly and intuitively communicate both the core ideas and the results while also inviting deeper study. Given how much I understood at a glance, I was really pleased with how much more I learned by digging in and exploring some of the important design decisions and trade-offs that Ziming made along the way. Because this work is so visual, we've taken the extra step of showing the figures in the YouTube version. So if it's convenient, I would encourage you to watch this one on YouTube. If not, you can definitely still get the bulk of the benefit by following the links in the show notes to the GitHub page, which has all of the important animations, to Max Tegmark's announcement tweet, and, of course, to the paper itself. Just a couple minutes spent looking at the animations and the images will do wonders for your understanding of this discussion. As always, if you're finding value in the show, we'd appreciate it if you take a moment to share it with friends or leave us a review. We'd also welcome your guest suggestions either by email at tcr@turpentine.co or via Twitter DM where I am at Labenz. Now, I hope you enjoy this illuminating conversation about neural networks made sparse, modular, and interpretable by design with Ziming Liu from MIT. Ziming Liu, welcome to the Cognitive Revolution.

Ziming Liu: 4:01 Thank you for the invitation. It's my pleasure.

Nathan Labenz: 4:04 Very excited to have you. Pleasure's all mine. You guys have published this really interesting work, "Seeing is Believing: Brain-Inspired Modular Training for Mechanistic Interpretability." And you had me at mechanistic interpretability, but especially when I saw the visual that you guys posted on Twitter, it's so rare to see in AI a mechanistic work that you can absorb and really start to get an intuition for in just a few seconds. Immediately, I was like, I need to read this paper in its entirety. And obviously, now here we are to dive deep into it. So, great job. Very exciting work. You started off with this question: What training techniques can induce modularity in otherwise non-modular networks? Maybe we could just start there and you can tell us about the motivation and how you came into this line of research.

Ziming Liu: 4:56 We all know that deep neural networks are successful but hard to understand. To interpret them, we hope they can be decomposed into different modules so that we can divide and conquer: once we divide a network into smaller parts like modules, those parts are more amenable to interpretation. Here we rely on the analogy to human brains, where different parts of the brain are responsible for different functions. What differs between biological and artificial neural networks is that artificial neural networks have no incentive to become modular. By contrast, biological neural networks have evolutionary reasons to become modular, because modular brains are more energy efficient than non-modular brains. That's why modular brains had a selective advantage in evolution, and we human beings today have modular brains. So the philosophy is: if we want to make artificial neural networks modular, we should explicitly make them modular in training. Or, to generalize beyond modularity, if we want neural networks to be interpretable, we should explicitly encourage them to be interpretable in training. That's when we posed this research question: what kind of training techniques do we need to induce modularity in otherwise non-modular neural networks? Borrowing the lesson from biology, we also asked: why are our human brains remarkably modular? The answer is closely tied to locality. We all live in three-dimensional space, where you can define distances, and your brain needs energy to create neurons, to maintain them, and to transmit signals. Local neuron connections tend to be more energy efficient than non-local ones. So the lesson we learned from our human brains is that locality gives rise to modularity, and that is the key trick we used.
That's a key strategy we used in the BIMT paper where we explicitly embed the network into a geometric space where you can define distances, and we have a penalty proportional to the length of the neural connections. We penalize non-local or long neural connections more than short or local connections to mimic these evolutionary effects to some extent. That's the basic idea of it.
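To make the idea concrete, here is a minimal Python sketch of a distance-weighted connection cost for one fully connected layer. It assumes neurons sit evenly spaced on a line within each layer, layers are stacked one unit apart, and distance is Euclidean; `layer_positions` and `connection_cost` are hypothetical helper names, not code from the BIMT repository.

```python
import math


def layer_positions(n, y, width=1.0):
    """Place n neurons evenly along x in [0, width] at height y (assumed layout)."""
    if n == 1:
        return [(width / 2, y)]
    return [(width * i / (n - 1), y) for i in range(n)]


def connection_cost(weights, pos_in, pos_out):
    """Sum of |w_ji| * distance(input neuron i, output neuron j) for one layer.

    weights[j][i] is the weight from input neuron i to output neuron j.
    """
    total = 0.0
    for j, row in enumerate(weights):
        for i, w in enumerate(row):
            total += abs(w) * math.dist(pos_in[i], pos_out[j])
    return total


pos_in = layer_positions(3, y=0.0, width=2.0)    # x = 0, 1, 2
pos_out = layer_positions(3, y=1.0, width=2.0)
local = [[1, 0, 0], [0, 0, 0], [0, 0, 0]]        # input 0 -> output 0 (straight up)
far = [[0, 0, 1], [0, 0, 0], [0, 0, 0]]          # input 2 -> output 0 (diagonal)
print(connection_cost(local, pos_in, pos_out))   # 1.0
print(connection_cost(far, pos_in, pos_out))     # sqrt(5), about 2.236
```

Two weights of the same magnitude pay different prices: the straight connection costs 1.0 and the diagonal one about 2.24, so under this kind of penalty training prefers the local one.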

Nathan Labenz: 7:58 You definitely want to look at some of the artifacts from this paper. As the title suggests, seeing is believing, so the audio-only format almost can't do this work full justice. But the good news is you can probably spend five minutes looking at the key figures, and they're simple enough and striking enough that you can absorb that information, keep that image in mind, and then listen to the rest in audio form if that's what you want to do. So definitely do that. But okay, just to restate the key ideas. Obviously, we as people and as biological beings exist in space, so we have this three-dimensional problem to solve: how do we get the most bang for our buck out of a certain-sized skull? That naturally means there's going to be a cost imposed on super long-range connections across the brain. If nothing else, they're going to clog up space. But also, the cell has to get bigger, and it seems like it might be more susceptible to damage or supply chain issues, so to speak. So it makes sense intuitively that most of the brain activity would be local. We also see that brain activity is modular: there are particular parts, often with fairly clearly defined boundaries between them, where this part does something and that other part does something else. In contrast, neural networks as they've generally existed just have this concept of layers, but there isn't any spatial meaning given to positions within those layers. Whether two neurons or activations have the same index or are very far from each other in a layer has never really been taken into account in training. So whatever connections happen to form at the beginning are the ones that continue to be built around and refined, and that's why we get spaghetti-looking networks.

Ziming Liu: 10:24 Yes. Normally, we treat artificial neural networks as topological objects. Sorry for being a little bit mathematical, but for a topological object there's no numeric measure of how close two points are. We only care about whether two things are connected: whether two nodes are joined by an edge or not. However, in a geometric space, for example the 3D Euclidean space that we live in, we have the privilege of defining distances. So switching from normal neural networks to BIMT is, so to speak, just switching from a topological space to a geometric space.

Nathan Labenz: 11:11 Yeah. I thought this was a really interesting concept, in that it's ultimately a pretty simple and elegant one, as you describe it there: moving from this topological conception to a more literal spatial conception. And if I understand correctly, the modifications you're making to the loss function are also pretty simple. In terms of code, we're probably talking about not that many characters of additional code to define that each connection has a length, and that length is now going to be counted against you.

Ziming Liu: 11:54 Yes. Basically, we want to sparsify the network, so there is this L1 regularization. Even if we just do L1 regularization without the locality trick, you need one hyperparameter, which is the penalty strength. On top of that, we have just one additional hyperparameter, which is how much you want to encourage locality. So it's simply L1 regularization, but distance dependent, and the only extra hyperparameter is the locality strength. When you set the locality strength to zero, you recover vanilla L1 regularization. When you set it large, you encourage more locality. The implementation code is really simple.
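A sketch of the regularizer as Ziming describes it: L1 on every weight, scaled by a distance-dependent factor, with one extra knob for locality. The per-weight form `lam * |w| * (1 + locality * d)` is an assumed parameterization, chosen only so that `locality = 0` gives back plain L1 as described; the paper's exact formula may differ, and the function names are illustrative.

```python
import math


def bimt_style_penalty(layers, lam, locality):
    """Distance-weighted L1 penalty summed over all layers.

    layers: list of (weights, dists) pairs; dists[j][i] is the distance between
    input neuron i and output neuron j, matching weights[j][i].
    Assumed per-weight cost: lam * |w| * (1 + locality * d).
    With locality = 0 this reduces to plain L1: lam * sum(|w|).
    """
    total = 0.0
    for weights, dists in layers:
        for w_row, d_row in zip(weights, dists):
            for w, d in zip(w_row, d_row):
                total += abs(w) * (1.0 + locality * d)
    return lam * total


def total_loss(prediction_loss, layers, lam, locality):
    """Training objective: prediction loss plus the locality-aware penalty."""
    return prediction_loss + bimt_style_penalty(layers, lam, locality)


# One layer with two weights: a short connection (d = 1) and a long one (d = 3).
layers = [([[0.5, -2.0]], [[1.0, 3.0]])]
plain_l1 = bimt_style_penalty(layers, lam=0.1, locality=0.0)  # 0.1 * (0.5 + 2.0), about 0.25
with_locality = bimt_style_penalty(layers, lam=0.1, locality=1.0)  # 0.1 * (0.5*2 + 2.0*4), about 0.9
```

Turning locality on from 0 to 1 doubles the short weight's contribution but quadruples the long one's, which is the sense in which long connections are penalized more.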

Nathan Labenz: 12:44 L1 regularization is, again, motivated by the same thing. We want to make the network sparse; we don't want wasteful, needless extra connections. So let's just penalize, in this case, the magnitude of the weight itself. We want to optimize a loss function that includes prediction accuracy, but now we're adding to it: we also want to minimize the total weight in the network overall. So the absolute values of all the weights get summed up into that loss function, and therefore, when we take the derivative, or find the gradient, they're all naturally incentivized to go down. All the weights naturally tend down at every step?

Ziming Liu: 13:26 Yes. For each weight, the L1 penalty looks like an absolute-value function, so there's an incentive to go down to zero magnitude. The prediction loss can drag you somewhere else, but the penalty term drags you down toward zero.

Nathan Labenz: 13:44 Right. So over rounds of gradient descent, only those weights that contribute enough to prediction accuracy to offset the weight, no pun intended, that they carry in the loss function actually survive, and everything else gradually gets pruned back to zero. So that's traditional L1. And now you're adding this additional parameter on top of that, and you're doing both, right? You're penalizing the weight itself. Is it multiplied by the length? I'm not the best notation guy, so as I was reading the paper, I wasn't 100% sure of all that.
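The "survive or get pruned" dynamic can be seen on a single weight. In this toy, a quadratic term stands in for prediction loss and wants the weight at `target`, while the L1 term pulls it toward zero. The sketch uses proximal gradient steps (soft-thresholding), a standard way to optimize L1 that we've swapped in here because it produces exact zeros; it is not necessarily the optimizer used in the paper, and the names are illustrative.

```python
def soft_threshold(x, thresh):
    """Proximal operator of thresh * |x|: shrink toward zero, snapping to exactly 0."""
    if x > thresh:
        return x - thresh
    if x < -thresh:
        return x + thresh
    return 0.0


def fit_single_weight(target, lam, lr=0.1, steps=500):
    """Minimize 0.5 * (w - target)**2 + lam * |w| by proximal gradient descent.

    The quadratic term is a toy stand-in for prediction loss.
    """
    w = 1.0
    for _ in range(steps):
        grad = w - target                              # gradient of the "prediction" term
        w = soft_threshold(w - lr * grad, lr * lam)    # L1 step
    return w


useful = fit_single_weight(target=2.0, lam=0.5)   # survives, shrunk to about 1.5
useless = fit_single_weight(target=0.3, lam=0.5)  # pruned to exactly 0.0
```

A weight whose usefulness (here, `|target|`) exceeds the penalty strength survives, shrunk by `lam`; one that doesn't ends up at exactly zero.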

Ziming Liu: 14:25 It's multiplied by the distance. Say for layer i and layer i plus 1, even the closest two neurons, the most local connection, still have a nonzero distance. So every weight has a nonzero distance, and it gets penalized even if it's the most local connection.

Nathan Labenz: 14:47 Gotcha. So the minimum would be one if it's the same index from one layer to the next, and it grows the more distance there is across the layer. And, again, go look at the figures in the paper. I guarantee they will make this much clearer. Hey, we'll continue our interview in a moment after a word from our sponsors.

Nathan Labenz: 15:09 So, basically, the penalty gets stronger in proportion to the length.

Ziming Liu: 15:15 Yeah. For example, if you go straight from this neuron to the nearest neuron in the next layer, the penalty is like one. But if you go at a certain angle, then maybe the penalty is two. And if you tilt it further away, maybe it becomes three. So the network has an incentive to keep only the most local connections nonzero.
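Ziming's "one, two, three" is a rough picture; in a concrete layout the numbers depend on how neurons are placed. This sketch assumes evenly spaced neurons, layers one unit apart, and Euclidean distance (all assumptions), and shows the two properties that matter: even the straight connection has nonzero length, and the length grows as the connection tilts.

```python
import math


def distances_from(x_src, xs_next, layer_gap=1.0):
    """Euclidean distance from a neuron at (x_src, 0) to each neuron at (x, layer_gap)."""
    return [math.hypot(x - x_src, layer_gap) for x in xs_next]


xs_next = [0.0, 1.0, 2.0, 3.0]   # assumed: next-layer neurons spaced 1 unit apart
d = distances_from(0.0, xs_next)
# d is about [1.0, 1.414, 2.236, 3.162]: the straight connection is cheapest but
# still costs layer_gap > 0, and the cost grows as the connection tilts further.
```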

Nathan Labenz: 15:43 And then, just for completeness on regularization, there's also L2 regularization, which is a somewhat less aggressive version: instead of adding the absolute value of the weight, you add the square of the weight. That makes a smoother curve around zero, so you're not pulled as hard toward zero, and you get maybe a little less sparsity with L2 than with L1.

Ziming Liu: 16:16 Yeah. My experience with L2 is that it shrinks the overall weight norm of all the parameters, but doesn't necessarily push any single parameter close enough to zero that you can really ignore it and prune it away. I think they are engineering tricks: if you really care about pruning, you should try both. They have their own advantages and limitations. L1 sometimes has optimization issues, because when you really get down to zero, you bounce back and forth. But with L2, as long as your step size is smaller than some threshold, you converge steadily toward zero. And, for the reason I mentioned before, L2 does not encourage any single parameter to go all the way to zero; it encourages all the parameters overall to have a smaller weight norm. So there are these kinds of trade-offs you want to explore when you do engineering.
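The optimization behavior Ziming describes shows up even on a single weight with no prediction loss at all. A minimal sketch, assuming plain fixed-step gradient descent (not the optimizer used in the paper): the L1 subgradient has constant size, so a fixed step overshoots zero and bounces; the L2 gradient shrinks with the weight, so each step multiplies `w` by `1 - 2*lr`, which converges for step sizes below 1 and decays without sign flips below 0.5.

```python
def descend(grad_fn, w0, lr, steps):
    """Plain gradient descent on the penalty alone; returns the whole trajectory."""
    w, trajectory = w0, [w0]
    for _ in range(steps):
        w = w - lr * grad_fn(w)
        trajectory.append(w)
    return trajectory


def l1_grad(w):
    """Subgradient of |w| (choosing 0 at w == 0): always a fixed-size push."""
    return (w > 0) - (w < 0)


def l2_grad(w):
    """Gradient of w**2: the push shrinks as w approaches zero."""
    return 2.0 * w


l1_path = descend(l1_grad, w0=0.35, lr=0.1, steps=10)
l2_path = descend(l2_grad, w0=0.35, lr=0.1, steps=10)
# l1_path is about 0.35, 0.25, 0.15, 0.05, -0.05, 0.05, -0.05, ... (bounces around zero)
# l2_path is about 0.35, 0.28, 0.224, ... (each step multiplies w by 0.8)
```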

Nathan Labenz: 17:18 So you run the process with everything that we've described so far. So there's this balance between all the weights tending to zero. The weights that have a long distance on the connection are particularly trending to zero. But, of course, the prediction accuracy itself is maintaining some weights up because you need to make good predictions. And so you run this, and you do indeed find and you show these images of sparse networks. But then you add this additional concept of neuron swapping, and that takes the outcome from sparse to sparse but also much more symmetrical, much more aesthetic, much more elegant. How did you come up with this idea of swapping? And then we can get more into understanding it as well.

Ziming Liu: 18:16 Sure. There's an anecdote behind it. When I first tried modular addition without swapping, the network looked a little bit like the 3D DNA double helix, but projected to 2D. So it looks something like this: it just swaps and then swaps back. So I was thinking, if I make these neurons movable, make their embeddings trainable, I should be able to avoid this topological problem. But then I asked my advisor Max Tegmark, and he asked me to try something as simple as possible, and he suggested that I try swapping. That sounded like the simplest thing we could try to solve this double helix problem. It turned out to work pretty well, so that's why we ended up using swapping, but one could also imagine making those embeddings trainable. In hindsight, we also think that sometimes there is no natural or preferred ordering of your inputs and outputs. You want an automatic way for your neural network to determine the ordering, rather than prespecifying it. Here's a really contrived example: if your task is just swapping two numbers, your inputs are x and y, x on the left and y on the right, and your outputs are y and x, y on the left and x on the right. Then your neural network would have two really long connections which also cross each other. That's okay, because you can understand what's happening, but it's not aesthetic. We want everything to be local and modular, so naturally we want to swap these two things, either the inputs or the outputs, to make the connections as short as possible, each going straight to where it should go. That's the philosophy behind swapping, I think. For some tasks we find that swapping does not help that much; it basically just reorganizes things a little so that the pictures look better. But in other cases it really helps, and it even helps neural network training.
It helps us get lower training loss, because you can imagine that without swapping, your neural network gets stuck in those weird double helix structures, so you have a large penalty, and a large penalty means you need to sacrifice prediction loss. So if you can swap these neurons cleverly to get a more locally connected configuration, your penalty drops, which means that for the same penalty strength you can get a better prediction loss. There's a hyperparameter for how frequently you do the swapping. I think I just set it to 200 or set it to infinity, but you could imagine there's a phase diagram, with phase transitions like in physics when we talk about ice, water, and gas. When we tune these hyperparameters, maybe different phases can emerge. Right now I just set them to something simple, but there's more to be done in the future.
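The contrived x, y to y, x example can be checked numerically. A minimal sketch, assuming unit-weight connections, neurons one unit apart, layers one unit apart, and Euclidean wire length; `wiring_cost` is a hypothetical helper. Moving the output neurons changes only where they sit, not which input feeds which output.

```python
import math


def wiring_cost(edges, xs_in, xs_out, layer_gap=1.0):
    """Total length of unit-weight connections (i -> j) between two layers."""
    return sum(math.hypot(xs_out[j] - xs_in[i], layer_gap) for i, j in edges)


xs_in = [0.0, 1.0]   # input 0 = x on the left, input 1 = y on the right
xs_out = [0.0, 1.0]  # output 0 on the left, output 1 on the right

# Task (x, y) -> (y, x): input 0 feeds output 1 and input 1 feeds output 0.
edges = [(0, 1), (1, 0)]
before = wiring_cost(edges, xs_in, xs_out)           # two crossing diagonals: 2*sqrt(2)

# Swap where the two output neurons sit; the edges (the computation) are unchanged.
xs_out_swapped = [xs_out[1], xs_out[0]]
after = wiring_cost(edges, xs_in, xs_out_swapped)    # two straight wires: 2.0
```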

Nathan Labenz: 22:13 And again, just to be super explicit, the swapping process is basically this: every so often, every however many steps of training, we add an extra step that goes through the layers, takes two different neuron positions, and essentially asks, would the loss go down if I just flipped the positions of these two, where the loss already reflects the length penalty? So you can just do a bunch of pairwise comparisons and find those improvements. And it sounds like it's largely aesthetic, as you said. It's also kind of helping you not get stuck in local minima. I guess those local minima would form when you reach some stable place where the prediction is pretty good and the structure the network has settled into could be shortened, but there's no mechanism other than swapping to shorten it if the other weight isn't helping with prediction just yet. Right? Like, a weight that's not active basically can't be turned back on in this scenario. Right?

Ziming Liu: 23:49 We only care about important weights, in the sense of having a large value. So we attach to each neuron a significance score, an importance score, and we only try swapping the top 10, say, most important neurons with other neurons to see if such a swap can reduce the locality loss. We also constrain swaps to two neurons in the same layer, so that the prediction loss remains the same but the locality loss may change. And we want the swaps that reduce the locality loss.

Nathan Labenz: 24:32 And the prediction loss would stay the same by definition, or am I missing something there?

Ziming Liu: 24:38 The prediction loss remains the same by definition. When I say I swap two neurons, I'm actually also swapping their incoming and outgoing weights. So the computational graph, topologically speaking, remains the same, and the whole neural network as a function remains the same; it only changes geometrically.
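Putting the last few answers together, a swap step can be sketched on a tiny two-layer network. Much of this is our assumption for illustration: the importance score (sum of absolute incoming and outgoing weights), the `top_k` cutoff, the layout, and all helper names are ours, not taken from the paper's code. What it does reflect from the conversation: only neurons in the same layer are swapped, each swap carries the neuron's incoming and outgoing weights with it so the network computes the same function, and a swap is kept only if it lowers the distance-weighted cost.

```python
import math

# Hypothetical layout: x-positions of input, hidden and output neurons; layers 1 unit apart.
XS_IN, XS_HID, XS_OUT = [0.0, 2.0], [0.0, 1.0, 2.0], [0.0, 2.0]


def forward(x, W1, W2):
    """Two-layer MLP without biases: y = W2 @ relu(W1 @ x)."""
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W1]
    return [sum(w * hi for w, hi in zip(row, h)) for row in W2]


def cost(W1, W2):
    """Distance-weighted L1 over both layers."""
    c = sum(abs(w) * math.hypot(XS_HID[j] - XS_IN[i], 1.0)
            for j, row in enumerate(W1) for i, w in enumerate(row))
    c += sum(abs(w) * math.hypot(XS_OUT[k] - XS_HID[j], 1.0)
             for k, row in enumerate(W2) for j, w in enumerate(row))
    return c


def swap_hidden(W1, W2, a, b):
    """Swap hidden neurons a and b together with their incoming and outgoing weights."""
    W1 = [row[:] for row in W1]
    W2 = [row[:] for row in W2]
    W1[a], W1[b] = W1[b], W1[a]          # incoming weights are rows of W1
    for row in W2:                        # outgoing weights are columns of W2
        row[a], row[b] = row[b], row[a]
    return W1, W2


def swap_round(W1, W2, top_k=2):
    """For each of the top_k most important hidden neurons, apply the best swap
    with another hidden neuron, but only if it lowers the cost."""
    n = len(W1)
    importance = [sum(abs(w) for w in W1[j]) + sum(abs(row[j]) for row in W2)
                  for j in range(n)]
    for a in sorted(range(n), key=lambda j: importance[j], reverse=True)[:top_k]:
        best_cost, best_b = cost(W1, W2), None
        for b in range(n):
            if b != a:
                c = cost(*swap_hidden(W1, W2, a, b))
                if c < best_cost:
                    best_cost, best_b = c, b
        if best_b is not None:
            W1, W2 = swap_hidden(W1, W2, a, best_b)
    return W1, W2


# Input 0 is routed through the far-right hidden neuron to output 0, and input 1
# through the far-left hidden neuron to output 1: crossing wires.
W1 = [[0.0, 1.0], [0.0, 0.0], [1.0, 0.0]]
W2 = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
W1_new, W2_new = swap_round(W1, W2)
```

In this example the cost drops from 4*sqrt(5), about 8.94, to 4.0 after one swap round, while `forward` returns the same outputs for any input before and after.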

Nathan Labenz: 25:02 So that's true at that step, but within the broader optimization process it can still help achieve better end performance, because the network pays less length penalty. You're essentially helping it carry as little length penalty as it can along the way.

Ziming Liu: 25:23 Yes, so at the single point when we do the swapping, the prediction loss remains the same, but in the long term, having connections in the right place encourages the network to decrease prediction loss. Otherwise, if the network got stuck at some bad local minima configurations, this would have a negative effect on prediction loss.

Nathan Labenz: 25:50 This feels like something, when I saw it, I was like, why did nobody do that years ago? Do you also feel that way? Like, why did nobody else think of this sooner? Or was there some unlock that you think made this kind of an idea whose time has come now?

Ziming Liu: 26:09 Actually, when we wrote the paper, we were not aware of those related papers. There's work from about 10 years ago by Jeff Clune, who is an expert in this. Their papers basically focus on the biological side: they evolve tiny networks with evolutionary algorithms to mimic what really happens in biological evolution. They have some follow-ups that extended this to gradient-based methods, but still in different setups. So first, at that time, machine learning and gradient-based methods were not as popular as they are now, so the previous works did not fully leverage them; we're integrating the current machine learning tools. Second, we have a different focus. Our focus is mechanistic interpretability, understanding how neural networks work internally, whereas, as I said, they care more about how modularity emerges in biological systems. So I would not shamelessly claim that this is a totally novel idea. We always stand on the shoulders of those who came before, and there are clever minds we always learn lessons from. But we're in a different setup, and we integrate new tools that only recently became available in the machine learning community. As long as it's a useful tool, there's no harm in rediscovering things, because we find valuable things. Only really valuable things get rediscovered over and over again. So I wouldn't say it's totally novel, but it still makes some contribution, I guess.

Nathan Labenz: 28:12 Everything's a remix, but this had a feeling of, I bet a lot of people saw it and thought, why didn't I think of that? Especially once you see that animation, it's like, man, I definitely should have thought of that. So maybe I'm reacting more to how compellingly the final result has been packaged up and communicated, and how densely it conveys the idea, than to anything else, but it certainly creates that feeling.

Ziming Liu: 28:40 I see. Yeah. I myself was really shocked when I got the results. When I actually saw those connectivity graphs, I thought: how can a neural network possibly do this complicated thing with such a sparse and modular network? Also, we are using this locality trick, and I was a little worried that locality is not exactly modularity, but it works. So there's still a lot that's unexpected here. Even when we first came up with the idea, we weren't sure whether it would work on these examples or what we would end up with. When I first saw the really sparse and modular graphs for those symbolic regression tasks, I thought, we are really onto something, and we should dig deep into it, try a lot of examples that people have already tested in mech interp, and make it a real thing. So yeah, a lot of surprises along the way.

Nathan Labenz: 29:46 So a couple little follow-ups here. In particular, I definitely want to dig in a little bit on the distinction between locality and modularity. And then I definitely want to get into also the things that you actually did and kind of understanding, you know, what can we interpret out of all this? Just one more little follow-up on the swapping. Does it seem likely to you that the swapping is more valuable at the beginning of the training process? Would you expect that front loading those swaps would be more useful than, like, late in training swaps?

Ziming Liu: 30:24 That's exactly the opposite of what I'm hoping for. Right now I'm thinking of applying BIMT to pretrained large language models. What I hope is that I can just take a large language model at the end of pretraining and do a little bit of fine-tuning, because we may not have the compute to train from scratch. So hopefully swapping would be most valuable at the end of training. But since you brought up the point, I'm a little concerned. Maybe I tend to agree with you that swapping is most valuable at the beginning of training. At the beginning there are many lottery tickets, many directions you can go, and that's when the neural network needs to decide which way to branch, which lottery ticket to branch into. After you branch into that basin of attraction, or lottery ticket, so to speak, swapping is no longer important because you're already in that basin. You only need to roll down to the bottom; you don't need to select a basin. But that's just my conjecture. We'll see how it works on language models.

Nathan Labenz: 31:58 I'm sure you've looked at the Git Re-Basin paper, which had a similar concept. I think they called it, somewhat confusingly, aligning the model. There was a swapping mechanism meant to take two same-shape networks trained on different things and align them to each other so they could be merged?

Ziming Liu: 32:27 Yeah. For the Git Re-Basin paper, I think there's a follow-up paper that contradicts the idea of Git Re-Basin, called Mechanistic Mode Connectivity. If a network has multiple mechanisms to do the same thing, for example modular addition, where many group representations can achieve the same goal, then you cannot make them equivalent just by swapping neurons. I also observed this in my code. This may not be directly relevant, but it's interesting to note: when I run the modular addition example, to make my code reproducible I have to set the precision to double-precision numbers. Even with single-precision floats, the truncation errors can dominate. That kind of noise is large enough to send you into another basin. There are exponentially many equivalent basins, and this permutation symmetry makes the loss landscape of neural networks very complicated. So just a note: if you want to make the code reproducible, you need to use double-precision numbers instead of single-precision floats.

Nathan Labenz: 34:01 And is that 64? How many bits is that?

Ziming Liu: 34:05 64, yeah.

Nathan Labenz: 34:06 Yeah. So if you were to round down, you're saying even at 32 bits, it still has this problem?

Ziming Liu: 34:13 Yeah. That's problem specific. Only for the algorithmic datasets, like modular addition and permutation, do you need to set it to 64 bits, because in addition to neuron permutation you also have this representation equivalence. For other tasks like image classification or symbolic regression, there is no such problem, because there's no representation problem. But unfortunately, language models need to learn token embeddings, and they do have such a problem. I'm not sure whether there's a bug in my code, but this is what I found: for tasks with trainable embeddings, like modular addition or, in the future, language models, you need to use 64 bits to ensure reproducibility. That's a little bit unfortunate, but...
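Why does precision matter this much? A minimal illustration, not a reproduction of Ziming's setup: float32 has a 24-bit significand, so near 1e8 adjacent representable values are 8 apart and adding 1.0 gets rounded away, while float64 keeps it. In training, small rounding differences like this (for example from a different summation order) nudge the trajectory, and in a loss landscape with many equivalent basins that can be enough to land in a different one. `to_f32` is a hypothetical helper that emulates float32 rounding with the standard `struct` module.

```python
import struct


def to_f32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


big, small = 1e8, 1.0
f64_result = (big + small) - big                                       # 1.0: float64 keeps the +1
f32_result = to_f32(to_f32(to_f32(big) + to_f32(small)) - to_f32(big))  # 0.0: float32 drops it
```

In PyTorch, one way to run in double precision is `torch.set_default_dtype(torch.float64)`, or calling `.double()` on the model and inputs; which mechanism Ziming used isn't stated.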

Nathan Labenz: 35:16 On to modularity and locality, then. I was thinking about that. Locality obviously doesn't necessarily imply clear boundaries. I guess that's the simple way I was framing the question. So if you wanted to more strongly encourage boundaries, or a separation of concerns, is there a loss you could conjure up to do that too? Could you count the number of nonzero weights coming into a position as part of the loss? I'm sure you've thought about this.

Ziming Liu: 35:58 Yeah. There are a bunch of off-the-shelf modularity measures in the literature, especially in graph theory. The idea is basically that you want to detect communities: you separate a whole graph into subgraphs, where connections inside a subgraph can be dense but connections between different subgraphs should be sparse. We already have such modularity measures, but they're defined on graphs, which are topological objects. They're good, but they're hard to visualize. This is just a preliminary thought, but maybe we could combine the BIMT penalty with some modularity loss people use in graph theory. I'm a little skeptical of this, though, because it's more like an engineering trick. Maybe we can find something more elegant, but it's the simplest thing that comes to mind.
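As one concrete example of the graph-theory measures mentioned here, Newman-Girvan modularity Q compares the fraction of edges inside each community with what you would expect by chance given the node degrees. It's our choice of example, not necessarily the measure Ziming has in mind, and note that it only sees the graph's topology: neuron positions play no role, unlike BIMT's distance-based penalty. The function name is illustrative.

```python
def newman_modularity(edges, community):
    """Newman-Girvan modularity Q of an undirected, unweighted graph.

    Q = sum over communities c of [ L_c / m - (d_c / (2 m))**2 ],
    where m is the number of edges, L_c the number of edges inside c,
    and d_c the total degree of the nodes in c.
    """
    m = len(edges)
    inside, degree_sum = {}, {}
    for u, v in edges:
        for node in (u, v):
            c = community[node]
            degree_sum[c] = degree_sum.get(c, 0) + 1
        if community[u] == community[v]:
            inside[community[u]] = inside.get(community[u], 0) + 1
    return sum(inside.get(c, 0) / m - (d / (2 * m)) ** 2 for c, d in degree_sum.items())


# Two triangles joined by a single bridge edge (2-3).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
split = {0: "A", 1: "A", 2: "A", 3: "B", 4: "B", 5: "B"}
lumped = {n: "A" for n in range(6)}
q_split = newman_modularity(edges, split)    # 5/14, about 0.357
q_lumped = newman_modularity(edges, lumped)  # 0.0
```

Splitting the graph along the bridge scores higher than lumping everything together, which is the kind of signal such a measure would contribute if combined with a locality penalty.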

Nathan Labenz: 37:19 As cool as this is, we're not at the end of this line of research. That's for sure. Maybe worth checking out the Tiny Stories project.

Ziming Liu: 37:28 I noticed that project. That's super cool.

Nathan Labenz: 37:32 You mentioned that compute budget is obviously a meaningful constraint when you're trying to run this research, probably in any setting these days, but certainly in an academic setting. The Tiny Stories authors' explicit goal was to create a dataset that would capture the full complexity, or almost the full complexity, of natural language while using a reduced vocabulary and relatively simple concepts that a 3-year-old can understand. They were able to get a lot of the reasoning behavior that we see in larger language models at, as far as I know, the smallest model scale I've seen, because they essentially narrowed the universe of what had to be learned while still preserving the value of learning to reason in that context. So there could be something interesting there.

Ziming Liu: 38:36 This is actually related to a paper by our group led by Eric Michaud called The Quantization Model of Neural Scaling. We have this quantization hypothesis, where we conjecture that a prediction task, say language modeling, can be decomposed into many subtasks called quanta, as in physics, where we talk about energy quanta. Those quanta have different importance: they appear with different frequencies in natural language. So presumably we can order the quanta from most significant to least significant, giving a quanta sequence, and we conjecture that a neural network, depending on its capacity, first learns the most significant quanta and then the rest in decreasing order of importance. So I would imagine that speaking English coherently like a 3-year-old child only requires, say, the first 10 quanta. But to role-play a physicist, or say something fancier than what a 3-year-old can, you need the less frequent quanta that appear in the tail. That's why you need a very large language model: you need to memorize, so to speak, that kind of fact. But to speak English coherently, you don't need that large a model. So I just want to mention that Tiny Stories also connects back to our quantization model.
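The "first 10 quanta" intuition can be made concrete with a toy calculation. A minimal sketch, assuming each sample relies on exactly one quantum, quanta are used with Zipf-like frequencies p_k ∝ k^-(α+1), and a model learns quanta strictly in frequency order; α = 0.5 and 10,000 quanta are arbitrary illustrative values, and `unlearned_mass` is a hypothetical name. It shows why a small model already covers most samples while the long tail needs far more capacity.

```python
import numpy as np

def unlearned_mass(n_learned, n_quanta=10_000, alpha=0.5):
    """Fraction of samples that need a quantum the model has not learned yet.

    Toy assumptions: one quantum per sample, Zipf-like usage p_k ∝ k^-(alpha+1),
    and quanta learned strictly from most to least frequent.
    """
    k = np.arange(1, n_quanta + 1)
    p = k ** -(alpha + 1.0)
    p /= p.sum()                      # normalize to a probability distribution
    return float(p[n_learned:].sum())  # mass of the quanta not yet learned

for n in (10, 100, 1000):
    print(n, round(unlearned_mass(n), 4))
```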

Nathan Labenz: 40:34 That's cool. I could definitely imagine a reconvergence of these ideas, and it would be really interesting to see if you could actually start to get a visual on the different reasoning microskills. I've been using that term to emphasize that, at least as I imagine them, they're very discrete, probably very small modules that do very specific tasks and ladder up to more general-purpose reasoning. It seems like what they see, and I guess this makes sense given what we understand about the mechanisms of training, is that super specific things like negation come online as a discrete skill, in superposition, obviously, with a bunch of memorized stuff and baseline priors. But it seems to grok these very specific little skills in a discrete way. So, yeah, fascinating. Let's get to some of the results, then. Just to recap, let me see if I can do this briefly. You say: we've built digital neural networks for the last 10 years on a very loose biological inspiration. But one biological reality that we haven't really taken inspiration from is that the larger the distance between 2 neurons in a human brain, the harder it is for those cells to connect, for very practical physical reasons. What if we take that same concept into the machine learning context? That becomes an additional penalty that grows with the length of each connection between neurons. It gets added to the loss function, you reoptimize around that, and you find that you get much better sparsity. But then you also realize you can sometimes get stuck in local minima, so you add a swapping step that untangles the network and helps it escape some of those local minima. As a result, you get both great performance and aesthetically organized, almost crystal-like structures that you can visualize at the end.
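The recap above amounts to an objective of the form prediction loss + λ · Σ |w_ij| · d_ij, plus occasional neuron swaps. A minimal NumPy sketch, assuming neurons are laid out evenly on a line in each layer, bias-free layers, and a greedy pairwise swap; the paper's exact distance definition, swap schedule, and any extra constants may differ. `connection_cost`, `swap_hidden`, `greedy_swaps`, and `lam` are hypothetical names.

```python
import numpy as np

def neuron_positions(width, layer_index):
    """Assumed layout: neurons evenly spaced on [0, 1] at height `layer_index`."""
    xs = np.linspace(0.0, 1.0, width) if width > 1 else np.array([0.5])
    return np.stack([xs, np.full(width, float(layer_index))], axis=1)

def connection_cost(weights):
    """Sum over all layers of |w_ij| times the length of the connection i <- j.

    weights[l] has shape (width of layer l+1, width of layer l); biases are omitted.
    """
    cost = 0.0
    for l, w in enumerate(weights):
        src = neuron_positions(w.shape[1], l)
        dst = neuron_positions(w.shape[0], l + 1)
        length = np.linalg.norm(dst[:, None, :] - src[None, :, :], axis=-1)  # (out, in)
        cost += float(np.sum(np.abs(w) * length))
    return cost

# Training objective (sketch): prediction_loss + lam * connection_cost(weights),
# where lam is a hypothetical penalty strength.

def swap_hidden(weights, layer, i, j):
    """Swap neurons i and j of the hidden layer fed by weights[layer - 1].

    Swapping rows of the incoming matrix together with columns of the outgoing
    matrix relabels the neurons without changing the function being computed.
    """
    w = [m.copy() for m in weights]
    w[layer - 1][[i, j], :] = w[layer - 1][[j, i], :]
    w[layer][:, [i, j]] = w[layer][:, [j, i]]
    return w

def greedy_swaps(weights, layer):
    """Try each pair of neurons in one hidden layer; keep any swap that lowers the cost."""
    best = connection_cost(weights)
    width = weights[layer - 1].shape[0]
    for i in range(width):
        for j in range(i + 1, width):
            candidate = swap_hidden(weights, layer, i, j)
            cost = connection_cost(candidate)
            if cost < best:
                weights, best = candidate, cost
    return weights

def forward(weights, x):
    """Bias-free MLP with SiLU on hidden layers; x has shape (batch, inputs)."""
    h = x
    for w in weights[:-1]:
        z = h @ w.T
        h = z / (1.0 + np.exp(-z))
    return h @ weights[-1].T

# Usage: a 2-2-2 net whose two paths cross; one swap untangles them.
crossed = [np.array([[0.0, 1.0], [1.0, 0.0]]),   # hidden 0 <- input 1, hidden 1 <- input 0
           np.array([[0.0, 1.0], [1.0, 0.0]])]   # output 0 <- hidden 1, output 1 <- hidden 0
untangled = greedy_swaps(crossed, layer=1)
print(connection_cost(crossed), connection_cost(untangled))  # about 5.657 (4*sqrt(2)), then 4.0
```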
You start off, there's a series of experiments. The first one, and again, make sure you look at these things. It will help tremendously to see them. The first things you're doing are these symbolic formulas where basically you're taking a few variables coming in, and then it's the job of the network to predict function output values based on those input values. And it's presented as symbolic functions, but I was a little confused by that because to the best of my understanding, at the end of the process, the network itself is just numbers in, numbers out. Is that right?

Ziming Liu: 43:39 Right. We synthesize the dataset with symbolic functions: we input numeric values of x and y and output f, also as a numeric value, where f is precomputed as a symbolic formula of x and y. So, yeah, numbers in, numbers out, everything numeric.

Nathan Labenz: 44:01 Again, look at the graphs, but in each case you're demonstrating a conceptual congruence between the nature of the functions the network is learning and the structure of the network that you can look at. To be super concrete, there is a network that's trained on x 1, 2, 3, and 4, so 4 numeric input values, and it's trained to predict 2 functional output values. I was surprised that they're not super simple. One is x 2 squared plus sine of pi x 4. That's one prediction. The other is x 1 plus x 3, squared. So in other words, take 2 numbers, add them, and square them, and then take the other 2 numbers, square one, take the sine of the other, and add those. Okay. That's kind of random. Should I understand that in any deeper way, or did you just pick 2 random functions and that's that?

Ziming Liu: 45:09 Yeah. To be fair, independence is the main thing here: we want to see that the network can split into 2 independent parts. For the functions, we just picked them off the top of my head: squared functions, cubic functions, sine, square root, something like that, as simple as possible.

Nathan Labenz: 45:31 Okay. I feel like I maybe didn't make that as clear as I could have. The independence concept is that one of the 2 output functions you're training the network to predict depends only on 2 of the inputs, and the other depends only on the other 2 inputs. So if everything is working according to the initial theory, we should see just 2 parallel paths through the network that don't interact at all, because information doesn't need to cross in that way. And, indeed, that is what you see. Presumably due to the swapping, as shown in the paper, x 4 and x 2 end up together on one side and feed directly up to their output function, and x 1 and x 3 feed directly up to theirs. I guess you chose it that way so that the paths were forced to cross initially, and then you see them uncross as training goes on.

Ziming Liu: 46:33 Yes. We deliberately ordered the inputs x 1 to x 4 so that the first output depends on x 1 and x 3 and the second output depends on x 2 and x 4. We wanted to test whether the swapping is effective enough to rearrange the inputs into the correct groups, and we were happy to see that it indeed works.
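The "numbers in, numbers out" setup for this independence task can be sketched as follows. This is a hypothetical reconstruction: the functional forms follow Nathan's verbal reading ("x 2 squared plus sine of pi x 4" and "add them and square them"), the output order follows Ziming's description, and the uniform [-1, 1] input range is the one mentioned later in the conversation; the paper's exact formulas may differ. `independence_dataset` is a hypothetical name.

```python
import numpy as np

def independence_dataset(n_samples, seed=0):
    """Hypothetical reconstruction of the 'independence' task from the conversation.

    Inputs are uniform on [-1, 1]^4. Output 0 uses only x1 and x3; output 1 uses
    only x2 and x4. Both inputs and targets are plain floating-point numbers.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.0, 1.0, size=(n_samples, 4))
    x1, x2, x3, x4 = x.T
    y = np.stack([(x1 + x3) ** 2,                   # "add them and square them"
                  x2 ** 2 + np.sin(np.pi * x4)],     # "x 2 squared plus sine of pi x 4"
                 axis=1)
    return x, y

x, y = independence_dataset(1024)
print(x.shape, y.shape)  # (1024, 4) (1024, 2)
```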

Nathan Labenz: 46:56 And they're extremely sparse. So these are literal graphs. Right? There are just 2 layers in the network that is trained to do this task.

Ziming Liu: 47:08 Yes. And I did test this. One might suspect that the connections you cannot see actually contribute a lot, but that's not the case. I literally plot every weight, with the thickness of each connection proportional to the weight magnitude. So if you don't see a connection, the weight is too small for the naked eye to spot. And if I deliberately set those small weights to 0, the output of the network is not affected at all. So this is not a visualization artifact that makes contributing weights merely look small: they are not contributing, and they are indeed 0, or close to 0.
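Ziming's check (zero out the faint weights and confirm the outputs don't move) is easy to reproduce in spirit. A minimal sketch, assuming a small SiLU MLP stored as (W, b) pairs and a hypothetical magnitude threshold; the weights below are made up for illustration and are not from the paper. `mlp`, `prune`, and `pruning_effect` are hypothetical names.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def mlp(params, x):
    """SiLU MLP; params is a list of (W, b) pairs with W of shape (out, in)."""
    h = x
    for w, b in params[:-1]:
        h = silu(h @ w.T + b)
    w, b = params[-1]
    return h @ w.T + b

def prune(params, threshold):
    """Set every weight with |w| < threshold to exactly 0 (biases are left alone)."""
    return [(np.where(np.abs(w) < threshold, 0.0, w), b) for w, b in params]

def pruning_effect(params, x, threshold):
    """Largest change in any output caused by pruning."""
    return float(np.max(np.abs(mlp(params, x) - mlp(prune(params, threshold), x))))

# Usage with made-up weights: a few large connections plus many tiny ones (1e-8).
rng = np.random.default_rng(0)
params = []
for n_out, n_in in [(8, 4), (8, 8), (2, 8)]:
    w = rng.normal(size=(n_out, n_in))
    w[np.abs(w) < 0.5] = 1e-8          # "invisible" connections
    params.append((w, np.zeros(n_out)))
x = rng.uniform(-1.0, 1.0, size=(64, 4))
print(pruning_effect(params, x, threshold=1e-3))  # tiny: the faint weights do not matter
```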

Nathan Labenz: 48:10 That's fascinating. I'm left feeling like this is a eureka breakthrough in the sense of, oh my god, look how simple the structure is. In this particular case of x 2 squared plus sine of pi x 4, it really is just a couple of neurons doing the job: there are the 2 inputs, 3 active neurons in the first layer, 2 active neurons in the second layer, and only 2 layers. So there are only 5 active neurons in total needed to translate these 2 inputs into that functional form. Again, that seems like eureka. But if I can be vulnerable with you for a second, I still don't really get how it's doing it. When I look at the graph, it's not clicking in my brain; I don't feel like I could diagram a 5-node thing and know how it computes that function. So how do you interpret it now that we have that super lean, super sparse network? Is that very meaningful to you?

Ziming Liu: 49:32 I can give some explanation of what the neural network did. I can actually write down the symbolic formulas and try to figure out what the network is doing. I was really shocked when I saw the results too. But in hindsight, I think it's understandable. Just as physicists are always fascinated by the unreasonable effectiveness of mathematics, here I would frame our surprise as the unreasonable effectiveness of smooth activation functions. I'm using SiLU or SELU as activations, and they're both smooth functions. In applied math, we know that if we approximate a smooth function with a Fourier basis, the approximation error drops exponentially as we add more and more higher-frequency modes. The statement also generalizes to other smooth basis functions, not just Fourier modes, so to SiLU or SELU in our case. I understand that's not a very satisfying explanation, but my take is that these smooth activation functions are remarkably, unreasonably effective.
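The applied-math fact Ziming cites can be checked numerically for the Fourier case. A minimal sketch, assuming an analytic 2π-periodic test function, exp(sin x), sampled on a uniform grid; it illustrates the exponential decay of truncation error for a Fourier basis only and says nothing directly about SiLU networks. `truncated_fourier_error` is a hypothetical name.

```python
import numpy as np

def truncated_fourier_error(f, n_modes, n_grid=256):
    """Max error on the grid after keeping only Fourier modes |k| <= n_modes of f."""
    x = 2 * np.pi * np.arange(n_grid) / n_grid
    values = f(x)
    coeffs = np.fft.fft(values)
    k = np.fft.fftfreq(n_grid, d=1.0 / n_grid)  # integer frequencies
    coeffs[np.abs(k) > n_modes] = 0.0           # drop the high-frequency modes
    approx = np.fft.ifft(coeffs).real
    return float(np.max(np.abs(approx - values)))

def smooth(x):
    return np.exp(np.sin(x))  # analytic and 2*pi-periodic

for n in (1, 2, 4, 8):
    print(n, truncated_fourier_error(smooth, n))  # error shrinks roughly geometrically
```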

Nathan Labenz: 50:59 Are there basically 2 nonlinearities? Like, you have a nonlinearity at each layer?

Ziming Liu: 51:05 Yes. It has 2 hidden layers. At each hidden layer, it has a nonlinearity.

Nathan Labenz: 51:10 If people know any activation function, they might know the ReLU function, which is 0 if the value is negative and then just y equals x if the value is positive. So it looks like straight line and then a sharp corner at 0 and then a straight line going up. And the SiLU function is essentially a curvy approximation of that, which remains differentiable at 0 for one thing.

Ziming Liu: 51:37 Right. So basically you take a ReLU function, but you drag it down around 0 to make a well there. So SiLU becomes non-monotonic. Asymptotically, when your input is very negative or very positive, it behaves the same as ReLU; it's just that around 0, it's differentiable.
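The ReLU/SiLU comparison above can be checked directly. A minimal sketch using the standard definition SiLU(x) = x · sigmoid(x); the grid resolution is an arbitrary choice. It locates the "well" and confirms that SiLU matches ReLU far from 0.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def silu(x):
    """SiLU (also called swish): x * sigmoid(x). Smooth everywhere, dips below 0 near x = -1.28."""
    return x / (1.0 + np.exp(-x))

xs = np.linspace(-6.0, 6.0, 12001)
i_min = np.argmin(silu(xs))
print(xs[i_min], silu(xs)[i_min])  # roughly -1.278 and -0.278: the "well"
print(silu(30.0) - relu(30.0), silu(-30.0))  # both essentially 0: same asymptotics as ReLU
```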

Nathan Labenz: 52:07 So if you were to use ReLU... I assume you tried ReLU and it didn't work? Is that a safe assumption?

Ziming Liu: 52:14 That's a good question. Actually, somehow I never tried ReLU. It's my unreasonable craze for SiLU or SELU.

Nathan Labenz: 52:25 When you said it's the unreasonable effectiveness of smooth activations, you were confident enough that this is the better activation function that you didn't even need to try the old one.

Ziming Liu: 52:37 Right. And there are papers proving that smooth functions like SiLU or SELU can construct quadratic functions, or multiply 2 numbers, with just 2 or 3 neurons. For example, for the squared function: the SiLU or SELU function is non-monotonic, so there's a bottom. If you Taylor expand around the bottom, you get a parabola, a quadratic function. So it may sound like you only need one neuron to approximate the quadratic function, which is true by construction. But in practice, the chance of being initialized around the bottom is very low, so what really happens is that the network uses 2 neurons. Both neurons have their own first-order terms, but the later layers weight them so that the first-order terms cancel while the second-order terms survive in the Taylor expansion. That's where the squared function comes from. I myself did not think of this sneaky Taylor expansion trick, but the network discovered it on its own. So it's a little bit shocking to me. In some sense, neural networks are more clever than me.
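Here is one construction consistent with the two-neuron trick described above; it is not extracted from the trained network, whose neurons need not be symmetric. Since SiLU(z) + SiLU(-z) = z · tanh(z/2) is even, the linear terms cancel, and because SiLU''(0) = 1/2 the sum is about z²/2 for small z. Rescaling gives x² - ε²x⁴/12 + ..., so with a hypothetical scale ε = 0.1 the error on [-1, 1] is about 8e-4. `square_from_two_silus` is a hypothetical name.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def square_from_two_silus(x, eps=0.1):
    """Approximate x**2 with two SiLU neurons whose first-order terms cancel.

    SiLU(eps*x) + SiLU(-eps*x) keeps only even Taylor terms; since SiLU''(0) = 1/2,
    the sum is about (eps*x)**2 / 2, so rescaling by 2/eps**2 recovers x**2.
    """
    return 2.0 * (silu(eps * x) + silu(-eps * x)) / eps**2

xs = np.linspace(-1.0, 1.0, 201)
err = np.max(np.abs(square_from_two_silus(xs) - xs**2))
print(err)  # about eps**2 / 12, i.e. roughly 8.3e-4, reached at |x| = 1
```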

Nathan Labenz: 54:18 Okay. I'm looking at this trained network. It takes these 4 inputs and predicts these 2 functions, and there's this remarkably sparse structure that achieves a pretty nontrivial functional form with just a few neurons. Then you report the loss, which I guess is for the joint task of the 2 predictions: 7.4e-3, or 7.4 times 10 to the minus 3. I guess we can jump down to the bottom, right? You have some graphs in the appendix that show this. I was trying to figure out what that loss means: is it really tightly fitting the functional form, or is it kind of loose around it? How close does it actually come to learning the functional form?

Ziming Liu: 55:24 Yeah. I had the same concern when I just looked at the loss. That's why in the appendix we also plot scatter plots to see how well the predicted results align with the ground truth. You see that they basically lie on the line, and the R squared is like 0.999 or something. So that's pretty good. Still, including this penalty can degrade performance. As I said, you can approximate a quadratic function acceptably well with just 2 neurons, but if you include more neurons, the approximation can get better and better. So there's a trade-off between accuracy and sparsity. Presumably there's a Pareto frontier: there's no single best solution, just a frontier of solutions trading off accuracy against sparsity, and choosing lambda, I mean the penalty strength, moves us along that frontier. I didn't try a smaller lambda. Maybe I should. If I set lambda smaller, maybe you'd see 3 or 4 neurons, but the prediction loss could be better.
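The Pareto frontier Ziming describes can be read off a sweep over lambda by keeping only the runs that no other run beats on both axes. A minimal sketch; the (prediction loss, connection cost) pairs below are invented for illustration and are not results from the paper. `pareto_front` is a hypothetical name.

```python
def pareto_front(points):
    """Return the (loss, cost) pairs not dominated by any other pair (lower is better for both)."""
    front = []
    for p in points:
        dominated = any(q[0] <= p[0] and q[1] <= p[1] and q != p for q in points)
        if not dominated:
            front.append(p)
    return sorted(front)

# Hypothetical (prediction_loss, connection_cost) pairs from runs with different lambda.
runs = [(0.020, 3.0), (0.007, 5.0), (0.003, 9.0), (0.010, 9.5), (0.001, 20.0)]
print(pareto_front(runs))  # [(0.001, 20.0), (0.003, 9.0), (0.007, 5.0), (0.02, 3.0)]
```

The run (0.010, 9.5) drops out because (0.007, 5.0) is better on both axes.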

Nathan Labenz: 56:44 If you turn the penalty up to infinity, then presumably everything just goes to 0 and your predictions are all terrible.

Ziming Liu: 56:51 Right. Yeah.

Nathan Labenz: 56:52 And then if you turn it down to no penalty, then you just are back to the beginning where you have no incentive to sparsify in the first place.

Ziming Liu: 57:02 Yes. Yes.

Nathan Labenz: 57:03 Yeah. This is figure 10 in the paper, and the scatterplot is tight to the line. I'd say it's basically indistinguishable from the line. In a couple of the others, you can see a little bit of wobble around the line, but it's still very close. All of these inputs are on the interval minus 1 to 1, right? Did you look at what would happen if you went a little outside that domain on the input? Would that ruin everything?

Ziming Liu: 57:47 Yeah. Good question. Honestly, I didn't try that. I would imagine it would fail. I don't believe in systematic generalization: we have no information about the region outside the training range, so at most the network can interpolate within the range we trained on. That's my guess.

Nathan Labenz: 58:12 That's maybe a good bridge to start to move into... We're still climbing the ladder of complexity in these experiments, right? The first one is a purely synthetic dataset derived from these symbolic functional forms, but the network itself just takes numbers from some narrow interval and predicts, essentially, the curve. Amazingly, it can learn to do that with very few active neurons, and we attribute that to the unreasonable effectiveness of smooth activation functions. We're also seeing the conceptual notions that we designed into the problem reflected back to us. When we set one up so that, in theory, we could have 2 completely distinct subgraphs, that's indeed what we get. When we set one up with feature sharing, where the first output depends only on one of the inputs, the next on 2, and the next on 3, again we see the information flow that aligns with our expectations. Then there's the third one, compositionality, which looks even more tangled but clearly has some parts I'd call interpretable: it appears to me that a square root function, much like a square function, consists of one neuron going to 2 and then back to 1. So we see these little motifs pop up, and that's all super interesting. But you could still ask: has this grokked anything more conceptual than the explicit shape of a curve? Your expectation is no. But maybe in some of the later experiments, that starts to become more relevant, right?

Ziming Liu: 1:00:10 Even if you cannot generalize outside the training data, it's still interesting to understand. I guess what physicists believe is that all the theories we have are effective theories: they're only valid within certain energy scales, so to speak. Even if we cannot generalize, there are still a lot of interesting things we can say, at least about the task.

Nathan Labenz: 1:00:37 Cool. So the next one, I think, is a really nice little visual: the two-moons classification problem. You've got a bunch of points in a 2D space. They're color-coded to help viewers interpret them, but the model itself doesn't get that. It only knows the position, and it has to use the position to predict which class the data point belongs to, so it has to learn a curve that separates these 2 almost, but not quite, overlapping shapes. And, indeed, it does that. And, again, you see a pretty sparse result. I was struggling a bit to interpret the last graph there, where there are 2 numbers in, the x y coordinates, and then the outputs. It seems like all the nodes are going to class 1, and class 2 has been sort of left out.
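For readers who want to reproduce the setup, here is a minimal two-moons generator in NumPy. It follows the common construction also used by scikit-learn's `sklearn.datasets.make_moons` (one half-circle, plus a second one flipped and shifted); the noise level and sample counts are arbitrary, and whether the paper used exactly this generator is an assumption. The model would see only `x`, with `y` as the target.

```python
import numpy as np

def two_moons(n_per_class, noise=0.1, seed=0):
    """Two interleaving half-circles in 2D with labels 0 and 1."""
    rng = np.random.default_rng(seed)
    t = rng.uniform(0.0, np.pi, size=n_per_class)
    upper = np.stack([np.cos(t), np.sin(t)], axis=1)              # class 0
    lower = np.stack([1.0 - np.cos(t), 0.5 - np.sin(t)], axis=1)  # class 1: flipped and shifted
    x = np.concatenate([upper, lower]) + rng.normal(scale=noise, size=(2 * n_per_class, 2))
    y = np.concatenate([np.zeros(n_per_class, dtype=int), np.ones(n_per_class, dtype=int)])
    return x, y  # the classifier sees only the positions x; y is the target

x, y = two_moons(200)
print(x.shape, np.bincount(y))  # (400, 2) [200 200]
```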

Ziming Liu: 1:01:31 Yeah, that's a very interesting observation. I was surprised by how clever neural networks can be when I saw the connectivity graph. For binary classification, what really matters is the difference between the logits of the 2 classes, that is, their relative magnitude, not their absolute magnitude. So an efficient strategy for a neural classifier is to simply set class a to always have a 0 logit and only learn the logit function for the other class, b. A positive logit for class b means the classification is b, and a negative logit for class b means the classification is a. If we also look at the evolution of the whole thing, the training dynamics are actually very interesting. There is an intermediate phase where both classes have output logits: both have network connections to the outputs, and they're almost symmetric. But as training goes on and more and more weights get pruned, the network transitions from this symmetric phase to an asymmetric phase, because the asymmetric phase is more energy efficient, requiring fewer neuron connections. So it's very interesting to see that there's actually this phase transition. At first it's messy, with everything fully connected; in the middle, there's a sparse network that still connects to both classes; and finally, the network realizes that it only needs to predict the logit for 1 class and prunes away the other class entirely. It's a very clever strategy that I learned from my neural networks.
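Why pinning one logit to 0 loses nothing: softmax depends only on logit differences, so [z_a, z_b] and [0, z_b - z_a] give identical probabilities, and the probability of class b reduces to sigmoid(z_b - z_a). A minimal check with arbitrary example logits:

```python
import numpy as np

def softmax(z):
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z_a, z_b = 1.3, -0.4                  # arbitrary logits for classes a and b
p_full = softmax([z_a, z_b])          # both logits learned
p_pinned = softmax([0.0, z_b - z_a])  # class a pinned to 0, only b's logit learned
print(p_full, p_pinned, sigmoid(z_b - z_a))  # same probabilities; P(b) equals the sigmoid
```

A positive pinned logit for b gives P(b) > 0.5, matching the rule "positive means b, negative means a".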

Nathan Labenz: 1:03:31 I'm studying that final visual, and that definitely jumps out. It's basically reduced the dimensionality of the problem on its own to just having now to make 1 prediction instead of 2 effectively. I don't have any other immediate intuitions for the shape of what I'm looking at. Is there anything else you could say about that?

Ziming Liu: 1:03:53 They're sparse. What's good about sparsity is that now you know there are just 7 weights in the graph, if I remember correctly, so you can intervene on any important weight or neuron to see what it does to the predictions. I'm guessing that because the problem is too simple, no meaningful structure emerges. Maybe the whole thing is itself a module. That's why we don't have a very good explanation for it: the whole thing is a module, and we don't expect to have a good explanation for internal activations inside modules.

Nathan Labenz: 1:04:34 Yeah, especially for something as kind of arbitrary as learn to separate 2.

Ziming Liu: 1:04:40 Right. But in the appendix, we did intervention experiments, so you can see what each weight and neuron is doing.

Nathan Labenz: 1:04:50 Yeah. It's amazing. There's only 6 active neurons across the 2 layers, 4 in the first, 2 in the second. Oh, you're actually eliminating connections in this case. Is that right?

Ziming Liu: 1:05:01 Yeah. I'm zeroing out individual weights, but in principle you can also ablate neurons. One more thing to notice: if you change the random seed, the graph looks completely different. That's another sign that there is no consistent modularity. But in the modular addition case, the next example, you'll see that no matter what random seed you use, you almost always get 3 parallel modules emerging. So there, this is a more consistent thing you can say about the task. Random seeds also play a huge role here, but my take is that if there's a consistent structure in the task, then no matter what your random seed is, you should be able to find it most of the time. And in the most difficult case, where only some random seeds find the structure, that's also fine. You can just run 100 models with different random seeds in parallel, select the one that feels most interpretable to you, and go from there to do mech interp.

Nathan Labenz: 1:06:24 It seems like we probably could graph, maybe even in just closed form. Again, I'm not amazing with all the linear algebra notation, but either in closed form or at worst, with just a mesh of points, you could kind of graph the value of class 1 as kind of a z axis over the 2 inputs and kind of get a visualization for essentially, you're learning a sort of elevation landscape.

Ziming Liu: 1:06:56 Yeah, that's correct. In the appendix, we write down the symbolic formulas explicitly: basically, we extract the weights and biases and write them down as a symbolic formula. So, yeah, as you said, in principle you can make a 3D surface plot like that. That may be more intuitive to look at, and in principle one can do it.
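Writing a sparse network down as a formula amounts to reading off each layer's surviving weights and biases. A minimal sketch, assuming dense (W, b) layers with SiLU on hidden units and a hypothetical tolerance below which weights are treated as pruned; the example weights are invented, not the paper's. `layer_to_formulas` is a hypothetical name. To get Nathan's elevation landscape, one would instead evaluate the same network on a `np.meshgrid` of (x, y) points.

```python
import numpy as np

def layer_to_formulas(w, b, inputs, tol=1e-3, act="silu"):
    """Write each output unit of a dense layer as a string, skipping near-zero weights.

    w has shape (out, in), b has shape (out,); `inputs` are the names of the input units.
    """
    formulas = []
    for row, bias in zip(w, b):
        terms = [f"{c:+.3g}*{name}" for c, name in zip(row, inputs) if abs(c) >= tol]
        if abs(bias) >= tol:
            terms.append(f"{bias:+.3g}")
        body = " ".join(terms) if terms else "0"
        formulas.append(f"{act}({body})" if act else body)
    return formulas

# Usage with hypothetical, already-sparse weights for a 2-input, 2-hidden-unit, 1-output net.
hidden = layer_to_formulas(np.array([[1.5, 0.0004], [0.0, -2.0]]), np.array([0.1, 0.0]), ["x", "y"])
output = layer_to_formulas(np.array([[0.7, 0.0]]), np.array([0.0]), hidden, act=None)
print(hidden)  # ['silu(+1.5*x +0.1)', 'silu(-2*y)']
print(output)  # ['+0.7*silu(+1.5*x +0.1)']
```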

Nathan Labenz: 1:07:25 So, yeah, there's no... It makes sense then that there's not any super interpretable... There's not a 2 sentence summary of this because it's an arbitrary shape. It's almost like you just kinda sprinkled some stuff out there, and it had to kind of learn this particular shape. But that's not a super principled problem in the first place.

Ziming Liu: 1:07:51 Yeah. Or, using the language of the quantization model I mentioned before, the task itself has only 1 subtask, so there's no need for the network to modularize. If we're dealing with compositional tasks like language modeling that actually involve many subtasks, then there's an incentive to grow modules for different abilities. But this two-moons classification, or later the transformer example on linear regression, is just one single task. At least I myself cannot imagine any subtasks underneath the whole task. That's why the graphs look a little less interpretable and less modular: the whole network is presumably one module, and there are degrees of freedom inside it that make the graph look messy.

Nathan Labenz: 1:08:50 If you were to add a third moon to this, or a third region, would you expect that that would then create some sort of subtask that you could see reflected in modularity?

Ziming Liu: 1:09:08 Yeah. Maybe we can look at the MNIST figures. There, I'm not sure I can confidently say I see something modular, but there's something like this pattern mismatching thing emerging that's also interesting. I can't say for sure that there's meaningful modularity there. But that's a reasonable conjecture. Maybe classifying between 1 and 2 is a subtask, 2 versus 3 is another subtask, and then you can classify 3 things by combining these binary classification tasks. I think that's a reasonable conjecture, but it still needs to be tested.

Nathan Labenz: 1:09:52 So next, modular addition, one of my favorite problems in the world due to all of the mechanistic interpretability focus on it. The big questions I had on this one were, is this doing the same thing that we've seen kind of prior interpretability work demonstrate that a grokked network is doing? Is it solving the problem in the same way?

Ziming Liu: 1:10:20 That's a very good question. Actually, we have follow-up work that we're still working on. The short answer is yes and no. By yes, I mean the embeddings of those numbers still form circles, that is, Fourier features, the same thing as in previous mech interp work. By no, I mean the model's internal computations are different. We actually found a new algorithm, different from the one Neel Nanda and collaborators' paper described. We call Neel's algorithm the clock algorithm, and we call the new one we discovered the pizza algorithm. Which algorithm the network ends up learning depends on the architecture and hyperparameters, and this can be very subtle. There can be phase transitions from clock to pizza and also from pizza to clock. This sounds really funny out of context, but it's ongoing work with my MIT colleagues Jacob Andreas and Max Tegmark, and we'll post the preprint to arXiv soon. After discovering the clock and pizza algorithms, we analyzed the network we got with BIMT, and we found that BIMT actually ends up with the pizza algorithm, the new one we discovered, not what Neel Nanda and collaborators described as the clock algorithm. No one is wrong; it's simply that mech interp can be much more subtle than we imagined. It depends very subtly on the architecture, the hyperparameters, all kinds of stuff.

Nathan Labenz: 1:12:06 So the lesson here is there's a lot of possible mechanisms to interpret, and which one you end up needing to interpret is kind of cast in the upstream decisions of exactly how you lay out your network and training process, and in this case, localization incentive as well.

Ziming Liu: 1:12:27 Right. Yes. In short, the clock algorithm is more accurate, but it requires more resources. By contrast, the pizza algorithm is less accurate, but it takes fewer resources. That's why in our BIMT paper we see 3 shapes, which I can explain more later: each shape, each parallel module, is an imperfect pizza algorithm. So the network needs some kind of error correction, aggregating the results of the imperfect algorithms cleverly so that the final outputs are perfect. That's the takeaway. We hope to post a preprint to arXiv soon.

Nathan Labenz: 1:13:24 Is there a simple... I mean, clock algorithm is you're doing modular addition, you take advantage of the cyclicality of the modular math, and so you sort of say, okay, I'm gonna rotate some and then rotate some more. And then if I get past the origin, it doesn't matter how many times I went around, it's just kind of the final position, right, that I need to look at.
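The clock idea Nathan describes can be written out directly: represent each input as an angle, compose the rotations with the angle-addition identities, and score each candidate c by cos(w(a + b - c)), which peaks exactly when c = (a + b) mod p. A minimal sketch computed in closed form rather than learned; the modulus p = 59 and frequencies (1, 2, 3) are arbitrary choices, and the pizza variant is not shown because its exact form isn't given here. `clock_logits` is a hypothetical name.

```python
import numpy as np

def clock_logits(a, b, p, freqs=(1, 2, 3)):
    """Logits for c in 0..p-1 from the clock idea: compose rotations, then read the angle.

    For each frequency w_k = 2*pi*k/p, cos(w_k*(a + b - c)) peaks at c = (a + b) mod p.
    The angle-addition identities build it from per-input cos/sin features only.
    """
    c = np.arange(p)
    logits = np.zeros(p)
    for k in freqs:
        w = 2 * np.pi * k / p
        cos_ab = np.cos(w * a) * np.cos(w * b) - np.sin(w * a) * np.sin(w * b)  # cos(w(a+b))
        sin_ab = np.sin(w * a) * np.cos(w * b) + np.cos(w * a) * np.sin(w * b)  # sin(w(a+b))
        logits += cos_ab * np.cos(w * c) + sin_ab * np.sin(w * c)               # cos(w(a+b-c))
    return logits

p = 59
print(np.argmax(clock_logits(17, 50, p)), (17 + 50) % p)  # 8 8
```

The wrap-around Nathan mentions is automatic: angles that differ by a multiple of 2π give the same cosines, so only the final position on the clock matters.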

Ziming Liu: 1:13:48 Right. By contrast, the pizza algorithm adds the 2 numbers in a way where the final prediction is more like slicing a pizza and determining which slice the outcome lies on. The frequency in the pizza algorithm is double the frequency in the clock algorithm, or the other way around; I can't remember exactly. There's just this tiny detail that distinguishes clock from pizza. If you don't look carefully enough, for example if you use the metric Neel Nanda and collaborators used in their paper, you cannot distinguish the clock and pizza algorithms. We came up with other metrics that can distinguish between them. Each metric is coarse grained: it drops some information about the whole system. So even if 2 things are different, if you coarse grain too much, they look the same and you cannot distinguish them. My takeaway is that there's still a lot to be done in mech interp; at the very least, we need more metrics to really distinguish these algorithms.

Nathan Labenz: 1:15:11 Yeah, I'd say we are just scratching the surface. So that's really interesting. When you say it is less accurate, is that a matter of getting problems wrong, or is it a matter of being less confident with the right answer so your loss is higher, but you're still getting all the answers right?

Ziming Liu: 1:15:38 When I say clock algorithm or pizza algorithm, both have perfect versions and imperfect versions, and what we end up getting in experiments are the imperfect versions. An imperfect version can give wrong answers on some samples, but it can be complemented by another imperfect algorithm that makes mistakes on a different subset of samples. If you have a sufficient number of imperfect copies, then on every sample at least 1 copy gives the right answer. So by aggregating the results, the final prediction is correct on every sample, even though each head, or each parallel module, is correct only on some subset of the samples.
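A toy version of "imperfect copies, perfect aggregate": three modules each get a different third of the samples wrong, yet a per-sample majority vote is right everywhere. This is a simplification under stated assumptions: the labels and error pattern are hypothetical, each sample is assumed to be misclassified by at most one module, and in the actual network the aggregation is done by the last linear layer combining logits rather than by a hard vote. `majority_vote` is a hypothetical name.

```python
import numpy as np

def majority_vote(predictions):
    """Per-sample majority label across modules; predictions has shape (n_modules, n_samples)."""
    predictions = np.asarray(predictions)
    n_labels = predictions.max() + 1
    counts = np.stack([(predictions == label).sum(axis=0) for label in range(n_labels)])
    return counts.argmax(axis=0)

truth = np.arange(12) % 5                 # hypothetical labels for 12 samples
modules = np.tile(truth, (3, 1))          # start from three perfect modules
modules[0, 0:4] = (truth[0:4] + 1) % 5    # module 0 is wrong on samples 0-3
modules[1, 4:8] = (truth[4:8] + 2) % 5    # module 1 is wrong on samples 4-7
modules[2, 8:12] = (truth[8:12] + 3) % 5  # module 2 is wrong on samples 8-11
print((majority_vote(modules) == truth).mean())  # 1.0, while each module alone is right on 8 of 12
```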

Nathan Labenz: 1:16:44 So you essentially train n copies of this tiny little network. And then you find that, I guess, due to random seeds, they sort of cohere in different ways. Each one represents an imperfect approximation of the ideal algorithm, and then in aggregating them, you can kind of get good performance even if they have their individual weaknesses.

Ziming Liu: 1:17:09 I mean, we don't explicitly have these n copies. We just have a whole network, but we can somehow disentangle it and find that n copies emerge automatically from training, which is really fascinating. And this is again an example where I think, oh, maybe the neural network is more clever than I am. It leverages a kind of error correction.

Nathan Labenz: 1:17:35 Is that what is meant by voting?

Ziming Liu: 1:17:37 Yes, exactly.

Nathan Labenz: 1:17:38 Okay. So as we're looking at this graph, again, it's not that big, right? We're talking about two hidden layers. So that is to say inputs, two hidden layers, again, I assume each with a nonlinearity as part of that layer, and then outputs. There's not really that much room to do the work here. I have that same feeling about every graph in this paper. And then, as you look at each of these, it's just clear that there are three distinct modules that go from the inputs up to layer two. And then from layer two to the output, it looks like it's basically all fully connected. But those three distinct sections, now I'm understanding these shapes more deeply. So each of those is an approximation, and I guess you can see that because you can do the ablation, you're still like...

Ziming Liu: 1:18:35 Yes. So each copy is an algorithm that tries to perfectly do the modular addition but fails. At some point, the neural network figures out that it's more expensive to form one perfect copy of a single algorithm, so instead, "I form, like, three copies of imperfect algorithms and somehow aggregate..."

Nathan Labenz: 1:19:00 Or some sort of majority voting to get the end results. And how literally do you understand that concept of voting? I always push really hard on analogies because I find I confuse myself as often as I clarify things for myself.

Ziming Liu: 1:19:17 Yeah. So we borrow the term voting from error correction in information theory. Let's say you want to communicate a classical bit, zero or one, over some channel. The channel is imperfect: it can randomly flip the bit with some probability. To achieve better accuracy, there's something called repetition codes, where you basically repeat the bit three times or five times or even more. Then the receiver can infer the bit by doing a majority vote over the received bits. The probability of making an error decreases exponentially as the number of repetitions grows. This analogy of repetition codes also applies to modular addition here. Each shape, or each module, is an imperfect algorithm for modular addition. One module alone can make mistakes on some examples, but by aggregating these three modules together (and when I say vote, I actually mean that just the last linear layer aggregates their results), they can end up with perfect classification results.
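The repetition-code intuition can be checked with a few lines of arithmetic. This sketch assumes, as in Ziming's channel example, that each copy of the bit is flipped independently with the same probability; majority decoding fails only when more than half the copies are flipped, which is a binomial tail. The function name is ours, and the choice of flip probability 0.1 is just an example.

```python
from math import comb


def repetition_error(p_flip: float, n: int) -> float:
    """Probability that majority voting over n independent noisy copies decodes the wrong bit.

    Each copy is flipped independently with probability p_flip. n must be odd so a
    majority always exists; decoding fails when more than half of the copies flip.
    """
    if n % 2 == 0:
        raise ValueError("use an odd number of repetitions to avoid ties")
    return sum(
        comb(n, k) * p_flip**k * (1 - p_flip) ** (n - k)
        for k in range(n // 2 + 1, n + 1)
    )


# For p_flip < 0.5 the error rate shrinks rapidly as n grows.
for n in (1, 3, 5, 7):
    print(n, repetition_error(0.1, n))
```

With p_flip = 0.1, three copies already cut the error from 0.1 to 0.028 (3·0.1²·0.9 + 0.1³). The independence assumption is essential to this formula, which is exactly the point Nathan and Ziming go on to question for the network's modules.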

Nathan Labenz: 1:20:45 In the kind of bit level example from the information theory, you basically repeat the thing, you repeat the signal three times, and you trust that it's unlikely that two out of three get degraded, so I can trust two out of three. That's the basic intuition.

Ziming Liu: 1:21:05 Yes, and also, the voters are independent. The bit flips are independent of each other.

Nathan Labenz: 1:21:14 So I'm trying to map that concept onto the data that I see here, and I would expect, and I'm being very literal. So my approach in this is trying to be very literal and see if like I get confused. And I am getting a little confused because, I guess, if I was thinking of it in like a strict analogy and I said, let's imagine that I'm using this voting approach, but then I just eliminate one of my inputs. Now I've got two inputs. Now, like, if they agree, great. If they disagree, I'm at like 50-50 chance. But what it seems here is like you knock out one and your performance is like much more degraded than 50-50.

Ziming Liu: 1:22:00 Yeah. So I think the majority voting thing is more like a metaphor. What really happens is a linear aggregation of the outputs, and it's more like cooperation. In majority voting, we assume the voters to be independent of each other. But here, since we're adding this kind of penalty to the network, the three modules really need to cooperate with each other. So they rely on each other; somehow they talk to each other in training. For example, the first module says, I take sample A, and then the second module says, then I will take sample B, something like that. But it's even more complicated because it's a linear combination of things, so we cannot partition things that clearly. In principle, we can try to understand the structure more clearly, but I haven't done that. I think it's something to investigate in detail.
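The difference between majority voting and a linear readout can be seen in a toy example. The logits below are entirely hypothetical, hand-picked numbers (not data from the paper): three modules A, B, and C each misclassify a different sample, and the final layer is assumed to simply sum their class scores before taking the argmax. Because scores are summed rather than counted as votes, a confident module can outweigh a wrong one, so the modules do not need to be independent voters for the aggregate to be correct.

```python
# Hypothetical logits: LOGITS[module][sample] is a list of 3 class scores.
LABELS = [0, 1, 2]
LOGITS = {
    "A": [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.5, 0.0, 1.0]],  # wrong on sample 2
    "B": [[1.0, 0.0, 1.5], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]],  # wrong on sample 0
    "C": [[2.0, 0.0, 0.0], [0.0, 1.0, 1.5], [0.0, 0.0, 2.0]],  # wrong on sample 1
}


def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])


def accuracy(modules):
    """Sum the chosen modules' logits (a fixed linear readout) and score the argmax."""
    correct = 0
    for s, label in enumerate(LABELS):
        summed = [sum(LOGITS[m][s][c] for m in modules) for c in range(3)]
        correct += argmax(summed) == label
    return correct / len(LABELS)


for m in LOGITS:
    print(m, accuracy([m]))  # each module alone gets 2 of 3 samples right
print("A+B+C", accuracy(["A", "B", "C"]))  # the summed readout gets all 3 right
```

In a trained network the modules co-adapt during training, as Ziming describes, so ablating one can hurt more or less than a simple vote-counting picture would suggest.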

Nathan Labenz: 1:23:06 So the major basis, I guess, then for the kind of notion of like a voting metaphor, or the reason that you're saying that you interpret these three modules as kind of each an independent approximation, is more anchored to those shapes, I guess? Like, you show kind of the spatial representation, and that definitely looks like something.

Ziming Liu: 1:23:37 Yes, yes. I think it's more anchored on the shapes. And it's surprising to me, at least, to see that three modules emerge. Finally, I understand that, oh, there is this error correction mechanism that we didn't think about. Yeah, it's amazing. We learned something from neural networks.

Nathan Labenz: 1:23:58 When I looked at the next problem around permutations, anything you want to say on that that kind of jumps out as the most important takeaway for you?

Ziming Liu: 1:24:08 So the part I love most about the permutation example is this: Neel and his collaborators found that the network can automatically learn the group representations. For the S4 permutation group, it has a nine-dimensional representation, and they still needed to probe and find where to look. But with BIMT, this nine-dimensional representation just naturally emerges in privileged spaces, aligned with the neurons after training. There are exactly nine active neurons, and there's exactly one sign neuron, the 22nd "magic" neuron, corresponding to the sign of the representation. All the other neurons we can also explain based on Cayley graphs, and it tells us that the neural network indeed leverages the structure of the group dataset. So I think the main takeaway is that with BIMT, you can clearly know where to look just by looking at the graph. With other methods like probing or intervention, at first you have no idea where to look, or you need to really understand your problem so that you know where to look. But with BIMT, you don't need to know anything about the task. You just look at the graph and you can say something useful about it.
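The "sign neuron" tracks the sign representation of S4: the map sending even permutations to +1 and odd ones to −1, which respects composition. As a small check that is independent of any trained network, the sketch below computes the sign by counting inversions and verifies sign(σ∘τ) = sign(σ)·sign(τ) over all 24 elements. It only illustrates the group-theoretic object the neuron is said to encode; it does not reproduce the paper's neuron analysis.

```python
from itertools import permutations


def sign(perm):
    """Sign of a permutation of 0..n-1: +1 if it has an even number of inversions, else -1."""
    inversions = sum(
        1
        for i in range(len(perm))
        for j in range(i + 1, len(perm))
        if perm[i] > perm[j]
    )
    return -1 if inversions % 2 else 1


def compose(s, t):
    """(s o t)(i) = s(t(i)): apply t first, then s."""
    return tuple(s[t[i]] for i in range(len(t)))


S4 = list(permutations(range(4)))
# The sign is a 1-dimensional representation: it turns composition into multiplication.
assert all(sign(compose(s, t)) == sign(s) * sign(t) for s in S4 for t in S4)
print(len(S4), sum(sign(s) == 1 for s in S4))  # 24 elements, 12 of them even
```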

Nathan Labenz: 1:25:45 Yeah. I have this visualization sometimes in my mind of like shrink-wrapping conceptual reality. And I don't know, there's probably a lot of, that's kind of an analogy. So my analogy detector goes off, but it does feel like you've created another way to create additional kind of negative pressure to like suck all of the extraneous activity out of the network and then it kind of, boom, snaps or coheres right into the appropriate dimension.

Ziming Liu: 1:26:21 So in the paper, we only have those static images, but if you look at the videos I posted on Twitter and also GitHub, this is exactly what you described. You seem to have some kind of pressure that pushes inwards and shrink-wraps the whole thing. Yeah, really amazing. I encourage everyone to look at the videos, to watch the videos.

Nathan Labenz: 1:26:43 Okay. So we're getting into the homestretch. So the next thing you do is then extend this to the transformers. Everything we've talked about so far, I wanted to ask you, like, how long does it take to train these really simple networks? Like, if you were to set up one run to train the modular addition network, and it goes through how many steps?

Ziming Liu: 1:27:09 20,000, I guess.

Nathan Labenz: 1:27:11 So how long does that actually take just on a fairly standard issue computer?

Ziming Liu: 1:27:15 Yeah. So, actually, all the results I got, I'm just using my Mac M1. There's no GPU, so for the modular addition, it takes around 20 minutes. For symbolic formulas, just less than one minute. The transformer example is the most time-consuming one, almost ranging from one hour to two hours at most. So if you want to play with BIMT, I encourage you just start playing with the symbolic regression example. You can even make modification to the code. It's very fast iteration because to get the results, it only takes less than one minute to get the results. So it's a good point to start.

Nathan Labenz: 1:28:07 So it basically can feel with, I guess, with these later projects, not quite, but with the simplest cases, it's almost like kind of real-time. I presume you could even just kind of, although, yeah, I guess if you don't go all the steps, you don't see all the sparsity, so you do kind of have to run it to some sort of completion. You mentioned the transformer portion. It wasn't entirely clear to me. Are you modifying the attention mechanism as well as the fully connected portion of the transformer?

Ziming Liu: 1:28:42 Yeah, yeah, yeah. That's a very good question. So think about the attention layers. If we put aside the softmax part of the attention layer, the attention layer is just three matrix multiplications, the key, query, and value. So these linear matrices, we can understand them as a linear layer with zero bias. So we apply the same trick we used in MLPs, just treated as a linear layer. And yeah. So attention layer, the softmax has no parameters, so we don't have to worry about that. We just need to deal with these linear matrices as in MLPs.
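A minimal sketch of what "treat the key, query, and value matrices as bias-free linear layers and apply the same trick" could look like. We assume a BIMT-style locality penalty of the form Σ |w_ij| · d_ij, where d_ij is the distance between the two neurons a weight connects. The layout (each layer's neurons evenly spaced on a unit-width line, layers one unit apart) and the use of Euclidean distance are assumptions for illustration; the paper's exact distance definition and scaling may differ, and the helper names are hypothetical.

```python
import math


def layer_positions(n, y):
    """Place n neurons evenly on the segment x in [0, 1] at height y (assumed layout)."""
    if n == 1:
        return [(0.5, y)]
    return [(i / (n - 1), y) for i in range(n)]


def connection_cost(W, pos_in, pos_out):
    """Sum over weights of |w_ij| times the distance between the two neurons it connects.

    W[i][j] connects input neuron j to output neuron i (a bias-free linear layer).
    """
    return sum(
        abs(W[i][j]) * math.dist(pos_out[i], pos_in[j])
        for i in range(len(W))
        for j in range(len(W[0]))
    )


def attention_locality_penalty(W_q, W_k, W_v, y_in=0.0, y_out=1.0):
    """Treat the query, key, and value projections as three bias-free linear layers
    reading from the same inputs, and apply the same penalty to each of them."""
    total = 0.0
    for W in (W_q, W_k, W_v):
        pos_in = layer_positions(len(W[0]), y_in)
        pos_out = layer_positions(len(W), y_out)
        total += connection_cost(W, pos_in, pos_out)
    return total


identity = [[1.0, 0.0], [0.0, 1.0]]
crossed = [[0.0, 1.0], [1.0, 0.0]]
pin, pout = layer_positions(2, 0.0), layer_positions(2, 1.0)
# Straight-up wiring (cost 2) is cheaper than crossed wiring (cost 2 * sqrt(2)).
print(connection_cost(identity, pin, pout), connection_cost(crossed, pin, pout))
```

In training, such a term would be added to the task loss with some weight λ (a hypothetical hyperparameter here). Computing it touches each weight once, so its cost grows linearly with the parameter count, like an L1 weight penalty; that is consistent with Ziming's later remark that the penalty adds no meaningful overhead.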

Nathan Labenz: 1:29:28 So the representation of this then for kind of purposes of the loss function is in the like fully expanded form? Like, you're applying, you have this sort of three multiplications, but you're treating those as essentially one giant...

Ziming Liu: 1:29:46 That's a good question. So there are three matrices. I overlapped them. I stacked them along a third dimension I actually did not plot. And the shift along the third dimension is very small, just to visualize it. But like, because query, key, and value, they're actually, in some sense, they're equivalent, so there's no sense to separate them in space. So they are overlapping each other. But visually, you can think of them separated by small amounts along the third dimension.

Nathan Labenz: 1:30:20 Would this have, for kind of scaling up purposes, I'm not great with this sort of optimization of this, but I definitely followed Neel Nanda's work quite a bit. And what I understand from, I think, taking away from some of his YouTube tutorials is that there's a divergence between the representation of the attention mechanism that is most compute efficient and, therefore, actually like gets used in code from a more purely analytical representation that's like easier to interpret. And if I understand correctly, you're using, you're applying the penalty in a way that works on that more analytically sound, like not separated, but maybe like less compute efficient representation. Do I have that right, or am I going off track somewhere?

Ziming Liu: 1:31:20 Sorry. What do you mean by low compute representation here?

Nathan Labenz: 1:31:24 So you had said a second ago that the sort of multiple matrix operations that constitute the transformer mechanism, in analytical terms, are not really separable, but they are in practice. They are sort of coded separately, as I understand it, for efficiency reasons.

Ziming Liu: 1:31:45 I see, I see. So I treat them independently.

Nathan Labenz: 1:31:49 So the penalty is added in and calculated in such a way that the process of calculating that penalty, the whole process would still scale just like a normal transformer scales?

Ziming Liu: 1:32:03 Yeah, yeah, exactly. There's no additional overhead in this.

Nathan Labenz: 1:32:07 Can you maybe just help develop a little bit more of the intuition of how exactly should I understand the locality penalty in the context of attention? In the above, like, simple graphs, it makes total sense. It's just, okay, 2D space. Boom. I got it. Is there anything more that I should kind of be intuiting for the attention version?

Ziming Liu: 1:32:31 Yeah. So that's a very good question. Actually, it takes a lot of modifications to make BIMT work on transformers. I don't have an intuitive picture for that, but the modifications are: you need to consider the heads, the multi-head attention, and that affects swapping. We cannot just swap any two neurons in the same layer. We can only swap two neurons that are in the same layer and also in the same head. But we can also swap two heads as a whole. And the residual stream is another pain in the ass, because with the residual stream, the neurons get aligned, so the permutations along the residual stream are in some sense locked; they are not independent. The residual stream in different layers shares the same permutation. That may also be one thing that makes the network less interpretable, or less local.
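Why residual-stream swaps are "locked" across layers can be checked on a toy residual network. This is a simplified stand-in, assumed for illustration: residual MLP blocks only, no attention, no layer norm, no embedding or unembedding, random weights, and hypothetical helper names. Swapping two hidden neurons inside one block (rows of its input matrix, columns of its output matrix) leaves the function unchanged. Swapping two residual coordinates only preserves the function, up to relabeling the output, if the same swap is applied to every block and to the input.

```python
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def block(x, W_in, W_out):
    """Residual MLP block with biases omitted: x + W_out relu(W_in x)."""
    h = [max(0.0, t) for t in matvec(W_in, x)]
    return [a + b for a, b in zip(x, matvec(W_out, h))]


def forward(x, layers):
    for W_in, W_out in layers:
        x = block(x, W_in, W_out)
    return x


def swap_hidden(W_in, W_out, i, j):
    """Swap hidden neurons i and j of one block: rows of W_in, columns of W_out."""
    W_in, W_out = [r[:] for r in W_in], [r[:] for r in W_out]
    W_in[i], W_in[j] = W_in[j], W_in[i]
    for row in W_out:
        row[i], row[j] = row[j], row[i]
    return W_in, W_out


def swap_residual(W_in, W_out, i, j):
    """Swap residual coordinates i and j for one block: columns of W_in, rows of W_out."""
    W_in, W_out = [r[:] for r in W_in], [r[:] for r in W_out]
    for row in W_in:
        row[i], row[j] = row[j], row[i]
    W_out[i], W_out[j] = W_out[j], W_out[i]
    return W_in, W_out


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d, hidden = 4, 6
layers = [(rand_matrix(hidden, d, rng), rand_matrix(d, hidden, rng)) for _ in range(3)]
x = [rng.uniform(-1, 1) for _ in range(d)]
y = forward(x, layers)

# 1) A hidden-neuron swap is local to one block: the output is unchanged.
layers_h = list(layers)
layers_h[1] = swap_hidden(*layers[1], 0, 5)
y_h = forward(x, layers_h)

# 2) A residual swap must be applied to every block (and to the input) at once;
#    the output is then y with coordinates 0 and 3 exchanged.
layers_r = [swap_residual(W_in, W_out, 0, 3) for W_in, W_out in layers]
x_r = x[:]
x_r[0], x_r[3] = x_r[3], x_r[0]
y_r = forward(x_r, layers_r)
```

Applying the residual swap to only one block would generally change the output, because the next block would read the residual coordinates in the wrong order. That is the global constraint Ziming describes.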

Nathan Labenz: 1:33:55 And that's kind of presumably related to like superposition or not really, I guess. You could basically, the idea is that you would have to untangle if you're going to swap, you'd have to swap at every layer of the residual stream?

Ziming Liu: 1:34:10 Yes. If you swap, you need to swap every layer of the residual stream. So that sounds to me like a constraint. I think it's undesirable, because if you have this constraint, you have to consider all the layers at once, so you cannot untangle some non-local connections in a certain layer, because you need to consider everything globally. Maybe there are better ways to do this. So again, there are still a lot of things to be done to improve BIMT. And another pain in the ass is the layer norm. I think Anthropic published results also arguing that layer norm does something sneaky. They didn't drop layer norm; they just said, oh, layer norm is a pain in the ass. But I dropped layer norm entirely because it normalizes things. Even if it has a small input, it still normalizes it to a 0 to 1 range, and that's something we don't want with sparse neural networks. So again, there are still engineering tricks that we need to incorporate to make BIMT work on transformers. This is just a first step to show that, in principle, BIMT is a high-level idea that you can apply to any architecture as long as we embed the architecture in a geometric space. This is just a prototype example, but a lot of things can be done in the future.
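The layer norm concern is easy to see numerically. The sketch below uses the standard LayerNorm formula, which standardizes a vector to zero mean and (near) unit variance; the learnable gain and bias are omitted, and the input vector is a made-up example. A small, nearly "off" activation vector comes out with entries of order 1, so a sparse, low-magnitude signal is amplified rather than left near zero.

```python
import math


def layer_norm(x, eps=1e-5):
    """LayerNorm without the learned gain and bias: zero mean, (near) unit variance."""
    mean = sum(x) / len(x)
    var = sum((t - mean) ** 2 for t in x) / len(x)
    return [(t - mean) / math.sqrt(var + eps) for t in x]


tiny = [0.01, -0.01, 0.02, -0.02]  # a small, nearly "off" activation vector
out = layer_norm(tiny)
# The largest output entry is roughly 60x the largest input entry.
print(max(abs(t) for t in tiny), max(abs(t) for t in out))
```

Only the `eps` term keeps very small inputs from being blown up even further, which is one reason a sparsity-encouraging method like BIMT might prefer to drop the normalization, as Ziming did.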

Nathan Labenz: 1:35:54 That brings us, I think, to the last experiment, the last visualization, which is applying this to 3D space instead of 2D space. And here, you're using the classic MNIST handwritten digit dataset. And this, of all the visualizations, I don't know, there are some good ones, but this one might be the best, because you see at the beginning just a massive 3D tangle of intersecting weights. And then over the course of all the steps, it gets, first of all, cropped in to the space where the images actually are, and then just obviously a lot more sparse. It seems like in this one, again, what it's really doing is sort of limited, so we still need the Neel Nanda-style breakdown, but you've run the first leg of the relay in that, by sparsifying things, you've made it pretty clear where to look, as you say. And then the next step is for somebody to figure out to what degree we can assign any meaning to this. And maybe none, I guess. Right? It could just be that this is a super tight module that does this really weird task. And the shape of the numerals themselves, as played out across a ton of handwritten versions, is just weird. There's not necessarily more to say about it?

Ziming Liu: 1:37:32 Yeah. That's a reasonable conjecture. It could be that the task is just too simple. The network views the whole thing as a single task; there's no subtask in it. Even linear classifiers can get above 90% accuracy on MNIST. It could also be that vision tasks, at least image classification, demonstrate less modularity than language tasks. As I imagine it, a language task presumably has many subtasks. And also the chain-of-thought reasoning in language models might be related to modularity: you need each step to be a concrete reasoning step. So I personally think that BIMT may be better at making language models modular than at making vision models modular. That's what I'm imagining. So as a next step, I think I will try things on language models. But I also think it's important that, lately, AI for science has been tackling all kinds of scientific problems, where people also care a lot about interpretability. So I think that's where BIMT can also play a huge role. So my next steps would be language models and scientific problems.

Nathan Labenz: 1:39:03 Cool. Do you want to comment just for a second on the pattern anti-matching followed by pattern matching? And so it's kind of like, there's the input layer, and then in kind of the, there's, again, just 2 hidden layers. Right? Between the first and the second hidden layers, it looks like basically all the weights are negative. And then from the second layer to the output, all the weights are positive. And you're interpreting that as pattern anti-matching. I guess my answer is probably going to be the nonlinearity, because I was initially thinking if we just change the sign of all the weights, would it still just work? But the answer, I guess, is no because of the nonlinearity.

Ziming Liu: 1:39:49 Yeah. I think nonlinearity also plays a role here. So in the appendix, we also tried 1 hidden layer and 3 hidden layers. We find this pattern mismatching thing is consistent no matter what. Yeah, no matter how deep your neural networks are, at least for 1 or 2 or 3 hidden layers, they're consistent. So I'm guessing there's something deep or fundamental about it, like this kind of strategy is more biologically or energy efficient than what we would imagine. Like how we classify, how we do image recognition. We would previously imagine that we want to do pattern matching. We have a template of things. Say, we have a dog feature, and we scan over the whole image and try max pooling to see if there is a dog on the image. But it seems like by having this BIMT strategy, you end up doing something opposite. Like you want to tell, oh, this image does not have a dog, so it's more likely a cat, something like that. So I can't say that I have fully understood this. There's still things to be understood.
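Nathan's guess that flipping the sign of every weight would not give an equivalent network can be checked with a scalar toy example (hypothetical weights, no biases). For a purely linear network the two flips would cancel, since (−w2)(−w1)x = w2·w1·x. Through a ReLU they do not, because relu(−z) equals relu(z) − z, not −relu(z).

```python
def relu(z):
    return max(0.0, z)


def two_layer(x, w1, w2):
    """Scalar two-layer network: y = w2 * relu(w1 * x)."""
    return w2 * relu(w1 * x)


x, w1, w2 = 1.0, 1.0, 1.0
# The original network outputs 1; the sign-flipped network outputs 0.
print(two_layer(x, w1, w2), two_layer(x, -w1, -w2))

# The identity behind this: relu(-z) = relu(z) - z, so flipping signs changes the function.
assert all(relu(-z) == relu(z) - z for z in (-2.0, -0.5, 0.0, 0.5, 2.0))
```

So the sign pattern Ziming describes (negative weights into the second hidden layer, positive weights out of it) is not an arbitrary convention that could be flipped for free; with a ReLU-type nonlinearity, the signs matter.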

Nathan Labenz: 1:41:19 Anything that you would add just to kind of wrap things up on the paper front before just asking a few more general questions?

Ziming Liu: 1:41:26 Yeah. So one thing I still want to note is the next direction. You mentioned TinyStories. That's one direction I'm looking into lately. One thing I'm hoping BIMT can achieve on TinyStories, which it already achieves for modular addition, is that the learnable token embeddings have privileged spaces, or are better aligned with the coordinate basis. For example, there's presumably a direction in the embedding space corresponding to, say, color. With an unprivileged basis, you need to try really hard to search for that color direction and then project token embeddings onto that direction, and there are actually infinitely many possible directions. But with privileged spaces, or with BIMT, very likely just 1 dimension of the token embeddings corresponds to the color dimension. So you can obtain that direction just by enumerating the finite number of embedding dimensions. So yeah, one thing I'm thinking about is whether BIMT can scale to large networks. My only concern so far is that it might be harder to visualize the whole network. That's an apparent difficulty, but other than that, the way I'm seeing it, I don't see any deal breaker here. I don't see any factor that's blocking it from scaling.
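A minimal sketch of "obtain that direction by enumerating the embedding dimensions": if an attribute such as "is a color word" lines up with a single coordinate, you can score each of the d axes and take the best one, instead of searching over all directions (for example by fitting a linear probe). The embeddings and labels below are invented for illustration, and using the absolute Pearson correlation as the score is our own assumption, not a method from the paper.

```python
import math


def pearson(xs, ys):
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    cov = sum((a - mx) * (b - my) for a, b in zip(xs, ys))
    sx = math.sqrt(sum((a - mx) ** 2 for a in xs))
    sy = math.sqrt(sum((b - my) ** 2 for b in ys))
    return cov / (sx * sy)


def best_axis(embeddings, labels):
    """Enumerate the d coordinate axes; return the one whose values best track the label."""
    d = len(embeddings[0])
    scores = [abs(pearson([e[k] for e in embeddings], labels)) for k in range(d)]
    return max(range(d), key=lambda k: scores[k]), scores


# Hypothetical 4-d token embeddings; label 1 means "is a color word".
embeddings = [
    [0.2, 0.1, 0.9, -0.1],   # e.g. "red"
    [-0.1, 0.3, 1.1, 0.0],   # "blue"
    [0.0, -0.2, 1.0, 0.1],   # "green"
    [0.1, 0.2, 0.0, 0.2],    # "cat"
    [-0.2, 0.0, 0.1, -0.1],  # "run"
    [0.3, -0.1, -0.1, 0.0],  # "the"
]
labels = [1, 1, 1, 0, 0, 0]
axis, scores = best_axis(embeddings, labels)
print(axis)  # 2
```

The search costs only d scores, one per coordinate. In an unprivileged basis, the same attribute could point along an arbitrary mixture of coordinates, and this axis-by-axis search would miss it.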

Nathan Labenz: 1:43:01 I'm actually doing another interview with one of the authors of the MegaByte paper that has kind of the hierarchical approach. This seems like it could sort of play nicely with that as well, perhaps, although I don't quite know how. But it has these sort of a global model and then a bunch of these local models. And I don't know, have you thought about any sort of trying to apply this to any sort of more hierarchical structure, such that modules could sort of naturally cohere in different places?

Ziming Liu: 1:43:37 Yeah, that's a good question. Actually, like I mentioned, there's an outstanding researcher, Jeff Clune, who worked on this connection cost idea even 10 years ago, but in biology, right? They discovered that by imposing this connection cost, you not only induce modularity, you also induce hierarchy, and it also allows you to do lifelong learning or continual learning. So there are a lot of promising directions. But on the regularity thing, as you said: is it possible to have one module that can be copied all over the network? For example, there's a basic unit that you need to repeat multiple times. Unfortunately, BIMT, or this kind of connection cost technique by itself, cannot induce this kind of repetition or regularity. So this is another thing I'm thinking about lately. I want to draw inspiration from biology. There's some literature using hypernetworks or compositional pattern-producing networks, something like that, to get regularity, or repetition, so that modules are reused. But that's something orthogonal to modularity and hierarchy. There are many traits I'm very interested in, like modularity, hierarchy, locality, and module reusability. Maybe they're connected somehow, but at least BIMT cannot solve the last one, module reusability. One thing I'm thinking is that maybe I can integrate some kind of hypernetwork technique into BIMT to make it also leverage this kind of module reusability. This is also for better efficiency. If we can decompose a network into different motifs, we only need to store those motifs and how they are connected to each other. We don't need to store the same motif thousands of times. So that's also more efficient storage and more efficient inference, especially for large-scale models. If we can downsize large-scale models in this way, decomposing them into motifs, that would be awesome.

Nathan Labenz: 1:46:21 And the inference time efficiency gains, at least on these like, toy models are pretty huge, right? I mean, if you run the fully connected version versus the, like, say you take the sparse version and actually prune it, and you just only have the fewer weights at the end, presumably, it's like under 10%, the compute cost at inference time.

Ziming Liu: 1:46:43 So if we explicitly sparsify the network after training, the inference time would be very much reduced. In the modular addition example, 3 parallel modules emerge from training. We don't know a priori that there are 3 parallel modules, but they just emerge. With that, we can parallelize them. For example, in principle we can run these 3 modules in parallel and then use 1 machine to aggregate the results. So if this kind of parallelization can also generalize to language models, that would be fascinating, because then inference time can be saved. You can leverage all sorts of multi-threading or parallelization techniques we have at hand so that the inference time and the inference memory can be reduced by a large margin.
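A minimal sketch of "explicitly sparsify the network after training": drop weights whose magnitude falls below a threshold, store only the survivors, and count how many multiply-accumulates remain. The matrix and the threshold are hypothetical, and real speedups depend on sparse kernels and hardware, so the fraction below is an operation count, not a measured inference time.

```python
def prune(W, threshold):
    """Keep only (column, weight) pairs with |weight| >= threshold, row by row."""
    return [[(j, w) for j, w in enumerate(row) if abs(w) >= threshold] for row in W]


def sparse_matvec(S, x):
    """Multiply a pruned matrix by x, touching only the surviving weights."""
    return [sum(w * x[j] for j, w in row) for row in S]


def mac_fraction(W, S):
    """Fraction of multiply-accumulates left after pruning (dense count = all entries)."""
    return sum(len(row) for row in S) / sum(len(row) for row in W)


W = [
    [0.9, 1e-4, 0.0],
    [0.0, -0.5, 2e-5],
]
S = prune(W, threshold=1e-3)
print(mac_fraction(W, S))  # 2 of the 6 weights survive
print(sparse_matvec(S, [1.0, 2.0, 3.0]))
```

If pruning leaves disconnected modules, as in the modular addition example, each module's rows and columns form an independent block that could in principle be evaluated on a separate worker, with a final step summing their outputs.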

Nathan Labenz: 1:47:50 That seems extremely compelling, actually. So Max Tegmark is your adviser. He has not, for most of his career, been in AI. Did you... did he already like, make the switch into AI and then you joined the group after that switch? Or tell me the story of like, working with somebody who's obviously got some special ability but is new to AI.

Ziming Liu: 1:48:13 I cannot recall the exact year Max switched to AI, but I joined the group after Max switched to AI. When I joined the group, we were still focusing on AI for physics, because Max used to be a cosmologist, has a strong background in physics, and has his position in the department of physics. I was also admitted to the department of physics at MIT. I did research on AI for physics with Max for 2 or 3 years; I'm in my third year right now. And then Max decided to switch the focus entirely to AI safety, mech interp specifically, because, trained as physicists, we think that we need to first fully understand something so that we can fully control it, so that we can keep AI systems under control and they do no harm to human beings. But the first step is basically to understand it.

Nathan Labenz: 1:49:22 Yeah. I've been an admirer of his work for a number of years, and it's been cool to see that he's been so flexible and adaptable and, obviously, has been able to attract some bright minds to the group to do some exciting work. I'm struck by the fact that, from what I understand just on your website and briefly looking into your background, you grew up in China, went to university in China, and then came here after university. Is that right?

Ziming Liu: 1:49:49 Yeah. So I went to Tsinghua University in the physics department, school of physics, and then Max hired me as his PhD student. So, yeah, I've been here in the US for like 2 years.

Nathan Labenz: 1:50:06 So does the general, like broader context of sort of US-China tension and the increasingly, like, center stage that AI has in that broader debate, is that like, relevant to you personally? Do you just try to stay out of it, or how do you engage with that subject, if at all?

Ziming Liu: 1:50:25 I think I agree with Max on this point. Like, the extinction risk is a common threat to all human beings. So if we see this as kind of extinction risk, no one can survive if we don't collaborate. So with this, actually this can bring together US and China collaboration on this AI safety thing. I think collaboration is the right way, not opposing each other.

Nathan Labenz: 1:50:57 Yeah, totally. I try to highlight any instance of positive US-China relationship and collaboration that I can, so your time here and work on this project is a great example of that. And I'm really appreciative that you've spent so much time walking through it with us. Keep up the great work, whether it's here or whether it's one day back in China, or hopefully, you can continue to span the two indefinitely. We certainly need a lot more of this ability to understand because, as I think you guys put it very well, understanding is usually a precondition for control. And everybody has a shared interest in figuring out how to control these systems so we can get the best from them and avoid the worst.

Ziming Liu: 1:51:51 Yeah. Sure. Thank you so much for the invitation. Yeah, it was great speaking to you. And have fun with your family.

Nathan Labenz: 1:52:00 Thank you very much. Ziming Liu, thank you for being part of the Cognitive Revolution.
