Git Re-Basin: Merging Models modulo Permutation Symmetries

Two neural networks $A$ and $B$ with the same architecture, trained by SGD, are compared after finding a permutation $\pi$ of hidden units that minimizes a notion of distance between their weights. Experimentally, weights linearly interpolated between $A$ and $\pi(B)$ yield essentially the same loss as the endpoints. In the context of federated learning, merging $A$ and $\pi(B)$ performs much better than merging $A$ and $B$.

Figure 1. Linear mode connectivity during training: loss barriers as a function of training time for an MLP trained on MNIST, with loss interpolation plots inlaid for early and late epochs. Linear mode connectivity emerges gradually as the models are trained.

Pick two neural networks $A$ and $B$ with the same architecture that were trained using SGD on the same (or similar) data but with different initializations and hyperparameters. How similar are they? Training of NNs is often surprisingly stable, and the two networks typically reach very similar performance. Their weight matrices, however, are usually very different. In fact, one should expect them to differ, since there is an enormous number of permutation symmetries that change the weights but leave the predictions, and hence the loss, unaltered. A meaningful comparison between the weights of $A$ and $B$ is only possible once such symmetries have been accounted for. [Ain22G] provides ideas and algorithms that go in this direction.
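
To make the symmetry concrete, consider a one-hidden-layer MLP $f(x) = W_2\,\sigma(W_1 x + b_1)$ (notation mine, not the paper's). For any permutation matrix $P$ acting on the hidden units,

$$
f(x) = W_2 P^\top\,\sigma\!\left(P W_1 x + P b_1\right),
$$

because the elementwise nonlinearity $\sigma$ commutes with permutations. The reparametrized weights $(P W_1,\, P b_1,\, W_2 P^\top)$ therefore define exactly the same function, and with $n$ hidden units there are already $n!$ such equivalent weight settings for this single layer.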

The goal is to find a permutation $\pi$ of the weights that minimizes some notion of distance between $A$ and $\pi(B)$, and then to compare $A$ with $\pi(B)$. This is easier said than done: even for a small NN there are far more prediction-preserving permutations than atoms in the observable universe, so an exhaustive search is impossible. The paper defines several strategies and algorithms for finding a suitable $\pi$, some making use of the training set and some only using the weights; a minimal sketch of the weight-based matching is given after the list below. The main experimental results from comparing $A$ to $\pi(B)$ are the following:

  1. In many situations, the loss is nearly constant on the entire line between $A$ and $\pi(B)$, i.e. for all $\lambda \in [0, 1]$ the weights $\lambda A + (1-\lambda) \pi(B)$ perform as well as those of $A$ or $B$. This property is called linear mode connectivity and appears in several conjectures about the loss landscape of NNs, see Figure 1.
  2. Finding the $\pi(B)$ closest to $A$ gives a natural strategy for merging neural networks in the context of federated learning: given $A$ and $B$ trained on non-intersecting subsets of the training set, the combinations $\lambda A + (1-\lambda) \pi(B)$ outperform using only $A$ or $B$ individually. Merging $A$ and $B$ directly as $\lambda A + (1-\lambda) B$ leads to highly suboptimal results. However, ensembling $A$ and $B$ still outperforms any merging, with the downside that it needs twice as many resources at inference time, see Figure 2.

Figure 2. Models trained on disjoint datasets can be merged with positive-sum results: two ResNet models trained on disjoint, biased subsets of CIFAR-100 can be merged in weight space.
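
Here is a minimal sketch of the weight-based matching for a network with a single hidden layer, assuming NumPy weight arrays and using SciPy's linear-sum-assignment solver; the function and variable names are mine, not the paper's:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_hidden_units(W1_a, b1_a, W2_a, W1_b, b1_b, W2_b):
    """Permute B's hidden units to best align them with A's.

    Shapes: W1 is (hidden, in), b1 is (hidden,), W2 is (out, hidden).
    Permuting hidden units leaves B's predictions unchanged, so we are free
    to pick the permutation that maximizes agreement with A's weights.
    """
    # Similarity between hidden unit i of A and hidden unit j of B, measured by
    # the agreement of their incoming weights, biases and outgoing weights.
    sim = W1_a @ W1_b.T + np.outer(b1_a, b1_b) + W2_a.T @ W2_b
    # linear_sum_assignment minimizes a cost, so negate to maximize similarity.
    _, perm = linear_sum_assignment(-sim)
    # Apply the permutation to B; the permuted network computes the same function.
    return W1_b[perm], b1_b[perm], W2_b[:, perm]
```

For deeper networks the permutations of adjacent layers interact, so the full algorithm in the paper, roughly speaking, cycles through the layers solving one such assignment problem at a time, while the data-dependent variants match hidden activations rather than weights.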
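
Given such a permuted copy of $B$, the interpolation and loss barrier from point 1 can be sketched as follows, assuming a user-supplied `loss_fn` that evaluates a list of weight arrays on held-out data (a placeholder, not part of the paper's code):

```python
import numpy as np

def loss_barrier(loss_fn, weights_a, weights_b_perm, num_points=11):
    """Trace the loss along lam * A + (1 - lam) * pi(B) and report the barrier.

    `loss_fn` maps a list of weight arrays to a scalar loss on held-out data;
    it is a placeholder for whatever evaluation code the project provides.
    """
    lambdas = np.linspace(0.0, 1.0, num_points)
    losses = []
    for lam in lambdas:
        merged = [lam * wa + (1.0 - lam) * wb
                  for wa, wb in zip(weights_a, weights_b_perm)]
        losses.append(loss_fn(merged))
    # Barrier: largest gap between the loss on the path and the straight line
    # joining the endpoint losses; values near zero mean linear mode connectivity.
    reference = lambdas * losses[-1] + (1.0 - lambdas) * losses[0]
    return float(max(l - r for l, r in zip(losses, reference)))
```

Sweeping $\lambda$ like this produces the kind of interpolation curves inlaid in Figure 1; in the federated-learning setting of point 2, the merged weights at a chosen $\lambda$ simply replace the two separately trained models.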

The paper provides an interesting experimental analysis of the loss landscape of NNs, and the algorithms for finding $\pi$ given $A$ and $B$ might prove useful in practice, especially for federated learning. Personally, I am a bit skeptical about the applicability of these ideas: training on the full dataset is still the best option when it is possible, and it is not clear whether sharing network weights in federated learning is legally much easier than sharing data, since various forms of data leakage through trained weights remain a serious and, to my knowledge, not yet fully resolved problem.

References

[Ain22G] Samuel K. Ainsworth, Jonathan Hayase, Siddhartha Srinivasa. Git Re-Basin: Merging Models modulo Permutation Symmetries. arXiv:2209.04836, 2022.