Drift detection on image classifiers

In this example, we take a simple classification model for image inputs and implement drift detection. As you can see below, the drift detection itself is very straightforward, perhaps even more so than the finetuning of the model.

Load data

In order to experiment, we need a dataset. We grab an ants versus bees dataset from the PyTorch tutorials.

Because we will want to apply transformations to simulate input drift, we split the PyTorch data-loading chain a bit differently than is usually done: instead of normalizing the channels to zero mean / unit variance (with the ImageNet statistics) in the dataset augmentation, we do this in an auxiliary function in the model, just before feeding a batch to the network.

We implement a LightningDataModule, but with an important additional default_dataloader method that lets us freely choose the number of samples and the batch size. We will later use the additional_transform argument to generate a drifted dataset on the fly.

[2]:
torchvision.datasets.utils.download_and_extract_archive('https://download.pytorch.org/tutorial/hymenoptera_data.zip', 'data/')


# these are the standard transforms without the normalization (which we move into the model, where it is applied just before the backbone)
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333)),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor()])
val_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=256),
    torchvision.transforms.CenterCrop(size=(224, 224)),
    torchvision.transforms.ToTensor()])


class OurDataModule(pl.LightningDataModule):
    def __init__(self, parent: Optional['OurDataModule']=None, additional_transform=None):
        super().__init__()
        if parent is None:
            self.train_dataset = torchvision.datasets.ImageFolder('./data/hymenoptera_data/train/',
                                                                  transform=train_transform)
            self.val_dataset = torchvision.datasets.ImageFolder('./data/hymenoptera_data/val/',
                                                                  transform=val_transform)
            self.test_dataset = torchvision.datasets.ImageFolder('./data/hymenoptera_data/test/',
                                                                  transform=val_transform)
            self.train_batch_size = 4
            self.val_batch_size = 128
            self.additional_transform = None
        else:
            self.train_dataset = parent.train_dataset
            self.val_dataset = parent.val_dataset
            self.test_dataset = parent.test_dataset
            self.train_batch_size = parent.train_batch_size
            self.val_batch_size = parent.val_batch_size
            self.additional_transform = parent.additional_transform
        if additional_transform is not None:
            self.additional_transform = additional_transform

        self.prepare_data()
        self.setup('fit')
        self.setup('test')

    def setup(self, typ):
        pass

    def collate_fn(self, batch):
        batch = torch.utils.data._utils.collate.default_collate(batch)
        if self.additional_transform:
            # apply the drift-simulating transform to the whole image batch
            batch = (self.additional_transform(batch[0]), *batch[1:])
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.train_batch_size,
                                           num_workers=4, shuffle=True, collate_fn=self.collate_fn)
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.val_batch_size,
                                           num_workers=4, shuffle=False, collate_fn=self.collate_fn)
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.val_batch_size,
                                           num_workers=4, shuffle=False, collate_fn=self.collate_fn)

    def default_dataloader(self, batch_size=None, num_samples=None, shuffle=True):
        # a dataloader over the validation set with freely chosen batch size and number of samples
        dataset = self.val_dataset
        if batch_size is None:
            batch_size = self.val_batch_size
        replacement = num_samples is not None
        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=replacement, num_samples=num_samples)
        else:
            sampler = None
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                                           collate_fn=self.collate_fn)


datamodule = OurDataModule()

Using downloaded and verified file: data/hymenoptera_data.zip
Extracting data/hymenoptera_data.zip to data/

Build a model

We use the TorchVision ResNet18 as our starting point.

Note that while we train a model here, in the end we only use the feature extractor for the drift detection; we just want to recreate the realistic situation where you have a classification model. Nothing here is specific to drift detection.

[3]:
feature_extractor = torchvision.models.resnet18(pretrained=True)
feature_extractor.fc = torch.nn.Identity()
[4]:
class Classifier(pl.LightningModule):
    def __init__(self, base_classifier):
        super().__init__()
        self.backbone = base_classifier
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.classifier = torch.nn.Linear(512, 2)

    def normalize(self, x: torch.Tensor):
        # We pull the normalization, usually done in the dataset, into the model forward
        x = torchvision.transforms.functional.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        return x

    def forward(self, x: torch.Tensor):
        x = self.normalize(x)
        y = self.backbone(x)
        return self.classifier(y)

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.max(1).indices == y).float().mean()
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.max(1).indices == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return loss

    def test_step(self, batch: torch.Tensor, batch_idx: int):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.max(1).indices == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        return loss

    def predict(self, batch: Any, batch_idx: Optional[int]=None, dataloader_idx: Optional[int] = None):
        return self(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = Classifier(feature_extractor)

We finetune our network a bit. Since we set requires_grad to False for the ResNet backbone, only the linear classifier is trained:

[5]:
trainer = pl.Trainer(max_epochs=3, gpus=1, checkpoint_callback=False, logger=False)
trainer.fit(model, datamodule)
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type   | Params
--------------------------------------
0 | backbone   | ResNet | 11.2 M
1 | classifier | Linear | 1.0 K
--------------------------------------
1.0 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)
[5]:
1

Let us check that things are OK. You should get roughly 85%-95% accuracy.

[6]:

trainer.test(model, datamodule=datamodule)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.9607843160629272, 'test_loss': 0.20751166343688965}
--------------------------------------------------------------------------------
[6]:
[{'test_loss': 0.20751166343688965, 'test_acc': 0.9607843160629272}]

Simulating drifted data

For systematic experiments, we want to compare the output of the drift detector on benign (non-drifted) and drifted, here (partially) out-of-distribution, samples. We simulate out-of-distribution data by applying a Gaussian blur. In reality, you might have effects like the camera lens losing focus or dirt degrading the picture quality.

Note that we do not use the drifted data for “training” the drift detector, but just for evaluation!

On the technical side, we keep our datamodule as the in-distribution datamodule as is, and use a derived datamodule, which applies the Gaussian blur in addition to the usual transforms, as the out-of-distribution datamodule.

[7]:
def corruption_function(x: torch.Tensor):
    return torchdrift.data.functional.gaussian_blur(x, severity=2)

ind_datamodule = datamodule
ood_datamodule = OurDataModule(parent=datamodule, additional_transform=corruption_function)
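
As a quick sanity check, we can draw the same (unshuffled) batch from both datamodules and confirm that the derived one applies the blur in its collate function. This is just a sketch; it relies on the validation transforms being deterministic and on shuffle=False preserving the dataset order:

check_ind, _ = next(iter(ind_datamodule.default_dataloader(batch_size=8, shuffle=False)))
check_ood, _ = next(iter(ood_datamodule.default_dataloader(batch_size=8, shuffle=False)))
# the OOD batch should be exactly the blurred version of the benign batch
assert torch.allclose(corruption_function(check_ind), check_ood)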

Let us grab a few inputs and show them. We also let our model predict labels. Sometimes we see that labels are flipped for the blurred images.

[8]:
inputs, _ = next(iter(datamodule.default_dataloader(shuffle=True)))
inputs_ood = corruption_function(inputs)

N = 6
model.eval()
inps = torch.cat([inputs[:N], inputs_ood[:N]])
model.cpu()
predictions = model.predict(inps).max(1).indices

predicted_labels = [["ant","bee"][p] for p in predictions]
pyplot.figure(figsize=(15, 5))
for i in range(2 * N):
    pyplot.subplot(2, N, i + 1)
    pyplot.title(predicted_labels[i])
    pyplot.imshow(inps[i].permute(1, 2, 0))
    pyplot.xticks([])
    pyplot.yticks([])
[Figure: benign inputs (top row) and their blurred counterparts (bottom row), with the model's predicted labels as titles]

For drift detection, we need a feature extractor (of course, we had one above, too, but let’s play along and pretend we got the model from our colleague).

[9]:
feature_extractor = copy.deepcopy(model)
feature_extractor.classifier = torch.nn.Identity()
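
A quick shape check (just a sketch) confirms that we now get the 512-dimensional backbone features rather than the two class logits:

with torch.no_grad():
    print(feature_extractor(inputs[:2]).shape)  # expect torch.Size([2, 512])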

The drift detector

We instantiate the drift detector. In this example we use the Kernel MMD drift detector on the features extracted by the pretrained ResNet. We believe this is a good choice for your initial experiments, too.

[10]:
drift_detector = torchdrift.detectors.KernelMMDDriftDetector()

We fit the drift detector using benign training data only. If you have a more sizeable dataset, you likely want to limit the number of samples you use here; you can do this by specifying the num_batches parameter of the fit function.

[11]:
torchdrift.utils.fit(datamodule.train_dataloader(), feature_extractor, drift_detector)
100%|██████████| 61/61 [00:01<00:00, 36.39it/s]
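
On a larger dataset, the call might look like the following sketch; the value of num_batches is an arbitrary illustration:

# fit on at most 10 batches of benign training data (10 is arbitrary)
torchdrift.utils.fit(datamodule.train_dataloader(), feature_extractor, drift_detector,
                     num_batches=10)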

We can then put the drift detector into action. TorchDrift uses vanilla torch.nn.Modules, so you can just build a Sequential of the feature extractor and the drift detector (or, in other use cases, with an additional dimensionality-reduction step in between; see the sketch below).

[12]:
drift_detection_model = torch.nn.Sequential(
    feature_extractor,
    drift_detector
)
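
As a sketch of the variant with a dimensionality-reduction step, we could insert any torch.nn.Module between the two; the fixed random projection below is purely illustrative (and the detector would then have to be fit on the reduced features):

# purely illustrative: a fixed (untrained) linear projection from 512 to 32 dimensions
reducer = torch.nn.Linear(512, 32, bias=False)
reducer.requires_grad_(False)

reduced_drift_detection_model = torch.nn.Sequential(
    feature_extractor,
    reducer,
    drift_detector,  # would need to be re-fit on the reduced features
)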

First we test the output on benign input. By our setup (see the overview for details), we expect the computed p-value to be uniformly distributed on the unit interval \([0, 1]\), so with 95% probability it is above 0.05. Note that for the p-value, lower means “more likely that things have drifted”, while drift leads to higher scores.

[13]:
features = feature_extractor(inputs)
score = drift_detector(features)
p_val = drift_detector.compute_p_value(features)
score, p_val
[13]:
(tensor(0.0129), tensor(0.0520))
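
If you want to convince yourself of the approximate uniformity, here is a sketch that re-samples benign validation batches and estimates how often we would raise a false alarm at the 5% level (the 20 repetitions are an arbitrary choice):

p_vals = []
for _ in range(20):
    benign_batch, _ = next(iter(datamodule.default_dataloader(batch_size=64, num_samples=64)))
    with torch.no_grad():
        p_vals.append(drift_detector.compute_p_value(feature_extractor(benign_batch)))
p_vals = torch.stack(p_vals)
# with well-calibrated p-values, this should be around 0.05
print('empirical false alarm rate at the 5% level:', (p_vals < 0.05).float().mean().item())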

We can also visualize the two distributions to determine whether we “intuitively believe” that the two samples come from the same distribution.

You may wonder why we use Isomap here instead of the (in certain circles) more well-known t-SNE. The reason is that Isomap gives us an actual transformation (fitted on the training data) that we can then apply to new inputs.

[14]:
N_base = drift_detector.base_outputs.size(0)
mapper = sklearn.manifold.Isomap(n_components=2)
base_embedded = mapper.fit_transform(drift_detector.base_outputs)
features_embedded = mapper.transform(features)
pyplot.scatter(base_embedded[:, 0], base_embedded[:, 1], s=2, c='r')
pyplot.scatter(features_embedded[:, 0], features_embedded[:, 1], s=4)
pyplot.title(f'score {score:.2f} p-value {p_val:.2f}');
[Figure: Isomap embedding of the reference features (red) and the benign batch's features, with score and p-value in the title]

We can also do the same with out-of-distribution data. Typically, the p-value here is very low; that it is exactly 0 is an artifact of the bootstrapping we use for computing the p-value.

Can you spot a difference in the plot?

[15]:
features = feature_extractor(inputs_ood)
score = drift_detector(features)
p_val = drift_detector.compute_p_value(features)

features_embedded = mapper.transform(features)
pyplot.scatter(base_embedded[:, 0], base_embedded[:, 1], s=2, c='r')
pyplot.scatter(features_embedded[:, 0], features_embedded[:, 1], s=4)
pyplot.title(f'score {score:.2f} p-value {p_val:.2f}');
[Figure: Isomap embedding of the reference features (red) and the drifted batch's features, with score and p-value in the title]

This concludes our overview.

Here, we saw how to get scores and p-values for drift detection, taking the feature extractor from our classification model (or you might use an ImageNet-trained feature extractor regardless of your model).

We show some more systematic experimentation in Comparing Drift Detectors. Also check out our example of how to use the drift detector in monitoring a model.

View this document as a notebook: https://github.com/torchdrift/torchdrift/blob/master/notebooks/drift_detection_on_images.ipynb