

Poster in Workshop: MATH-AI: The 4th Workshop on Mathematical Reasoning and AI

A Hessian View of Grokking in Mathematical Reasoning

Zhenshuo Zhang · Jerry Liu · Christopher Ré · Hongyang Zhang

Keywords: [ Hessian ] [ grokking ]


Abstract: Mathematical reasoning is a central problem in developing more intelligent language models. An intriguing phenomenon observed in arithmetic tasks is grokking, where the training loss of a transformer model stays near zero for an extended period before the validation loss finally drops to near zero. In this work, we approach this phenomenon through the lens of the Hessian of the loss surface. The Hessian relates to the generalization properties of neural networks because it captures geometric properties of the loss surface, such as the sharpness of local minima. We begin by observing in our experiments that high weight decay is essential for grokking to occur in several arithmetic tasks (trained with a GPT-2 style transformer model). However, we also find that the training loss is highly unstable and exhibits strong oscillations. To address this issue, we regularize the Hessian by injecting isotropic Gaussian noise into the weights of the transformer network, and find that this combination of high weight decay and Hessian regularization smooths out the training loss during grokking. We also find that this approach accelerates the grokking stage compared to existing methods by at least $50\%$, measured on seven arithmetic tasks. Finally, to understand the precise cause of grokking, we consider a Hessian-based measurement for multi-layer networks and find that this measure yields non-vacuous estimates of the generalization errors observed in practice. We hope these empirical findings facilitate future research towards understanding grokking (and generalization) in mathematical reasoning.
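Why weight-noise injection acts as a Hessian regularizer can be seen from a standard second-order Taylor expansion (a textbook identity, not a result specific to this paper): averaging the loss $L$ over isotropic Gaussian perturbations of the weights $w$ gives

$$\mathbb{E}_{\epsilon \sim \mathcal{N}(0,\, \sigma^2 I)}\big[L(w + \epsilon)\big] \approx L(w) + \frac{\sigma^2}{2}\, \operatorname{tr}\big(\nabla^2 L(w)\big),$$

so minimizing the noise-averaged loss implicitly penalizes the trace of the Hessian, i.e., the sharpness of the reached minimum.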
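The sketch below illustrates one way such a training step could look in PyTorch. It is a minimal illustration under stated assumptions, not the authors' method: the noise scale sigma, the single noise sample per step, and the Hugging Face-style model output with a .loss field are all illustrative choices.

import torch
from torch.optim import AdamW

def noisy_training_step(model, batch, optimizer, sigma=0.01):
    """One training step with isotropic Gaussian weight perturbation.

    Illustrative sketch only; the paper's exact scheme may differ
    (e.g., in noise scale or number of noise samples per step).
    """
    # Save the unperturbed weights, then perturb them in place.
    saved = [p.detach().clone() for p in model.parameters()]
    with torch.no_grad():
        for p in model.parameters():
            p.add_(sigma * torch.randn_like(p))

    # Gradients are taken at the perturbed point, which in expectation
    # penalizes the Hessian trace per the expansion above.
    optimizer.zero_grad()
    loss = model(**batch).loss  # assumes a Hugging Face-style model output
    loss.backward()

    # Restore the clean weights before applying the update; AdamW's
    # decoupled weight decay supplies the high-decay component.
    with torch.no_grad():
        for p, w in zip(model.parameters(), saved):
            p.copy_(w)
    optimizer.step()
    return loss.item()

A typical setup would pair this with an optimizer such as AdamW(model.parameters(), lr=1e-3, weight_decay=1.0), where the large weight_decay value plays the role of the high weight decay the abstract describes.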
