"""Tail input.txt and append a French translation of each new English line to output_fr.txt."""
import time
import torch
from transformers import MarianMTModel, MarianTokenizer
|
|
|
|
MODEL_NAME = "Helsinki-NLP/opus-mt-en-fr"
|
|
tokenizer = MarianTokenizer.from_pretrained(MODEL_NAME)
|
|
model = MarianMTModel.from_pretrained(MODEL_NAME)
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
model = model.to(DEVICE)
|
|
model.eval() # disable dropout, slightly faster inference
|
|
|
|
FILE_PATH = "input.txt"
|
|
OUTPUT_PATH = "output_fr.txt"
|
|
|
|
def translate(text: str) -> str:
|
|
inputs = tokenizer(
|
|
text,
|
|
return_tensors="pt",
|
|
padding=True,
|
|
truncation=True,
|
|
max_length=512
|
|
).to(DEVICE)
|
|
|
|
with torch.no_grad(): # saves memory, faster on GPU
|
|
translated = model.generate(**inputs)
|
|
|
|
return tokenizer.decode(translated[0], skip_special_tokens=True)
|
|
|
|
def tail_and_translate(filepath: str):
|
|
with open(filepath, "r", encoding="utf-8") as f:
|
|
f.seek(0, 2) # jump to end of file
|
|
print(f"Watching {filepath}...")
|
|
|
|
with open(OUTPUT_PATH, "a", encoding="utf-8") as out:
|
|
while True:
|
|
line = f.readline()
|
|
if line:
|
|
line = line.strip()
|
|
if line:
|
|
translated = translate(line)
|
|
print(f"EN: {line}")
|
|
print(f"FR: {translated}")
|
|
out.write(translated + "\n")
|
|
out.flush()
|
|
else:
|
|
time.sleep(0.2)
|
|
|
|
tail_and_translate(FILE_PATH)
|