# ComfyUI-AnyText/AnyText/nodes.py
#
# Repository web-viewer header captured with the file (not part of the code):
# 281 lines, 13 KiB, Python; timestamp shown: 2024-09-25 15:18:31 +08:00.

import os
import folder_paths
import re
import cv2
import numpy as np
from .utils import is_module_imported, pil2tensor, get_device_by_name, comfy_tensor_Image2np_Image
# Scratch image files kept in the temp directory reported by folder_paths:
#   Random_Gen_Mask_path - position mask produced by random rectangle generation
#   tmp_pose_img_path    - the user-supplied position (mask) image
#   tmp_ori_img_path     - the user-supplied original image
comfy_temp_dir = folder_paths.get_temp_directory()
Random_Gen_Mask_path, tmp_pose_img_path, tmp_ori_img_path = (
    os.path.join(comfy_temp_dir, file_name)
    for file_name in (
        "AnyText_random_mask_pos_img.png",
        "AnyText_manual_mask_pos_img.png",
        "AnyText_ori_img.png",
    )
)
class AnyText:
    """ComfyUI node that runs ``AnyText_Pipeline`` in generation or editing mode.

    The double-quoted spans of the prompt are counted to decide how many
    random position rectangles to generate when ``Random_Gen`` is enabled;
    otherwise the position mask comes from ``pos_image``.  Images are handed
    to the pipeline as paths to files saved in the temp directory.
    """

    # Matches one double-quoted span (non-greedy; "." does not cross newlines).
    _QUOTED_TEXT = re.compile(r'"(.*?)"')

    def __init__(self):
        # Last element of the tuple returned by the previous pipeline call;
        # passed back to the next pipeline as ``loaded_model_tensor``.
        self.model = None

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's input specification.

        Returns:
            dict: ``"required"`` and ``"optional"`` sections, each mapping an
            input name to a ``(type, options)`` tuple.
        """
        return {
            "required": {
                "AnyText_Loader": ("AnyText_Loader", {"forceInput": True}),
                "prompt": ("STRING", {"default": "A raccoon stands in front of the blackboard with the words \"你好呀~Hello!\" written on it.", "multiline": True}),
                "a_prompt": ("STRING", {"default": "best quality, extremely detailed,4k, HD, supper legible text, clear text edges, clear strokes, neat writing, no watermarks", "multiline": True}),
                "n_prompt": ("STRING", {"default": "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture", "multiline": True}),
                "mode": (['text-generation', 'text-editing'], {"default": 'text-generation'}),
                # NOTE(review): both option strings were empty in the reviewed
                # text (two identical, empty choices).  Restored to the
                # up-down / left-right arrow glyphs; confirm against the
                # values the pipeline's ``sort_priority`` expects.
                "sort_radio": (["\u2195", "\u2194"], {"default": "\u2194"}),
                "revise_pos": ("BOOLEAN", {"default": False}),
                "img_count": ("INT", {"default": 1, "min": 1, "max": 10}),
                "ddim_steps": ("INT", {"default": 20, "min": 2, "max": 100}),
                "seed": ("INT", {"default": 9999, "min": -1, "max": 99999999}),
                "nonEdit_random_gen_width": ("INT", {"default": 512, "min": 128, "max": 1920, "step": 64}),
                "nonEdit_random_gen_height": ("INT", {"default": 512, "min": 128, "max": 1920, "step": 64}),
                "Random_Gen": ("BOOLEAN", {"default": False}),
                "strength": ("FLOAT", {
                    "default": 1.00,
                    "min": -999999,
                    "max": 9999999,
                    "step": 0.01
                }),
                "cfg_scale": ("FLOAT", {
                    "default": 9,
                    "min": 1,
                    "max": 99,
                    "step": 0.1
                }),
                "eta": ("FLOAT", {
                    "default": 0,
                    "min": 0,
                    "max": 1,
                    "step": 0.1
                }),
                "device": (["auto", "cuda", "cpu", "mps", "xpu"], {"default": "auto"}),
                "fp16": ("BOOLEAN", {"default": True}),
                "cpu_offload": ("BOOLEAN", {"default": False, "label_on": "model_to_cpu", "label_off": "unload_model"}),
                "all_to_device": ("BOOLEAN", {"default": False}),
            },
            "optional": {
                "ori_image": ("IMAGE", {"forceInput": True}),
                "pos_image": ("IMAGE", {"forceInput": True}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "ExtraModels/AnyText"
    FUNCTION = "anytext_process"
    TITLE = "AnyText Geneation"

    @classmethod
    def _find_quoted(cls, prompt):
        """Normalize curly double quotes and list the quoted spans.

        Returns:
            tuple: ``(normalized_prompt, quoted_strings)`` where
            ``quoted_strings`` holds the text between each pair of ``"``.
        """
        # U+201C / U+201D are the left/right curly double quotes.  Replacing
        # the empty string instead would insert a quote between every
        # character of the prompt.
        prompt = prompt.replace('\u201c', '"').replace('\u201d', '"')
        return prompt, cls._QUOTED_TEXT.findall(prompt)

    @classmethod
    def _prompt_replace(cls, prompt):
        """Return ``prompt`` with every double-quoted span replaced by ``' * '``."""
        prompt, quoted = cls._find_quoted(prompt)
        for text in quoted:
            # Replace one occurrence per match so repeated texts are each
            # substituted exactly once.
            prompt = prompt.replace(f'"{text}"', ' * ', 1)
        return prompt

    @classmethod
    def _count_lines(cls, prompt):
        """Return the number of double-quoted spans in ``prompt`` (at least 1)."""
        _, quoted = cls._find_quoted(prompt)
        return len(quoted) or 1

    @staticmethod
    def _check_overlap_polygon(rect_pts1, rect_pts2):
        """Report whether the bounding boxes of two point sets overlap.

        The test compares the axis-aligned bounding rectangles of the convex
        hulls, so it can report an overlap for rotated shapes that do not
        actually intersect; boxes that merely touch count as overlapping.
        """
        x1, y1, w1, h1 = cv2.boundingRect(cv2.convexHull(rect_pts1))
        x2, y2, w2, h2 = cv2.boundingRect(cv2.convexHull(rect_pts2))
        return bool(
            x1 + w1 >= x2 and x2 + w2 >= x1
            and y1 + h1 >= y2 and y2 + h2 >= y1
        )

    @classmethod
    def _generate_rectangles(cls, w, h, n, max_trys=200):
        """Draw ``n`` random, non-overlapping rectangles and save them as a mask.

        Args:
            w: Canvas width in pixels.
            h: Canvas height in pixels.
            n: Number of rectangles to place.
            max_trys: Number of rejected candidates tolerated before giving up.

        Returns:
            numpy.ndarray: ``(h, w, 1)`` uint8 image, 255 inside the rectangles.
            The inverted image is also written to ``Random_Gen_Mask_path``.

        Raises:
            Exception: If ``n`` rectangles could not be placed.
        """
        img = np.zeros((h, w, 1), dtype=np.uint8)
        rectangles = []
        attempts = 0
        # Minimum edge length: 30% of the longer side for up to three
        # rectangles, 20% otherwise (about 150 / 100 px on a 512 canvas).
        low_edge = int(max(w, h) * 0.3 if n <= 3 else max(w, h) * 0.2)
        # Loop-invariant width bounds; the lower bound is cast to int because
        # the floor division above yields a float.
        min_rect_w = int(max((w * 0.5) // n, low_edge))
        max_rect_w = int(w * 0.8)
        while attempts < max_trys:
            if min_rect_w < w:
                rect_w = min(np.random.randint(min_rect_w, w), max_rect_w)
            else:
                # On tall canvases the minimum edge can exceed the width;
                # randint(low, high) requires low < high, so use the cap.
                rect_w = max_rect_w
            ratio = np.random.uniform(4, 10)
            rect_h = max(low_edge, int(rect_w / ratio))
            rect_h = min(rect_h, int(h * 0.8))
            # Rotation: 70% none, 10% in [0, 40), 10% in [140, 180),
            # 10% in [85, 95).
            rotation_angle = 0
            rand_value = np.random.rand()
            if rand_value < 0.7:
                pass
            elif rand_value < 0.8:
                rotation_angle = np.random.randint(0, 40)
            elif rand_value < 0.9:
                rotation_angle = np.random.randint(140, 180)
            else:
                rotation_angle = np.random.randint(85, 95)
            # Random top-left offset for the unrotated box.
            x = np.random.randint(0, w - rect_w)
            y = np.random.randint(0, h - rect_h)
            # Corner points of the rotated box, then shifted to (x, y).
            rect_pts = cv2.boxPoints(((rect_w / 2, rect_h / 2), (rect_w, rect_h), rotation_angle))
            rect_pts = np.int32(rect_pts)
            rect_pts += (x, y)
            # Reject candidates that leave the canvas after rotation.
            if np.any(rect_pts < 0) or np.any(rect_pts[:, 0] >= w) or np.any(rect_pts[:, 1] >= h):
                attempts += 1
                continue
            # Reject candidates overlapping an already placed rectangle.
            if any(cls._check_overlap_polygon(rect_pts, placed) for placed in rectangles):
                attempts += 1
                continue
            img = cv2.fillPoly(img, [rect_pts], 255)
            rectangles.append(rect_pts)
            if len(rectangles) == n:
                break
        print("attempts:", attempts)
        if len(rectangles) != n:
            raise Exception(f'Failed in auto generate positions after {attempts} attempts, try again!')
        # Write the finished (inverted) mask once instead of after every
        # placed rectangle.
        cv2.imwrite(Random_Gen_Mask_path, 255 - img[..., ::-1])
        return img

    @staticmethod
    def _ensure_translator_downloaded(translator_path):
        """Download the translator model if its weight file is not on disk.

        Args:
            translator_path: Translator entry from the loader string; selects
                the modelscope model when it contains
                ``damo/nlp_csanmt_translation_zh2en``, else the Hugging Face one.
        """
        prompt_generator_dir = os.path.join(folder_paths.models_dir, "prompt_generator")
        if 'damo/nlp_csanmt_translation_zh2en' in translator_path:
            weight_file = os.path.join(prompt_generator_dir, "nlp_csanmt_translation_zh2en", "tf_ckpts", "ckpt-0.data-00000-of-00001")
            if not os.access(weight_file, os.F_OK):
                from modelscope.hub.snapshot_download import snapshot_download
                snapshot_download('damo/nlp_csanmt_translation_zh2en')
        else:
            weight_file = os.path.join(prompt_generator_dir, "models--utrobinmv--t5_translate_en_ru_zh_small_1024", "model.safetensors")
            if not os.access(weight_file, os.F_OK):
                from huggingface_hub import snapshot_download as hg_snapshot_download
                hg_snapshot_download(repo_id="utrobinmv/t5_translate_en_ru_zh_small_1024")

    def anytext_process(self,
        mode,
        AnyText_Loader,
        ori_image=None,
        pos_image=None,
        sort_radio="\u2194",
        revise_pos=False,
        Random_Gen=False,
        prompt="",
        cpu_offload=False,
        img_count=1,
        fp16=True,
        device="auto",
        all_to_device=False,
        ddim_steps=20,
        strength=1,
        cfg_scale=9,
        seed="",
        eta=0.0,
        a_prompt="",
        n_prompt="",
        nonEdit_random_gen_width=512,
        nonEdit_random_gen_height=512,
        ):
        """Run the AnyText pipeline and return the generated images.

        Args:
            mode: ``"text-generation"`` or ``"text-editing"``.
            AnyText_Loader: ``"|"``-separated string; fields 0-4 are used as
                font path, checkpoint path, clip path, translator path and
                config path respectively.
            ori_image: Optional original image; only forwarded to the pipeline
                when ``mode`` is not ``"text-generation"``.
            pos_image: Optional position image; used as the draw position
                unless ``Random_Gen`` is enabled.
            sort_radio: Forwarded as ``sort_priority``.
            revise_pos: Forwarded in generation mode; forced to ``False``
                otherwise.
            Random_Gen: When true, a random position mask is generated with
                one rectangle per quoted span of ``prompt``.
            prompt: Prompt text, forwarded unmodified to the pipeline.
            cpu_offload, img_count, fp16, device, all_to_device, strength,
            cfg_scale, seed, eta, a_prompt, n_prompt: Forwarded to the
                pipeline (``device`` after ``get_device_by_name``).
            ddim_steps: The pipeline receives ``ddim_steps - 1``.
            nonEdit_random_gen_width, nonEdit_random_gen_height: Image size
                passed to the pipeline and canvas size of the random mask.

        Returns:
            The result of ``pil2tensor`` applied to the pipeline samples.

        Raises:
            Exception: If the pipeline reports a negative return code, or if
                random rectangle placement fails.
        """
        # Imported unconditionally: repeated imports are served from the
        # module cache, whereas a guarded import leaves the local name
        # unbound whenever the guard skips it.
        from .AnyText_scripts.AnyText_pipeline import AnyText_Pipeline

        # Check whether the prompt (with quoted texts masked out) is Chinese
        # to decide whether the translator is needed.
        prompt_modify = self._prompt_replace(prompt)
        use_translator = bool(AnyText_Pipeline.is_chinese(self, prompt_modify))
        device = get_device_by_name(device)
        loader_out = AnyText_Loader.split("|")
        if use_translator:
            self._ensure_translator_downloaded(loader_out[3])

        pipe = AnyText_Pipeline(ckpt_path=loader_out[1], clip_path=loader_out[2], translator_path=loader_out[3], cfg_path=loader_out[4], use_translator=use_translator, device=device, use_fp16=fp16, all_to_device=all_to_device, loaded_model_tensor=self.model)

        # Convert the connected tensor images and save them to the temp
        # directory; both inputs are optional, so unconnected ones are skipped
        # and passed on as None.
        pos = None
        if pos_image is not None:
            comfy_tensor_Image2np_Image(self, pos_image).save(tmp_pose_img_path)
            pos = tmp_pose_img_path
        ori = None
        if ori_image is not None:
            comfy_tensor_Image2np_Image(self, ori_image).save(tmp_ori_img_path)
            ori = tmp_ori_img_path

        if mode == "text-generation":
            ori_image = None
        else:
            revise_pos = False
            ori_image = ori

        if Random_Gen:
            n_lines = self._count_lines(prompt)
            self._generate_rectangles(nonEdit_random_gen_width, nonEdit_random_gen_height, n_lines, max_trys=500)
            pos_img = Random_Gen_Mask_path
        else:
            pos_img = pos

        params = {
            "mode": mode,
            "use_fp16": fp16,
            "Random_Gen": Random_Gen,
            "sort_priority": sort_radio,
            "revise_pos": revise_pos,
            "image_count": img_count,
            # NOTE(review): one step fewer than the UI value is passed on;
            # confirm this offset is intentional.
            "ddim_steps": ddim_steps - 1,
            "image_width": nonEdit_random_gen_width,
            "image_height": nonEdit_random_gen_height,
            "strength": strength,
            "cfg_scale": cfg_scale,
            "eta": eta,
            "a_prompt": a_prompt,
            "n_prompt": n_prompt,
        }
        input_data = {
            "prompt": prompt,
            "seed": seed,
            "draw_pos": pos_img,
            "ori_image": ori_image,
        }
        x_samples, results, rtn_code, rtn_warning, debug_info, self.model = pipe(input_data, font_path=loader_out[0], cpu_offload=cpu_offload, **params)
        if rtn_code < 0:
            raise Exception(f"Error in AnyText pipeline: {rtn_warning}")
        output = pil2tensor(x_samples)
        print("\n", debug_info)
        # NOTE(review): the value is returned as-is, not wrapped in a 1-tuple;
        # kept unchanged because downstream handling may rely on it.
        return output
# Node class mappings: node identifier -> implementing class.
NODE_CLASS_MAPPINGS = dict(AnyText=AnyText)