Implementation of models

To implement a new model, you can use the following template:

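from typing import Dict, List, Optional

from torch import Tensor, nn

# Model, LossFn, FnIn and FnOut are placeholders for the model-specific parts.


class ModelWrapper(nn.Module):

    def __init__(self):
        super().__init__()  # must run before submodules are assigned
        self.model = Model()
        self.criterion = LossFn()

    def preprocess(self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None):
        # FnIn converts the ObjDetAdapter targets into the model's annotation format
        annotations = None if target is None else FnIn(target)
        return x, annotations

    def forward(self, *x):
        # forward(imgs, targets) during training, forward(imgs) during inference
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        else:
            return self.criterion(preds, targets)  # loss as Dict[str, Tensor]

    def postprocess(self, preds) -> List[Dict[str, Tensor]]:
        # FnOut converts the raw predictions into one dict per image
        return FnOut(preds)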
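
The following is a minimal runnable instance of the template, meant only to show how forward dispatches between the loss and postprocess. Every model-specific part is a hypothetical stand-in, not part of this code base: ToyModel maps each image to one box, a score logit and a class logit; the inline FnIn keeps only the first target box of each image; ToyLoss is a plain L1 box loss; and the inline FnOut emits a single detection per image.

```python
from typing import Dict, List, Optional

import torch
from torch import Tensor, nn


class ToyModel(nn.Module):
    """Hypothetical model: predicts one box per image from the mean pixel value."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(1, 6)  # -> 4 box coords, 1 score logit, 1 class logit

    def forward(self, imgs: Tensor) -> Tensor:
        feats = imgs.mean(dim=(1, 2, 3)).unsqueeze(1)  # [B, 1]
        return self.head(feats)  # [B, 6]


class ToyLoss(nn.Module):
    """Hypothetical loss: L1 distance between predicted and annotated boxes."""

    def forward(self, preds: Tensor, annotations: Tensor) -> Dict[str, Tensor]:
        return {"loss_box": (preds[:, :4] - annotations).abs().mean()}


class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyModel()
        self.criterion = ToyLoss()

    def preprocess(self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None):
        # hypothetical FnIn: keep only the first box of every image -> [B, 4]
        annotations = None if target is None else torch.stack([t["boxes"][0] for t in target])
        return x, annotations

    def forward(self, *x):
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        return self.criterion(preds, targets)

    def postprocess(self, preds: Tensor) -> List[Dict[str, Tensor]]:
        # hypothetical FnOut: one dict per image holding a single detection
        return [
            {
                "boxes": p[:4].unsqueeze(0),       # [1, 4]
                "scores": p[4:5].sigmoid(),        # [1]
                "labels": (p[5:6] > 0).long(),     # [1]
            }
            for p in preds
        ]


model = ToyWrapper()
imgs = torch.rand(2, 3, 32, 32)
targets = [
    {"boxes": torch.tensor([[1.0, 2.0, 10.0, 12.0]]), "labels": torch.tensor([1])},
    {"boxes": torch.tensor([[0.0, 0.0, 5.0, 5.0]]), "labels": torch.tensor([0])},
]
losses = model(imgs, targets)  # training path: Dict[str, Tensor]
detections = model(imgs)       # inference path: List[Dict[str, Tensor]]
```

Calling the wrapper with (imgs, targets) yields the loss dict used for training; calling it with imgs alone yields the per-image dicts that are handed to the metrics.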

preprocess receives the targets from ObjDetAdapter as a list of dicts, one per image, with the keys boxes, labels and, optionally, masks, each mapped to a tensor. The tensors have the following shapes (n: number of objects in the image):

  • labels: [n]
  • boxes: [n, 4]
  • masks: [n, h, w]
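
For reference, a batch of targets in this format can be built and checked as below. The image size, label values and box values are made up, the box coordinate convention is not specified here, and check_target is a hypothetical helper, not part of ObjDetAdapter.

```python
from typing import Dict, List

import torch
from torch import Tensor


def check_target(target: Dict[str, Tensor]) -> int:
    """Hypothetical helper: validate one target dict and return its object count n."""
    n = target["labels"].shape[0]
    assert target["labels"].shape == (n,)
    assert target["boxes"].shape == (n, 4)
    if "masks" in target:  # masks are optional
        assert target["masks"].dim() == 3 and target["masks"].shape[0] == n
    return n


h, w = 64, 48  # example image size
targets: List[Dict[str, Tensor]] = [
    {  # image with 2 objects
        "labels": torch.tensor([3, 7]),
        "boxes": torch.tensor([[4.0, 5.0, 20.0, 30.0], [10.0, 12.0, 40.0, 60.0]]),
        "masks": torch.zeros(2, h, w, dtype=torch.uint8),
    },
    {  # image with 1 object
        "labels": torch.tensor([1]),
        "boxes": torch.tensor([[0.0, 0.0, 8.0, 8.0]]),
        "masks": torch.zeros(1, h, w, dtype=torch.uint8),
    },
]
counts = [check_target(t) for t in targets]  # objects per image
```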

preprocess then converts these targets into the annotations required by the model (FnIn).
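
What FnIn looks like depends entirely on the model. As one hypothetical example, some single-stage detectors expect all objects of a batch in a single flat tensor with one row (image_index, label, box) per object. The function name fn_in_flat and this row layout are assumptions for illustration only; masks are ignored.

```python
from typing import Dict, List

import torch
from torch import Tensor


def fn_in_flat(target: List[Dict[str, Tensor]]) -> Tensor:
    """Hypothetical FnIn: flatten per-image targets into one [N, 6] tensor.

    Each row is (image_index, label, box[0], box[1], box[2], box[3]), where N is
    the total number of objects in the batch. Masks, if present, are ignored.
    """
    rows = []
    for i, t in enumerate(target):
        n = t["labels"].shape[0]
        idx = torch.full((n, 1), float(i))  # image index repeated for each object
        labels = t["labels"].float().unsqueeze(1)  # [n, 1]
        rows.append(torch.cat([idx, labels, t["boxes"].float()], dim=1))  # [n, 6]
    return torch.cat(rows, dim=0) if rows else torch.zeros(0, 6)


batch = [
    {"labels": torch.tensor([3, 7]),
     "boxes": torch.tensor([[4.0, 5.0, 20.0, 30.0], [10.0, 12.0, 40.0, 60.0]])},
    {"labels": torch.tensor([1]), "boxes": torch.tensor([[0.0, 0.0, 8.0, 8.0]])},
]
annotations = fn_in_flat(batch)  # shape [3, 6]
```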

If targets are passed, forward returns the loss as a Dict[str, Tensor]. Otherwise it calls postprocess, which converts the predictions (FnOut) into a list of dicts with the keys boxes, labels, scores and, optionally, masks, each mapped to a tensor. These dicts are then passed to the metrics.
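
A sketch of FnOut, assuming a hypothetical model whose raw output is a [B, K, 6] tensor with K candidate rows per image, each laid out as (box, score, class_id). The name fn_out_dense, this layout and the score_thresh parameter are illustrative assumptions; masks are omitted.

```python
from typing import Dict, List

import torch
from torch import Tensor


def fn_out_dense(preds: Tensor, score_thresh: float = 0.5) -> List[Dict[str, Tensor]]:
    """Hypothetical FnOut for a model whose raw output is a [B, K, 6] tensor.

    Each of the K candidate rows is assumed to be (box[0..3], score, class_id).
    Returns one dict per image with the keys boxes, scores and labels.
    """
    results = []
    for p in preds:  # p: [K, 6] candidates for one image
        p = p[p[:, 4] >= score_thresh]  # drop low-confidence candidates
        results.append({
            "boxes": p[:, :4],         # [n, 4]
            "scores": p[:, 4],         # [n]
            "labels": p[:, 5].long(),  # [n]
        })
    return results


raw = torch.tensor([
    [[0.0, 0.0, 10.0, 10.0, 0.9, 2.0], [5.0, 5.0, 15.0, 15.0, 0.2, 1.0]],
    [[1.0, 1.0, 4.0, 4.0, 0.6, 0.0], [2.0, 2.0, 6.0, 6.0, 0.7, 3.0]],
])
detections = fn_out_dense(raw)  # first image keeps 1 box, second keeps 2
```

When these dicts feed an mAP-style metric, a low threshold (or none) is usually preferable, since such metrics evaluate detections across all score levels.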