diff --git a/classify.py b/classify.py index dd123d0..0abf8cc 100644 --- a/classify.py +++ b/classify.py @@ -81,6 +81,18 @@ def run_training(args): model = models.__dict__[args.arch](args.pretrained) model = torch.nn.DataParallel(model).cuda() + # ===== NEWEdit: adjust classifier to 2 classes (real/fake) ===== + NUM_CLASSES = 2 + base_model = model.module # unwrap DataParallel + # DRN uses a Conv2d "fc" with out_dim -> num_classes, then avgpoo + in_channels = base_model.out_dim # this is channels[-1] + + base_model.fc = nn.Conv2d(in_channels, NUM_CLASSES, kernel_size=1, stride=1, padding=0, bias=True) + nn.init.kaiming_normal_(base_model.fc.weight, mode='fan_out', nonlinearity='relu') + if base_model.fc.bias is not None: + nn.init.constant_(base_model.fc.bias, 0.) + # =========================================================== + best_prec1 = 0 @@ -100,14 +112,14 @@ def run_training(args): cudnn.benchmark = True # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') + traindir = os.path.join(args.data, 'training') + valdir = os.path.join(args.data, 'validation') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_loader = torch.utils.data.DataLoader( datasets.ImageFolder(traindir, transforms.Compose([ - transforms.RandomSizedCrop(224), + transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, @@ -117,7 +129,7 @@ def run_training(args): val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Scale(256), + transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize, @@ -162,6 +174,17 @@ def test_model(args): model = models.__dict__[args.arch](args.pretrained) model = torch.nn.DataParallel(model).cuda() + # ===== NEW: adjust classifier to 2 classes (real/fake) ===== + NUM_CLASSES = 2 + base_model = model.module # unwrap DataParallel + # DRN uses a Conv2d "fc" with out_dim -> num_classes, then avgpool + in_channels = base_model.out_dim # this is channels[-1] + base_model.fc = nn.Conv2d(in_channels, NUM_CLASSES, kernel_size=1, stride=1, padding=0, bias=True) + nn.init.kaiming_normal_(base_model.fc.weight, mode='fan_out', nonlinearity='relu') + if base_model.fc.bias is not None: + nn.init.constant_(base_model.fc.bias, 0.) + # =========================================================== + if args.resume: if os.path.isfile(args.resume): @@ -178,12 +201,12 @@ def test_model(args): cudnn.benchmark = True # Data loading code - valdir = os.path.join(args.data, 'val') + valdir = os.path.join(args.data, 'validation') normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) t = transforms.Compose([ - transforms.Scale(args.scale_size), + transforms.Resize(256),, transforms.CenterCrop(args.crop_size), transforms.ToTensor(), normalize]) @@ -212,17 +235,14 @@ def train(args, train_loader, model, criterion, optimizer, epoch): # measure data loading time data_time.update(time.time() - end) - target = target.cuda(async=True) - input_var = torch.autograd.Variable(input) - target_var = torch.autograd.Variable(target) + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) - # compute output - output = model(input_var) - loss = criterion(output, target_var) + output = model(input) + loss = criterion(output, target) + + losses.update(loss.item(), input.size(0)) - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - losses.update(loss.data[0], input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) @@ -257,17 +277,15 @@ def validate(args, val_loader, model, criterion): end = time.time() for i, (input, target) in enumerate(val_loader): - target = target.cuda(async=True) - input_var = torch.autograd.Variable(input, volatile=True) - target_var = torch.autograd.Variable(target, volatile=True) + input = input.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + with torch.no_grad(): + output = model(input) + loss = criterion(output, target) - # compute output - output = model(input_var) - loss = criterion(output, target_var) + losses.update(loss.item(), input.size(0)) - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - losses.update(loss.data[0], input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0))