Docs
API
# Demonstrate that GroupNorm re-centers and re-scales its input: print the
# mean/std of a shifted random tensor, normalize it, and print them again.
# NOTE(review): assumes `Tensor` and `nn` are already in scope (presumably
# `from tinygrad import Tensor, nn`) — TODO confirm against the page's setup.

# Presumably GroupNorm(num_groups=2, num_channels=12): the 12 channels are
# split into 2 groups that are normalized independently — verify signature.
norm = nn.GroupNorm(2, 12)

# 4-D input whose axis 1 has size 12, matching the layer's channel count
# (presumably channels-first). Assuming Tensor.rand samples uniformly from
# [0, 1), `* 2 + 1` moves values into [1, 3), so the first print should show
# a mean near 2 — i.e. the data is clearly not normalized yet.
t = Tensor.rand(2, 12, 4, 4) * 2 + 1
print(t.mean().item(), t.std().item())

# After the layer, the mean is expected near 0 and the std near 1 (assuming
# the default affine parameters start as identity) — confirm via the output.
t = norm(t)
print(t.mean().item(), t.std().item())