
Commit a18e08e
Added scene image sampler 🌉
1 parent: ceabef4

File tree: 6 files changed, +272 -15 lines

configs/coco_scene_images_transformer.yaml (8 additions, 5 deletions)

@@ -8,9 +8,12 @@ model:
       params:
         vocab_size: 8192
         block_size: 348 # = 256 + 92 = dim(vqgan_latent_space,16x16) + dim(conditional_builder.embedding_dim)
-        n_layer: 32
+        n_layer: 40
         n_head: 16
-        n_embd: 912
+        n_embd: 1408
+        embd_pdrop: 0.1
+        resid_pdrop: 0.1
+        attn_pdrop: 0.1
     first_stage_config:
       target: taming.models.vqgan.VQModel
       params:
@@ -59,7 +62,7 @@ data:
         crop_method: random-1d
         random_flip: true
         use_group_parameter: true
-        encode_crop: true
+        encode_crop: false
     validation:
       target: taming.data.annotated_objects_coco.AnnotatedObjectsCoco
       params:
@@ -71,7 +74,7 @@ data:
         min_object_area: 0.00001
         min_objects_per_image: 2
         max_objects_per_image: 30
-        crop_method: random-1d
-        random_flip: true
+        crop_method: center
+        random_flip: false
         use_group_parameter: true
         encode_crop: true
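The transformer in this config grows from 32 layers with 912-dim embeddings to 40 layers with 1408-dim embeddings, and explicit dropout rates are added; block_size stays at 348, so the extra capacity comes entirely from depth and width. A rough back-of-the-envelope sketch of what that means for model size (not a measurement; it uses the standard 12 * n_layer * n_embd^2 estimate for a GPT-style decoder and ignores embedding and output-head parameters):

def approx_gpt_params(n_layer: int, n_embd: int) -> int:
    # ~4*d^2 attention weights + ~8*d^2 MLP weights per block, biases ignored
    return 12 * n_layer * n_embd ** 2

print(f"old: {approx_gpt_params(32, 912) / 1e6:.0f}M")   # ~319M
print(f"new: {approx_gpt_params(40, 1408) / 1e6:.0f}M")  # ~952M

So the scene-image transformer is roughly three times larger than the previous configuration.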

main.py (6 additions, 3 deletions)

@@ -11,6 +11,9 @@
 from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
+from taming.data.utils import custom_collate
+
+
 def get_obj_from_str(string, reload=False):
     module, cls = string.rsplit(".", 1)
     if reload:
@@ -160,16 +163,16 @@ def setup(self, stage=None):
 
     def _train_dataloader(self):
         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=True)
+                          num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate)
 
     def _val_dataloader(self):
         return DataLoader(self.datasets["validation"],
                           batch_size=self.batch_size,
-                          num_workers=self.num_workers)
+                          num_workers=self.num_workers, collate_fn=custom_collate)
 
     def _test_dataloader(self):
         return DataLoader(self.datasets["test"], batch_size=self.batch_size,
-                          num_workers=self.num_workers)
+                          num_workers=self.num_workers, collate_fn=custom_collate)
 
 
 class SetupCallback(Callback):
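Why the dataloaders need a custom collate function: PyTorch's default_collate tries to stack every field of a batch into tensors, which breaks for per-image lists of Annotation tuples whose lengths differ between images. The scene datasets need those lists passed through untouched (make_scene_samples.py below iterates batch['annotations'] per image). A minimal sketch of the pattern, where MyAnnotation and ToyScenes are hypothetical stand-ins for taming.data.helper_types.Annotation and the COCO dataset, and collate_keep_annotation_lists only illustrates the idea (the actual custom_collate added in taming/data/utils.py below is a modified copy of PyTorch's default collate):

import collections.abc
from typing import NamedTuple

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class MyAnnotation(NamedTuple):  # hypothetical stand-in for helper_types.Annotation
    category_no: int
    bbox: tuple


def collate_keep_annotation_lists(batch):
    """Like default_collate, but leaves per-sample lists of MyAnnotation as plain lists."""
    elem = batch[0]
    if isinstance(elem, collections.abc.Mapping):
        return {k: collate_keep_annotation_lists([d[k] for d in batch]) for k in elem}
    if isinstance(elem, collections.abc.Sequence) and elem and isinstance(elem[0], MyAnnotation):
        return list(batch)  # one annotation list per sample; lengths may differ
    return default_collate(batch)


class ToyScenes(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, i):
        anns = [MyAnnotation(category_no=j, bbox=(0.1 * j, 0.1, 0.2, 0.2)) for j in range(i + 1)]
        return {"image": torch.zeros(3, 8, 8), "annotations": anns}


loader = DataLoader(ToyScenes(), batch_size=4, collate_fn=collate_keep_annotation_lists)
batch = next(iter(loader))
print(batch["image"].shape)                     # torch.Size([4, 3, 8, 8])
print([len(a) for a in batch["annotations"]])   # [1, 2, 3, 4]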

scripts/make_scene_samples.py (new file, 198 additions)

import glob
import os
import sys
from itertools import product
from pathlib import Path
from typing import Literal, List, Optional, Tuple

import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm

from scripts.make_samples import get_parser, load_model_and_dset
from taming.data.conditional_builder.object_center_points_builder import CoordinatesCenterPointsConditionalBuilder
from taming.data.helper_types import BoundingBox, Annotation
from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
from taming.models.cond_transformer import Net2NetTransformer

seed_everything(42424242)
device: Literal['cuda', 'cpu'] = 'cuda'
first_stage_factor = 16
trained_on_res = 256


def _helper(coord: int, coord_max: int, coord_window: int) -> (int, int):
    assert 0 <= coord < coord_max
    coord_desired_center = (coord_window - 1) // 2
    return np.clip(coord - coord_desired_center, 0, coord_max - coord_window)


def get_crop_coordinates(x: int, y: int) -> BoundingBox:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, WIDTH, first_stage_factor) / WIDTH
    y0 = _helper(y, HEIGHT, first_stage_factor) / HEIGHT
    w = first_stage_factor / WIDTH
    h = first_stage_factor / HEIGHT
    return x0, y0, w, h


def get_z_indices_crop_out(z_indices: Tensor, predict_x: int, predict_y: int) -> Tensor:
    WIDTH, HEIGHT = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(predict_x, WIDTH, first_stage_factor)
    y0 = _helper(predict_y, HEIGHT, first_stage_factor)
    no_images = z_indices.shape[0]
    cut_out_1 = z_indices[:, y0:predict_y, x0:x0+first_stage_factor].reshape((no_images, -1))
    cut_out_2 = z_indices[:, predict_y, x0:predict_x]
    return torch.cat((cut_out_1, cut_out_2), dim=1)


@torch.no_grad()
def sample(model: Net2NetTransformer, annotations: List[Annotation], dataset: AnnotatedObjectsDataset,
           conditional_builder: CoordinatesCenterPointsConditionalBuilder, no_samples: int,
           temperature: float, top_k: int) -> Tensor:
    x_max, y_max = desired_z_shape[1], desired_z_shape[0]

    annotations = [a._replace(category_no=dataset.get_category_number(a.category_id)) for a in annotations]

    recompute_conditional = any((desired_resolution[0] > trained_on_res, desired_resolution[1] > trained_on_res))
    if not recompute_conditional:
        crop_coordinates = get_crop_coordinates(0, 0)
        conditional_indices = conditional_builder.build(annotations, crop_coordinates)
        c_indices = conditional_indices.to(device).repeat(no_samples, 1)
        z_indices = torch.zeros((no_samples, 0), device=device).long()
        output_indices = model.sample(z_indices, c_indices, steps=x_max*y_max, temperature=temperature,
                                      sample=True, top_k=top_k)
    else:
        output_indices = torch.zeros((no_samples, y_max, x_max), device=device).long()
        for predict_y, predict_x in tqdm(product(range(y_max), range(x_max)), desc='sampling_image', total=x_max*y_max):
            crop_coordinates = get_crop_coordinates(predict_x, predict_y)
            z_indices = get_z_indices_crop_out(output_indices, predict_x, predict_y)
            conditional_indices = conditional_builder.build(annotations, crop_coordinates)
            c_indices = conditional_indices.to(device).repeat(no_samples, 1)
            new_index = model.sample(z_indices, c_indices, steps=1, temperature=temperature, sample=True, top_k=top_k)
            output_indices[:, predict_y, predict_x] = new_index[:, -1]
    z_shape = (
        no_samples,
        model.first_stage_model.quantize.e_dim,  # codebook embed_dim
        desired_z_shape[0],  # z_height
        desired_z_shape[1]   # z_width
    )
    x_sample = model.decode_to_img(output_indices, z_shape) * 0.5 + 0.5
    x_sample = x_sample.to('cpu')

    plotter = conditional_builder.plot
    figure_size = (x_sample.shape[2], x_sample.shape[3])
    scene_graph = conditional_builder.build(annotations, (0., 0., 1., 1.))
    plot = plotter(scene_graph, dataset.get_textual_label_for_category_no, figure_size)
    return torch.cat((x_sample, plot.unsqueeze(0)))


def get_resolution(resolution_str: str) -> (Tuple[int, int], Tuple[int, int]):
    if not resolution_str.count(',') == 1:
        raise ValueError("Give resolution as in 'height,width'")
    res_h, res_w = resolution_str.split(',')
    res_h = max(int(res_h), trained_on_res)
    res_w = max(int(res_w), trained_on_res)
    z_h = int(round(res_h/first_stage_factor))
    z_w = int(round(res_w/first_stage_factor))
    return (z_h, z_w), (z_h*first_stage_factor, z_w*first_stage_factor)


def add_arg_to_parser(parser):
    parser.add_argument(
        "-R",
        "--resolution",
        type=str,
        default='256,256',
        help=f"give resolution in multiples of {first_stage_factor}, default is '256,256'",
    )
    parser.add_argument(
        "-C",
        "--conditional",
        type=str,
        default='objects_bbox',
        help="objects_bbox or objects_center_points",
    )
    parser.add_argument(
        "-N",
        "--n_samples_per_layout",
        type=int,
        default=4,
        help="how many samples to generate per layout",
    )
    return parser


if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = add_arg_to_parser(parser)

    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            try:
                idx = len(paths)-paths[::-1].index("logs")+1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs+opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    if opt.ignore_base_data:
        for config in configs:
            if hasattr(config, "data"):
                del config["data"]
    config = OmegaConf.merge(*configs, cli)
    desired_z_shape, desired_resolution = get_resolution(opt.resolution)
    conditional = opt.conditional

    print(ckpt)
    gpu = True
    eval_mode = True
    show_config = False
    if show_config:
        print(OmegaConf.to_container(config))

    dsets, model, global_step = load_model_and_dset(config, ckpt, gpu, eval_mode)
    print(f"Global step: {global_step}")

    data_loader = dsets.val_dataloader()
    print(dsets.datasets["validation"].conditional_builders)
    conditional_builder = dsets.datasets["validation"].conditional_builders[conditional]

    outdir = Path(opt.outdir).joinpath(f"{global_step:06}_{opt.top_k}_{opt.temperature}")
    outdir.mkdir(exist_ok=True, parents=True)
    print("Writing samples to ", outdir)

    p_bar_1 = tqdm(enumerate(iter(data_loader)), desc='batch', total=len(data_loader))
    for batch_no, batch in p_bar_1:
        save_img: Optional[Tensor] = None
        for i, annotations in tqdm(enumerate(batch['annotations']), desc='within_batch', total=data_loader.batch_size):
            imgs = sample(model, annotations, dsets.datasets["validation"], conditional_builder,
                          opt.n_samples_per_layout, opt.temperature, opt.top_k)
            save_image(imgs, outdir.joinpath(f'{batch_no:04}_{i:02}.png'), nrow=opt.n_samples_per_layout+1)
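How the beyond-training-resolution path works: when the requested output is larger than the 256 px the model was trained on, the script samples one latent index at a time; get_z_indices_crop_out cuts a first_stage_factor-wide window of already-generated indices around the target position, and get_crop_coordinates rebuilds the conditioning for the matching relative crop, so the transformer never sees more context than its 16x16 training grid (and block_size of 348) allows. A standalone sketch of the window arithmetic, with desired_z_shape hard-coded to an assumed example value for a 256x512 output:

import numpy as np

first_stage_factor = 16
desired_z_shape = (16, 32)  # (z_height, z_width) for an assumed 256x512 target

def _helper(coord, coord_max, coord_window):
    # top/left corner of a window of size coord_window that keeps `coord`
    # near its centre while staying inside [0, coord_max); int() only for clean printing
    coord_desired_center = (coord_window - 1) // 2
    return int(np.clip(coord - coord_desired_center, 0, coord_max - coord_window))

def get_crop_coordinates(x, y):
    width, height = desired_z_shape[1], desired_z_shape[0]
    x0 = _helper(x, width, first_stage_factor) / width
    y0 = _helper(y, height, first_stage_factor) / height
    return x0, y0, first_stage_factor / width, first_stage_factor / height

print(get_crop_coordinates(0, 0))   # (0.0, 0.0, 0.5, 1.0): window over the left half of the layout
print(get_crop_coordinates(20, 8))  # (0.40625, 0.0, 0.5, 1.0): window shifted right, centred near x=20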

taming/data/annotated_objects_coco.py (0 additions, 2 deletions)

@@ -12,14 +12,12 @@
 COCO_PATH_STRUCTURE = {
     'train': {
         'top_level': '',
-        'person_annotations': 'annotations/person_keypoints_train2017.json',
         'instances_annotations': 'annotations/instances_train2017.json',
         'stuff_annotations': 'annotations/stuff_train2017.json',
         'files': 'train2017'
     },
     'validation': {
         'top_level': '',
-        'person_annotations': 'annotations/person_keypoints_val2017.json',
         'instances_annotations': 'annotations/instances_val2017.json',
         'stuff_annotations': 'annotations/stuff_val2017.json',
         'files': 'val2017'

taming/data/annotated_objects_dataset.py (3 additions, 3 deletions)

@@ -99,7 +99,7 @@ def no_classes(self) -> int:
         return len(self.categories)
 
     @property
-    def conditional_builders(self):
+    def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
        # cannot set this up in init because no_classes is only known after loading data in init of superclass
         if self._conditional_builders is None:
             self._conditional_builders = {
@@ -109,15 +109,15 @@ def conditional_builders(self):
                     self.no_tokens,
                     self.encode_crop,
                     self.use_group_parameter,
-                    getattr(self, 'self.use_additional_parameters', False)
+                    getattr(self, 'use_additional_parameters', False)
                 ),
                 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
                     self.no_classes,
                     self.max_objects_per_image,
                     self.no_tokens,
                     self.encode_crop,
                     self.use_group_parameter,
-                    getattr(self, 'self.use_additional_parameters', False)
+                    getattr(self, 'use_additional_parameters', False)
                 )
             }
         return self._conditional_builders
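The getattr change is more than cosmetic: the old attribute name literally contained the 'self.' prefix, so it could never match a real attribute and the builders were always created with the default False. A minimal illustration:

class Example:
    use_additional_parameters = True

e = Example()
print(getattr(e, 'self.use_additional_parameters', False))  # False: no attribute has that literal name
print(getattr(e, 'use_additional_parameters', False))       # True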

taming/data/utils.py (57 additions, 2 deletions)

@@ -1,8 +1,15 @@
+import collections
 import os
-import numpy as np
+import tarfile
 import urllib
-import tarfile, zipfile
+import zipfile
 from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+from torch._six import string_classes
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
 from tqdm import tqdm
 
 
@@ -112,3 +119,51 @@ def quadratic_crop(x, bbox, alpha=1.0):
     xmin = int(center[0] - l / 2)
     ymin = int(center[1] - l / 2)
     return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+    r"""source: pytorch 1.9.0, only one modification to original code """
+
+    elem = batch[0]
+    elem_type = type(elem)
+    if isinstance(elem, torch.Tensor):
+        out = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = elem.storage()._new_shared(numel)
+            out = elem.new(storage)
+        return torch.stack(batch, 0, out=out)
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+            # array of string classes and object
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+            return custom_collate([torch.as_tensor(b) for b in batch])
+        elif elem.shape == ():  # scalars
+            return torch.as_tensor(batch)
+    elif isinstance(elem, float):
+        return torch.tensor(batch, dtype=torch.float64)
+    elif isinstance(elem, int):
+        return torch.tensor(batch)
+    elif isinstance(elem, string_classes):
+        return batch
+    elif isinstance(elem, collections.abc.Mapping):
+        return {key: custom_collate([d[key] for d in batch]) for key in elem}
+    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
+        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+    if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation):  # added
+        return batch  # added
+    elif isinstance(elem, collections.abc.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError('each element in list of batch should be of equal size')
+        transposed = zip(*batch)
+        return [custom_collate(samples) for samples in transposed]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
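One portability note: custom_collate imports string_classes from torch._six, which matches the PyTorch 1.9 code it was copied from. As a purely defensive sketch (an assumption about later PyTorch releases, where torch._six was removed, not something this commit requires), the import could be guarded like this:

try:
    from torch._six import string_classes
except ImportError:
    # newer PyTorch dropped torch._six; (str, bytes) is what its default collate checks for
    string_classes = (str, bytes)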
