"""Classification heads that collapse one axis of the input before applying
a wrapped classifier module."""

import torch
import torch.nn as nn


class GlobalClassifier(nn.Module):
    """Add a global average pooling before the classifier.

    The input is averaged over ``dim`` and the pooled tensor is passed to
    ``classifier``.

    Args:
        classifier: Module applied to the pooled tensor.
        dim: Axis to average over. Defaults to ``1``, which reproduces the
            original hard-coded behaviour. NOTE(review): axis 1 is presumably
            the frame/time axis of a ``(batch, frames, features)`` input —
            confirm against callers.
    """

    def __init__(self, classifier: nn.Module, *, dim: int = 1) -> None:
        super().__init__()
        self.classifier = classifier
        # Plain int attribute: not a parameter/buffer, so state_dict is unchanged.
        self.dim = dim

    def extra_repr(self) -> str:
        return f"dim={self.dim}"

    def forward(self, x: torch.Tensor):
        """Average ``x`` over ``self.dim`` and classify the result.

        Returns:
            Whatever ``self.classifier`` returns for the pooled tensor.
        """
        pooled = torch.mean(x, dim=self.dim)
        return self.classifier(pooled)


class LastClassifier(nn.Module):
    """Select last frame to do the classification.

    The last index along ``dim`` is taken, and that slice is passed to
    ``classifier``.

    Args:
        classifier: Module applied to the selected slice.
        dim: Axis from which the last element is taken. Defaults to ``1``,
            which matches the original ``x[:, -1, :]`` indexing. NOTE(review):
            presumably the frame/time axis — confirm against callers.
    """

    def __init__(self, classifier: nn.Module, *, dim: int = 1) -> None:
        super().__init__()
        self.classifier = classifier
        # Plain int attribute: not a parameter/buffer, so state_dict is unchanged.
        self.dim = dim

    def extra_repr(self) -> str:
        return f"dim={self.dim}"

    def forward(self, x: torch.Tensor):
        """Take the last element of ``x`` along ``self.dim`` and classify it.

        Returns:
            Whatever ``self.classifier`` returns for the selected slice.
        """
        # select() removes `dim` just like integer indexing does, and it works
        # for any input rank >= dim + 1 (``x[:, -1, :]`` required rank >= 3).
        last = x.select(self.dim, -1)
        return self.classifier(last)