DriftDetectionExperiment(drift_detector, feature_extractor, ood_ratio=1.0, sample_size=1)
An experimental setup for exploring the ROC (receiver operating characteristic) of drift detection setups.
This tests a setup based on a drift detector and a feature extractor (the latter including reducers).
The detector is fitted with post_training.
Then, given datamodules for in-distribution and out-of-distribution data, non-drifted and drifted batches are constructed. Each test batch contains sample_size samples, as given in the constructor, and a fraction ood_ratio of those samples (rounded up) is drawn from the out-of-distribution datamodule.
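As a rough illustration of the batch composition described above, the following sketch computes how many samples of one test batch would come from each datamodule. The helper name `split_counts` is hypothetical and not part of the library, and the sketch assumes that "rounded up" means `math.ceil(ood_ratio * sample_size)`.

```python
import math


def split_counts(sample_size, ood_ratio):
    """Return (num_ind, num_ood) for one drifted test batch.

    Hypothetical helper, not library code: assumes the out-of-distribution
    count is ceil(ood_ratio * sample_size), capped at sample_size.
    """
    if not 0.0 <= ood_ratio <= 1.0:
        raise ValueError("ood_ratio must be in [0, 1]")
    num_ood = min(sample_size, math.ceil(ood_ratio * sample_size))
    return sample_size - num_ood, num_ood


print(split_counts(1, 1.0))  # constructor defaults: the single sample is out-of-distribution
print(split_counts(5, 0.5))  # 2.5 rounds up to 3 out-of-distribution samples
```

With the default arguments (ood_ratio=1.0, sample_size=1), every drifted batch would consist of exactly one out-of-distribution sample.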
The datamodules are expected to provide a default_dataloader method taking batch_size and num_samples arguments (see the examples for details).
evaluate(ind_datamodule, ood_datamodule, num_runs=50)
Runs the experiment on num_runs inputs.
- Returns: auc, (fp, tp)
auc: area under the ROC curve (AUC) score.
fp, tp: false positive and true positive rates for plotting the ROC curve.
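To make the relationship between the returned values concrete, here is a minimal pure-Python sketch of how false positive and true positive rates can be derived from detector scores, and how an AUC follows from them by the trapezoidal rule. This is not the library's implementation; the function names `roc_points` and `auc_trapezoid` and the sample scores are illustrative assumptions.

```python
def roc_points(scores, labels):
    """Return (fp, tp) rate lists for thresholds swept from high to low.

    scores: detector scores, where higher means "more likely drifted".
    labels: 1 for drifted batches, 0 for non-drifted batches.
    """
    num_pos = sum(labels)
    num_neg = len(labels) - num_pos
    # Sort by descending score so that each step lowers the threshold.
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    fp, tp = [0.0], [0.0]
    false_pos = true_pos = 0
    for i, (score, label) in enumerate(ranked):
        if label == 1:
            true_pos += 1
        else:
            false_pos += 1
        # Emit a point only after the last item of a group of tied scores.
        if i + 1 == len(ranked) or ranked[i + 1][0] != score:
            fp.append(false_pos / num_neg)
            tp.append(true_pos / num_pos)
    return fp, tp


def auc_trapezoid(fp, tp):
    """Area under the piecewise-linear curve through the (fp, tp) points."""
    return sum(
        (fp[i] - fp[i - 1]) * (tp[i] + tp[i - 1]) / 2.0
        for i in range(1, len(fp))
    )


# Illustrative scores for two drifted (1) and two non-drifted (0) batches.
fp, tp = roc_points([0.9, 0.8, 0.7, 0.3], [1, 0, 1, 0])
print(auc_trapezoid(fp, tp))  # 0.75
```

An AUC of 0.75 here matches the pairwise reading of the score: in three of the four (drifted, non-drifted) pairs, the drifted batch has the higher score.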
post_training is called after training the main model and fits the drift detector.
Tests check and raises a RuntimeError with message if check is false.
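The behavior described above can be sketched in a few lines. The name `check_sketch` and its two-argument signature are assumptions made for illustration; the library's actual helper may take additional parameters.

```python
def check_sketch(condition, message):
    """Raise RuntimeError(message) when condition is false; otherwise do nothing."""
    if not condition:
        raise RuntimeError(message)


check_sketch(2 + 2 == 4, "arithmetic is broken")  # passes silently

try:
    check_sketch(len([]) > 0, "expected a non-empty batch")
except RuntimeError as err:
    print(err)  # expected a non-empty batch
```

Compared with a bare assert statement, a helper of this kind keeps the validation active even when Python runs with assertions disabled.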
fit(dl: torch.utils.data.dataloader.DataLoader, feature_extractor: torch.nn.modules.module.Module, reducers_detectors: Union[torchdrift.reducers.reducer.Reducer, torchdrift.detectors.detector.Detector, List[Union[torchdrift.reducers.reducer.Reducer, torchdrift.detectors.detector.Detector]]], *, num_batches: Optional[int] = None, device: Optional[Union[torch.device, str]] = None)
Train drift detector on reference distribution.
The dataloader dl should provide the reference distribution. Optionally you can limit the number of batches sampled from the dataloader with num_batches.
The feature extractor can be any module that does not need to be fit.
The reducers and detectors should be passed (in the order they should be applied, one takes the output from the previous) as a list. A single detector or reducer can also be passed.
If you provide a device, data is moved there before running through the feature extractor; otherwise, the function tries to infer the device from the feature_extractor.
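The chaining of reducers and detectors described above (each stage consuming the output of the previous one, with a single stage also accepted) can be sketched without any library dependencies. Everything in this example is a hypothetical stand-in: plain functions replace the real reducer and detector modules, and `as_stage_list` and `run_chain` are illustrative names, not library functions.

```python
def as_stage_list(stages):
    """Accept a single stage or a list of stages and always return a list."""
    return list(stages) if isinstance(stages, (list, tuple)) else [stages]


def run_chain(features, stages):
    """Feed features through each stage in order; each sees the previous output."""
    out = features
    for stage in as_stage_list(stages):
        out = stage(out)
    return out


# Hypothetical stand-in for a reducer: keep the first two feature columns.
def keep_two_columns(rows):
    return [row[:2] for row in rows]


# Hypothetical stand-in for a detector: score is the mean of all values.
def mean_score(rows):
    values = [v for row in rows for v in row]
    return sum(values) / len(values)


features = [[1.0, 3.0, 100.0], [2.0, 2.0, 200.0]]
print(run_chain(features, [keep_two_columns, mean_score]))  # 2.0
print(run_chain(features, mean_score))  # a single stage is also accepted
```

Normalizing the argument to a list up front keeps the loop identical for both call styles, which is one simple way to support the "list or single object" convention.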