Implementation of models
To implement a new model, you can use the following template:
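```python
from typing import Dict, List, Optional

from torch import Tensor, nn


class ModelWrapper(nn.Module):
    # Model, LossFn, FnIn and FnOut are placeholders for your own implementations.
    def __init__(self):
        super().__init__()  # required before assigning submodules to an nn.Module
        self.model = Model()
        self.criterion = LossFn()

    def preprocess(self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None):
        # FnIn converts the adapter's targets into the annotations the model expects
        annotations = None if target is None else FnIn(target)
        return x, annotations

    def forward(self, *x):
        # called as forward(imgs, targets) when targets exist, otherwise forward(imgs)
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        return self.criterion(preds, targets)  # loss as Dict[str, Tensor]

    def postprocess(self, preds) -> List[Dict[str, Tensor]]:
        # FnOut converts raw predictions into dicts with boxes, labels, scores (and masks)
        return FnOut(preds)
```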
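To see both call paths in action, here is a minimal runnable instantiation of the template. `ToyModel`, `ToyLoss`, `toy_fn_in` and `toy_fn_out` are hypothetical stand-ins for `Model`, `LossFn`, `FnIn` and `FnOut` (a linear layer that predicts one box per image, plus an L1 box loss). This is not a real detector. Only the wrapper's dispatch logic mirrors the template.

```python
from typing import Dict, List, Optional

import torch
from torch import Tensor, nn


class ToyModel(nn.Module):
    """Hypothetical Model: predicts 4 box values plus 1 score logit per image."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(3, 5)

    def forward(self, imgs: Tensor) -> Tensor:
        return self.head(imgs.mean(dim=(2, 3)))  # [B, 3, H, W] -> [B, 3] -> [B, 5]


class ToyLoss(nn.Module):
    """Hypothetical LossFn: L1 loss between predicted and annotated boxes."""

    def forward(self, preds: Tensor, annotations: Tensor) -> Dict[str, Tensor]:
        return {"loss_box": nn.functional.l1_loss(preds[:, :4], annotations)}


def toy_fn_in(target: List[Dict[str, Tensor]]) -> Tensor:
    # Hypothetical FnIn: keep only the first box of each image -> [B, 4]
    return torch.stack([t["boxes"][0] for t in target])


def toy_fn_out(preds: Tensor) -> List[Dict[str, Tensor]]:
    # Hypothetical FnOut: one detection per image, label 0, sigmoid of the logit as score
    return [{"boxes": p[:4].unsqueeze(0),
             "labels": torch.zeros(1, dtype=torch.int64),
             "scores": p[4].sigmoid().unsqueeze(0)} for p in preds]


class ToyWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyModel()
        self.criterion = ToyLoss()

    def preprocess(self, x: Tensor, target: Optional[List[Dict[str, Tensor]]] = None):
        annotations = None if target is None else toy_fn_in(target)
        return x, annotations

    def forward(self, *x):
        imgs, targets = x if len(x) == 2 else (x[0], None)
        imgs, targets = self.preprocess(imgs, targets)
        preds = self.model(imgs)
        if targets is None:
            return self.postprocess(preds)
        return self.criterion(preds, targets)

    def postprocess(self, preds: Tensor) -> List[Dict[str, Tensor]]:
        return toy_fn_out(preds)


wrapper = ToyWrapper()
imgs = torch.rand(2, 3, 32, 32)
targets = [{"boxes": torch.tensor([[1.0, 2.0, 10.0, 12.0]]), "labels": torch.tensor([1])}
           for _ in range(2)]
losses = wrapper(imgs, targets)  # with targets: dict of losses
detections = wrapper(imgs)       # without targets: one prediction dict per image
```

Calling the wrapper with targets yields the loss dict. Calling it with images alone yields the list of prediction dicts.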
`preprocess` receives the targets from `ObjDetAdapter` as a list of dicts with the keys `boxes`, `labels` and, optionally, `masks`, each mapping to a tensor. The tensors have the following shapes, where n is the number of objects:

- `labels`: `[n]`
- `boxes`: `[n, 4]`
- `masks`: `[n, h, w]`

It then converts this input into the annotations required by the model (`FnIn`).
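As an illustration of what `FnIn` might do, the sketch below flattens the target dicts into a single `[M, 6]` tensor. Each row holds an image index, a label, and the four box values copied as given. This is a layout some single-stage detectors use for their loss. The function name `targets_to_rows` and the row layout are assumptions for this example, and the annotation format your model expects may differ. Masks are accepted but ignored.

```python
from typing import Dict, List

import torch
from torch import Tensor


def targets_to_rows(targets: List[Dict[str, Tensor]]) -> Tensor:
    """Hypothetical FnIn: flatten a list of target dicts into one [M, 6] tensor.

    Each row is (image_index, label, b0, b1, b2, b3), where b0..b3 are the four
    box values copied as given; M is the total number of objects in the list.
    Masks, if present, are ignored by this sketch.
    """
    rows = []
    for i, t in enumerate(targets):
        n = t["labels"].shape[0]
        img_idx = torch.full((n, 1), float(i))               # [n, 1]
        labels = t["labels"].to(torch.float32).unsqueeze(1)  # [n] -> [n, 1]
        boxes = t["boxes"].to(torch.float32)                 # [n, 4]
        rows.append(torch.cat([img_idx, labels, boxes], dim=1))
    if not rows:
        return torch.zeros((0, 6))
    return torch.cat(rows, dim=0)


targets = [
    {   # two objects, with optional masks of shape [n, h, w]
        "boxes": torch.tensor([[10.0, 20.0, 50.0, 80.0], [0.0, 0.0, 5.0, 5.0]]),
        "labels": torch.tensor([3, 1]),
        "masks": torch.zeros((2, 64, 64), dtype=torch.uint8),
    },
    {   # no objects
        "boxes": torch.zeros((0, 4)),
        "labels": torch.zeros((0,), dtype=torch.int64),
    },
]
annotations = targets_to_rows(targets)
print(annotations.shape)  # torch.Size([2, 6])
```

The image index column keeps objects from different images distinguishable after concatenation, and an entry with no objects simply contributes no rows.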
If targets are present, `forward` returns the loss (`Dict[str, Tensor]`). Otherwise it calls `postprocess`, which converts the predictions (`FnOut`) into a list of dicts with the keys `boxes`, `labels`, `scores` and, optionally, `masks`, each mapping to a tensor. This list is then passed to the metrics.
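Going the other way, here is a sketch of a possible `FnOut`. It assumes, hypothetically, that the model emits a `[B, K, 6]` tensor whose rows hold four box values, a score and a class index, and it drops detections below a score threshold. The name `rows_to_predictions`, this input layout and the `score_thresh` parameter are assumptions for illustration. Masks are not produced.

```python
from typing import Dict, List

import torch
from torch import Tensor


def rows_to_predictions(preds: Tensor, score_thresh: float = 0.5) -> List[Dict[str, Tensor]]:
    """Hypothetical FnOut: [B, K, 6] rows (b0, b1, b2, b3, score, label) -> list of dicts."""
    results = []
    for p in preds:  # p: [K, 6], the candidate detections for one image
        p = p[p[:, 4] >= score_thresh]  # keep rows whose score passes the threshold
        results.append({
            "boxes": p[:, :4],
            "scores": p[:, 4],
            "labels": p[:, 5].to(torch.int64),
        })
    return results


preds = torch.tensor([[[10.0, 20.0, 50.0, 80.0, 0.9, 3.0],
                       [0.0, 0.0, 5.0, 5.0, 0.2, 1.0]]])
out = rows_to_predictions(preds)
print(out[0]["labels"])  # tensor([3])
```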