Mixed Precision and Other Data Types

author: Jacob Schreiber contact: jmschreiber91@gmail.com

Because pomegranate models are all instances of torch.nn.Module, you can do anything with them that you could do with other PyTorch models. In the first tutorial, we saw that this means one can use GPUs in exactly the same way as with any other PyTorch model. It also means that all of PyTorch's built-in support for half precision, quantization, automatic mixed precision (AMP), etc., can be used with pomegranate.

[1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

%pylab inline
import seaborn; seaborn.set_style('whitegrid')

import torch

numpy.random.seed(0)
numpy.set_printoptions(suppress=True)

%load_ext watermark
%watermark -m -n -p numpy,scipy,torch,pomegranate
Populating the interactive namespace from numpy and matplotlib
numpy      : 1.23.4
scipy      : 1.9.3
torch      : 1.13.0
pomegranate: 1.0.0

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 4.15.0-208-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 8
Architecture: 64bit

float16 (half) Precision

Doing operations at half precision works just the same as it does in PyTorch: you can use the .half() method or the .to(torch.float16) method. However, relatively few operations are currently supported at half precision; notably, log and sqrt are not. So, until more operations are supported, you will probably want to use one of the other data types described below.
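
As a minimal sketch (not a cell from the original notebook), converting a distribution and its data to float16 uses the standard torch.nn.Module API; whether .fit succeeds depends on which operations your PyTorch build implements for float16, so the call is wrapped in a try/except here.

[ ]:
# A minimal sketch, assuming pomegranate >= 1.0: conversion to float16 is the
# usual torch.nn.Module API, but fitting may raise a RuntimeError if an
# internal operation (e.g. log or sqrt on CPU) is not implemented for Half.
import torch
from pomegranate.distributions import Normal

X_half = torch.randn(1000, 5).half()
d_half = Normal(covariance_type='diag').to(torch.float16)

try:
    d_half.fit(X_half)
    print(d_half.means.dtype)  # torch.float16 when the operations are supported
except RuntimeError as e:
    print("float16 not supported for this operation:", e)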

bfloat16

More operations seem to be supported for bfloat16 though!

[2]:
from pomegranate.distributions import Normal

X = torch.randn(1000, 5)

d = Normal(covariance_type='diag').fit(X)
d.means, d.covs
[2]:
(Parameter containing:
 tensor([-0.0076,  0.0164,  0.0164, -0.0033,  0.0105]),
 Parameter containing:
 tensor([1.0291, 1.0491, 0.9661, 0.8947, 1.0778]))
[3]:
X = X.to(torch.bfloat16)

d = Normal(covariance_type='diag').to(torch.bfloat16).fit(X)
d.means, d.covs
[3]:
(Parameter containing:
 tensor([-0.0075,  0.0165,  0.0165, -0.0034,  0.0103], dtype=torch.bfloat16),
 Parameter containing:
 tensor([1.0312, 1.0469, 0.9688, 0.8945, 1.0781], dtype=torch.bfloat16))

However, not all operations are supported for torch.bfloat16 either, including the Cholesky decomposition used by full-covariance normal distributions.
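
As a sketch of what that limitation looks like (again, not a cell from the original notebook), fitting a full-covariance Normal at bfloat16 is expected to raise, while the diagonal-covariance version above works:

[ ]:
# A minimal sketch: full-covariance Normals rely on a Cholesky decomposition
# internally which, per the note above, is not implemented for bfloat16 in this
# PyTorch version, so the fit is expected to fail with a RuntimeError.
import torch
from pomegranate.distributions import Normal

X_b = torch.randn(1000, 5).to(torch.bfloat16)

try:
    Normal(covariance_type='full').to(torch.bfloat16).fit(X_b)
except RuntimeError as e:
    print("bfloat16 full covariance failed:", e)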

Although not all operations support all data types, all models and methods (inference and training) support them to the extent that the underlying operations allow. For instance, we can just as easily use a mixture model with bfloat16 data types as with full floats.

[4]:
from pomegranate.gmm import GeneralMixtureModel

model = GeneralMixtureModel([Normal(covariance_type='diag'), Normal(covariance_type='diag')], verbose=True)
model = model.to(torch.bfloat16)
model.fit(X)
[1] Improvement: 32.0, Time: 0.001619s
[2] Improvement: 32.0, Time: 0.001082s
[3] Improvement: 0.0, Time: 0.00106s
[4]:
GeneralMixtureModel(
  (distributions): ModuleList(
    (0-1): 2 x Normal()
  )
)

And we can use the resulting trained model to make predictions at whatever precision we'd like.

[5]:
y_hat = model.predict_proba(X)
y_hat, y_hat.dtype
[5]:
(tensor([[0.8281, 0.1680],
         [0.2451, 0.7539],
         [0.0874, 0.9375],
         ...,
         [0.6445, 0.3574],
         [0.3145, 0.6875],
         [0.2695, 0.7305]], dtype=torch.bfloat16),
 torch.bfloat16)
[6]:
model = model.to(torch.float32)

y_hat = model.predict_proba(X)
y_hat, y_hat.dtype
[6]:
(tensor([[0.8299, 0.1701],
         [0.2463, 0.7537],
         [0.0793, 0.9207],
         ...,
         [0.6496, 0.3504],
         [0.3148, 0.6852],
         [0.2788, 0.7212]]),
 torch.float32)

Automatic Mixed Precision

An automatic way to get around some of these issues is to use AMP, which casts operations that can run at lower precision and leaves the rest at full precision. Keeping with the theme, this works exactly the same way as using AMP with your other PyTorch models.

[7]:
X = torch.randn(1000, 50).cuda()

model = GeneralMixtureModel([Normal(), Normal()]).cuda()

with torch.autocast('cuda', dtype=torch.float16):
    model.fit(X)

Running model.fit at half precision directly, without the autocast context, would have crashed because of the unsupported Cholesky decomposition; under autocast, that operation simply falls back to full precision.
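
Inference can be wrapped in the same way. As a sketch (not in the original notebook), reusing the model and X from the previous cell:

[ ]:
# A minimal sketch: the same autocast context works for inference, letting the
# operations that support float16 run at lower precision while the rest
# (such as the Cholesky decomposition) stay at float32.
with torch.autocast('cuda', dtype=torch.float16):
    y_hat = model.predict_proba(X)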

Speedups

Unfortunately, because pomegranate uses a wide range of operations to implement the underlying models, the speedups from using mixed precision are inconsistent. It may be worth trying in your application, but the speedups observed when training neural networks are not guaranteed here: not all operations are supported at lower precision, and those that are supported are not necessarily any faster. Because AMP falls back on full precision for many of the operations pomegranate uses, the method as a whole may not end up significantly faster in practice.

[8]:
X = torch.randn(10000, 500).cuda()

model = GeneralMixtureModel([Normal(covariance_type='diag') for i in range(10)], max_iter=5, verbose=True).cuda()
with torch.autocast('cuda', dtype=torch.bfloat16):
    model.fit(X)

print()

model = GeneralMixtureModel([Normal(covariance_type='diag') for i in range(10)], max_iter=5, verbose=True).cuda()
model.fit(X)
[1] Improvement: 1015.0, Time: 0.1042s
[2] Improvement: 417.5, Time: 0.1044s
[3] Improvement: 310.0, Time: 0.1041s
[4] Improvement: 243.5, Time: 0.1041s

[1] Improvement: 1232.5, Time: 0.1039s
[2] Improvement: 613.5, Time: 0.1042s
[3] Improvement: 446.5, Time: 0.1043s
[4] Improvement: 344.0, Time: 0.1043s
[8]:
GeneralMixtureModel(
  (distributions): ModuleList(
    (0-9): 10 x Normal()
  )
)