ChatGPT解决这个技术问题 Extra ChatGPT

model.eval() 在 pytorch 中做了什么?

我应该什么时候使用 .eval()?我知道它应该允许我“评估我的模型”。如何将其关闭以进行培训?

使用 .eval() 训练示例 code

这回答了你的问题了吗? What does model.train() do in pytorch?
是否有一个标志来检测模型是否处于评估模式?例如mdl.is_eval()
我建议对任何具有良好文档的工具有任何疑问,请查看文档:pytorch.org/docs/stable/generated/torch.nn.Module.html。如果文档不清楚 - 只需对一些在训练/优化变量中以不同模式工作的计算块进行小注释,并将其用于进行预测。一个示例是该模型:arxiv.org/abs/1502.03167
它通过执行 self.train(False) 简单地通过 self.training = training 递归地更改所有模块的 self.training。事实上,这就是 self.train 所做的,将所有模块的标志递归地更改为 true。见代码:github.com/pytorch/pytorch/blob/…

t
trsvchn

model.eval() 是模型中某些特定层/部分的一种开关,它们在训练和推理(评估)期间表现不同。例如,Dropouts Layers、BatchNorm Layers 等。您需要在模型评估期间关闭它们,.eval() 会为您完成。此外,评估/验证的常见做法是使用 torch.no_grad()model.eval() 来关闭梯度计算:

# evaluate model:
model.eval()

with torch.no_grad():
    ...
    out_data = model(data)
    ...

但是,不要忘记在 eval 步骤后返回 training 模式:

# training step
...
model.train()
...

torch.no_grad() 是一个上下文管理器,因此您应该以 with torch.no_grad(): 的形式使用它,以保证离开 with ... 块模型时会自动打开梯度计算
因此,model.train()model.eval() 仅对图层有效,而不对渐变有效,默认情况下渐变补偿已打开,但在评估期间使用上下文管理器 torch.no_grad() 可以让您轻松关闭然后自动打开渐变补偿结尾
为什么我们需要在 Eval 上关闭 grad comp?
@shtse8 我们在评估期间不计算或使用梯度,因此关闭 autograd 将加快执行速度并减少内存使用
@NagabhushanSN 是的!它们以递归方式工作,.eval() 看起来像这样:for module in self.children(): module.train(False).train()for module in self.children(): module.train(True)
i
iacob

model.train() model.eval() 将模型设置为训练模式: • 归一化层 1 使用每批统计 • 激活 Dropout 层 2 将模型设置为评估(推理)模式: • 归一化层使用运行统计 • 停用 Dropout 层 等效到model.train(假)。

您可以通过运行 model.train() 关闭评估模式。您应该在将模型作为推理引擎运行时使用它——即在测试、验证和预测时(尽管实际上,如果您的模型不包含任何 differently behaving layers,它不会有任何区别)。

例如 BatchNorm、InstanceNorm 这包括 RNN 模块的子模块等。


i
iacob

model.evaltorch.nn.Module 的方法:

eval() 将模块设置为评估模式。这仅对某些模块有任何影响。请参阅特定模块的文档以了解它们在训练/评估模式下的行为的详细信息,如果它们受到影响,例如 Dropout、BatchNorm 等。这相当于 self.train(False)。

Umang Gupta 很好地解释了相反的方法 model.train


是否有一个标志来检测模型是否处于评估模式?例如mdl.is_eval()
是的,@CharlieParker self.training 标志说明 stackoverflow.com/a/56828547/5884955
G
Gulzar

上述答案的额外补充:

我最近开始使用 Pytorch-lightning,它包含了训练-验证-测试管道中的大部分样板。

除其他外,它通过允许包装 evaltraintrain_stepvalidation_step 回调使 model.eval()model.train() 几乎是多余的,因此您永远不会忘记。


抱歉,我看到的是 delete 而不是 elaborate。有点不喜欢这个问题:Lightning 为你处理训练/测试循环,你只需要定义 train_stepval_step 等等。 model.eval()model.train() 在后台完成,您不必担心它们。我建议您观看他们的一些视频,这是一项非常值得 30 分钟的投资。