Skip to content

Commit

Permalink
ENH: Support fish speech 1.4 (#2295)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Sep 13, 2024
1 parent 4274507 commit 8f73b05
Show file tree
Hide file tree
Showing 43 changed files with 544 additions and 1,748 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ all =
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down Expand Up @@ -198,6 +199,7 @@ audio =
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
doc =
ipython>=6.5.0
sphinx>=3.0.0
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jj-pytorchvideo # For CogVLM2-video
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jj-pytorchvideo # For CogVLM2-video
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
ormsgpack # For Fish Speech
qwen-vl-utils # For qwen2-vl
datamodel_code_generator # for minicpm-4B
jsonschema # for minicpm-4B
14 changes: 7 additions & 7 deletions xinference/model/audio/fish_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def load(self):

checkpoint_path = os.path.join(
self._model_path,
"firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
"firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
self._model = load_decoder_model(
config_name="firefly_gan_vq",
Expand Down Expand Up @@ -213,12 +213,12 @@ def speech(
text=input,
enable_reference_audio=False,
reference_audio=None,
reference_text="",
max_new_tokens=0,
chunk_length=100,
top_p=0.7,
repetition_penalty=1.2,
temperature=0.7,
reference_text=kwargs.get("reference_text", ""),
max_new_tokens=kwargs.get("max_new_tokens", 1024),
chunk_length=kwargs.get("chunk_length", 200),
top_p=kwargs.get("top_p", 0.7),
repetition_penalty=kwargs.get("repetition_penalty", 1.2),
temperature=kwargs.get("temperature", 0.7),
)
)
sample_rate, audio = result[0][1]
Expand Down
6 changes: 3 additions & 3 deletions xinference/model/audio/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@
"multilingual": true
},
{
"model_name": "FishSpeech-1.2-SFT",
"model_name": "FishSpeech-1.4",
"model_family": "FishAudio",
"model_id": "fishaudio/fish-speech-1.2-sft",
"model_revision": "180288e21ec5c50cfc564023a22f789e4b88a0e0",
"model_id": "fishaudio/fish-speech-1.4",
"model_revision": "3c49651b8e583b6b13f55e375432e0d57e1aa84d",
"model_ability": "text-to-audio",
"multilingual": true
}
Expand Down
2 changes: 1 addition & 1 deletion xinference/model/audio/tests/test_fish_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_fish_speech(setup):
client = Client(endpoint)

model_uid = client.launch_model(
model_name="FishSpeech-1.2-SFT",
model_name="FishSpeech-1.4",
model_type="audio",
)
model = client.get_model(model_uid)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ head:
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
num_mels: 512
upsample_initial_channel: 512
use_template: false
pre_conv_kernel_size: 13
post_conv_kernel_size: 13
quantizer:
_target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
input_dim: 512
n_groups: 4
n_groups: 8
n_codebooks: 1
levels: [8, 5, 5, 5]
downsample_factor: [2]
downsample_factor: [2, 2]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ defaults:

project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft
pretrained_ckpt_path: checkpoints/fish-speech-1.4

# Lightning Trainer
trainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"Put your text here.": "Put your text here.",
"Reference Audio": "Reference Audio",
"Reference Text": "Reference Text",
"Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.",
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
"Remove Selected Data": "Remove Selected Data",
"Removed path successfully!": "Removed path successfully!",
"Repetition Penalty": "Repetition Penalty",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"Put your text here.": "Ponga su texto aquí.",
"Reference Audio": "Audio de Referencia",
"Reference Text": "Texto de Referencia",
"Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado se publica bajo la Licencia BSD-3-Clause, y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
"Remove Selected Data": "Eliminar Datos Seleccionados",
"Removed path successfully!": "¡Ruta eliminada exitosamente!",
"Repetition Penalty": "Penalización por Repetición",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"Put your text here.": "ここにテキストを入力してください。",
"Reference Audio": "リファレンスオーディオ",
"Reference Text": "リファレンステキスト",
"Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "関連コードはBSD-3-Clauseライセンスの下でリリースされ、重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
"Remove Selected Data": "選択したデータを削除",
"Removed path successfully!": "パスの削除に成功しました!",
"Repetition Penalty": "反復ペナルティ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
"Reference Text": "Texto de Referência",
"warning": "Aviso",
"Pre-processing begins...": "O pré-processamento começou!",
"Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado é licenciado sob a Licença BSD-3-Clause, e os pesos sob a Licença CC BY-NC-SA 4.0.",
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
"Remove Selected Data": "Remover Dados Selecionados",
"Removed path successfully!": "Caminho removido com sucesso!",
"Repetition Penalty": "Penalidade de Repetição",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"Put your text here.": "在此处输入文本.",
"Reference Audio": "参考音频",
"Reference Text": "参考文本",
"Related code are released under BSD-3-Clause License, and weights are released under CC BY-NC-SA 4.0 License.": "相关代码使用 BSD-3-Clause 许可证发布,权重使用 CC BY-NC-SA 4.0 许可证发布.",
"Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
"Remove Selected Data": "移除选中数据",
"Removed path successfully!": "移除路径成功!",
"Repetition Penalty": "重复惩罚",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def from_pretrained(

if "int8" in str(Path(path)):
logger.info("Using int8 weight-only quantization!")
from ...tools.llama.quantize import WeightOnlyInt8QuantHandler
from tools.llama.quantize import WeightOnlyInt8QuantHandler

simple_quantizer = WeightOnlyInt8QuantHandler(model)
model = simple_quantizer.convert_for_runtime()
Expand All @@ -363,7 +363,7 @@ def from_pretrained(
path_comps = path.name.split("-")
assert path_comps[-2].startswith("g")
groupsize = int(path_comps[-2][1:])
from ...tools.llama.quantize import WeightOnlyInt4QuantHandler
from tools.llama.quantize import WeightOnlyInt4QuantHandler

simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .lit_module import VQGAN

__all__ = ["VQGAN"]
Loading

0 comments on commit 8f73b05

Please sign in to comment.