Deploying Drift Detection

TorchDrift provides the tools you need to detect drift. But how do you actually hook drift monitoring into your model?

This short tutorial shows how to use model hooks on your feature extractor to capture data to feed into the drift detector.

First we need to set up a model and drift detector. Let us import some packages.

[1]:
import sys
sys.path.insert(0, '../')

import torch
import torchvision
import torchdrift
import copy
%matplotlib inline
from matplotlib import pyplot

device = "cuda" if torch.cuda.is_available() else "cpu"

We use a plain ResNet as our example model. As we often do, we move the normalization out of the dataset transforms and into the model. We do this here because we want to post-process the raw images to “fake” drifted inputs, so you would not need it for your own data and models (though I would argue that moving the normalization into the model, while uncommon, is best practice). We also split the fully connected layer out of the ResNet so that the rest can serve as a feature extractor.

[2]:
resnet = torchvision.models.resnet18(pretrained=True)
model = torch.nn.Sequential(
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    resnet,
    resnet.fc,  # the Sequential keeps a reference to the original fc layer
    )
resnet.fc = torch.nn.Identity()  # the resnet inside the model now outputs features only
model.eval().to(device)
for p in model.parameters():
    p.requires_grad_(False)
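
Because the fc trick above is easy to get wrong, here is a quick sanity check (my addition, not part of the original notebook) that model[:2] indeed acts as a feature extractor, using a random input batch:

# sanity check (illustrative): model[:2] is the normalization plus the
# fc-less ResNet, so it should map images to 512-dim ResNet18 features
with torch.no_grad():
    feats = model[:2](torch.randn(2, 3, 224, 224, device=device))
print(feats.shape)  # expected: torch.Size([2, 512])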

And we set up a dataset.

[3]:
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=256),
    torchvision.transforms.CenterCrop(size=(224, 224)),
    torchvision.transforms.ToTensor()])


ds_train = torchvision.datasets.ImageFolder('./data/hymenoptera_data/train/',
                                            transform=val_transform)
ds_val = torchvision.datasets.ImageFolder('./data/hymenoptera_data/val/',
                                          transform=val_transform)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=64, shuffle=True)

We fit the detector. We use the p-value here for demonstration. Note that this is currently computationally more expensive than the score (but we’ll work on pre-computing the score distribution under the null hypothesis).

[4]:
def fit_detector(N_train):
    detector = torchdrift.detectors.KernelMMDDriftDetector(return_p_value=True)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=N_train, shuffle=True)
    feature_extractor = model[:2]  # without the fc layer
    torchdrift.utils.fit(dl_train, feature_extractor, detector, num_batches=1)
    return detector

detector = fit_detector(N_train = 100)
100%|██████████| 1/1 [00:00<00:00,  1.17it/s]
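
If the p-value computation is too slow for your deployment, a cheaper alternative (sketched here under the same setup; the score has no fixed scale, so you would choose an alarm threshold for it empirically) is to fit the detector without return_p_value=True and monitor the raw MMD score instead:

# sketch: score-based detector, cheaper to evaluate than the p-value
score_detector = torchdrift.detectors.KernelMMDDriftDetector()  # returns the MMD score
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=100, shuffle=True)
torchdrift.utils.fit(dl_train, model[:2], score_detector, num_batches=1)
# calling score_detector(features) now yields the MMD score; larger means more drift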

We build a model monitor: it hooks into the model to capture the output of feature_layer and caches the last N captured feature vectors in a ring buffer.

If we provide a callback, the monitor calls the drift detector every callback_interval batches once it has seen enough samples to fill the ring buffer.

Just to show off, I also throw in a little plot function.

[5]:
class ModelMonitor:
    def __init__(self, drift_detector, feature_layer, N = 20, callback = None, callback_interval = 1):
        self.N = N
        base_outputs = drift_detector.base_outputs
        self.drift_detector = drift_detector
        assert base_outputs is not None, "fit drift detector first"
        feature_dim = base_outputs.size(1)
        self.feature_rb = torch.zeros(N, feature_dim, device=base_outputs.device, dtype=base_outputs.dtype)
        self.have_full_round = False
        self.next_idx = 0
        self.hook = feature_layer.register_forward_hook(self.collect_hook)
        self.counter = 0
        self.callback = callback
        self.callback_interval = callback_interval

    def unhook(self):
        self.hook.remove()

    def collect_hook(self, module, input, output):
        self.counter += 1
        bs = output.size(0)
        if bs > self.N:
            output = output[-self.N:]
            bs = self.N
        output = output.reshape(bs, -1)
        first_part = min(self.N - self.next_idx, bs)
        self.feature_rb[self.next_idx: self.next_idx + first_part] = output[:first_part]
        if first_part < bs:
            self.feature_rb[: bs - first_part] = output[first_part:]
        if not self.have_full_round and self.next_idx + bs >= self.N:
            self.have_full_round = True
        self.next_idx = (self.next_idx + bs) % self.N
        if self.callback and self.have_full_round and self.counter % self.callback_interval == 0:
            p_val = self.drift_detector(self.feature_rb)
            self.callback(p_val)

    def plot(self):
        import sklearn.manifold
        from matplotlib import pyplot

        mapping = sklearn.manifold.Isomap()
        ref = mapping.fit_transform(self.drift_detector.base_outputs.to("cpu").numpy())

        test = mapping.transform(self.feature_rb.to("cpu").numpy())
        pyplot.scatter(ref[:, 0], ref[:, 1])
        pyplot.scatter(test[:, 0], test[:, 1])

To instantiate our monitor, we need an alarm function. Here I just raise an exception, but you could also text the AI facility management instead.

[6]:
def alarm(p_value):
    assert p_value > 0.01, f"Drift alarm! p-value: {p_value*100:.03f}%"

mm = ModelMonitor(detector, model[1], callback=alarm)

We grab a batch each of benign and drifted samples.

Fun fact: for this dataset, shuffling in the dataloader matters. Without it, the class balance of the test batch is skewed enough to trigger the alarm.
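
If you want to see that balance for yourself, a quick check (illustrative, not part of the monitoring pipeline) is to count the labels in one batch:

# illustrative: inspect the label distribution of one validation batch
_, labels = next(iter(dl_val))
print(torch.bincount(labels))  # per-class counts; roughly balanced with shuffle=True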

[7]:
it = iter(dl_val)
batch = next(it)[0].to(device)
batch_drifted = torchdrift.data.functional.gaussian_blur(next(it)[0].to(device), 5)

Now we run our model. ImageNet class 309 is bee and 310 is ant. Do not believe the model if it says aircraft carrier (it did this during testing). Note that we might be unlucky and get an exception here: this is, at least in part, a sampling artifact from computing the p-value.
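
If such spurious alarms worry you, one mitigation (a sketch of mine, not a TorchDrift feature; the threshold and patience values are illustrative) is to require several consecutive low p-values before sounding the alarm:

# sketch: only alarm after `patience` consecutive p-values below `threshold`
def make_debounced_alarm(threshold=0.01, patience=3):
    strikes = 0
    def alarm(p_value):
        nonlocal strikes
        strikes = strikes + 1 if p_value < threshold else 0
        assert strikes < patience, f"Drift alarm! p-value: {p_value*100:.03f}%"
    return alarm

# mm = ModelMonitor(detector, model[1], callback=make_debounced_alarm())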

[8]:
res = model(batch).argmax(1)
res
[8]:
tensor([310, 310, 309, 494, 310, 301, 309, 309, 310, 310, 114, 309, 310, 309,
        310, 309, 310, 310, 309, 309, 310, 310, 309, 309, 310, 309, 310, 310,
        310, 309, 309, 319, 310, 403, 410, 309, 310, 310, 410, 323, 310, 309,
        310, 947, 309, 309, 310, 309, 310, 114, 309, 309, 410, 309, 309, 309,
        310, 310, 309, 310, 310, 309,  79, 309], device='cuda:0')
[9]:
detector.compute_p_value(mm.feature_rb)
[9]:
tensor(0.6550, device='cuda:0')

We can also look at the latents to form an opinion on whether they might come from the same distribution. If you happen to have a heavily class-imbalanced sample (e.g. because you disabled shuffling in the dataloader - for testing, not because you forgot!), you might spot that imbalance here in the projected features.

[10]:
mm.plot()
[Image: scatter plot of the Isomap-projected reference and monitored features]

When we call the model with drifted inputs, we are relatively sure to set off the alarm.

[11]:
# call it with drifted inputs...
model(batch_drifted)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-11-bf4c9a25f81a> in <module>
      1 # call it with drifted inputs...
----> 2 model(batch_drifted)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py in forward(self, input)
    116     def forward(self, input):
    117         for module in self:
--> 118             input = module(input)
    119         return input
    120

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    891                 _global_forward_hooks.values(),
    892                 self._forward_hooks.values()):
--> 893             hook_result = hook(self, input, result)
    894             if hook_result is not None:
    895                 result = hook_result

<ipython-input-5-a161160f0ec9> in collect_hook(self, module, input, output)
     33         if self.callback and self.have_full_round and self.counter % self.callback_interval == 0:
     34             p_val = self.drift_detector(self.feature_rb)
---> 35             self.callback(p_val)
     36
     37     def plot(self):

<ipython-input-6-8fe6769bae9a> in alarm(p_value)
      1 def alarm(p_value):
----> 2     assert p_value > 0.01, f"Drift alarm! p-value: {p_value*100:.03f}%"
      3
      4 mm = ModelMonitor(detector, model[1], callback=alarm)

AssertionError: Drift alarm! p-value: 0.000%

With any luck, you can also see the drift in the datapoints.

[12]:
mm.plot()
[Image: scatter plot with the drifted features clearly separated from the reference features]

So in this notebook we saw how to use model hooks with the drift detector to automatically set off the alarm when something bad happens. Just remember that if you set the p-value threshold to \(x\%\), you should expect a false alarm roughly every \(100/x\) batches; for example, a 1% threshold means on average one false alarm per 100 batches, so choose it carefully if you do not want to spam your emergency contact.
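
One last housekeeping note: the monitor stays attached to the model through its forward hook, so detach it when you no longer want monitoring (for example, before exporting the model):

mm.unhook()  # removes the forward hook from the feature extractor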

View this document as a notebook: https://github.com/torchdrift/torchdrift/blob/master/notebooks/deployment_monitoring_example.ipynb