-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #324 from awwaawwa/better_cache
- Loading branch information
Showing
8 changed files
with
497 additions
and
236 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,8 +2,6 @@ name: Test and Build Python Package | |
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
|
||
jobs: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,138 @@ | ||
import tempfile | ||
import os | ||
import time | ||
import hashlib | ||
import shutil | ||
|
||
cache_dir = os.path.join(tempfile.gettempdir(), "cache") | ||
os.makedirs(cache_dir, exist_ok=True) | ||
time_filename = "update_time" | ||
max_cache = 5 | ||
|
||
|
||
def deterministic_hash(obj): | ||
hash_object = hashlib.sha256() | ||
hash_object.update(str(obj).encode()) | ||
return hash_object.hexdigest()[0:20] | ||
|
||
|
||
def get_dirs(): | ||
dirs = [ | ||
os.path.join(cache_dir, dir) | ||
for dir in os.listdir(cache_dir) | ||
if os.path.isdir(os.path.join(cache_dir, dir)) | ||
] | ||
return dirs | ||
|
||
|
||
def get_time(dir): | ||
try: | ||
timefile = os.path.join(dir, time_filename) | ||
t = float(open(timefile, encoding="utf-8").read()) | ||
return t | ||
except FileNotFoundError: | ||
# handle the error as needed, for now we'll just return a default value | ||
return float( | ||
"inf" | ||
) # This ensures that this directory will be the first to be removed if required | ||
|
||
|
||
def write_time(dir): | ||
timefile = os.path.join(dir, time_filename) | ||
t = time.time() | ||
print(t, file=open(timefile, "w", encoding="utf-8"), end="") | ||
|
||
|
||
def argmin(iterable): | ||
return min(enumerate(iterable), key=lambda x: x[1])[0] | ||
|
||
|
||
def remove_extra(): | ||
dirs = get_dirs() | ||
for dir in dirs: | ||
if not os.path.isdir( | ||
dir | ||
): # This line might be redundant now, as get_dirs() ensures only directories are returned | ||
os.remove(dir) | ||
try: | ||
get_time(dir) | ||
except BaseException: | ||
shutil.rmtree(dir) | ||
while True: | ||
dirs = get_dirs() | ||
if len(dirs) <= max_cache: | ||
break | ||
times = [get_time(dir) for dir in dirs] | ||
arg = argmin(times) | ||
shutil.rmtree(dirs[arg]) | ||
|
||
|
||
def is_cached(hash_key): | ||
dir = os.path.join(cache_dir, hash_key) | ||
return os.path.exists(dir) | ||
|
||
|
||
def create_cache(hash_key): | ||
dir = os.path.join(cache_dir, hash_key) | ||
os.makedirs(dir, exist_ok=True) | ||
write_time(dir) | ||
|
||
|
||
def load_paragraph(hash_key, hash_key_paragraph): | ||
filename = os.path.join(cache_dir, hash_key, hash_key_paragraph) | ||
if os.path.exists(filename): | ||
return open(filename, encoding="utf-8").read() | ||
else: | ||
return None | ||
|
||
|
||
def write_paragraph(hash_key, hash_key_paragraph, paragraph): | ||
filename = os.path.join(cache_dir, hash_key, hash_key_paragraph) | ||
print(paragraph, file=open(filename, "w", encoding="utf-8"), end="") | ||
import json | ||
from peewee import Model, SqliteDatabase, AutoField, CharField, TextField, SQL | ||
from typing import Optional | ||
|
||
|
||
# we don't init the database here | ||
db = SqliteDatabase(None) | ||
|
||
|
||
class _TranslationCache(Model): | ||
id = AutoField() | ||
translate_engine = CharField(max_length=20) | ||
translate_engine_params = TextField() | ||
original_text = TextField() | ||
translation = TextField() | ||
|
||
class Meta: | ||
database = db | ||
constraints = [ | ||
SQL( | ||
""" | ||
UNIQUE ( | ||
translate_engine, | ||
translate_engine_params, | ||
original_text | ||
) | ||
ON CONFLICT REPLACE | ||
""" | ||
) | ||
] | ||
|
||
|
||
class TranslationCache: | ||
@staticmethod | ||
def _sort_dict_recursively(obj): | ||
if isinstance(obj, dict): | ||
return { | ||
k: TranslationCache._sort_dict_recursively(v) | ||
for k in sorted(obj.keys()) | ||
for v in [obj[k]] | ||
} | ||
elif isinstance(obj, list): | ||
return [TranslationCache._sort_dict_recursively(item) for item in obj] | ||
return obj | ||
|
||
def __init__(self, translate_engine: str, translate_engine_params: dict = None): | ||
self.translate_engine = translate_engine | ||
self.replace_params(translate_engine_params) | ||
|
||
# The program typically starts multi-threaded translation | ||
# only after cache parameters are fully configured, | ||
# so thread safety doesn't need to be considered here. | ||
def replace_params(self, params: dict = None): | ||
if params is None: | ||
params = {} | ||
self.params = params | ||
params = self._sort_dict_recursively(params) | ||
self.translate_engine_params = json.dumps(params) | ||
|
||
def update_params(self, params: dict = None): | ||
if params is None: | ||
params = {} | ||
self.params.update(params) | ||
self.replace_params(self.params) | ||
|
||
def add_params(self, k: str, v): | ||
self.params[k] = v | ||
self.replace_params(self.params) | ||
|
||
# Since peewee and the underlying sqlite are thread-safe, | ||
# get and set operations don't need locks. | ||
def get(self, original_text: str) -> Optional[str]: | ||
result = _TranslationCache.get_or_none( | ||
translate_engine=self.translate_engine, | ||
translate_engine_params=self.translate_engine_params, | ||
original_text=original_text, | ||
) | ||
return result.translation if result else None | ||
|
||
def set(self, original_text: str, translation: str): | ||
_TranslationCache.create( | ||
translate_engine=self.translate_engine, | ||
translate_engine_params=self.translate_engine_params, | ||
original_text=original_text, | ||
translation=translation, | ||
) | ||
|
||
|
||
def init_db(remove_exists=False): | ||
cache_folder = os.path.join(os.path.expanduser("~"), ".cache", "pdf2zh") | ||
os.makedirs(cache_folder, exist_ok=True) | ||
# The current version does not support database migration, so add the version number to the file name. | ||
cache_db_path = os.path.join(cache_folder, "cache.v1.db") | ||
if remove_exists and os.path.exists(cache_db_path): | ||
os.remove(cache_db_path) | ||
db.init( | ||
cache_db_path, | ||
pragmas={ | ||
"journal_mode": "wal", | ||
"busy_timeout": 1000, | ||
}, | ||
) | ||
db.create_tables([_TranslationCache], safe=True) | ||
|
||
|
||
def init_test_db(): | ||
import tempfile | ||
|
||
cache_db_path = tempfile.mktemp(suffix=".db") | ||
test_db = SqliteDatabase( | ||
cache_db_path, | ||
pragmas={ | ||
"journal_mode": "wal", | ||
"busy_timeout": 1000, | ||
}, | ||
) | ||
test_db.bind([_TranslationCache], bind_refs=False, bind_backrefs=False) | ||
test_db.connect() | ||
test_db.create_tables([_TranslationCache], safe=True) | ||
return test_db | ||
|
||
|
||
def clean_test_db(test_db): | ||
test_db.drop_tables([_TranslationCache]) | ||
test_db.close() | ||
db_path = test_db.database | ||
if os.path.exists(db_path): | ||
os.remove(test_db.database) | ||
wal_path = db_path + "-wal" | ||
if os.path.exists(wal_path): | ||
os.remove(wal_path) | ||
shm_path = db_path + "-shm" | ||
if os.path.exists(shm_path): | ||
os.remove(shm_path) | ||
|
||
|
||
init_db() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.