Scaling Back-propagation by Parallel Scan Algorithm
In an era when the performance of a single compute device plateaus, software must be designed to scale on massively parallel systems to achieve better runtime performance. However, the commonly used back-propagation (BP) algorithm imposes a strong sequential dependency on gradient computation. Under model parallelism, BP has a theoretical step complexity of Θ(n), where n is the number of compute devices across which a model is partitioned, and this hinders its scalability in a parallel computing environment. In this work, we restructure this dependency and reformulate BP as a scan operation, which we parallelize with a modified version of the Blelloch scan algorithm. Our algorithm achieves a theoretical step complexity of Θ(log n). We perform an in-depth performance analysis, identify the challenges of deploying our algorithm in a practical setting, and present a variety of approaches to tackle them. We demonstrate the scalability benefits of our algorithm in the use case of retraining pruned networks.
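To make the scan reformulation concrete, here is a minimal Python/NumPy sketch under stated assumptions. It is not the authors' implementation. It assumes a chain of n stages (one per device) with dense Jacobians J_i = ∂x_{i+1}/∂x_i, and it uses the reversed matrix product A ◊ B = B A as the scan operator. This operator is associative but not commutative, so the Blelloch down-sweep must preserve operand order. An inclusive scan over the array [∂L/∂x_n, J_{n-1}^T, ..., J_0^T] then yields every activation gradient. The helper names (combine, blelloch_exclusive_scan, scan_backprop, sequential_backprop) are hypothetical, and None stands in for the identity element so that vectors and matrices of different shapes can share one array.

```python
import numpy as np


def combine(a, b):
    """Hypothetical scan operator: reversed matrix product a ◊ b = b @ a.
    None is treated as the identity element."""
    if a is None:
        return b
    if b is None:
        return a
    return b @ a


def blelloch_exclusive_scan(xs, op):
    """Work-efficient Blelloch exclusive scan with None as the identity.
    Iterations of each inner for-loop touch disjoint slots, so each level
    could run in parallel: log2(n) up-sweep + log2(n) down-sweep levels."""
    if not xs:
        return []
    n = 1 << (len(xs) - 1).bit_length()  # pad length to a power of two
    a = list(xs) + [None] * (n - len(xs))
    # Up-sweep (reduce): each right node accumulates (left subtree ◊ right subtree).
    stride = 1
    while stride < n:
        for k in range(0, n, 2 * stride):
            a[k + 2 * stride - 1] = op(a[k + stride - 1], a[k + 2 * stride - 1])
        stride *= 2
    # Down-sweep: push prefixes back down the tree, keeping operand order.
    a[n - 1] = None
    stride = n // 2
    while stride >= 1:
        for k in range(0, n, 2 * stride):
            left, right = k + stride - 1, k + 2 * stride - 1
            left_sum = a[left]
            a[left] = a[right]                   # left child inherits parent prefix
            a[right] = op(a[right], left_sum)    # right child: prefix, then left subtree
        stride //= 2
    return a[: len(xs)]


def scan_backprop(jacobians, grad_out):
    """jacobians[i] = d x_{i+1} / d x_i; grad_out = dL/dx_n as a column vector.
    Returns [dL/dx_n, dL/dx_{n-1}, ..., dL/dx_0]."""
    elems = [grad_out] + [J.T for J in reversed(jacobians)]
    prefixes = blelloch_exclusive_scan(elems, combine)
    # Inclusive result = exclusive prefix combined with the element itself.
    return [combine(p, e) for p, e in zip(prefixes, elems)]


def sequential_backprop(jacobians, grad_out):
    """Reference: ordinary BP, one Jacobian-transpose product per step."""
    grads = [grad_out]
    for J in reversed(jacobians):
        grads.append(J.T @ grads[-1])
    return grads


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    widths = [3, 5, 4, 2, 6]  # x_0 .. x_4 have these (arbitrary) sizes
    jacs = [rng.standard_normal((widths[i + 1], widths[i])) for i in range(4)]
    g = rng.standard_normal((widths[-1], 1))
    for s, r in zip(scan_backprop(jacs, g), sequential_backprop(jacs, g)):
        assert np.allclose(s, r)
    print("scan-based BP matches sequential BP")
```

This sketch executes every level serially, so it only demonstrates that the scan produces the same gradients as sequential BP. The Θ(log n) step count comes from the fact that each of the 2·log2 n levels consists of independent operator applications that separate devices could perform concurrently.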