diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index a65767d084b6..2e1c01fba138 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -881,7 +881,7 @@ def collate_fn(examples): elif args.loss_type == "hinge": loss = torch.relu(1 - args.beta_dpo * logits).mean() elif args.loss_type == "ipo": - losses = (logits - 1 / (2 * args.beta)) ** 2 + losses = (logits - 1 / (2 * args.beta_dpo)) ** 2 loss = losses.mean() else: raise ValueError(f"Unknown loss type {args.loss_type}")