# ComfyUI-AnyText/AnyText/nodes.py
import os
import folder_paths
import re
import cv2
import numpy as np
from .utils import is_module_imported, pil2tensor, get_device_by_name, comfy_tensor_Image2np_Image
comfy_temp_dir = folder_paths.get_temp_directory()
Random_Gen_Mask_path = os.path.join(comfy_temp_dir, "AnyText_random_mask_pos_img.png")
tmp_pose_img_path = os.path.join(comfy_temp_dir, "AnyText_manual_mask_pos_img.png")
tmp_ori_img_path = os.path.join(comfy_temp_dir, "AnyText_ori_img.png")
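

# ComfyUI node that wraps the AnyText pipeline: it takes the loader string produced by the
# AnyText loader node, renders the quoted text in the prompt at user-supplied or randomly
# generated mask positions, and returns the generated images as a tensor batch.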
class AnyText:
    def __init__(self):
        self.model = None

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "AnyText_Loader": ("AnyText_Loader", {"forceInput": True}),
                "prompt": ("STRING", {"default": "A raccoon stands in front of the blackboard with the words \"你好呀~Hello!\" written on it.", "multiline": True}),
                "a_prompt": ("STRING", {"default": "best quality, extremely detailed, 4k, HD, super legible text, clear text edges, clear strokes, neat writing, no watermarks", "multiline": True}),
                "n_prompt": ("STRING", {"default": "low-res, bad anatomy, extra digit, fewer digits, cropped, worst quality, low quality, watermark, unreadable text, messy words, distorted text, disorganized writing, advertising picture", "multiline": True}),
                "mode": (['text-generation', 'text-editing'], {"default": 'text-generation'}),
"sort_radio": (["", ""],{"default": ""}),
"revise_pos": ("BOOLEAN", {"default": False}),
"img_count": ("INT", {"default": 1, "min": 1, "max": 10}),
"ddim_steps": ("INT", {"default": 20, "min": 2, "max": 100}),
"seed": ("INT", {"default": 9999, "min": -1, "max": 99999999}),
"nonEdit_random_gen_width": ("INT", {"default": 512, "min": 128, "max": 1920, "step": 64}),
"nonEdit_random_gen_height": ("INT", {"default": 512, "min": 128, "max": 1920, "step": 64}),
# "width": ("INT", {"forceInput": True}),
# "height": ("INT", {"forceInput": True}),
"Random_Gen": ("BOOLEAN", {"default": False}),
"strength": ("FLOAT", {
"default": 1.00,
"min": -999999,
"max": 9999999,
"step": 0.01
}),
"cfg_scale": ("FLOAT", {
"default": 9,
"min": 1,
"max": 99,
"step": 0.1
}),
"eta": ("FLOAT", {
"default": 0,
"min": 0,
"max": 1,
"step": 0.1
}),
"device": (["auto", "cuda", "cpu", "mps", "xpu"],{"default": "auto"}),
"fp16": ("BOOLEAN", {"default": True}),
"cpu_offload": ("BOOLEAN", {"default": False, "label_on": "model_to_cpu", "label_off": "unload_model"}),
"all_to_device": ("BOOLEAN", {"default": False}),
},
"optional": {
"ori_image": ("IMAGE", {"forceInput": True}),
"pos_image": ("IMAGE", {"forceInput": True}),
# "show_debug": ("BOOLEAN", {"default": False}),
},
}

    RETURN_TYPES = ("IMAGE",)
    CATEGORY = "ExtraModels/AnyText"
    FUNCTION = "anytext_process"
    TITLE = "AnyText Generation"

    def anytext_process(self,
                        mode,
                        AnyText_Loader,
                        ori_image,
                        pos_image,
                        sort_radio,
                        revise_pos,
                        Random_Gen,
                        prompt,
                        cpu_offload,
                        # show_debug,
                        img_count,
                        fp16,
                        device,
                        all_to_device,
                        ddim_steps=20,
                        strength=1,
                        cfg_scale=9,
                        seed="",
                        eta=0.0,
                        a_prompt="",
                        n_prompt="",
                        nonEdit_random_gen_width=512,
                        nonEdit_random_gen_height=512,
                        ):
        def prompt_replace(prompt):
            # Normalize full-width quotes, then replace each quoted text span with the " * " placeholder.
            prompt = prompt.replace('“', '"')
            prompt = prompt.replace('”', '"')
            p = '"(.*?)"'
            strs = re.findall(p, prompt)
            if len(strs) == 0:
                strs = [' ']
            else:
                for s in strs:
                    prompt = prompt.replace(f'"{s}"', f' * ', 1)
            return prompt
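        # e.g. prompt_replace('write "Hello" and "Hi"') -> 'write  *  and  * '
        # In this node the placeholder form is only used for the is_chinese() language check below.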
        def check_overlap_polygon(rect_pts1, rect_pts2):
            poly1 = cv2.convexHull(rect_pts1)
            poly2 = cv2.convexHull(rect_pts2)
            rect1 = cv2.boundingRect(poly1)
            rect2 = cv2.boundingRect(poly2)
            if rect1[0] + rect1[2] >= rect2[0] and rect2[0] + rect2[2] >= rect1[0] and rect1[1] + rect1[3] >= rect2[1] and rect2[1] + rect2[3] >= rect1[1]:
                return True
            return False
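        # Note: overlap is tested on the axis-aligned bounding rectangles of the two hulls, so
        # rotated rectangles can be flagged as overlapping even when they do not actually intersect.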
        def count_lines(prompt):
            # Count the quoted text spans; a prompt without quotes counts as one empty line.
            prompt = prompt.replace('“', '"')
            prompt = prompt.replace('”', '"')
            p = '"(.*?)"'
            strs = re.findall(p, prompt)
            if len(strs) == 0:
                strs = [' ']
            return len(strs)
        def generate_rectangles(w, h, n, max_trys=200):
            img = np.zeros((h, w, 1), dtype=np.uint8)
            rectangles = []
            attempts = 0
            n_pass = 0
            low_edge = int(max(w, h)*0.3 if n <= 3 else max(w, h)*0.2)  # ~150, ~100
            while attempts < max_trys:
                rect_w = min(np.random.randint(max((w*0.5)//n, low_edge), w), int(w*0.8))
                ratio = np.random.uniform(4, 10)
                rect_h = max(low_edge, int(rect_w/ratio))
                rect_h = min(rect_h, int(h*0.8))
                # gen rotate angle
                rotation_angle = 0
                rand_value = np.random.rand()
                if rand_value < 0.7:
                    pass
                elif rand_value < 0.8:
                    rotation_angle = np.random.randint(0, 40)
                elif rand_value < 0.9:
                    rotation_angle = np.random.randint(140, 180)
                else:
                    rotation_angle = np.random.randint(85, 95)
                # rand position
                x = np.random.randint(0, w - rect_w)
                y = np.random.randint(0, h - rect_h)
                # get vertex
                rect_pts = cv2.boxPoints(((rect_w/2, rect_h/2), (rect_w, rect_h), rotation_angle))
                rect_pts = np.int32(rect_pts)
                # move
                rect_pts += (x, y)
                # check border
                if np.any(rect_pts < 0) or np.any(rect_pts[:, 0] >= w) or np.any(rect_pts[:, 1] >= h):
                    attempts += 1
                    continue
                # check overlap
                if any(check_overlap_polygon(rect_pts, rp) for rp in rectangles):  # type: ignore
                    attempts += 1
                    continue
                n_pass += 1
                img = cv2.fillPoly(img, [rect_pts], 255)
                cv2.imwrite(Random_Gen_Mask_path, 255 - img[..., ::-1])
                rectangles.append(rect_pts)
                if n_pass == n:
                    break
            print("attempts:", attempts)
            if len(rectangles) != n:
                raise Exception(f'Failed to auto-generate positions after {attempts} attempts, try again!')
            return img
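        # generate_rectangles() draws n non-overlapping rectangles into a single-channel mask,
        # saves the inverted mask to Random_Gen_Mask_path after each accepted rectangle, and raises
        # if it cannot place all n rectangles within max_trys attempts.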
        if not is_module_imported('AnyText_Pipeline'):
            from .AnyText_scripts.AnyText_pipeline import AnyText_Pipeline
        # Check whether the prompt contains Chinese to decide if the translator should be loaded.
        prompt_modify = prompt_replace(prompt)
        bool_is_chinese = AnyText_Pipeline.is_chinese(self, prompt_modify)
        device = get_device_by_name(device)
        loader_out = AnyText_Loader.split("|")
        if bool_is_chinese == False:
            use_translator = False
        else:
            use_translator = True
        if 'damo/nlp_csanmt_translation_zh2en' in loader_out[3]:
            if not os.access(os.path.join(folder_paths.models_dir, "prompt_generator", "nlp_csanmt_translation_zh2en", "tf_ckpts", "ckpt-0.data-00000-of-00001"), os.F_OK):
                if not is_module_imported('snapshot_download'):
                    from modelscope.hub.snapshot_download import snapshot_download
                snapshot_download('damo/nlp_csanmt_translation_zh2en')
        else:
            if not os.access(os.path.join(folder_paths.models_dir, "prompt_generator", "models--utrobinmv--t5_translate_en_ru_zh_small_1024", "model.safetensors"), os.F_OK):
                if not is_module_imported('hg_snapshot_download'):
                    from huggingface_hub import snapshot_download as hg_snapshot_download
                hg_snapshot_download(repo_id="utrobinmv/t5_translate_en_ru_zh_small_1024")
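        # Build the AnyText pipeline from the loader string: font | ckpt | clip | translator | yaml config.
        # self.model holds the previously loaded model so repeated runs can skip reloading it.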
        pipe = AnyText_Pipeline(ckpt_path=loader_out[1], clip_path=loader_out[2], translator_path=loader_out[3], cfg_path=loader_out[4], use_translator=use_translator, device=device, use_fp16=fp16, all_to_device=all_to_device, loaded_model_tensor=self.model)
        # Convert the ComfyUI tensor images to numpy images.
        pos_image = comfy_tensor_Image2np_Image(self, pos_image)
        ori_image = comfy_tensor_Image2np_Image(self, ori_image)
        # Save the converted images to the ComfyUI temp folder.
        pos_image.save(tmp_pose_img_path)
        ori_image.save(tmp_ori_img_path)
        ori = tmp_ori_img_path
        pos = tmp_pose_img_path
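        # text-generation draws new text into a fresh image (ori_image is ignored); text-editing
        # repaints text inside the supplied ori_image, and position revision is disabled in that mode.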
        if mode == "text-generation":
            ori_image = None
            revise_pos = revise_pos
        else:
            revise_pos = False
            ori_image = ori
        n_lines = count_lines(prompt)
        if Random_Gen == True:
            generate_rectangles(nonEdit_random_gen_width, nonEdit_random_gen_height, n_lines, max_trys=500)
            pos_img = Random_Gen_Mask_path
        else:
            pos_img = pos
        # lora_path = r"D:\AI\ComfyUI_windows_portable\ComfyUI\models\loras\ys艺术\sd15_mw_bpch_扁平风格插画v1d1.safetensors"
        # lora_ratio = 1
        # lora_path_ratio = str(lora_path) + " " + str(lora_ratio)
        # print("\033[93m", lora_path_ratio, "\033[0m")
        params = {
            "mode": mode,
            "use_fp16": fp16,
            "Random_Gen": Random_Gen,
            "sort_priority": sort_radio,
            "revise_pos": revise_pos,
            # "show_debug": show_debug,
            "image_count": img_count,
            "ddim_steps": ddim_steps - 1,
            "image_width": nonEdit_random_gen_width,
            "image_height": nonEdit_random_gen_height,
            "strength": strength,
            "cfg_scale": cfg_scale,
            "eta": eta,
            "a_prompt": a_prompt,
            "n_prompt": n_prompt,
            # "lora_path_ratio": lora_path_ratio,
        }
        input_data = {
            "prompt": prompt,
            "seed": seed,
            "draw_pos": pos_img,
            "ori_image": ori_image,
        }
        # if show_debug ==True:
        #     print(f'\033[93mloader from .util(从.util输入的loader): {AnyText_Loader}, \033[0m\n \
        #     \033[93mloader_out split form loader(分割loader得到4个参数): {loader_out}, \033[0m\n \
        #     \033[93mFont(字体)--loader_out[0]: {loader_out[0]}, \033[0m\n \
        #     \033[93mAnyText Model(AnyText模型)--loader_out[1]: {loader_out[1]}, \033[0m\n \
        #     \033[93mclip model(clip模型)--loader_out[2]: {loader_out[2]}, \033[0m\n \
        #     \033[93mTranslator(翻译模型)--loader_out[3]: {loader_out[3]}, \033[0m\n \
        #     \033[93myaml_file(yaml配置文件): {loader_out[4]}, \033[0m\n) \
        #     \033[93mIs Chinese Input(是否中文输入): {use_translator}, \033[0m\n \
        #     \033[93mNumber of text-content to generate(需要生成的文本数量): {n_lines}, \033[0m\n \
        #     \033[93mpos_image location(遮罩图位置): {pos}, \033[0m\n \
        #     \033[93mori_image location(原图位置): {ori}, \033[0m\n \
        #     \033[93mSort Position(文本生成位置排序): {sort_radio}, \033[0m\n \
        #     \033[93mEnable revise_pos(启用位置修正): {revise_pos}, \033[0m')
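        # pipe() returns the generated samples, intermediate results, a status code, a warning message,
        # a debug string, and the loaded model; the model is kept on self.model so the next run can reuse it.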
        x_samples, results, rtn_code, rtn_warning, debug_info, self.model = pipe(input_data, font_path=loader_out[0], cpu_offload=cpu_offload, **params)
        if rtn_code < 0:
            raise Exception(f"Error in AnyText pipeline: {rtn_warning}")
        output = pil2tensor(x_samples)
        print("\n", debug_info)
        # ComfyUI expects the node function to return a tuple matching RETURN_TYPES.
        return (output,)


# Node class and display name mappings
NODE_CLASS_MAPPINGS = {
    "AnyText": AnyText,
}