Efficient Householder Transformation in PyTorch

torch-householder

Overview

This landing page provides an overview of the Householder transformation algorithm for calculating orthogonal matrices and Stiefel frames. The algorithm is implemented as a Python package with differentiable bindings to PyTorch. In particular, the package provides an enhanced drop-in replacement for the torch.orgqr function. For a detailed analysis of its functionality, please visit the project GitHub.
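To make the role of torch.orgqr concrete, here is a minimal sketch of what a Householder product computes, assuming the LAPACK convention in which each reflector is H_i = I - tau_i * v_i * v_i^T and v_i has an implicit unit entry on the diagonal. The helper name householder_product_reference is hypothetical and exists only for illustration. The sketch uses PyTorch's built-in torch.geqrf and torch.orgqr rather than this package's API, and it is a naive reference loop, not the package's implementation.

```python
import torch


def householder_product_reference(v, tau):
    """Naive reference: multiply the reflectors H_1 ... H_k explicitly.

    v   : (m, n) matrix whose strictly lower-triangular part of column i
          stores reflector v_i (the unit diagonal entry is implicit).
    tau : (k,) scaling factors of the reflectors.
    Returns the first n columns of Q = H_1 @ H_2 @ ... @ H_k.
    """
    m, n = v.shape
    eye = torch.eye(m, dtype=v.dtype)
    q = eye.clone()
    for i in range(tau.shape[0]):
        u = v[:, i].clone()
        u[:i] = 0  # entries above the diagonal belong to R, not to v_i
        u[i] = 1   # implicit unit diagonal entry
        q = q @ (eye - tau[i] * torch.outer(u, u))
    return q[:, :n]


torch.manual_seed(0)
a = torch.randn(5, 3, dtype=torch.float64)  # non-square input
v, tau = torch.geqrf(a)                     # reflectors and their scalings

q_ref = householder_product_reference(v, tau)
q_lib = torch.orgqr(v, tau)                 # PyTorch's built-in counterpart

# The 5x3 result has orthonormal columns, i.e. it is a Stiefel frame.
gram = q_ref.T @ q_ref
```

The explicit loop forms a full m-by-m matrix for every reflector, which is only acceptable for illustration; it serves as a readable specification of the output rather than an efficient algorithm.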

APIs for orthogonal transformations have existed since LAPACK; however, support for them in deep learning frameworks is lacking. Recently, orthogonal constraints have become popular in deep learning as a way to regularize models and improve training dynamics, which has created the need to backpropagate through orthogonal transformations.

PyTorch 1.7 implements the matrix exponential function torch.matrix_exp, which can be repurposed to perform an orthogonal transformation when the input matrix is skew-symmetric. This is the baseline we use in the Speed and Precision evaluation.
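The baseline can be sketched in a few lines. The example below assumes an unconstrained square parameter matrix w (a hypothetical name), builds a skew-symmetric matrix from it, and checks that its matrix exponential is orthogonal up to floating-point error. It illustrates the idea only and is not the evaluation code used for the comparison.

```python
import torch

torch.manual_seed(0)
m = 4

# Unconstrained square parameter; only its skew-symmetric part is used.
w = torch.randn(m, m, dtype=torch.float64)
a = w - w.T  # skew-symmetric by construction: a.T == -a

# The exponential of a skew-symmetric matrix is an orthogonal matrix.
q = torch.matrix_exp(a)

# Largest deviation of q.T @ q from the identity.
err = (q.T @ q - torch.eye(m, dtype=torch.float64)).abs().max()
```

Note that this route always works with a full m-by-m matrix, even when only a few orthonormal columns are needed.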

Compared to torch.matrix_exp, the Householder transformation implemented in this package has the following advantages:

  • Orders of magnitude lower memory footprint
  • Ability to transform non-square matrices (Stiefel frames)
  • A significant speed-up for non-square matrices
  • Better numerical precision for all matrix and batch sizes

Source code

Check out the project GitHub page

Download

pip install torch-householder

Citation

@misc{obukhov2021torchhouseholder,
  author={Anton Obukhov},
  year=2021,
  title={Efficient Householder transformation in PyTorch},
  url={https://github.com/toshas/torch-householder}
}