This commit is contained in:
FlintyLemming
2024-09-25 15:18:31 +08:00
commit e1911954ed
99 changed files with 38062 additions and 0 deletions

View File

@@ -0,0 +1,627 @@
import einops
import torch
import torch as th
import torch.nn as nn
import copy
from easydict import EasyDict as edict
from ..ldm.modules.diffusionmodules.util import (
conv_nd,
linear,
zero_module,
timestep_embedding,
)
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
from ..ldm.models.diffusion.ddpm import LatentDiffusion
from ..ldm.util import log_txt_as_img, exists, instantiate_from_config
# from ldm.models.diffusion.ddim import DDIMSampler
from .ddim_hacked import DDIMSampler
from ..ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from .recognizer import TextRecognizer, create_predictor
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class ControlledUnetModel(UNetModel):
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
hs = []
with torch.no_grad():
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if self.use_fp16:
t_emb = t_emb.half()
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
glyph_channels,
position_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.use_fp16 = use_fp16
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.glyph_block = TimestepEmbedSequential(
conv_nd(dims, glyph_channels, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
nn.SiLU(),
)
self.position_block = TimestepEmbedSequential(
conv_nd(dims, position_channels, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 8, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 8, 16, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1),
nn.SiLU(),
conv_nd(dims, 32, 64, 3, padding=1, stride=2),
nn.SiLU(),
)
self.fuse_block = zero_module(conv_nd(dims, 256+64+4, model_channels, 3, padding=1))
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, text_info, timesteps, context, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
if self.use_fp16:
t_emb = t_emb.half()
emb = self.time_embed(t_emb)
# guided_hint from text_info
B, C, H, W = x.shape
glyphs = torch.cat(text_info['glyphs'], dim=1).sum(dim=1, keepdim=True)
positions = torch.cat(text_info['positions'], dim=1).sum(dim=1, keepdim=True)
enc_glyph = self.glyph_block(glyphs, emb, context)
enc_pos = self.position_block(positions, emb, context)
guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info['masked_x']], dim=1))
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
outs.append(zero_conv(h, emb, context))
h = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h, emb, context))
return outs
class ControlLDM(LatentDiffusion):
def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
self.use_fp16 = kwargs.pop('use_fp16', False)
super().__init__(*args, **kwargs)
self.control_model = instantiate_from_config(control_stage_config)
self.control_key = control_key
self.glyph_key = glyph_key
self.position_key = position_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
self.loss_alpha = loss_alpha
self.loss_beta = loss_beta
self.with_step_weight = with_step_weight
self.use_vae_upsample = use_vae_upsample
self.latin_weight = latin_weight
if embedding_manager_config is not None and embedding_manager_config.params.valid:
self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
else:
self.embedding_manager = None
if self.loss_alpha > 0 or self.loss_beta > 0 or self.embedding_manager:
if embedding_manager_config.params.emb_type == 'ocr':
self.text_predictor = create_predictor().eval()
args = edict()
args.rec_image_shape = "3, 48, 320"
args.rec_batch_num = 6
args.rec_char_dict_path = './ocr_recog/ppocr_keys_v1.txt'
args.use_fp16 = self.use_fp16
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
for param in self.text_predictor.parameters():
param.requires_grad = False
if self.embedding_manager:
self.embedding_manager.recog = self.cn_recognizer
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
if self.embedding_manager is None: # fill in full caption
self.fill_caption(batch)
x, c, mx = super().get_input(batch, self.first_stage_key, mask_k='masked_img', *args, **kwargs)
control = batch[self.control_key] # for log_images and loss_alpha, not real control
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
control = control.to(memory_format=torch.contiguous_format).float()
inv_mask = batch['inv_mask']
if bs is not None:
inv_mask = inv_mask[:bs]
inv_mask = inv_mask.to(self.device)
inv_mask = einops.rearrange(inv_mask, 'b h w c -> b c h w')
inv_mask = inv_mask.to(memory_format=torch.contiguous_format).float()
glyphs = batch[self.glyph_key]
gly_line = batch['gly_line']
positions = batch[self.position_key]
n_lines = batch['n_lines']
language = batch['language']
texts = batch['texts']
assert len(glyphs) == len(positions)
for i in range(len(glyphs)):
if bs is not None:
glyphs[i] = glyphs[i][:bs]
gly_line[i] = gly_line[i][:bs]
positions[i] = positions[i][:bs]
n_lines = n_lines[:bs]
glyphs[i] = glyphs[i].to(self.device)
gly_line[i] = gly_line[i].to(self.device)
positions[i] = positions[i].to(self.device)
glyphs[i] = einops.rearrange(glyphs[i], 'b h w c -> b c h w')
gly_line[i] = einops.rearrange(gly_line[i], 'b h w c -> b c h w')
positions[i] = einops.rearrange(positions[i], 'b h w c -> b c h w')
glyphs[i] = glyphs[i].to(memory_format=torch.contiguous_format).float()
gly_line[i] = gly_line[i].to(memory_format=torch.contiguous_format).float()
positions[i] = positions[i].to(memory_format=torch.contiguous_format).float()
info = {}
info['glyphs'] = glyphs
info['positions'] = positions
info['n_lines'] = n_lines
info['language'] = language
info['texts'] = texts
info['img'] = batch['img'] # nhwc, (-1,1)
info['masked_x'] = mx
info['gly_line'] = gly_line
info['inv_mask'] = inv_mask
return x, dict(c_crossattn=[c], c_concat=[control], text_info=info)
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
_cond = torch.cat(cond['c_crossattn'], 1)
_hint = torch.cat(cond['c_concat'], 1)
if self.use_fp16:
x_noisy = x_noisy.half()
control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
control = [c * scale for c, scale in zip(control, self.control_scales)]
eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
return eps
def instantiate_embedding_manager(self, config, embedder):
model = instantiate_from_config(config, embedder=embedder)
return model
@torch.no_grad()
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning(dict(c_crossattn=[[""] * N], text_info=None))
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
if self.embedding_manager is not None and c['text_info'] is not None:
self.embedding_manager.encode_text(c['text_info'])
if isinstance(c, dict):
cond_txt = c['c_crossattn'][0]
else:
cond_txt = c
if self.embedding_manager is not None:
cond_txt = self.cond_stage_model.encode(cond_txt, embedding_manager=self.embedding_manager)
else:
cond_txt = self.cond_stage_model.encode(cond_txt)
if isinstance(c, dict):
c['c_crossattn'][0] = cond_txt
else:
c = cond_txt
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def fill_caption(self, batch, place_holder='*'):
bs = len(batch['n_lines'])
cond_list = copy.deepcopy(batch[self.cond_stage_key])
for i in range(bs):
n_lines = batch['n_lines'][i]
if n_lines == 0:
continue
cur_cap = cond_list[i]
for j in range(n_lines):
r_txt = batch['texts'][j][i]
cur_cap = cur_cap.replace(place_holder, f'"{r_txt}"', 1)
cond_list[i] = cur_cap
batch[self.cond_stage_key] = cond_list
@torch.no_grad()
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c = self.get_input(batch, self.first_stage_key, bs=N)
if self.cond_stage_trainable:
with torch.no_grad():
c = self.get_learned_conditioning(c)
c_crossattn = c["c_crossattn"][0][:N]
c_cat = c["c_concat"][0][:N]
text_info = c["text_info"]
text_info['glyphs'] = [i[:N] for i in text_info['glyphs']]
text_info['gly_line'] = [i[:N] for i in text_info['gly_line']]
text_info['positions'] = [i[:N] for i in text_info['positions']]
text_info['n_lines'] = text_info['n_lines'][:N]
text_info['masked_x'] = text_info['masked_x'][:N]
text_info['img'] = text_info['img'][:N]
N = min(z.shape[0], N)
n_row = min(z.shape[0], n_row)
log["reconstruction"] = self.decode_first_stage(z)
log["masked_image"] = self.decode_first_stage(text_info['masked_x'])
log["control"] = c_cat * 2.0 - 1.0
log["img"] = text_info['img'].permute(0, 3, 1, 2) # log source image if needed
# get glyph
glyph_bs = torch.stack(text_info['glyphs'])
glyph_bs = torch.sum(glyph_bs, dim=0) * 2.0 - 1.0
log["glyph"] = torch.nn.functional.interpolate(glyph_bs, size=(512, 512), mode='bilinear', align_corners=True,)
# fill caption
if not self.embedding_manager:
self.fill_caption(batch)
captions = batch[self.cond_stage_key]
log["conditioning"] = log_txt_as_img((512, 512), captions, size=16)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c], "text_info": text_info},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N)
uc_cat = c_cat # torch.zeros_like(c_cat)
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross['c_crossattn'][0]], "text_info": text_info}
samples_cfg, tmps = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c_crossattn], "text_info": text_info},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
pred_x0 = False # wether log pred_x0
if pred_x0:
for idx in range(len(tmps['pred_x0'])):
pred_x0 = self.decode_first_stage(tmps['pred_x0'][idx])
log[f"pred_x0_{tmps['index'][idx]}"] = pred_x0
return log
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
ddim_sampler = DDIMSampler(self)
b, c, h, w = cond["c_concat"][0].shape
shape = (self.channels, h // 8, w // 8)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, log_every_t=5, **kwargs)
return samples, intermediates
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if self.embedding_manager:
params += list(self.embedding_manager.embedding_parameters())
if not self.sd_locked:
# params += list(self.model.diffusion_model.input_blocks.parameters())
# params += list(self.model.diffusion_model.middle_block.parameters())
params += list(self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
if self.unlockKV:
nCount = 0
for name, param in self.model.diffusion_model.named_parameters():
if 'attn2.to_k' in name or 'attn2.to_v' in name:
params += [param]
nCount += 1
print(f'Cross attention is unlocked, and {nCount} Wk or Wv are added to potimizers!!!')
opt = torch.optim.AdamW(params, lr=lr)
return opt
def low_vram_shift(self, is_diffusing):
if is_diffusing:
self.model = self.model.cuda()
self.control_model = self.control_model.cuda()
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_model = self.control_model.cpu()
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()

View File

@@ -0,0 +1,317 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ..ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
num_reference_steps = timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec

View File

@@ -0,0 +1,168 @@
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from ..ldm.modules.diffusionmodules.util import conv_nd, linear, zero_module
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
token = token[0, 1]
return token
def get_clip_vision_emb(encoder, processor, img):
_img = img.repeat(1, 3, 1, 1)*255
inputs = processor(images=_img, return_tensors="pt")
inputs['pixel_values'] = inputs['pixel_values'].to(img.device)
outputs = encoder(**inputs)
emb = outputs.image_embeds
return emb
def get_recog_emb(encoder, img_list):
_img_list = [(img.repeat(1, 3, 1, 1)*255)[0] for img in img_list]
encoder.predictor.eval()
_, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
return preds_neck
def pad_H(x):
_, _, H, W = x.shape
p_top = (W - H) // 2
p_bot = W - H - p_top
return F.pad(x, (0, 0, p_top, p_bot))
class EncodeNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(EncodeNet, self).__init__()
chan = 16
n_layer = 4 # downsample
self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
self.conv_list = nn.ModuleList([])
_c = chan
for i in range(n_layer):
self.conv_list.append(conv_nd(2, _c, _c*2, 3, padding=1, stride=2))
_c *= 2
self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.act = nn.SiLU()
def forward(self, x):
x = self.act(self.conv1(x))
for layer in self.conv_list:
x = self.act(layer(x))
x = self.act(self.conv2(x))
x = self.avgpool(x)
x = x.view(x.size(0), -1)
return x
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
valid=True,
glyph_channels=20,
position_channels=1,
placeholder_string='*',
add_pos=False,
emb_type='ocr',
**kwargs
):
super().__init__()
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
token_dim = 768
if hasattr(embedder, 'vit'):
assert emb_type == 'vit'
self.get_vision_emb = partial(get_clip_vision_emb, embedder.vit, embedder.processor)
self.get_recog_emb = None
else: # using LDM's BERT encoder
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
token_dim = 1280
self.token_dim = token_dim
self.emb_type = emb_type
self.add_pos = add_pos
if add_pos:
self.position_encoder = EncodeNet(position_channels, token_dim)
if emb_type == 'ocr':
self.proj = nn.Sequential(
zero_module(linear(40*64, token_dim)),
nn.LayerNorm(token_dim)
)
if emb_type == 'conv':
self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
self.placeholder_token = get_token_for_string(placeholder_string)
def encode_text(self, text_info):
if self.get_recog_emb is None and self.emb_type == 'ocr':
self.get_recog_emb = partial(get_recog_emb, self.recog)
gline_list = []
pos_list = []
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
for j in range(n_lines): # line
gline_list += [text_info['gly_line'][j][i:i+1]]
if self.add_pos:
pos_list += [text_info['positions'][j][i:i+1]]
if len(gline_list) > 0:
if self.emb_type == 'ocr':
recog_emb = self.get_recog_emb(gline_list)
enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
elif self.emb_type == 'vit':
enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
elif self.emb_type == 'conv':
enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
if self.add_pos:
enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
enc_glyph = enc_glyph+enc_pos
self.text_embs_all = []
n_idx = 0
for i in range(len(text_info['n_lines'])): # sample index in a batch
n_lines = text_info['n_lines'][i]
text_embs = []
for j in range(n_lines): # line
text_embs += [enc_glyph[n_idx:n_idx+1]]
n_idx += 1
self.text_embs_all += [text_embs]
def forward(
self,
tokenized_text,
embedded_text,
):
b, device = tokenized_text.shape[0], tokenized_text.device
for i in range(b):
idx = tokenized_text[i] == self.placeholder_token.to(device)
if sum(idx) > 0:
if i >= len(self.text_embs_all):
print('truncation for log images...')
break
text_emb = torch.cat(self.text_embs_all[i], dim=0)
if sum(idx) != len(text_emb):
print('truncation for long caption...')
embedded_text[i][idx] = text_emb[:sum(idx)]
return embedded_text
def embedding_parameters(self):
return self.parameters()

View File

@@ -0,0 +1,111 @@
import torch
import einops
from ..ldm.modules.encoders import modules
from ..ldm.modules import attention
from transformers import logging
from ..ldm.modules.attention import default
def disable_verbosity():
logging.set_verbosity_error()
print('logging improved.')
return
def enable_sliced_attention():
attention.CrossAttention.forward = _hacked_sliced_attentin_forward
print('Enabled sliced_attention.')
return
def hack_everything(clip_skip=0):
disable_verbosity()
modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
modules.FrozenCLIPEmbedder.clip_skip = clip_skip
print('Enabled clip hacks.')
return
# Written by Lvmin
def _hacked_clip_forward(self, text):
PAD = self.tokenizer.pad_token_id
EOS = self.tokenizer.eos_token_id
BOS = self.tokenizer.bos_token_id
def tokenize(t):
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
def transformer_encode(t):
if self.clip_skip > 1:
rt = self.transformer(input_ids=t, output_hidden_states=True)
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
else:
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
def split(x):
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
def pad(x, p, i):
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
raw_tokens_list = tokenize(text)
tokens_list = []
for raw_tokens in raw_tokens_list:
raw_tokens_123 = split(raw_tokens)
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
tokens_list.append(raw_tokens_123)
tokens_list = torch.IntTensor(tokens_list).to(self.device)
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
y = transformer_encode(feed)
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
return z
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = 1
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range(0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i + att_step, :, :] = sim_buffer
del sim_buffer
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)

View File

@@ -0,0 +1,76 @@
import os
import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.rank_zero import rank_zero_only
class ImageLogger(Callback):
def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "image_log", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
hasattr(pl_module, "log_images") and
callable(pl_module.log_images) and
self.max_images > 0):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1., 1.)
self.log_local(pl_module.logger.save_dir, split, images,
pl_module.global_step, pl_module.current_epoch, batch_idx)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
return check_idx % self.batch_freq == 0
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if not self.disabled:
self.log_img(pl_module, batch, batch_idx, split="train")

View File

@@ -0,0 +1,34 @@
import os
import torch
from omegaconf import OmegaConf
from ..ldm.util import instantiate_from_config
def get_state_dict(d):
return d.get('state_dict', d)
def load_state_dict(ckpt_path, location='cpu'):
_, extension = os.path.splitext(ckpt_path)
if extension.lower() == ".safetensors":
import safetensors.torch
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
else:
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
state_dict = get_state_dict(state_dict)
print(f'Loaded state_dict from [{ckpt_path}]')
return state_dict
def create_model(config_path, cond_stage_path=None, use_fp16=False):
config = OmegaConf.load(config_path)
if cond_stage_path:
config.model.params.cond_stage_config.params.version = cond_stage_path # use pre-downloaded ckpts, in case blocked
if use_fp16:
config.model.params.use_fp16 = True
config.model.params.control_stage_config.params.use_fp16 = True
config.model.params.unet_config.params.use_fp16 = True
model = instantiate_from_config(config.model).cpu()
print(f'Loaded model config from [{config_path}]')
return model

View File

@@ -0,0 +1,210 @@
from torch import nn
import torch
from .RecSVTR import Block
class Swish(nn.Module):
def __int__(self):
super(Swish, self).__int__()
def forward(self,x):
return x*torch.sigmoid(x)
class Im2Im(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
return x
class Im2Seq(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
# assert H == 1
x = x.reshape(B, C, H * W)
x = x.permute((0, 2, 1))
return x
class EncoderWithRNN(nn.Module):
def __init__(self, in_channels,**kwargs):
super(EncoderWithRNN, self).__init__()
hidden_size = kwargs.get('hidden_size', 256)
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)
def forward(self, x):
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
return x
class SequenceEncoder(nn.Module):
def __init__(self, in_channels, encoder_type='rnn', **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == 'reshape':
self.only_reshape = True
else:
support_encoder_dict = {
'reshape': Im2Seq,
'rnn': EncoderWithRNN,
'svtr': EncoderWithSVTR
}
assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys())
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels,**kwargs)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
if self.encoder_type != 'svtr':
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act=nn.GELU):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr)
self.norm = nn.BatchNorm2d(out_channels)
self.act = Swish()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class EncoderWithSVTR(nn.Module):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path=0.,
qk_scale=None):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(
in_channels, in_channels // 8, padding=1, act='swish')
self.conv2 = ConvBNLayer(
in_channels // 8, hidden_dims, kernel_size=1, act='swish')
self.svtr_block = nn.ModuleList([
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer='Global',
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer='swish',
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer='nn.LayerNorm',
epsilon=1e-05,
prenorm=False) for i in range(depth)
])
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
self.conv3 = ConvBNLayer(
hidden_dims, in_channels, kernel_size=1, act='swish')
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(
2 * in_channels, in_channels // 8, padding=1, act='swish')
self.conv1x1 = ConvBNLayer(
in_channels // 8, dims, kernel_size=1, act='swish')
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# for use guide
if self.use_guide:
z = x.clone()
z.stop_gradient = True
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).permute(0, 2, 1)
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
z = self.conv3(z)
z = torch.cat((h, z), dim=1)
z = self.conv1x1(self.conv4(z))
return z
if __name__=="__main__":
svtrRNN = EncoderWithSVTR(56)
print(svtrRNN)

View File

@@ -0,0 +1,48 @@
from torch import nn
class CTCHead(nn.Module):
def __init__(self,
in_channels,
out_channels=6625,
fc_decay=0.0004,
mid_channels=None,
return_feats=False,
**kwargs):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = dict()
result['ctc'] = predicts
result['ctc_neck'] = x
else:
result = predicts
return result

View File

@@ -0,0 +1,45 @@
from torch import nn
from .RNN import SequenceEncoder, Im2Seq, Im2Im
from .RecMv1_enhance import MobileNetV1Enhance
from .RecCTCHead import CTCHead
backbone_dict = {"MobileNetV1Enhance":MobileNetV1Enhance}
neck_dict = {'SequenceEncoder': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
head_dict = {'CTCHead':CTCHead}
class RecModel(nn.Module):
def __init__(self, config):
super().__init__()
assert 'in_channels' in config, 'in_channels must in model config'
backbone_type = config.backbone.pop('type')
assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)
neck_type = config.neck.pop('type')
assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)
head_type = config.head.pop('type')
assert head_type in head_dict, f'head.type must in {head_dict}'
self.head = head_dict[head_type](self.neck.out_channels, **config.head)
self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'
def load_3rd_state_dict(self, _3rd_name, _state):
self.backbone.load_3rd_state_dict(_3rd_name, _state)
self.neck.load_3rd_state_dict(_3rd_name, _state)
self.head.load_3rd_state_dict(_3rd_name, _state)
def forward(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
return x
def encode(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head.ctc_encoder(x)
return x

View File

@@ -0,0 +1,233 @@
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import Activation
class ConvBNLayer(nn.Module):
def __init__(self,
num_channels,
filter_size,
num_filters,
stride,
padding,
channels=None,
num_groups=1,
act='hard_swish'):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias=False)
self._batch_norm = nn.BatchNorm2d(
num_filters,
)
if self.act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DepthwiseSeparable(nn.Module):
def __init__(self,
num_channels,
num_filters1,
num_filters2,
num_groups,
stride,
scale,
dw_size=3,
padding=1,
use_se=False):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale))
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Module):
def __init__(self,
in_channels=3,
scale=0.5,
last_conv_stride=1,
last_pool_type='max',
**kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=in_channels,
filter_size=3,
channels=3,
num_filters=int(32 * scale),
stride=2,
padding=1)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale),
num_filters1=32,
num_filters2=64,
num_groups=32,
stride=1,
scale=scale)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale),
num_filters1=64,
num_filters2=128,
num_groups=64,
stride=1,
scale=scale)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=128,
num_groups=128,
stride=1,
scale=scale)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=256,
num_groups=256,
stride=1,
scale=scale)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == 'avg':
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
def hardsigmoid(x):
return F.relu6(x + 3., inplace=True) / 6.
class SEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
bias=True)
self.conv2 = nn.Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
bias=True)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
x = torch.mul(inputs, outputs)
return x

View File

@@ -0,0 +1,591 @@
import torch
import torch.nn as nn
import numpy as np
from torch.nn.init import trunc_normal_, zeros_, ones_
from torch.nn import functional
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0. or not training:
return x
keep_prob = torch.tensor(1 - drop_prob)
shape = (x.size()[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
random_tensor = torch.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
class Swish(nn.Module):
def __int__(self):
super(Swish, self).__int__()
def forward(self,x):
return x*torch.sigmoid(x)
class ConvBNLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=0,
bias_attr=False,
groups=1,
act=nn.GELU):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr)
self.norm = nn.BatchNorm2d(out_channels)
self.act = act()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
if isinstance(act_layer, str):
self.act = Swish()
else:
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
num_heads=8,
HW=(8, 25),
local_k=(3, 3), ):
super().__init__()
self.HW = HW
self.dim = dim
self.local_mixer = nn.Conv2d(
dim,
dim,
local_k,
1, (local_k[0] // 2, local_k[1] // 2),
groups=num_heads,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
def forward(self, x):
h = self.HW[0]
w = self.HW[1]
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
x = self.local_mixer(x)
x = x.flatten(2).transpose([0, 2, 1])
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
mixer='Global',
HW=(8, 25),
local_k=(7, 11),
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.HW = HW
if HW is not None:
H = HW[0]
W = HW[1]
self.N = H * W
self.C = dim
if mixer == 'Local' and HW is not None:
hk = local_k[0]
wk = local_k[1]
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
for h in range(0, H):
for w in range(0, W):
mask[h * W + w, h:h + hk, w:w + wk] = 0.
mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk //
2].flatten(1)
mask_inf = torch.full([H * W, H * W],fill_value=float('-inf'))
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask[None,None,:]
# self.mask = mask.unsqueeze([0, 1])
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4))
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = (q.matmul(k.permute((0, 1, 3, 2))))
if self.mixer == 'Local':
attn += self.mask
attn = functional.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mixer='Global',
local_mixer=(7, 11),
HW=(8, 25),
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer='nn.LayerNorm',
epsilon=1e-6,
prenorm=True):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm1 = norm_layer(dim)
if mixer == 'Global' or mixer == 'Local':
self.mixer = Attention(
dim,
num_heads=num_heads,
mixer=mixer,
HW=HW,
local_k=local_mixer,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop)
elif mixer == 'Conv':
self.mixer = ConvMixer(
dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
self.prenorm = prenorm
def forward(self, x):
if self.prenorm:
x = self.norm1(x + self.drop_path(self.mixer(x)))
x = self.norm2(x + self.drop_path(self.mlp(x)))
else:
x = x + self.drop_path(self.mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self,
img_size=(32, 100),
in_channels=3,
embed_dim=768,
sub_num=2):
super().__init__()
num_patches = (img_size[1] // (2 ** sub_num)) * \
(img_size[0] // (2 ** sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False))
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 4,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False))
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
class SubSample(nn.Module):
def __init__(self,
in_channels,
out_channels,
types='Pool',
stride=(2, 1),
sub_norm='nn.LayerNorm',
act=None):
super().__init__()
self.types = types
if types == 'Pool':
self.avgpool = nn.AvgPool2d(
kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.maxpool = nn.MaxPool2d(
kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
else:
self.act = None
def forward(self, x):
if self.types == 'Pool':
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
out = self.proj(x.flatten(2).permute((0, 2, 1)))
else:
x = self.conv(x)
out = x.flatten(2).permute((0, 2, 1))
out = self.norm(out)
if self.act is not None:
out = self.act(out)
return out
class SVTRNet(nn.Module):
def __init__(
self,
img_size=[48, 100],
in_channels=3,
embed_dim=[64, 128, 256],
depth=[3, 6, 3],
num_heads=[2, 4, 8],
mixer=['Local'] * 6 + ['Global'] *
6, # Local atten, Global atten, Conv
local_mixer=[[7, 11], [7, 11], [7, 11]],
patch_merging='Conv', # Conv, Pool, None
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
last_drop=0.1,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer='nn.LayerNorm',
sub_norm='nn.LayerNorm',
epsilon=1e-6,
out_channels=192,
out_char_num=25,
block_unit='Block',
act='nn.GELU',
last_stage=True,
sub_num=2,
prenorm=True,
use_lenhead=False,
**kwargs):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging
self.patch_embed = PatchEmbed(
img_size=img_size,
in_channels=in_channels,
embed_dim=embed_dim[0],
sub_num=sub_num)
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
# self.pos_embed = self.create_parameter(
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
# self.add_parameter("pos_embed", self.pos_embed)
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[0],
num_heads=num_heads[0],
mixer=mixer[0:depth[0]][i],
HW=self.HW,
local_mixer=local_mixer[0],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[0:depth[0]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[0])
]
)
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0],
embed_dim[1],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging)
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
self.blocks2 = nn.ModuleList([
Block_unit(
dim=embed_dim[1],
num_heads=num_heads[1],
mixer=mixer[depth[0]:depth[0] + depth[1]][i],
HW=HW,
local_mixer=local_mixer[1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[1])
])
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1],
embed_dim[2],
sub_norm=sub_norm,
stride=[2, 1],
types=patch_merging)
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
self.blocks3 = nn.ModuleList([
Block_unit(
dim=embed_dim[2],
num_heads=num_heads[2],
mixer=mixer[depth[0] + depth[1]:][i],
HW=HW,
local_mixer=local_mixer[2],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1]:][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm) for i in range(depth[2])
])
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
self.last_conv = nn.Conv2d(
in_channels=embed_dim[2],
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False)
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = nn.Hardswish()
self.dropout_len = nn.Dropout(
p=last_drop)
trunc_normal_(self.pos_embed,std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight,std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample1(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.use_lenhead:
len_x = self.len_conv(x.mean(1))
len_x = self.dropout_len(self.hardswish_len(len_x))
if self.last_stage:
if self.patch_merging is not None:
h = self.HW[0] // 4
else:
h = self.HW[0]
x = self.avg_pool(
x.permute([0, 2, 1]).reshape(
[-1, self.embed_dim[2], h, self.HW[1]]))
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.use_lenhead:
return x, len_x
return x
if __name__=="__main__":
a = torch.rand(1,3,48,100)
svtr = SVTRNet()
out = svtr(a)
print(svtr)
print(out.size())

View File

@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3., inplace=self.inplace) / 6.
# out = max(0, min(1, slop*x+offset))
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
return F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
class GELU(nn.Module):
def __init__(self, inplace=True):
super(GELU, self).__init__()
self.inplace = inplace
def forward(self, x):
return torch.nn.functional.gelu(x)
class Swish(nn.Module):
def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
if self.inplace:
x.mul_(torch.sigmoid(x))
return x
else:
return x*torch.sigmoid(x)
class Activation(nn.Module):
def __init__(self, act_type, inplace=True):
super(Activation, self).__init__()
act_type = act_type.lower()
if act_type == 'relu':
self.act = nn.ReLU(inplace=inplace)
elif act_type == 'relu6':
self.act = nn.ReLU6(inplace=inplace)
elif act_type == 'sigmoid':
raise NotImplementedError
elif act_type == 'hard_sigmoid':
self.act = Hsigmoid(inplace)
elif act_type == 'hard_swish':
self.act = Hswish(inplace=inplace)
elif act_type == 'leakyrelu':
self.act = nn.LeakyReLU(inplace=inplace)
elif act_type == 'gelu':
self.act = GELU(inplace=inplace)
elif act_type == 'swish':
self.act = Swish(inplace=inplace)
else:
raise NotImplementedError
def forward(self, inputs):
return self.act(inputs)

View File

@@ -0,0 +1,95 @@
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,310 @@
'''
Copyright (c) Alibaba, Inc. and its affiliates.
'''
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import cv2
import numpy as np
import math
import traceback
from easydict import EasyDict as edict
import time
from .ocr_recog.RecModel import RecModel
import torch
import torch.nn.functional as F
from skimage.transform._geometric import _umeyama as get_sym_mat
current_directory = os.path.dirname(os.path.abspath(__file__))
ocr_txt_path = os.path.join(os.path.dirname(os.path.dirname(current_directory)), "ocr_weights", "ppocr_keys_v1.txt")
ocr_model_path = os.path.join(os.path.dirname(os.path.dirname(current_directory)), "ocr_weights", "ppv3_rec.pth")
def min_bounding_rect(img):
ret, thresh = cv2.threshold(img, 127, 255, 0)
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if len(contours) == 0:
print('Bad contours, using fake bbox...')
return np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
max_contour = max(contours, key=cv2.contourArea)
rect = cv2.minAreaRect(max_contour)
box = cv2.boxPoints(rect)
box = np.int0(box)
# sort
x_sorted = sorted(box, key=lambda x: x[0])
left = x_sorted[:2]
right = x_sorted[2:]
left = sorted(left, key=lambda x: x[1])
(tl, bl) = left
right = sorted(right, key=lambda x: x[1])
(tr, br) = right
if tl[1] > bl[1]:
(tl, bl) = (bl, tl)
if tr[1] > br[1]:
(tr, br) = (br, tr)
return np.array([tl, tr, br, bl])
def adjust_image(box, img):
pts1 = np.float32([box[0], box[1], box[2], box[3]])
width = max(np.linalg.norm(pts1[0]-pts1[1]), np.linalg.norm(pts1[2]-pts1[3]))
height = max(np.linalg.norm(pts1[0]-pts1[3]), np.linalg.norm(pts1[1]-pts1[2]))
pts2 = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
# get transform matrix
M = get_sym_mat(pts1, pts2, estimate_scale=True)
C, H, W = img.shape
T = np.array([[2 / W, 0, -1], [0, 2 / H, -1], [0, 0, 1]])
theta = np.linalg.inv(T @ M @ np.linalg.inv(T))
theta = torch.from_numpy(theta[:2, :]).unsqueeze(0).type(torch.float32).to(img.device)
grid = F.affine_grid(theta, torch.Size([1, C, H, W]), align_corners=True)
result = F.grid_sample(img.unsqueeze(0), grid, align_corners=True)
result = torch.clamp(result.squeeze(0), 0, 255)
# crop
result = result[:, :int(height), :int(width)]
return result
'''
mask: numpy.ndarray, mask of textual, HWC
src_img: torch.Tensor, source image, CHW
'''
def crop_image(src_img, mask):
box = min_bounding_rect(mask)
result = adjust_image(box, src_img)
if len(result.shape) == 2:
result = torch.stack([result]*3, axis=-1)
return result
def create_predictor(model_dir=None, model_lang='ch', is_onnx=False):
model_file_path = model_dir
if model_file_path is not None and not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(model_file_path))
if is_onnx:
import onnxruntime as ort
sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider']) # 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
return sess
else:
if model_lang == 'ch':
n_class = 6625
elif model_lang == 'en':
n_class = 97
else:
raise ValueError(f"Unsupported OCR recog model_lang: {model_lang}")
rec_config = edict(
in_channels=3,
backbone=edict(type='MobileNetV1Enhance', scale=0.5, last_conv_stride=[1, 2], last_pool_type='avg'),
neck=edict(type='SequenceEncoder', encoder_type="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True),
head=edict(type='CTCHead', fc_decay=0.00001, out_channels=n_class, return_feats=True)
)
rec_model = RecModel(rec_config)
if model_file_path is not None:
rec_model.load_state_dict(torch.load(model_file_path, map_location="cpu"))
rec_model.eval()
return rec_model.eval()
def _check_image_file(path):
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
return any([path.lower().endswith(e) for e in img_end])
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and _check_image_file(file_path):
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
class TextRecognizer(object):
def __init__(self, args, predictor):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
self.rec_batch_num = args.rec_batch_num
self.predictor = predictor
self.chars = self.get_char_dict(ocr_txt_path)
self.char2id = {x: i for i, x in enumerate(self.chars)}
self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
self.use_fp16 = args.use_fp16
# img: CHW
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
assert imgC == img.shape[0]
imgW = int((imgH * max_wh_ratio))
h, w = img.shape[1:]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = torch.nn.functional.interpolate(
img.unsqueeze(0),
size=(imgH, resized_w),
mode='bilinear',
align_corners=True,
)
resized_image /= 255.0
resized_image -= 0.5
resized_image /= 0.5
padding_im = torch.zeros((imgC, imgH, imgW), dtype=torch.float32).to(img.device)
padding_im[:, :, 0:resized_w] = resized_image[0]
return padding_im
# img_list: list of tensors with shape chw 0-255
def pred_imglist(self, img_list, show_debug=False):
img_num = len(img_list)
assert img_num > 0
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[2] / float(img.shape[1]))
# Sorting can speed up the recognition process
indices = torch.from_numpy(np.argsort(np.array(width_list)))
batch_num = self.rec_batch_num
preds_all = [None] * img_num
preds_neck_all = [None] * img_num
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[1:]
if h > w * 1.2:
img = img_list[indices[ino]]
img = torch.transpose(img, 1, 2).flip(dims=[1])
img_list[indices[ino]] = img
h, w = img.shape[1:]
# wh_ratio = w * 1.0 / h
# max_wh_ratio = max(max_wh_ratio, wh_ratio) # comment to not use different ratio
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
if self.use_fp16:
norm_img = norm_img.half()
norm_img = norm_img.unsqueeze(0)
norm_img_batch.append(norm_img)
norm_img_batch = torch.cat(norm_img_batch, dim=0)
if show_debug:
for i in range(len(norm_img_batch)):
_img = norm_img_batch[i].permute(1, 2, 0).detach().cpu().numpy()
_img = (_img + 0.5)*255
_img = _img[:, :, ::-1]
file_name = f'{indices[beg_img_no + i]}'
if os.path.exists(file_name + '.jpg'):
file_name += '_2' # ori image
cv2.imwrite(file_name + '.jpg', _img)
if self.is_onnx:
input_dict = {}
input_dict[self.predictor.get_inputs()[0].name] = norm_img_batch.detach().cpu().numpy()
outputs = self.predictor.run(None, input_dict)
preds = {}
preds['ctc'] = torch.from_numpy(outputs[0])
preds['ctc_neck'] = [torch.zeros(1)] * img_num
else:
preds = self.predictor(norm_img_batch)
for rno in range(preds['ctc'].shape[0]):
preds_all[indices[beg_img_no + rno]] = preds['ctc'][rno]
preds_neck_all[indices[beg_img_no + rno]] = preds['ctc_neck'][rno]
return torch.stack(preds_all, dim=0), torch.stack(preds_neck_all, dim=0)
def get_char_dict(self, character_dict_path):
character_str = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str.append(line)
dict_character = list(character_str)
dict_character = ['sos'] + dict_character + [' '] # eos is space
return dict_character
def get_text(self, order):
char_list = [self.chars[text_id] for text_id in order]
return ''.join(char_list)
def decode(self, mat):
text_index = mat.detach().cpu().numpy().argmax(axis=1)
ignored_tokens = [0]
selection = np.ones(len(text_index), dtype=bool)
selection[1:] = text_index[1:] != text_index[:-1]
for ignored_token in ignored_tokens:
selection &= text_index != ignored_token
return text_index[selection], np.where(selection)[0]
def get_ctcloss(self, preds, gt_text, weight):
if not isinstance(weight, torch.Tensor):
weight = torch.tensor(weight).to(preds.device)
ctc_loss = torch.nn.CTCLoss(reduction='none')
log_probs = preds.log_softmax(dim=2).permute(1, 0, 2) # NTC-->TNC
targets = []
target_lengths = []
for t in gt_text:
targets += [self.char2id.get(i, len(self.chars)-1) for i in t]
target_lengths += [len(t)]
targets = torch.tensor(targets).to(preds.device)
target_lengths = torch.tensor(target_lengths).to(preds.device)
input_lengths = torch.tensor([log_probs.shape[0]]*(log_probs.shape[1])).to(preds.device)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss = loss / input_lengths * weight
return loss
def main():
rec_model_dir = ocr_model_path
predictor = create_predictor(rec_model_dir)
args = edict()
args.rec_image_shape = "3, 48, 320"
args.rec_char_dict_path = ocr_txt_path
args.rec_batch_num = 6
text_recognizer = TextRecognizer(args, predictor)
image_dir = './test_imgs_cn'
gt_text = ['韩国小馆']*14
image_file_list = get_image_file_list(image_dir)
valid_image_file_list = []
img_list = []
for image_file in image_file_list:
img = cv2.imread(image_file)
if img is None:
print("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(torch.from_numpy(img).permute(2, 0, 1).float())
try:
tic = time.time()
times = []
for i in range(10):
preds, _ = text_recognizer.pred_imglist(img_list) # get text
preds_all = preds.softmax(dim=2)
times += [(time.time()-tic)*1000.]
tic = time.time()
print(times)
print(np.mean(times[1:]) / len(preds_all))
weight = np.ones(len(gt_text))
loss = text_recognizer.get_ctcloss(preds, gt_text, weight)
for i in range(len(valid_image_file_list)):
pred = preds_all[i]
order, idx = text_recognizer.decode(pred)
text = text_recognizer.get_text(order)
print(f'{valid_image_file_list[i]}: pred/gt="{text}"/"{gt_text[i]}", loss={loss[i]:.2f}')
except Exception as E:
print(traceback.format_exc(), E)
if __name__ == "__main__":
main()