Docs
API
# Example: normalize a tensor with nn.InstanceNorm and compare its overall
# mean/std before and after normalization.
# NOTE(review): `Tensor` and `nn` are assumed to be imported earlier
# (e.g. `from tinygrad import Tensor, nn`); that import is not visible here.
norm = nn.InstanceNorm(3)  # presumably the channel count; matches dim 1 of `t`

# Shape (2, 3, 4, 4) — presumably (batch, channels, height, width); TODO confirm.
# `* 2 + 1` scales and shifts the random values so the input is clearly not
# zero-mean / unit-std before normalization (assuming Tensor.rand is uniform
# in [0, 1), values land in roughly [1, 3)).
t = Tensor.rand(2, 3, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())

t = norm(t)
# Expected: mean near 0 and std near 1, assuming the layer's learnable
# scale/shift are still at their initial values — verify against the layer.
print(t.mean().item(), t.std().item())