Reducing your labeled data requirements (2–5x) for Deep Learning: DeepMind’s new “Contrastive Predictive Coding 2.0”
Current Deep Learning for vision, audio, etc. requires vast amounts of human-labeled data, with many examples of each category, to train a classifier to acceptable accuracy.
By contrast, humans need to see only a few examples of a class before they can accurately recognize and classify new examples of that class.
The difference is that humans quickly form accurate mental ‘representations’ of things, and then use those representations to flexibly account for future variations. After seeing a few images of blue jays, for example, we can form a mental model or representation of a blue jay, and then spot and accurately identify blue jays in new images even when the birds are seen from different angles and perspectives. Deep learning struggles to build representations in the same way, and thus needs to train on lots and lots and lots of labeled instances (with images ‘augmented’ to show different angles and perspectives) to reliably handle future data and generalize successfully.
That gap in representation ability, which drives the need for large amounts of labeled data, may now be shrinking rapidly thanks to DeepMind’s improvements on its own earlier breakthrough, “CPC” or Contrastive Predictive Coding.
Their work, CPC 2.0, is presented in a new paper called “Data Efficient Image Recognition with Contrastive Predictive Coding”.
Using CPC 2.0, image classification and recognition networks build representations that generalize well after training on only small amounts of labeled data, getting closer to how humans perform.
Performance Results: Some comparisons drive home the significance. Image classifiers trained with CPC 2.0 and only 1% of the ImageNet labels (the unsupervised CPC pre-training itself uses the full set of unlabeled images) achieved 78% top-5 accuracy, outperforming purely supervised classifiers (regular labeled training) trained on 5x more labeled data.
Continuing with training on all the available labeled images (100%), the CPC2 ResNet outperformed fully supervised systems, also trained on the full dataset, by 3.2% (top-1 accuracy). Note that with only half the labeled data (50%), the CPC ResNet matched the accuracy of fully supervised NNs trained on 100% of the data.
Finally, to show the generality of CPC representations, the authors took the CPC2 ResNet and applied transfer learning to object detection (PASCAL VOC 2007 dataset), where it achieved new state-of-the-art performance of 76.6% mAP, surpassing the previous record by 2%.
Why does CPC work? The core concept here, per the authors’ hypothesis, is that CPC may allow better spatial representations, moving artificial systems closer to biological representations and helping close the gap between the two.
This ties in with a general principle the authors state:
A new principle for deep learning: “good representations should make the spatio-temporal variability in natural signals more predictable.” (quote from paper)
How does CPC work?
In simplified form, CPC2 works by taking the following steps:
1- Divide an image into overlapping squares, or ‘patches’.
2 — Run each patch through a feature extractor (blue stack in the image above), culminating in a mean pooling layer and thus a final representation vector (thin spikes above).
3 — Feed the grid of local feature vectors into a masked ConvNet, the context network (red blocks in the image above). The mask restricts what each output can see to feature vectors on one side of a dividing line (e.g. top vs. bottom), so each context vector summarizes only its own row and the rows above it. In the example above, the context is built from the patches above the center.
4 — Use the context vectors to predict the feature vectors on the other side of the dividing line. In this case, the network must pick out the true feature vectors below the center (see the arrows pointing down from the red blocks) from among negative samples, i.e. feature vectors taken from other images.
The quality of these predictions is measured with a contrastive loss, hence the name Contrastive Predictive Coding.
The full loss function is called InfoNCE, inspired by Noise Contrastive Estimation (NCE). Minimizing InfoNCE has been shown to maximize a lower bound on the mutual information between the context vector and the target feature vector.
5 — Discard the masked ConvNet (the context network), put a standard linear classifier on top of the feature extractor, and begin training/evaluation with labeled data.
By doing the above, the network is forced to build better representations, which reduces the amount of labeled data it needs and helps it generalize better.
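To make steps 1 to 4 concrete, here is a minimal NumPy sketch of a single CPC prediction. It is a toy under stated assumptions, not DeepMind’s implementation: the helper names (`extract_patches`, `encode`, `context_from_above`, `info_nce`) are hypothetical, a one-layer linear map with ReLU stands in for the ResNet feature extractor, a running mean over the rows above stands in for the masked ConvNet, and the patch size (8), stride (4), prediction offset `k = 2` and the bilinear score z·(W c) are illustrative choices. The real loss is averaged over many positions, offsets and images; this computes one term.

```python
import numpy as np


def extract_patches(image, patch=8, stride=4):
    """Step 1: cut an (H, W, C) image into overlapping square patches.

    Returns an array of shape (rows, cols, patch, patch, C).
    """
    h, w, _ = image.shape
    rows = (h - patch) // stride + 1
    cols = (w - patch) // stride + 1
    return np.stack([
        np.stack([image[r * stride:r * stride + patch,
                        c * stride:c * stride + patch] for c in range(cols)])
        for r in range(rows)
    ])


def encode(patches, weights):
    """Step 2, toy stand-in for the ResNet: a per-pixel linear map + ReLU,
    then mean pooling over each patch -> one feature vector per patch."""
    hidden = np.maximum(patches @ weights, 0.0)  # (rows, cols, p, p, D)
    return hidden.mean(axis=(2, 3))              # (rows, cols, D)


def context_from_above(features):
    """Step 3, toy stand-in for the masked ConvNet: the context at (r, c)
    only sees feature vectors in column c, rows 0..r (their running mean)."""
    running_sum = np.cumsum(features, axis=0)
    counts = np.arange(1, features.shape[0] + 1)[:, None, None]
    return running_sum / counts


def info_nce(context, positive, negatives, W):
    """Step 4: bilinear scores for the true target and the negatives;
    InfoNCE is the cross-entropy of picking the true target (index 0)."""
    prediction = W @ context                                 # (D,)
    candidates = np.vstack([positive[None, :], negatives])   # (1 + N, D)
    logits = candidates @ prediction                         # (1 + N,)
    logits = logits - logits.max()                           # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return float(-log_probs[0])


# Usage: one image, negatives drawn from a second image.
rng = np.random.default_rng(0)
image, other = rng.random((32, 32, 3)), rng.random((32, 32, 3))
weights = rng.normal(size=(3, 16))
W = 0.1 * rng.normal(size=(16, 16))

z = encode(extract_patches(image), weights)        # (7, 7, 16)
z_other = encode(extract_patches(other), weights)
c = context_from_above(z)
k = 2  # predict the feature vector k rows below the context position
loss = info_nce(c[2, 3], z[2 + k, 3], z_other.reshape(-1, 16)[:10], W)
print(f"InfoNCE loss for one position: {loss:.3f}")
```

Because the true target always sits at index 0, InfoNCE here is ordinary softmax cross-entropy with that index as the label; training the encoder to lower it is what pushes the feature vectors to be predictable from their context.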
From CPC to CPC 2.0:
CPC v1 was introduced in DeepMind’s original paper, “Representation Learning with Contrastive Predictive Coding”.
The authors built on this to create CPC v2 by stacking a whole series of improvements, resulting in a large jump in final accuracy, as shown in the chart above. Details are in the paper, but I’ll highlight two of the biggest jumps in that chart.
Data augmentation (PA above): the authors found that dropping 2 out of 3 color channels served as an excellent augmentation (+3%). They then improved on that by adding further augmentations such as shearing, rotation, and color transforms, for another 4.5% gain.
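The color-dropping augmentation is simple enough to sketch directly. This is a hypothetical helper, not the authors’ code: it assumes the surviving channel is chosen uniformly at random and that the two dropped channels are filled with zeros, neither of which is specified above.

```python
import numpy as np


def drop_color_channels(patch, rng):
    """Keep one randomly chosen color channel of an (H, W, 3) patch and set
    the other two to zero (zero-filling is an assumption of this sketch)."""
    keep = rng.integers(3)             # index of the surviving channel
    out = np.zeros_like(patch)
    out[..., keep] = patch[..., keep]
    return out


rng = np.random.default_rng(0)
patch = rng.random((8, 8, 3))
augmented = drop_color_channels(patch, rng)
print("channels kept:", np.flatnonzero(augmented.reshape(-1, 3).any(axis=0)))
```

One plausible reason this helps is that it removes easy low-level shortcuts, such as matching color statistics between neighboring patches, that would let the network solve the prediction task without learning useful features.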
Larger model (MC above): moving from a ResNet-101 to a (customized) ResNet-161 added a 5% jump. In addition, using larger patches boosted results by another 2%.
Other changes included running the predictions in all directions rather than just top vs. bottom, i.e. also bottom vs. top and left vs. right. In sum, these added another +4.5% (see chart above).
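How the paper wires up the extra directions isn’t described here. One straightforward way, assumed purely for illustration, is to re-orient the grid of patch features so that a single “context from rows above, predict rows below” model covers all four directions. The helper `directional_views` is hypothetical.

```python
import numpy as np


def directional_views(features):
    """Re-orient a (rows, cols, D) grid of patch features so that one
    top-to-bottom predictor (context from rows above, target below) can be
    reused for all four prediction directions."""
    by_columns = features.transpose(1, 0, 2)  # column j becomes row j
    return {
        "top_to_bottom": features,
        "bottom_to_top": features[::-1],        # flip the rows
        "left_to_right": by_columns,
        "right_to_left": by_columns[::-1],
    }


grid = np.arange(3 * 4 * 2).reshape(3, 4, 2)    # toy 3x4 grid, D = 2
for name, view in directional_views(grid).items():
    # In each view, the "rows above" are the patches the context may see.
    print(name, view.shape)
```

Summing the contrastive loss over the four views would then give the network four prediction tasks per image from one set of features.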
Conclusion: CPC 2.0 sets a new benchmark for ‘unsupervised learning’ and shows a new way to help networks build better representations, and thus learn in a manner closer to human/biological methods. The key concept shown is that representations which make spatio-temporal variability more predictable are the future path for AI in general, as they reduce the data required for learning.
The authors note that CPC is task-agnostic: although the paper focuses on vision, language, audio, and other modalities are also candidates for CPC-style training, and it may be especially useful for robotics, where multi-modal inputs are used but data is scarce.
You’ll be able to use CPC 2.0 in your own projects shortly, because:
Open Source coming soon! To quote the authors: “We will open-source our implementation and pre-trained models to make these techniques widely accessible.”