diff --git a/utils/loss.py b/utils/loss.py index 2b1d968f8fe..47cc0df8175 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -1007,7 +1007,7 @@ def build_targets(self, p, targets, imgs): all_gj.append(gj) all_gi.append(gi) all_anch.append(anch[i][idx]) - from_which_layer.append(torch.ones(size=(len(b),)) * i) + from_which_layer.append((torch.ones(size=(len(b),), device=targets.device) * i)) fg_pred = pi[b, a, gj, gi] p_obj.append(fg_pred[:, obj_idx:(obj_idx+1)]) @@ -1327,7 +1327,7 @@ def build_targets(self, p, targets, imgs): all_gj.append(gj) all_gi.append(gi) all_anch.append(anch[i][idx]) - from_which_layer.append(torch.ones(size=(len(b),)) * i) + from_which_layer.append((torch.ones(size=(len(b),), device=targets.device) * i)) fg_pred = pi[b, a, gj, gi] p_obj.append(fg_pred[:, 4:5]) @@ -1480,7 +1480,7 @@ def build_targets2(self, p, targets, imgs): all_gj.append(gj) all_gi.append(gi) all_anch.append(anch[i][idx]) - from_which_layer.append(torch.ones(size=(len(b),)) * i) + from_which_layer.append((torch.ones(size=(len(b),), device=targets.device) * i)) fg_pred = pi[b, a, gj, gi] p_obj.append(fg_pred[:, 4:5])