forked from FlintyLemming/ComfyUI-AnyText
49 lines
1.2 KiB
Python
49 lines
1.2 KiB
Python
from torch import nn
|
|
|
|
|
|
class CTCHead(nn.Module):
    """Linear prediction head that maps features to per-class scores.

    The head is either a single ``nn.Linear`` (``in_channels`` ->
    ``out_channels``) or, when ``mid_channels`` is given, two stacked
    ``nn.Linear`` layers (``in_channels`` -> ``mid_channels`` ->
    ``out_channels``) with no activation in between.

    Args:
        in_channels: Size of the last dimension of the input features.
        out_channels: Number of output scores per position.
        fc_decay: Accepted for signature compatibility; not referenced
            anywhere in this class.
        mid_channels: Width of the optional intermediate projection. ``None``
            selects the single-layer variant.
        return_feats: When true, ``forward`` returns a dict that also
            carries the features fed to the last linear layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels=6625,
                 fc_decay=0.0004,
                 mid_channels=None,
                 return_feats=False,
                 **kwargs):
        super().__init__()

        # Submodule names (fc / fc1 / fc2) are part of the state-dict keys,
        # so they must stay stable for checkpoints to keep loading.
        if mid_channels is None:
            self.fc = nn.Linear(in_channels, out_channels, bias=True)
        else:
            self.fc1 = nn.Linear(in_channels, mid_channels, bias=True)
            self.fc2 = nn.Linear(mid_channels, out_channels, bias=True)

        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, labels=None):
        """Project ``x`` to class scores.

        Args:
            x: Input features; the last dimension must equal ``in_channels``
                (requirement of ``nn.Linear``).
            labels: Unused; accepted so callers may pass it.

        Returns:
            The raw linear outputs, or -- when ``return_feats`` is set -- a
            dict with keys ``'ctc'`` (the outputs) and ``'ctc_neck'`` (the
            tensor that was fed to the final linear layer: the ``fc1`` output
            in the two-layer variant, otherwise the unmodified input).
        """
        if self.mid_channels is None:
            predicts = self.fc(x)
        else:
            # Rebinding x is intentional: 'ctc_neck' below reports the
            # intermediate features in the two-layer variant.
            x = self.fc1(x)
            predicts = self.fc2(x)

        if self.return_feats:
            return {'ctc': predicts, 'ctc_neck': x}
        return predicts
|