{ "cells": [ { "cell_type": "markdown", "id": "indirect-worker", "metadata": {}, "source": [ "# Deploying Drift Detection\n", "\n", "TorchDrift provides the tools you need to detect drift. But how do you actually get your model to monitor drift?\n", "\n", "This short tutorial shows how to use a forward hook on your feature extractor to capture features and feed them into the drift detector.\n", "\n", "First, we need to set up a model and a drift detector. Let us import some packages." ] }, { "cell_type": "code", "execution_count": 1, "id": "secure-potential", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '../')\n", "\n", "import torch\n", "import torchvision\n", "import torchdrift\n", "import copy\n", "%matplotlib inline\n", "from matplotlib import pyplot\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "id": "liberal-generic", "metadata": {}, "source": [ "We use a small ResNet (ResNet-18, pretrained on ImageNet) as our example model. As we often do, we move the normalization out of the dataset transforms and into the model. We do this because we want to post-process the images to \"fake\" drifted inputs; you would not need to do this for your own data and models (though I would argue that putting the normalization into the model, while uncommon, is best practice). We also split out the fully connected layer from the ResNet, so that everything before it can serve as our feature extractor." ] }, { "cell_type": "code", "execution_count": 2, "id": "trained-mobility", "metadata": {}, "outputs": [], "source": [ "resnet = torchvision.models.resnet18(pretrained=True)\n", "# Normalize -> ResNet trunk -> original fully connected layer\n", "model = torch.nn.Sequential(\n", "    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", "    resnet,\n", "    resnet.fc\n", ")\n", "# the fc layer now lives in model[2], so inside the ResNet it becomes a no-op\n", "resnet.fc = torch.nn.Identity()\n", "model.eval().to(device)\n", "for p in model.parameters():\n", "    p.requires_grad_(False)\n" ] }, { "cell_type": "markdown", "id": "vocal-official", "metadata": {}, "source": [ "And we set up the training and validation datasets. 
" ] }, { "cell_type": "code", "execution_count": 3, "id": "excellent-white", "metadata": {}, "outputs": [], "source": [ "val_transform = torchvision.transforms.Compose([\n", "    torchvision.transforms.Resize(size=256),\n", "    torchvision.transforms.CenterCrop(size=(224, 224)),\n", "    torchvision.transforms.ToTensor()])  # no Normalize here: it is part of the model\n", "\n", "\n", "ds_train = torchvision.datasets.ImageFolder('./data/hymenoptera_data/train/',\n", "    transform=val_transform)\n", "ds_val = torchvision.datasets.ImageFolder('./data/hymenoptera_data/val/',\n", "    transform=val_transform)\n", "dl_val = torch.utils.data.DataLoader(ds_val, batch_size=64, shuffle=True)" ] }, { "cell_type": "markdown", "id": "aware-attack", "metadata": {}, "source": [ "We fit the detector on a single batch of `N_train = 100` training images, using the model without its fully connected layer as the feature extractor. For demonstration, the detector returns a p-value rather than the raw MMD score. Note that the p-value is currently computationally more expensive than the score (but we'll work on pre-computing the score distribution under the null hypothesis)." ] }, { "cell_type": "code", "execution_count": 4, "id": "split-madison", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1/1 [00:00<00:00, 1.17it/s]\n" ] } ], "source": [ "def fit_detector(N_train):\n", "    detector = torchdrift.detectors.KernelMMDDriftDetector(return_p_value=True)\n", "    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=N_train, shuffle=True)\n", "    feature_extractor = model[:2]  # Normalize + ResNet trunk, without the fc layer\n", "    torchdrift.utils.fit(dl_train, feature_extractor, detector, num_batches=1)\n", "    return detector\n", "\n", "detector = fit_detector(N_train = 100)" ] }, { "cell_type": "markdown", "id": "herbal-georgia", "metadata": {}, "source": [ "Next, we build a model monitor. It hooks into the model to capture the output of `feature_layer`. 
It will cache the last `N` captured feature vectors in a ring buffer.\n", "\n", "If we provide a `callback`, the monitor calls the drift detector every `callback_interval` forward passes once the buffer has been filled, and passes the detector's output (here the p-value) to the callback.\n", "\n", "Just to show off, I also throw in a little plot function that projects the reference features and the buffered features to 2D with Isomap." ] }, { "cell_type": "code", "execution_count": 5, "id": "after-security", "metadata": {}, "outputs": [], "source": [ "class ModelMonitor:\n", "    def __init__(self, drift_detector, feature_layer, N = 20, callback = None, callback_interval = 1):\n", "        self.N = N\n", "        base_outputs = drift_detector.base_outputs\n", "        self.drift_detector = drift_detector\n", "        assert base_outputs is not None, \"fit drift detector first\"\n", "        feature_dim = base_outputs.size(1)\n", "        self.feature_rb = torch.zeros(N, feature_dim, device=base_outputs.device, dtype=base_outputs.dtype)  # ring buffer of the last N feature vectors\n", "        self.have_full_round = False  # True once N samples have been written\n", "        self.next_idx = 0  # buffer position of the next write\n", "        self.hook = feature_layer.register_forward_hook(self.collect_hook)\n", "        self.counter = 0  # number of forward passes seen\n", "        self.callback = callback\n", "        self.callback_interval = callback_interval\n", "\n", "    def unhook(self):\n", "        self.hook.remove()\n", "\n", "    def collect_hook(self, module, input, output):\n", "        self.counter += 1\n", "        bs = output.size(0)\n", "        if bs > self.N:  # only the last N samples of an oversized batch fit into the buffer\n", "            output = output[-self.N:]\n", "            bs = self.N\n", "        output = output.reshape(bs, -1)\n", "        first_part = min(self.N - self.next_idx, bs)  # samples that fit before the end of the buffer\n", "        self.feature_rb[self.next_idx: self.next_idx + first_part] = output[:first_part]\n", "        if first_part < bs:  # wrap around to the start of the buffer\n", "            self.feature_rb[: bs - first_part] = output[first_part:]\n", "        if not self.have_full_round and self.next_idx + bs >= self.N:\n", "            self.have_full_round = True\n", "        self.next_idx = (self.next_idx + bs) % self.N\n", "        if self.callback and self.have_full_round and self.counter % self.callback_interval == 0:\n", "            p_val = self.drift_detector(self.feature_rb)\n", "            self.callback(p_val)\n", "\n", "    def plot(self):\n", "        import sklearn.manifold\n", "        from matplotlib import pyplot\n", "\n", "        mapping = sklearn.manifold.Isomap()  # 2D projection fitted on the reference features\n", "        ref = mapping.fit_transform(self.drift_detector.base_outputs.to(\"cpu\").numpy())\n", "\n", "        test = mapping.transform(self.feature_rb.to(\"cpu\").numpy())\n", "        pyplot.scatter(ref[:, 0], ref[:, 1], label=\"reference\")\n", "        pyplot.scatter(test[:, 0], test[:, 1], label=\"monitored\")\n", "        pyplot.legend()" ] }, { "cell_type": "markdown", "id": "unique-arabic", "metadata": {}, "source": [ "To instantiate our monitor, we need an alarm function. It receives the p-value and fires when the p-value is 1% or lower. We hook the monitor into `model[1]`, the ResNet trunk, whose output is exactly the feature representation the detector was fitted on.\n", "I just raise an exception, but you could also text the AI facility management or similar.\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "indie-province", "metadata": {}, "outputs": [], "source": [ "def alarm(p_value):\n", "    assert p_value > 0.01, f\"Drift alarm! p-value: {p_value*100:.03f}%\"\n", "\n", "mm = ModelMonitor(detector, model[1], callback=alarm)" ] }, { "cell_type": "markdown", "id": "suspected-cream", "metadata": {}, "source": [ "We grab one batch of benign samples and one batch of drifted samples; we fake the drift by blurring a second validation batch with `torchdrift.data.functional.gaussian_blur`.\n", "\n", "Fun fact: for this dataset, shuffling in the dataloader is important. Without it, the class balance of the test batch is skewed enough to set off the alarm.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "stable-borough", "metadata": {}, "outputs": [], "source": [ "it = iter(dl_val)\n", "batch = next(it)[0].to(device)\n", "batch_drifted = torchdrift.data.functional.gaussian_blur(next(it)[0].to(device), 5)" ] }, { "cell_type": "markdown", "id": "pretty-disorder", "metadata": {}, "source": [ "Now we run our model. ImageNet class 309 is _bee_ and 310 is _ant_. Do not believe the model if it says aircraft carrier (it did this during testing). Note that we might be unlucky and get an exception here even though this batch is benign. This is at least in part a sampling artifact from computing the p-value; in addition, with the 1% threshold in `alarm`, about one in a hundred checks on drift-free data is expected to raise a false alarm."
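, "\n", "\n", "The p-value is itself estimated by sampling, so it fluctuates from run to run. A minimal sketch, assuming a permutation test (a common way to get a p-value for an MMD statistic): TorchDrift's `KernelMMDDriftDetector` may use a different kernel, bandwidth rule or number of permutations, and `mmd2_rbf` and `permutation_p_value` are hypothetical helpers, not TorchDrift API. Because the statistic has to be recomputed for every permutation, the sketch also shows why a p-value costs more than the plain score.

```python
# Assumption: a permutation test with a Gaussian (RBF) kernel and a
# median-heuristic bandwidth. This is a generic sketch, not TorchDrift code.
import torch


def mmd2_rbf(x, y, bandwidth):
    # Biased estimate of the squared MMD between samples x (n, d) and y (m, d).
    xy = torch.cat([x, y])
    kernel = torch.exp(-torch.cdist(xy, xy) ** 2 / (2 * bandwidth ** 2))
    n = x.size(0)
    return (kernel[:n, :n].mean() + kernel[n:, n:].mean()
            - 2 * kernel[:n, n:].mean())


def permutation_p_value(x, y, n_perm=100, generator=None):
    # Under the null hypothesis (no drift) the split of the pooled data into
    # x and y is arbitrary, so we compare the observed statistic with the
    # statistic of random re-splits.
    xy = torch.cat([x, y])
    n = x.size(0)
    # Median heuristic for the bandwidth, kept fixed across permutations.
    bandwidth = torch.cdist(xy, xy).median()
    observed = mmd2_rbf(x, y, bandwidth)
    count = 0
    for _ in range(n_perm):
        idx = torch.randperm(xy.size(0), generator=generator)
        shuffled = xy[idx]
        if mmd2_rbf(shuffled[:n], shuffled[n:], bandwidth) >= observed:
            count += 1
    # The +1 counts the observed split itself, so the p-value is never 0
    # and cannot be smaller than 1 / (n_perm + 1).
    return (count + 1) / (n_perm + 1)


data_gen = torch.Generator().manual_seed(0)
reference = torch.randn(100, 8, generator=data_gen)   # stands in for base_outputs
benign = torch.randn(20, 8, generator=data_gen)       # same distribution
drifted = torch.randn(20, 8, generator=data_gen) + 2  # shifted mean

# Same benign batch, different permutation seeds: the p-value is itself
# a random estimate.
benign_p = [permutation_p_value(reference, benign,
                                generator=torch.Generator().manual_seed(seed))
            for seed in range(5)]
drifted_p = permutation_p_value(reference, drifted,
                                generator=torch.Generator().manual_seed(0))
print(benign_p, drifted_p)
```

Typically the benign p-values differ somewhat from seed to seed, while the strongly shifted batch should get the smallest attainable value, $1/(n_{perm}+1)$, which is about 0.0099 for 100 permutations."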
] }, { "cell_type": "code", "execution_count": 8, "id": "olympic-feedback", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([310, 310, 309, 494, 310, 301, 309, 309, 310, 310, 114, 309, 310, 309,\n", " 310, 309, 310, 310, 309, 309, 310, 310, 309, 309, 310, 309, 310, 310,\n", " 310, 309, 309, 319, 310, 403, 410, 309, 310, 310, 410, 323, 310, 309,\n", " 310, 947, 309, 309, 310, 309, 310, 114, 309, 309, 410, 309, 309, 309,\n", " 310, 310, 309, 310, 310, 309, 79, 309], device='cuda:0')" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res = model(batch).argmax(1) \n", "res" ] }, { "cell_type": "code", "execution_count": 9, "id": "white-pierce", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.6550, device='cuda:0')" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "detector.compute_p_value(mm.feature_rb)" ] }, { "cell_type": "markdown", "id": "prescribed-senior", "metadata": {}, "source": [ "We can also look at the latent features to judge whether they might come from the same distribution as the reference. If you happen to have a heavily class-imbalanced sample (e.g. you disabled shuffling in the dataloader; for testing, not because you forgot!), you might spot that imbalance in the projected features." ] }, { "cell_type": "code", "execution_count": 10, "id": "furnished-repair", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mm.plot()" ] }, { "cell_type": "markdown", "id": "rural-passage", "metadata": {}, "source": [ "When we call the model with drifted inputs, we are very likely to set off the alarm." ] }, { "cell_type": "code", "execution_count": 11, "id": "appointed-corps", "metadata": { "tags": [ "raises-exception" ] }, "outputs": [ { "ename": "AssertionError", "evalue": "Drift alarm! p-value: 0.000%", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# call it with drifted inputs...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_drifted\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 888\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 889\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 890\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 891\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 892\u001b[0m self._forward_hooks.values()):\n\u001b[0;32m--> 893\u001b[0;31m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 894\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 895\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mcollect_hook\u001b[0;34m(self, module, input, output)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhave_full_round\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcounter\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_interval\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mp_val\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrift_detector\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_rb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp_val\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36malarm\u001b[0;34m(p_value)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0malarm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mp_value\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.01\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf\"Drift alarm! p-value: {p_value*100:.03f}%\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mmm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mModelMonitor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdetector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malarm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAssertionError\u001b[0m: Drift alarm! p-value: 0.000%" ] } ], "source": [ "# call it with drifted inputs...\n", "model(batch_drifted)" ] }, { "cell_type": "markdown", "id": "rolled-basics", "metadata": {}, "source": [ "With any luck, you can also see the drift in the datapoints." ] }, { "cell_type": "code", "execution_count": 12, "id": "mental-substance", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mm.plot()" ] }, { "cell_type": "markdown", "id": "accompanied-moscow", "metadata": {}, "source": [ "So in this notebook we saw how to use model hooks with the drift detector to automatically set off the alarm when something bad happens. Just remember that if you set the p-value threshold to $x\\%$, you should expect a false alarm about once every $100\\%/x\\%$ detector calls even without drift (once per 100 calls for the 1% threshold used here), so choose the threshold to not spam your emergency contact." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.1+" } }, "nbformat": 4, "nbformat_minor": 5 }