def get_parnum(m) -> int:
    """Print and return the total number of parameter elements in ``m``.

    Sums ``param.nelement()`` over every object yielded by
    ``m.parameters()`` and prints the total in millions, formatted to two
    decimal places.

    NOTE(review): ``m`` is presumably a ``torch.nn.Module`` (it needs a
    ``parameters()`` method whose items expose ``nelement()``); confirm
    against callers.

    Args:
        m: Object providing ``parameters()``.

    Returns:
        The total element count. The original version returned ``None``, so
        existing callers that ignore the result are unaffected.
    """
    # A generator expression avoids materialising a throwaway list.
    total = sum(param.nelement() for param in m.parameters())
    print(f"Number of parameter: {total / 1e6:.2f}M")
    return total