andy12_•3h ago
Very concise summary of the procedure described in this paper:
1. Run the model once across a dataset to estimate loss curvature per MLP weight matrix via K-FAC (activation/gradient covariances).
2. Decompose each weight matrix into curvature-ordered components; low-curvature directions correspond most to verbatim memorization, higher curvature to shared/general mechanisms.
3. Edit by dropping the low-curvature subspace and keeping only the top directions.
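For anyone who wants to see the three steps in code, here is a minimal NumPy sketch for a single linear layer. It is not the paper's implementation: the function name `kfac_edit`, the `keep_fraction` parameter, and the random arrays standing in for real activations and gradients are all assumptions made for illustration. It uses the standard K-FAC idea that the layer's curvature is approximated by the Kronecker product of the activation covariance `A` and the output-gradient covariance `G`, so the curvature of each eigen-direction is a product of one eigenvalue from each.

```python
import numpy as np

def kfac_edit(W, acts, grads, keep_fraction=0.5):
    """Hypothetical helper: keep only the highest-curvature components of W.

    W:     (d_out, d_in) weight matrix
    acts:  (n, d_in) inputs to the layer, one row per token/example
    grads: (n, d_out) loss gradients w.r.t. the layer outputs
    """
    n = acts.shape[0]
    # Step 1: K-FAC factors (uncentered covariances).
    A = acts.T @ acts / n      # (d_in, d_in)
    G = grads.T @ grads / n    # (d_out, d_out)

    # Step 2: eigenbases of both factors; the curvature of component (i, j)
    # is approximated by lam_G[i] * lam_A[j].
    lam_A, U_A = np.linalg.eigh(A)
    lam_G, U_G = np.linalg.eigh(G)
    C = U_G.T @ W @ U_A                  # coefficients of W in that basis
    curvature = np.outer(lam_G, lam_A)   # (d_out, d_in)

    # Step 3: zero out the low-curvature components, keep the top ones.
    k = max(1, int(round(keep_fraction * curvature.size)))
    threshold = np.sort(curvature, axis=None)[-k]
    mask = curvature >= threshold
    W_edited = U_G @ (C * mask) @ U_A.T
    return W_edited, mask

# Random stand-ins for a real layer and real statistics (assumption).
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
acts = rng.normal(size=(256, 16))
grads = rng.normal(size=(256, 8))
W_edited, mask = kfac_edit(W, acts, grads, keep_fraction=0.25)
```

With `keep_fraction=1.0` nothing is dropped and the edit returns the original matrix, which is a handy sanity check before trying smaller fractions.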
vessenes•2h ago
Thank you for this huge time saver.
Now, about the paper: that’s super interesting. I imagine the dream here is to distil down into a “reasoning” core. Or maybe reclaim space for more generalization. Lots of interesting use cases.
getnormality•2h ago
Thank you!
I think you may have accidentally switched low and high in #2, no? The abstract speaks of high curvature as associated with memorization:
> curvature for memorized training points is much sharper than non memorized
radarsat1•1h ago
This sounds more correct to me. I've read previously somewhere that better generalization is usually associated with wider, smoother minima, and this is why regularization is important, because it has a smoothing function on the loss landscape.
getnormality•1h ago
Yes. This is also not hard to see intuitively from scratch.
Say you have a smooth but highly flexible model y = f(x) and some data points you are fitting with a machine learning algorithm. For whatever reason, the algorithm decides it wants to reduce training error by interpolating some specific point, (x0,y0), without negatively affecting training error on nearby points. The direct, guaranteed successful way to do this is to adjust the model to y0 = f(x0) exactly on x0 by adding a Dirac delta there, leaving the rest of f exactly as-is. But this cannot be done on a differentiable model, as it would create a discontinuity. The next best thing that such a model can actually do is replace the Dirac delta with a smooth but very narrow bump (e.g. Gaussian). But this narrow bump will inevitably have extremely high curvature at x0, since the bump is flat at x0 and it has to merge with the neighborhood around x0 in a very short distance.
Think of driving: if you have to change lanes in a very short distance, you're going to have to steer hard. Steering is curvature.
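The narrow-bump argument can be checked numerically. The sketch below assumes a Gaussian bump of height h and width sigma (my choice of bump; any smooth bump behaves similarly). Its second derivative at the peak is -h / sigma^2, and since the slope there is zero, that is also the geometric curvature up to sign. The function names are made up for this example.

```python
import numpy as np

def bump(x, x0=0.0, height=1.0, sigma=1.0):
    """Smooth stand-in for the Dirac delta: a Gaussian bump centered at x0."""
    return height * np.exp(-((x - x0) ** 2) / (2.0 * sigma ** 2))

def second_derivative(f, x, eps):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + eps) - 2.0 * f(x) + f(x - eps)) / eps ** 2

def peak_curvature(sigma, height=1.0):
    """Numerical f'' at the top of the bump; step size scaled to the width."""
    eps = sigma * 1e-3
    return second_derivative(lambda x: bump(x, 0.0, height, sigma), 0.0, eps)

# Compare the numerical estimate with the closed form -height / sigma**2.
for sigma in (1.0, 0.1, 0.01):
    print(sigma, peak_curvature(sigma), -1.0 / sigma ** 2)
```

Shrinking the width by 10x multiplies the curvature at the peak by 100x, which is the "steer hard" effect in the lane-change picture.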
woadwarrior01•1h ago
That's very reminiscent of the idea behind the SAM (Sharpness Aware Minimization) family of optimizers.
andy12_•56m ago
Actually, no! Look at this in the paper:
> In extending from studying per-example to bulk memorization, we propose a novel inversion of the previous interpretation of loss curvature: while individual memorized points are associated with high curvature, the direction of curvature varies across examples, meaning that, averaged across multiple examples, memorization directions are actually flatter than generalizing directions, which maintain a consistent moderate curvature across points
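A toy model makes the averaging effect concrete. This is my own construction, not the paper's setup: each example contributes a sharp rank-1 curvature term along its own random direction (the "memorization" part) plus a moderate term along one direction shared by all examples (the "generalizing" part). The dimensions and curvature values are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 200, 1000             # size of the memorization subspace, number of examples
sharp, moderate = 10.0, 1.0  # per-example curvature: memorized vs shared direction

# One unit-norm "memorization" direction per example, living in axes 1..d.
V = rng.normal(size=(n, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)

# Average of the per-example Hessians sharp * v v^T + moderate * u u^T,
# where the shared direction u is axis 0.
H_avg = np.zeros((d + 1, d + 1))
H_avg[0, 0] = moderate
H_avg[1:, 1:] = sharp * (V.T @ V) / n

shared_curv = H_avg[0, 0]
mem_curv_max = np.linalg.eigvalsh(H_avg[1:, 1:]).max()
print("per-example sharp curvature:", sharp)
print("averaged curvature, shared direction:", shared_curv)
print("averaged curvature, sharpest memorization direction:", mem_curv_max)
```

Every individual example is 10x sharper along its own direction than along the shared one, yet after averaging the sharp curvature is spread over about d directions, so each of them ends up flatter than the shared direction.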
getnormality•19m ago
Ah! I figured I should be very circumspect in the question since I hadn't read in full and there could be some crazy reason it's actually the opposite.
kingstnap•1h ago
A very similar idea is presented here in the first 5 minutes of this recent talk. But more from observing a kink in loss curves.