Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6021db52aa | |||
| a2375e6daf | |||
| 8504126d4a | |||
| 4dc4ffa58c | |||
| 26b38811cc | |||
| 1d63461e35 | |||
| 5f1e2e67a4 | |||
| c23b1d42bf | |||
| b2866d073c | |||
| a889c68e40 |
@@ -1,10 +0,0 @@
|
|||||||
VA_ALIAS='("джарвис",)'
|
|
||||||
VA_TBR='("скажи", "покажи", "ответь", "произнеси", "расскажиv, "сколько", "слушай")'
|
|
||||||
VOSK_MODEL_NAME='vosk-model-small-ru-0.22' # vosk-model-ru-0.42
|
|
||||||
MICROPHONE_INDEX=-1
|
|
||||||
PICOVOICE_TOKEN='token'
|
|
||||||
|
|
||||||
|
|
||||||
# home assistant
|
|
||||||
HOME_ASSISTANT_URL='http://localhost:8123/api'
|
|
||||||
HOME_ASSISTANT_TOKEN=''
|
|
||||||
@@ -8,9 +8,6 @@ __pycache__/
|
|||||||
# Custom
|
# Custom
|
||||||
data/model_small/
|
data/model_small/
|
||||||
data/model_large/
|
data/model_large/
|
||||||
data/v4_ru.pt
|
|
||||||
MyTTSDataset/
|
|
||||||
vocal.wav
|
|
||||||
|
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|||||||
@@ -72,10 +72,3 @@ weather:
|
|||||||
- возможен дождь сегодня?
|
- возможен дождь сегодня?
|
||||||
- прогноз погоды на сегодня
|
- прогноз погоды на сегодня
|
||||||
- погода
|
- погода
|
||||||
home_assistant_execute:
|
|
||||||
- включи телевизор
|
|
||||||
- выключи телевизор
|
|
||||||
- начни уборку
|
|
||||||
- убрать мою комнату
|
|
||||||
home_assistant_get:
|
|
||||||
- тест
|
|
||||||
+5
-15
@@ -1,15 +1,5 @@
|
|||||||
import environs
|
VA_ALIAS = ('джарвис',)
|
||||||
|
VA_TBR = ('скажи', 'покажи', 'ответь', 'произнеси', 'расскажи', 'сколько', 'слушай')
|
||||||
env = environs.Env()
|
MODEL_NAME = "vosk-model-small-ru-0.22" # vosk-model-ru-0.42
|
||||||
env.read_env()
|
MICROPHONE_INDEX = -1
|
||||||
|
PICOVOICE_TOKEN = "4xbwaZwZmSHeTiowFl5Rgqsc8CR4FKGV8YueJUlR4Zt2e1kB64IDcA=="
|
||||||
|
|
||||||
VA_ALIAS = env.str("VA_ALIAS")
|
|
||||||
VA_TBR = env.str("VA_TBR")
|
|
||||||
VOSK_MODEL_NAME = env.str("VOSK_MODEL_NAME")
|
|
||||||
MICROPHONE_INDEX = env.int("MICROPHONE_INDEX")
|
|
||||||
PICOVOICE_TOKEN = env.str("PICOVOICE_TOKEN")
|
|
||||||
|
|
||||||
# home assistant
|
|
||||||
HOME_ASSISTANT_URL = env.str("HOME_ASSISTANT_URL")
|
|
||||||
HOME_ASSISTANT_TOKEN = env.str("HOME_ASSISTANT_TOKEN")
|
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
пылесос:
|
|
||||||
- entity_id:vacuum.roborock_vacuum_m1s
|
|
||||||
- state:находится в
|
|
||||||
- attributes.battery_level:а его уровень зарядки
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
import requests
|
|
||||||
import yaml
|
|
||||||
from fuzzywuzzy import process
|
|
||||||
from requests import Response
|
|
||||||
|
|
||||||
from data import config
|
|
||||||
|
|
||||||
|
|
||||||
class HomeAssistant:
|
|
||||||
"""
|
|
||||||
Модуль home assistant для работы с его api
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
self.url = "http://192.168.0.112:9999/api"
|
|
||||||
self.token = config.HOME_ASSISTANT_TOKEN
|
|
||||||
self.HA_CMD_LIST = yaml.safe_load(open('data/home_assistant_entities.yaml', encoding='utf8'))
|
|
||||||
|
|
||||||
def get_info(self, state: str) -> Response:
|
|
||||||
"""
|
|
||||||
Функция для получения информации о заданном entity
|
|
||||||
|
|
||||||
:param state: str - объект в home assistant информацию о котором надо узнать
|
|
||||||
:return: Response - ответ от сервера api
|
|
||||||
"""
|
|
||||||
response = requests.get(
|
|
||||||
url=f"{self.url}/states",
|
|
||||||
headers={
|
|
||||||
"Authorization": "Bearer " + self.token
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for entity in response.json():
|
|
||||||
if entity["entity_id"] == state:
|
|
||||||
return entity
|
|
||||||
return response
|
|
||||||
|
|
||||||
def send_process(self, command: str = "выключи телевизор") -> bool:
|
|
||||||
"""
|
|
||||||
Функция для отправки запроса о выполнении команды к api
|
|
||||||
|
|
||||||
:param command: str - команда в виде строки
|
|
||||||
:return: bool - удачная ли отправка запроса к api
|
|
||||||
"""
|
|
||||||
response = requests.post(
|
|
||||||
url=f"{self.url}/services/conversation/process",
|
|
||||||
json={"text": command},
|
|
||||||
headers={
|
|
||||||
"Authorization": "Bearer " + self.token,
|
|
||||||
"content-type": "application/json"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if response.status_code == 200:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def voice_to_name(self, voice: str) -> str:
|
|
||||||
"""
|
|
||||||
Функция для неточного сравнивания входной строки голоса
|
|
||||||
и списка устройств дял которых можно узнать информацию
|
|
||||||
|
|
||||||
:param voice: str - распознанная фраза без проверки по списку
|
|
||||||
:return: str - найденный объект для получения информации
|
|
||||||
"""
|
|
||||||
words = voice.lower().split()
|
|
||||||
best_match = None
|
|
||||||
highest_score = 0
|
|
||||||
for word in words:
|
|
||||||
result, score = process.extractOne(word, self.HA_CMD_LIST.keys())
|
|
||||||
if score > highest_score:
|
|
||||||
highest_score = score
|
|
||||||
best_match = result
|
|
||||||
return best_match
|
|
||||||
|
|
||||||
def validate_info(self, name: str) -> str:
|
|
||||||
"""
|
|
||||||
Функция для получения готовой строки информации entity по его имени.
|
|
||||||
Эта строка готова для произношения
|
|
||||||
|
|
||||||
:param name: str - имя entity для нахождения информации о нём
|
|
||||||
:return: str - готовая строка для найденного по имени объекта для её произношения
|
|
||||||
"""
|
|
||||||
answer = name
|
|
||||||
entity_config = self.HA_CMD_LIST.get(name)
|
|
||||||
if entity_config:
|
|
||||||
# Создание словаря, разделяя каждый элемент конфигурации на ключ и значение
|
|
||||||
entity_details = {item.split(':')[0]: item.split(':')[1] for item in entity_config}
|
|
||||||
entity_id = entity_details.pop("entity_id", "robot")
|
|
||||||
if entity_id:
|
|
||||||
responses = self.get_info(entity_id)
|
|
||||||
for attribute_path, label in entity_details.items():
|
|
||||||
response = responses
|
|
||||||
try:
|
|
||||||
for attribute in attribute_path.split("."):
|
|
||||||
response = response[attribute]
|
|
||||||
answer += f" {label} {response}"
|
|
||||||
except KeyError:
|
|
||||||
continue
|
|
||||||
return answer
|
|
||||||
+3
-23
@@ -11,20 +11,15 @@ from fuzzywuzzy import fuzz
|
|||||||
from pvrecorder import PvRecorder
|
from pvrecorder import PvRecorder
|
||||||
|
|
||||||
from data import config
|
from data import config
|
||||||
from modules import HomeAssistant
|
|
||||||
from utils import download_models, execute_cmd, play
|
from utils import download_models, execute_cmd, play
|
||||||
|
|
||||||
|
|
||||||
class Jarvis:
|
class Jarvis:
|
||||||
"""
|
|
||||||
Это основной модуль голосового ассистента
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
download_models.install_vosk_model()
|
download_models.install_vosk_model()
|
||||||
self.recorder = None
|
self.recorder = None
|
||||||
self.CDIR = os.getcwd()
|
self.CDIR = os.getcwd()
|
||||||
self.VA_CMD_LIST = yaml.safe_load(open('data/commands.yaml', encoding='utf8'))
|
self.VA_CMD_LIST = yaml.safe_load(open('data/commands.yaml', encoding='utf8'))
|
||||||
self.home_assistant = HomeAssistant.HomeAssistant()
|
|
||||||
self.porcupine = pvporcupine.create(
|
self.porcupine = pvporcupine.create(
|
||||||
access_key=config.PICOVOICE_TOKEN,
|
access_key=config.PICOVOICE_TOKEN,
|
||||||
keywords=['jarvis'],
|
keywords=['jarvis'],
|
||||||
@@ -65,13 +60,7 @@ class Jarvis:
|
|||||||
print(f"Unexpected {err=}, {type(err)=}")
|
print(f"Unexpected {err=}, {type(err)=}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def va_respond(self, voice: str) -> bool:
|
def va_respond(self, voice: str):
|
||||||
"""
|
|
||||||
Функция предсказывает команду
|
|
||||||
|
|
||||||
:param voice: str - распознанная строка
|
|
||||||
:return: bool - распознана или нет команда
|
|
||||||
"""
|
|
||||||
print(f"Распознано: {voice}")
|
print(f"Распознано: {voice}")
|
||||||
for x in config.VA_ALIAS + config.VA_TBR:
|
for x in config.VA_ALIAS + config.VA_TBR:
|
||||||
voice = voice.replace(x, "").strip()
|
voice = voice.replace(x, "").strip()
|
||||||
@@ -82,7 +71,6 @@ class Jarvis:
|
|||||||
if vrt > rc['percent']:
|
if vrt > rc['percent']:
|
||||||
rc['cmd'] = c
|
rc['cmd'] = c
|
||||||
rc['percent'] = vrt
|
rc['percent'] = vrt
|
||||||
rc['recognized_phrase'] = x
|
|
||||||
if len(rc['cmd'].strip()) <= 0:
|
if len(rc['cmd'].strip()) <= 0:
|
||||||
return False
|
return False
|
||||||
elif rc['percent'] < 70 or rc['cmd'] not in self.VA_CMD_LIST.keys():
|
elif rc['percent'] < 70 or rc['cmd'] not in self.VA_CMD_LIST.keys():
|
||||||
@@ -90,16 +78,8 @@ class Jarvis:
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
execute_cmd.execute_cmd(self, rc['cmd'], rc['recognized_phrase'], voice)
|
execute_cmd.execute_cmd(self, rc['cmd'])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def play(self, phrase: str, wait_done: bool = True):
|
def play(self, phrase, wait_done=True):
|
||||||
"""
|
|
||||||
Функция для запуска голосовой команды
|
|
||||||
|
|
||||||
:param self: modules.Jarvis - объект основного модуля
|
|
||||||
:param phrase: str - фраза для запуска голосовой команды
|
|
||||||
:param wait_done: bool - нужно-ли ждать окончания фразы
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
play.play(self, phrase, wait_done)
|
play.play(self, phrase, wait_done)
|
||||||
|
|||||||
Generated
+16
-2721
File diff suppressed because it is too large
Load Diff
+1
-4
@@ -6,7 +6,7 @@ authors = ["dmitrium12 <belicdima8@gmail.com>"]
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.11,<3.12"
|
python = "^3.11"
|
||||||
vosk = "^0.3.45"
|
vosk = "^0.3.45"
|
||||||
pvporcupine = "^3.0.1"
|
pvporcupine = "^3.0.1"
|
||||||
pvrecorder = "^1.2.1"
|
pvrecorder = "^1.2.1"
|
||||||
@@ -27,9 +27,6 @@ torchaudio = "^2.1.1+cpu"
|
|||||||
ollama = "^0.1.6"
|
ollama = "^0.1.6"
|
||||||
ruff = "^0.4.2"
|
ruff = "^0.4.2"
|
||||||
noisereduce = "^3.0.2"
|
noisereduce = "^3.0.2"
|
||||||
environs = "^11.0.0"
|
|
||||||
webrtcvad = "^2.0.10"
|
|
||||||
tts = "^0.22.0"
|
|
||||||
|
|
||||||
|
|
||||||
[[tool.poetry.source]]
|
[[tool.poetry.source]]
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
import re
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
|
|
||||||
|
|
||||||
def filter_string(input_string: str) -> str:
|
|
||||||
allowed_chars = []
|
|
||||||
for j in "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюя1234567890 !,.?-":
|
|
||||||
allowed_chars.append(j)
|
|
||||||
input_string = re.sub(r'^\d+.\s+', '', input_string)
|
|
||||||
return ''.join([char for char in input_string if char in allowed_chars])
|
|
||||||
|
|
||||||
|
|
||||||
repetition = 0
|
|
||||||
response = {}
|
|
||||||
soup = BeautifulSoup(
|
|
||||||
requests.get('https://theportalwiki.com/wiki/GLaDOS_voice_lines/ru').text,
|
|
||||||
features='html.parser'
|
|
||||||
)
|
|
||||||
for li in soup.find_all('li'):
|
|
||||||
try:
|
|
||||||
i = li.find('i').text
|
|
||||||
url = li.find('span', class_=['audio-player']).find('a')['href']
|
|
||||||
if i not in response.keys():
|
|
||||||
response[i] = url
|
|
||||||
else:
|
|
||||||
repetition += 1
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
i = li.find('a').text
|
|
||||||
url = li.find('a')['href']
|
|
||||||
if i not in response.keys():
|
|
||||||
response[i] = url
|
|
||||||
else:
|
|
||||||
repetition += 1
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
print(f'Количество найденный элементов: {len(response)}')
|
|
||||||
print(f'Количество повторении: {repetition}')
|
|
||||||
with open('MyTTSDataset/transcript.txt', 'w') as f:
|
|
||||||
for index, (key, value) in enumerate(response.items()):
|
|
||||||
try:
|
|
||||||
response = requests.get(value)
|
|
||||||
if response.status_code == 200:
|
|
||||||
key = filter_string(key)
|
|
||||||
if key and len(key.replace(" ", "")) > 3:
|
|
||||||
with open(f"MyTTSDataset/wavs/wav{index}.wav", 'wb') as file:
|
|
||||||
file.write(response.content)
|
|
||||||
f.write(f'wav{index}|{key}\n')
|
|
||||||
except requests.exceptions.MissingSchema:
|
|
||||||
pass
|
|
||||||
except requests.exceptions.InvalidSchema:
|
|
||||||
pass
|
|
||||||
|
|||||||
+11
-22
@@ -4,34 +4,23 @@ import sys
|
|||||||
from data import config
|
from data import config
|
||||||
|
|
||||||
|
|
||||||
def install_vosk_model() -> None:
|
def install_vosk_model():
|
||||||
"""
|
|
||||||
Функция устанавливает заданную в конфигурационном файле модели
|
|
||||||
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
open('data/model_small/README')
|
open('data/model_small/README')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
if sys.platform == "linux" or sys.platform == "linux2":
|
if sys.platform == "linux" or sys.platform == "linux2":
|
||||||
os.system(
|
os.system(f"wget https://alphacephei.com/vosk/models/{config.MODEL_NAME}.zip")
|
||||||
f"wget https://alphacephei.com/vosk/models/{config.VOSK_MODEL_NAME}.zip"
|
os.system(f"unzip {config.MODEL_NAME}.zip")
|
||||||
)
|
os.system(f"mv {config.MODEL_NAME} data/model_small")
|
||||||
os.system(f"unzip {config.VOSK_MODEL_NAME}.zip")
|
os.system(f"rm -rf {config.MODEL_NAME}.zip")
|
||||||
os.system(f"mv {config.VOSK_MODEL_NAME} data/model_small")
|
|
||||||
os.system(f"rm -rf {config.VOSK_MODEL_NAME}.zip")
|
|
||||||
elif sys.platform == "darwin":
|
elif sys.platform == "darwin":
|
||||||
os.system(
|
os.system(f"curl https://alphacephei.com/vosk/models/{config.MODEL_NAME}.zip")
|
||||||
f"curl https://alphacephei.com/vosk/models/{config.VOSK_MODEL_NAME}.zip"
|
os.system(f"unzip {config.MODEL_NAME}.zip")
|
||||||
)
|
os.system(f"mv {config.MODEL_NAME} data/model_small")
|
||||||
os.system(f"unzip {config.VOSK_MODEL_NAME}.zip")
|
os.system(f"rm -rf {config.MODEL_NAME}.zip")
|
||||||
os.system(f"mv {config.VOSK_MODEL_NAME} data/model_small")
|
|
||||||
os.system(f"rm -rf {config.VOSK_MODEL_NAME}.zip")
|
|
||||||
elif sys.platform == "win32":
|
elif sys.platform == "win32":
|
||||||
os.system(
|
os.system(f"curl https://alphacephei.com/vosk/models/{config.MODEL_NAME}.zip --output 1.zip")
|
||||||
f"curl https://alphacephei.com/vosk/models/{config.VOSK_MODEL_NAME}.zip --output 1.zip"
|
|
||||||
)
|
|
||||||
os.system('powershell -command "Expand-Archive 1.zip ./"')
|
os.system('powershell -command "Expand-Archive 1.zip ./"')
|
||||||
os.system(f"rename {config.VOSK_MODEL_NAME} data/model_small")
|
os.system(f"rename {config.MODEL_NAME} data/model_small")
|
||||||
os.system("del /s /q 1.zip")
|
os.system("del /s /q 1.zip")
|
||||||
|
|||||||
+1
-16
@@ -1,13 +1,4 @@
|
|||||||
def execute_cmd(self, cmd: str, recognized_phrase: str, voice: str) -> None:
|
def execute_cmd(self, cmd: str):
|
||||||
"""
|
|
||||||
Функция выполняет полученные команды
|
|
||||||
|
|
||||||
:param self: modules.Jarvis - объект основного модуля
|
|
||||||
:param cmd: str - команда которую функция должна выполнить
|
|
||||||
:param recognized_phrase: str - распознанная фраза из списка фраз
|
|
||||||
:param voice: str - распознанная фраза без проверки по списку
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if cmd == 'thanks':
|
if cmd == 'thanks':
|
||||||
self.play("thanks")
|
self.play("thanks")
|
||||||
elif cmd == 'stupid':
|
elif cmd == 'stupid':
|
||||||
@@ -16,9 +7,3 @@ def execute_cmd(self, cmd: str, recognized_phrase: str, voice: str) -> None:
|
|||||||
self.play("off", True)
|
self.play("off", True)
|
||||||
self.porcupine.delete()
|
self.porcupine.delete()
|
||||||
exit(0)
|
exit(0)
|
||||||
elif cmd == 'home_assistant_execute':
|
|
||||||
self.home_assistant.send_process(recognized_phrase)
|
|
||||||
elif cmd == 'home_assistant_get':
|
|
||||||
entity_name = self.home_assistant.voice_to_name(voice)
|
|
||||||
entity_info = self.home_assistant.validate_info(entity_name)
|
|
||||||
print(entity_info)
|
|
||||||
|
|||||||
+1
-9
@@ -3,15 +3,7 @@ import random
|
|||||||
import simpleaudio as sa
|
import simpleaudio as sa
|
||||||
|
|
||||||
|
|
||||||
def play(self, phrase: str, wait_done: bool = True) -> None:
|
def play(self, phrase, wait_done=True):
|
||||||
"""
|
|
||||||
Функция для запуска голосовой команды
|
|
||||||
|
|
||||||
:param self: modules.Jarvis - объект основного модуля
|
|
||||||
:param phrase: str - фраза для запуска голосовой команды
|
|
||||||
:param wait_done: bool - нужно-ли ждать окончания фразы
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
filename = None
|
filename = None
|
||||||
file_array = ["not_found", "thanks", "run", "stupid", "ready", "off"]
|
file_array = ["not_found", "thanks", "run", "stupid", "ready", "off"]
|
||||||
if phrase == "greet":
|
if phrase == "greet":
|
||||||
|
|||||||
Regular → Executable
+26
-52
@@ -1,57 +1,31 @@
|
|||||||
import os
|
import time
|
||||||
|
|
||||||
|
import sounddevice as sd
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
|
||||||
|
language = 'ru'
|
||||||
|
model_id = 'ru_v3'
|
||||||
|
sample_rate = 48000 # 48000
|
||||||
|
speaker = 'aidar' # aidar, baya, kseniya, xenia, random
|
||||||
|
put_accent = True
|
||||||
|
put_yo = True
|
||||||
|
device = torch.device('cpu') # cpu или gpu
|
||||||
|
text = "Хауди Хо, друзья!!!"
|
||||||
|
|
||||||
|
model, _ = torch.hub.load(repo_or_dir='snakers4/silero-models',
|
||||||
|
model='silero_tts',
|
||||||
|
language=language,
|
||||||
|
speaker=model_id)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
|
||||||
def load_data(audio_folder):
|
def va_speak(what: str):
|
||||||
audios = []
|
audio = model.apply_tts(text=what + "..",
|
||||||
texts = []
|
speaker=speaker,
|
||||||
for audio_file in os.listdir(audio_folder):
|
sample_rate=sample_rate,
|
||||||
if audio_file.endswith('.wav'):
|
put_accent=put_accent,
|
||||||
audio_path = os.path.join(audio_folder, audio_file)
|
put_yo=put_yo)
|
||||||
waveform, sample_rate = torchaudio.load(audio_path)
|
|
||||||
text_path = audio_path.replace('.wav', '.txt')
|
|
||||||
with open(text_path) as f:
|
|
||||||
text = f.read().strip()
|
|
||||||
audios.append((waveform, sample_rate))
|
|
||||||
texts.append(text)
|
|
||||||
return audios, texts
|
|
||||||
|
|
||||||
|
sd.play(audio, sample_rate * 1.05)
|
||||||
def train(model, audios, texts, epochs=3, learning_rate=1e-4):
|
time.sleep((len(audio) / sample_rate) + 0.5)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
sd.stop()
|
||||||
criterion = torch.nn.MSELoss() # Вам нужно будет настроить эту функцию под вашу задачу
|
|
||||||
|
|
||||||
model.train()
|
|
||||||
for epoch in range(epochs):
|
|
||||||
total_loss = 0
|
|
||||||
for waveform, text in zip(audios, texts):
|
|
||||||
optimizer.zero_grad()
|
|
||||||
# Предполагается, что модель принимает текст и возвращает аудио
|
|
||||||
predicted_waveform = model(text)
|
|
||||||
loss = criterion(predicted_waveform, waveform)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
total_loss += loss.item()
|
|
||||||
average_loss = total_loss / len(audios)
|
|
||||||
print(f'Epoch {epoch + 1}: Average Loss = {average_loss}')
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
model_path = 'data/v4_ru.pt'
|
|
||||||
model = torch.load(model_path)
|
|
||||||
model.eval()
|
|
||||||
audio_folder = 'wav_files'
|
|
||||||
audios, texts = load_data(audio_folder)
|
|
||||||
train(model, audios, texts)
|
|
||||||
torch.save(model.state_dict(), 'fine_tuned_model.pth')
|
|
||||||
model.eval()
|
|
||||||
sample_text = "Пример текста для синтеза."
|
|
||||||
with torch.no_grad():
|
|
||||||
generated_waveform = model(sample_text)
|
|
||||||
torchaudio.save('output_audio.wav', generated_waveform, 16000)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user