TridentNet attempts to tackle the problem of multi-scale objects in 2D images through dilated convolutions. The changes are applied on Faster-RCNN, hence one must have at least a basic understanding of two-stage object detectors (e.g. Faster-RCNN) first to understand TridentNet.
Let’s say we would like to detect the Giraffes in this image:
Let’s take a classroom scenario. Assume that we can only use a neural network of depth 1, with a single channel of 2×2 Dilated Conv2D. We would like this layer to learn the appearance of a Giraffe.
In the picture below, each red rectangle represents the receptive field of our 2×2 Dilated Conv2D filter at three different locations during the convolution process. Each filled green rectangle represents the receptive field of one single weight (float32) of the Dilated Conv2D kernel (in reality, each weight gets multiplied with one single pixel value, but for the sake of explanation, just go with it please).
While sliding our 2×2 dilated convolution kernel over the image, we find it hard to find a common dilation rate that would fit the size of all Giraffes. In the picture below, the Conv2D has a dilation rate that is failing to interact with enough pixels of the smallest Giraffe and hence cannot detect it. If we use a smaller dilation rate, it would succeed in capturing the smallest Giraffe but would fail to capture the largest one.
FYI: This is an example of why objects with varying scales are always a problem in Computer Vision.
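To put numbers on the dilation-rate problem above, here is a small sketch of how far a dilated kernel reaches along one axis. The formula is the standard one for dilated convolutions; the function name `effective_kernel_size` and the dilation rates 1, 2 and 3 are illustrative choices of mine, not something taken from the authors' code.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Number of pixels a dilated kernel spans along one axis.

    A dilation of 1 is an ordinary convolution; larger rates insert
    (dilation - 1) skipped pixels between neighbouring weights.
    """
    return dilation * (kernel_size - 1) + 1


# Illustrative dilation rates (an assumption for this example).
for dilation in (1, 2, 3):
    span = effective_kernel_size(2, dilation)
    print(f"dilation={dilation}: a 2x2 kernel spans {span}x{span} pixels")
```

The number of weights never changes, only the distance between the pixels they touch. That is why one fixed rate fits one object size well and the others poorly.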
The proposed solution: Have three parallel Dilated Conv2D layers (the depth is still 1, channel count is still 1, kernel size is still 2×2) with 3 different dilation rates. Also, share the weights, i.e. still have only 4 weights (for 2×2 kernel) overall.
As evident now, a single depth neural network can fit all three Giraffes by just using 3 dilation rates. Also, notice that the 4 quadrants of all Giraffes look the same, hence the weight sharing makes sense.
It’s a fairly simple idea. This is similar to using 3 neural networks, except that here we would be sharing all the weights across the 3 branches. The image below describes how the authors modified a ResNet block to a Trident block. The dilation rates are used only for the 3×3 convolutions and the 1×1 convolutions stay the same.
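The weight-sharing idea can be sketched in plain Python. This is a toy under stated assumptions, not the authors' implementation: a single-channel image, a single 2×2 kernel, no padding, and dilation rates 1, 2 and 3. The names `dilated_conv2d` and `trident_block` are hypothetical.

```python
def dilated_conv2d(image, kernel, dilation):
    """Valid (no padding) 2D cross-correlation with a dilated square kernel."""
    k = len(kernel)
    span = dilation * (k - 1) + 1          # pixels covered along one axis
    height, width = len(image), len(image[0])
    output = []
    for top in range(height - span + 1):
        row = []
        for left in range(width - span + 1):
            total = 0.0
            for i in range(k):
                for j in range(k):
                    pixel = image[top + i * dilation][left + j * dilation]
                    total += kernel[i][j] * pixel
            row.append(total)
        output.append(row)
    return output


def trident_block(image, shared_kernel, dilations=(1, 2, 3)):
    """Run one branch per dilation rate; every branch reads the same kernel."""
    return {d: dilated_conv2d(image, shared_kernel, d) for d in dilations}


# Toy 6x6 "image" and one 2x2 kernel, i.e. 4 weights in total.
image = [[float(r * 6 + c) for c in range(6)] for r in range(6)]
shared_kernel = [[1.0, 0.0], [0.0, -1.0]]

outputs = trident_block(image, shared_kernel)
for d, feature_map in outputs.items():
    print(f"dilation={d}: output is {len(feature_map)}x{len(feature_map[0])}")
```

All three branches read the same 4 numbers, so adding a branch adds computation but no parameters. A real implementation would typically pad the input so that the branch outputs share the same spatial size; padding is left out here for brevity.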
- So, does that mean we would be doing 3x computation (3 times slower)?
Yes. Training definitely does 3 times the computation. During the inference phase, however, we could just use the middle branch, and the accuracy does not drop “that much” (the authors call this version TridentNet Fast).
- Why 3 branches?
Because 2 or 4 branches didn’t perform any better than 3.
- Any difference in model parameter count?
No. The model weights/layers etc. stay the same since the parameters are shared across branches.
TridentNet is a modification of Faster-RCNN. This is the only two-stage detector that the authors used in their work. Both stages (RPN and R-CNN) of Faster-RCNN were converted into 3-branch mode. Some more info:
- Not all existing convolution blocks of ResNet were converted to 3-branches. Authors recommend only converting a few Conv sub-blocks inside the conv4 block of ResNet. More details in sections below.
- During training, each branch is only learning from the ground truth boxes for specific scales. This is called scale-specific training. Small, Medium, Large branches learn their weights from boxes of scale (0 to 90), (30 to 160) and (90 to ∞) respectively. More details in sections below.
Input images’ shorter side is scaled to 800 pixels before feeding it to the network.
“This scale-aware training scheme could be applied to both RPN and R-CNN. For RPN, we select ground truth boxes that are valid for each branch according to Eq. 1 during anchor label assignment. Similarly, we remove all invalid proposals for each branch during the training of R-CNN.”
(Authors describing the modifications in Faster-RCNN)
During the RPN stage in Faster-RCNN, the entire image passes through the RPN, which generates region proposals that are likely to contain objects. In TridentNet, the image passes through all 3 branches. While computing the loss, for each branch we only take the ground truth boxes whose scale (the square root of the box area, defined below) falls within the branch’s responsibility:
- Small branch: Ground truth boxes of scale within range 0 to 90
- Medium branch: Ground truth boxes of scale within range 30 to 160
- Large branch: Ground truth boxes of scale within range 90 to Infinity
Here, the scale of a ground truth box is defined as (this is the Eq. 1 as mentioned in the author’s quote above)
scale = sqrt(box_width × box_height)
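The branch assignment above can be sketched as a small lookup. The ranges are the ones listed in this post; treating both ends of each range as inclusive is my assumption, and the names `BRANCH_RANGES`, `box_scale` and `valid_branches` are hypothetical.

```python
import math

# Valid scale range per branch, as listed above.
# The large branch has no upper bound.
BRANCH_RANGES = {
    "small": (0.0, 90.0),
    "medium": (30.0, 160.0),
    "large": (90.0, math.inf),
}


def box_scale(width, height):
    """Scale of a box: the square root of its area."""
    return math.sqrt(width * height)


def valid_branches(width, height, ranges=BRANCH_RANGES):
    """Names of the branches that should learn from a box of this size.

    Assumption: range ends are inclusive (low <= scale <= high).
    """
    scale = box_scale(width, height)
    return [name for name, (low, high) in ranges.items() if low <= scale <= high]


for width, height in [(20, 20), (50, 50), (100, 100), (200, 200)]:
    print(f"{width}x{height} box -> {valid_branches(width, height)}")
```

Because the ranges overlap, a mid-sized box can count as a valid training target for two branches at once.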
Each branch produces 12,000 proposals. These are then filtered by NMS (hard or soft), which leaves 500 proposals. Out of these 500 proposals, 128 are sampled (I don’t know what sampling strategy is used here) and sent to the second stage (R-CNN) of the object detector.
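Here is a rough sketch of that per-branch proposal funnel, using plain greedy (hard) NMS. Several things are assumptions on my part: the IoU threshold of 0.7, the uniform random sampling (the actual sampling strategy is not known to me), and all function names. Proposals are `(score, (x1, y1, x2, y2))` pairs.

```python
import random


def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def hard_nms(proposals, iou_threshold):
    """Greedy NMS over (score, box) pairs; survivors are sorted by score."""
    kept = []
    for score, box in sorted(proposals, key=lambda p: p[0], reverse=True):
        # Keep a box only if it does not overlap too much with a better one.
        if all(iou(box, kept_box) <= iou_threshold for _, kept_box in kept):
            kept.append((score, box))
    return kept


def select_proposals(proposals, pre_nms_top_n=12000, post_nms_top_n=500,
                     num_samples=128, iou_threshold=0.7, seed=0):
    """Top-scoring proposals -> NMS -> at most 500 -> sample at most 128."""
    top = sorted(proposals, key=lambda p: p[0], reverse=True)[:pre_nms_top_n]
    kept = hard_nms(top, iou_threshold)[:post_nms_top_n]
    rng = random.Random(seed)  # uniform sampling is an assumption
    return rng.sample(kept, min(num_samples, len(kept)))


proposals = [
    (0.9, (0, 0, 10, 10)),
    (0.8, (0, 0, 10, 11)),    # overlaps heavily with the first box
    (0.7, (50, 50, 60, 60)),
]
selected = select_proposals(proposals, num_samples=2)
print(sorted(selected, reverse=True))
```

In the toy input the second box is suppressed by the first, so only two proposals survive NMS and both are sampled.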
During the R-CNN stage, for each branch, we use only the ROIs whose scales fall within that branch’s range. ROIs outside the branch’s responsibility are ignored.
Two inference modes are available: Default and Fast.
In the default mode, the image is passed through all three branches of the network. NMS is then performed on the combined predictions, and that’s it.
Below is the snapshot of their performance results from the paper. The “deformable” backbone indicates the experiment where dilated convolutions (mediocre results) were replaced by deformable convolutions (better results).
In the fast mode, we realize that the default mode is using 3 times the computation power. We don’t like that. Hence, once we have trained our model with all 3 branches, we delete the Small and Large branches at inference time and derive our predictions from the Medium branch only. NMS is performed as usual.
This brings the compute-hunger back to the same as Faster-RCNN. The table below lists the performance with different branches enabled:
|Branches used at inference|AP|
|---|---|
|Small branch only|31.5|
|Large branch only|31.9|
|All branches enabled|40.6|
Note: these results are on minival set with a slightly different network. Hence, don’t compare them with numbers on test-dev above.
As listed, the performance does take a minor hit with this “optimization”. This mode of inference is termed TridentNet Fast.
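The two inference modes differ only in which branches get evaluated, which the sketch below tries to show. Everything in it is a stand-in: the branch functions return fixed detections, `sort_by_score` is a placeholder for the real NMS step, and `run_inference` is a hypothetical name.

```python
def run_inference(image, branches, merge, mode="default", fast_branch="medium"):
    """Collect detections from all branches (default) or one branch (fast)."""
    if mode == "default":
        active = list(branches)
    elif mode == "fast":
        active = [fast_branch]
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    detections = []
    for name in active:
        detections.extend(branches[name](image))
    return merge(detections)


# Stand-in branches: each returns a fixed list of (score, box) detections.
branches = {
    "small": lambda image: [(0.6, (0, 0, 20, 20))],
    "medium": lambda image: [(0.8, (10, 10, 90, 90))],
    "large": lambda image: [(0.7, (0, 0, 300, 300))],
}


def sort_by_score(detections):
    """Placeholder for NMS: only sorts, does not suppress anything."""
    return sorted(detections, reverse=True)


default_result = run_inference(None, branches, sort_by_score, mode="default")
fast_result = run_inference(None, branches, sort_by_score, mode="fast")
print(len(default_result), "detections from the default mode")
print(len(fast_result), "detections from the fast mode")
```

The fast mode calls one branch instead of three, which is where the compute saving comes from.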
Thank you for reading this post. I haven’t looked into their code yet, so some details may be incomplete. Please leave a comment or feel free to reach out to me if you find any issue with this post.