Implementation of models
To implement a new model, you can use the following template:
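```python
from typing import Dict, List, Optional

import torch.nn as nn
from torch import Tensor


class ModelWrapper(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules to an nn.Module
        self.model = Model()       # placeholder: your model
        self.criterion = LossFn()  # placeholder: your loss function

    def preprocess(self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None):
        # Convert the ObjDetAdapter targets into the model's annotation format (FnIn)
        annotations = None if target is None else FnIn(target)
        return x, annotations

    def forward(self, *x):
        # Called as forward(imgs, targets) for training, forward(imgs) for inference
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        return self.criterion(preds, targets)

    def postprocess(self, preds) -> List[Dict[str, Tensor]]:
        # Convert raw predictions into the format the metrics expect (FnOut)
        return FnOut(preds)
```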
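To make the template concrete, here is a minimal sketch that fills in the placeholders with toy components. ToyDetector, ToyWrapper, the single-box prediction head, the SmoothL1 box loss and the "loss_box" key are all hypothetical choices for illustration, not part of the original code. What it shows is the two-mode forward: called with targets it returns a loss dict, called with images only it returns one detection dict per image.

```python
import torch
import torch.nn as nn
from torch import Tensor


class ToyDetector(nn.Module):
    """Hypothetical stand-in model: predicts one box and one score per image."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(3, 5)  # 4 box coordinates + 1 score logit

    def forward(self, imgs: Tensor) -> Tensor:
        pooled = imgs.mean(dim=(2, 3))  # [b, 3] global average pool over h, w
        return self.head(pooled)        # [b, 5]


class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyDetector()
        self.criterion = nn.SmoothL1Loss()  # hypothetical choice of loss

    def preprocess(self, x, target=None):
        # Hypothetical FnIn: use the first box of each image as the regression target
        annotations = None if target is None else torch.stack([t["boxes"][0] for t in target])
        return x, annotations

    def forward(self, *x):
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        # Loss is returned as Dict[str, Tensor]; the key name is hypothetical
        return {"loss_box": self.criterion(preds[:, :4], targets)}

    def postprocess(self, preds):
        # Hypothetical FnOut: exactly one detection per image, always class 0
        return [
            {
                "boxes": p[:4].unsqueeze(0),                  # [1, 4]
                "scores": torch.sigmoid(p[4]).unsqueeze(0),   # [1]
                "labels": torch.zeros(1, dtype=torch.int64),  # [1]
            }
            for p in preds
        ]
```

Usage: `wrapper(imgs, targets)` returns `{"loss_box": <scalar tensor>}`, while `wrapper(imgs)` returns a list with one dict per image in the batch.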
preprocess receives the targets from ObjDetAdapter as a list of dicts with the keys boxes, labels and, optionally, masks, each mapped to a tensor. The tensors have the following shapes (n: number of objects):
- labels: [n]
- boxes: [n, 4]
- masks: [n, h, w]
preprocess then converts this input into the annotation format required by the model (FnIn).
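As an illustration of what FnIn might look like, the sketch below checks each target dict against the documented shapes and flattens the batch into a single [N, 6] tensor whose rows are (image_index, label, x1, y1, x2, y2). This layout, the function name targets_to_flat and the decision to ignore masks are assumptions made for the example; use whatever annotation format your model actually expects.

```python
from typing import Dict, List

import torch
from torch import Tensor


def targets_to_flat(target: List[Dict[str, Tensor]]) -> Tensor:
    """Hypothetical FnIn: list of target dicts -> one [N, 6] tensor.

    Each row is (image_index, label, x1, y1, x2, y2). Masks are only
    shape-checked and otherwise ignored in this sketch.
    """
    rows = []
    for i, t in enumerate(target):
        n = t["labels"].shape[0]
        # Validate against the documented shapes: boxes [n, 4], masks [n, h, w]
        if t["boxes"].shape != (n, 4):
            raise ValueError(f"target {i}: expected boxes of shape ({n}, 4), got {tuple(t['boxes'].shape)}")
        if "masks" in t and t["masks"].shape[0] != n:
            raise ValueError(f"target {i}: expected {n} masks, got {t['masks'].shape[0]}")
        img_idx = torch.full((n, 1), float(i))                # [n, 1]
        labels = t["labels"].to(torch.float32).unsqueeze(1)   # [n, 1]
        boxes = t["boxes"].to(torch.float32)                  # [n, 4]
        rows.append(torch.cat([img_idx, labels, boxes], dim=1))  # [n, 6]
    if not rows:
        return torch.zeros((0, 6))
    return torch.cat(rows, dim=0)
```

Images without objects (n = 0) simply contribute no rows, so the image index column is what keeps rows attributable to their image after flattening.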
If targets are present, forward returns the loss as a Dict[str,Tensor]. Otherwise it calls postprocess, which converts the predictions (FnOut) into a list of dicts with the keys boxes, labels, scores and, optionally, masks, each mapped to a tensor; these dicts are passed to the metrics.
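Conversely, FnOut has to turn the model's raw output into the dict format the metrics consume. The sketch below assumes, purely hypothetically, that the model returns a [b, k, 6] tensor whose rows are (x1, y1, x2, y2, score, label), and drops detections below a score threshold. raw_to_detections and score_thresh are illustrative names, and masks are omitted.

```python
from typing import Dict, List

import torch
from torch import Tensor


def raw_to_detections(preds: Tensor, score_thresh: float = 0.05) -> List[Dict[str, Tensor]]:
    """Hypothetical FnOut: [b, k, 6] raw predictions -> one dict per image.

    Assumes each row is (x1, y1, x2, y2, score, label); rows whose score is
    below score_thresh are dropped.
    """
    detections = []
    for p in preds:                    # p: [k, 6] predictions for one image
        keep = p[:, 4] >= score_thresh
        p = p[keep]                    # [m, 6], m <= k
        detections.append({
            "boxes": p[:, :4],         # [m, 4]
            "scores": p[:, 4],         # [m]
            "labels": p[:, 5].long(),  # [m], integer class ids
        })
    return detections
```

Many real detectors also need non-maximum suppression at this step; `torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)` can be applied per image before the dict is built.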