classify.py: 42 additions & 24 deletions
@@ -81,6 +81,18 @@ def run_training(args):
model = models.__dict__[args.arch](args.pretrained)

model = torch.nn.DataParallel(model).cuda()
+# ===== NEW: adjust classifier to 2 classes (real/fake) =====
+NUM_CLASSES = 2
+base_model = model.module  # unwrap DataParallel
+# DRN uses a 1x1 Conv2d "fc" (out_dim -> num_classes) paired with avgpool
+in_channels = base_model.out_dim  # this is channels[-1]
+
+base_model.fc = nn.Conv2d(in_channels, NUM_CLASSES, kernel_size=1, stride=1, padding=0, bias=True)
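+# init the fresh head the way torchvision inits conv layers (Kaiming fan_out, zero bias)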
+nn.init.kaiming_normal_(base_model.fc.weight, mode='fan_out', nonlinearity='relu')
+if base_model.fc.bias is not None:
+    nn.init.constant_(base_model.fc.bias, 0.)
+# ===========================================================


best_prec1 = 0

@@ -100,14 +112,14 @@ def run_training(args):
cudnn.benchmark = True

# Data loading code
-traindir = os.path.join(args.data, 'train')
-valdir = os.path.join(args.data, 'val')
+traindir = os.path.join(args.data, 'training')
+valdir = os.path.join(args.data, 'validation')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(traindir, transforms.Compose([
-transforms.RandomSizedCrop(224),
+transforms.RandomResizedCrop(224),
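+# RandomSizedCrop/Scale were deprecated in torchvision 0.2 and later removed;
+# RandomResizedCrop/Resize are the current names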
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
@@ -117,7 +129,7 @@ def run_training(args):

val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose([
-transforms.Scale(256),
+transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
@@ -162,6 +174,17 @@ def test_model(args):
model = models.__dict__[args.arch](args.pretrained)

model = torch.nn.DataParallel(model).cuda()
+# ===== NEW: adjust classifier to 2 classes (real/fake) =====
+NUM_CLASSES = 2
+base_model = model.module  # unwrap DataParallel
+# DRN uses a 1x1 Conv2d "fc" (out_dim -> num_classes) paired with avgpool
+in_channels = base_model.out_dim  # this is channels[-1]
+base_model.fc = nn.Conv2d(in_channels, NUM_CLASSES, kernel_size=1, stride=1, padding=0, bias=True)
+nn.init.kaiming_normal_(base_model.fc.weight, mode='fan_out', nonlinearity='relu')
+if base_model.fc.bias is not None:
+    nn.init.constant_(base_model.fc.bias, 0.)
+# ===========================================================


if args.resume:
if os.path.isfile(args.resume):
@@ -178,12 +201,12 @@ def test_model(args):
cudnn.benchmark = True

# Data loading code
-valdir = os.path.join(args.data, 'val')
+valdir = os.path.join(args.data, 'validation')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

t = transforms.Compose([
-transforms.Scale(args.scale_size),
+transforms.Resize(256),
transforms.CenterCrop(args.crop_size),
transforms.ToTensor(),
normalize])
@@ -212,17 +235,14 @@ def train(args, train_loader, model, criterion, optimizer, epoch):
# measure data loading time
data_time.update(time.time() - end)

-target = target.cuda(async=True)
-input_var = torch.autograd.Variable(input)
-target_var = torch.autograd.Variable(target)
+input = input.cuda(non_blocking=True)
+target = target.cuda(non_blocking=True)
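+# non_blocking copies overlap with compute only when the DataLoader uses pin_memory=True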

# compute output
-output = model(input_var)
-loss = criterion(output, target_var)
+output = model(input)
+loss = criterion(output, target)

+losses.update(loss.item(), input.size(0))

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
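+# NOTE: torch.topk requires k <= num_classes, so with NUM_CLASSES = 2 the
+# topk=(1, 5) call will raise; top-5 only makes sense for >= 5 classes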
-losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))

@@ -257,17 +277,15 @@ def validate(args, val_loader, model, criterion):

end = time.time()
for i, (input, target) in enumerate(val_loader):
-target = target.cuda(async=True)
-input_var = torch.autograd.Variable(input, volatile=True)
-target_var = torch.autograd.Variable(target, volatile=True)
+input = input.cuda(non_blocking=True)
+target = target.cuda(non_blocking=True)

+with torch.no_grad():
+    output = model(input)
+    loss = criterion(output, target)
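+# no_grad() replaces the removed volatile=True Variables: no autograd graph is built during eval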

-# compute output
-output = model(input_var)
-loss = criterion(output, target_var)
+losses.update(loss.item(), input.size(0))

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
-losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))

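
Sanity check for the head swap, runnable outside the training script. A minimal sketch, assuming the DRN repo's drn module is on PYTHONPATH and exposes the drn_c_26 constructor and the out_dim attribute (both assumptions; any DRN variant should behave the same):

import torch
import torch.nn as nn
import drn  # assumption: the DRN repo's drn.py is importable

model = drn.drn_c_26(pretrained=False)  # hypothetical arch choice
model.fc = nn.Conv2d(model.out_dim, 2, kernel_size=1, stride=1, padding=0, bias=True)
nn.init.kaiming_normal_(model.fc.weight, mode='fan_out', nonlinearity='relu')
nn.init.constant_(model.fc.bias, 0.)

with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expect torch.Size([2, 2]): avgpool + 1x1 "fc" + flatten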