大致思路如下:
模型A 预测输入 x 的标签 y_pred
模型B 根据输入的 y_pred 和真实标签y输出一个数值 z 作为模型A 的损失
根据z计算模型A的梯度,并更新模型A
更新的A重新预测 x 的标签记为 y_pred_new
此时计算 y_pred 和 y_pred_new 的交叉熵损失loss,更新模型B,但是在loss对模型B的求梯度时,梯度全为none
代码如下
公式过程如下图
推测问题在于③位置上的对θ更新时求导,因为grads=tape.gradient()求出来是tensor,相当于βx变成了tensor,不是variable了,导致⑤位置求导的时候无法对β求导,我该如何解决这个问题?
问个题外话, 梯度为 none 就是算法里常说的 《梯度消失》吗
不,这里是因为断流导致的,并不是常说的梯度消失
@h19615j: thx 加油