233. A Hessian-Aware Stochastic Differential Equation for Modelling SGD
Invited abstract in session WC-5: Recent Advances in Stochastic Optimization, stream Optimization for machine learning.
Wednesday, 14:00-16:00, Room: B100/4013
Authors (first author is the speaker)
1. Xiang Li, Department of Computer Science, ETH Zurich
Abstract
Understanding how Stochastic Gradient Descent (SGD) escapes from stationary points is essential for advancing optimization in machine learning, particularly in non-convex settings. While continuous-time approximations using stochastic differential equations (SDEs) have been instrumental in analyzing such behavior, existing models fall short in accurately capturing SGD dynamics—even for simple loss landscapes. In this work, we introduce a novel SDE model derived through a stochastic backward error analysis framework. This new formulation incorporates second-order information from the objective function into both its drift and diffusion components, leading to a more faithful representation of SGD’s behavior.
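For context, a minimal sketch of the conventional SDE approximations that this line of work builds on (the standard formulations from the stochastic modified equations literature, not the authors' new model) is:

```latex
% SGD with learning rate \eta and gradient-noise covariance \Sigma(x):
%   x_{k+1} = x_k - \eta\,\nabla f_{\gamma_k}(x_k), \qquad \mathbb{E}[\nabla f_{\gamma}(x)] = \nabla f(x).
% Conventional first-order SDE approximation:
\mathrm{d}X_t = -\nabla f(X_t)\,\mathrm{d}t + \sqrt{\eta}\,\Sigma(X_t)^{1/2}\,\mathrm{d}W_t .
% A standard second-order correction already injects the Hessian into the drift,
% since \nabla\,\tfrac{1}{2}\|\nabla f\|^2 = \nabla^2 f\,\nabla f:
\mathrm{d}X_t = -\nabla\!\Big(f(X_t) + \tfrac{\eta}{4}\,\|\nabla f(X_t)\|^2\Big)\,\mathrm{d}t
              + \sqrt{\eta}\,\Sigma(X_t)^{1/2}\,\mathrm{d}W_t .
```

The model described in this abstract goes further by making both the drift and the diffusion coefficient depend on second-order information; its exact form is not reproduced here.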
Our new SDE attains a tighter theoretical weak-approximation error than existing models, reducing the dependence on the smoothness parameter. Importantly, we demonstrate that, for quadratic objectives, our model is the first to exactly replicate the distributions of SGD iterates.
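To make the quadratic claim concrete, the sketch below is a toy illustration of our own (assuming additive Gaussian gradient noise; it is not the authors' model or experiments): on a one-dimensional quadratic, the stationary variance of SGD differs from the prediction of the conventional first-order SDE by a curvature-dependent factor, which is exactly the kind of gap a Hessian-aware SDE is designed to close.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setting (hypothetical, for illustration only):
# f(x) = 0.5 * h * x**2 with additive Gaussian gradient noise of std sigma.
h, sigma = 1.0, 1.0
eta = 0.4                      # learning rate, also the SDE time step
n_runs, n_steps = 100_000, 300

# Simulate SGD: x_{k+1} = x_k - eta * (h * x_k + sigma * noise)
x = np.zeros(n_runs)
for _ in range(n_steps):
    x = x - eta * (h * x + sigma * rng.standard_normal(n_runs))
sgd_var = x.var()

# Stationary variance predicted by the plain first-order SDE
#   dX = -h X dt + sqrt(eta) * sigma dW   (an Ornstein-Uhlenbeck process):
sde_var = eta * sigma**2 / (2 * h)

# Exact stationary variance of the SGD chain, from v = (1 - eta*h)**2 * v + (eta*sigma)**2.
# The extra (2 - eta*h) factor is curvature-dependent, i.e. Hessian-aware.
exact_var = eta * sigma**2 / (h * (2 - eta * h))

print(f"empirical SGD variance : {sgd_var:.4f}")
print(f"plain SDE prediction   : {sde_var:.4f}")
print(f"exact SGD variance     : {exact_var:.4f}")
```

With these parameters the empirical SGD variance lands near the curvature-corrected value (0.25 here) rather than the plain first-order SDE prediction (0.2), which is the discrepancy a Hessian-aware SDE model removes.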
Empirical evaluations on neural network loss surfaces further validate the practical advantages of our SDE. Additionally, the improved approximation enables a sharper analysis of the escape time of SGD near stationary points.
Keywords
- Stochastic optimization
Status: accepted