style: 格式化代码

This commit is contained in:
2025-12-21 07:56:10 +08:00
parent 1efe034a59
commit a2e12c7462
15 changed files with 373 additions and 290 deletions

View File

@@ -37,7 +37,7 @@ if pathlib.Path(workdir / "config" / "config_dev.toml").exists():
logger.debug("使用开发设置") logger.debug("使用开发设置")
config_var: ContextVar[ConfigFile] = ContextVar( config_var: ContextVar[ConfigFile] = ContextVar(
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml") "config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
) )
# runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据 # runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据

View File

@@ -1,9 +1,9 @@
from textual.app import App from textual.app import App
from textual.widgets import Button from textual.widgets import Button
from heurams.services.logger import get_logger
from heurams.context import config_var from heurams.context import config_var
from heurams.interface import HeurAMSApp from heurams.interface import HeurAMSApp
from heurams.services.logger import get_logger
from .screens.about import AboutScreen from .screens.about import AboutScreen
from .screens.dashboard import DashboardScreen from .screens.dashboard import DashboardScreen

View File

@@ -4,8 +4,7 @@ import pathlib
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView, from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
Static)
import heurams.services.timer as timer import heurams.services.timer as timer
import heurams.services.version as version import heurams.services.version as version
@@ -28,6 +27,7 @@ class DashboardScreen(Screen):
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"), Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"), Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'), Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
Label(f"使用算法: {config_var.get()['algorithm']['default']}"),
Label("选择待学习或待修改的记忆单元集:", classes="title-label"), Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
ListView(id="union-list", classes="union-list-view"), ListView(id="union-list", classes="union-list-view"),
Label( Label(

View File

@@ -148,16 +148,24 @@ class MemScreen(Screen):
def play_voice(self): def play_voice(self):
"""朗读当前内容""" """朗读当前内容"""
from heurams.services.audio_service import play_by_path
from pathlib import Path from pathlib import Path
from heurams.services.audio_service import play_by_path
from heurams.services.hasher import get_md5 from heurams.services.hasher import get_md5
path = Path(config_var.get()['paths']["cache_dir"])
path = path / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav" path = Path(config_var.get()["paths"]["cache_dir"])
path = (
path
/ f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
)
if path.exists(): if path.exists():
play_by_path(path) play_by_path(path)
else: else:
from heurams.services.tts_service import convertor from heurams.services.tts_service import convertor
convertor(self.atom.registry['nucleon'].metadata["formation"]["tts_text"], path)
convertor(
self.atom.registry["nucleon"].metadata["formation"]["tts_text"], path
)
play_by_path(path) play_by_path(path)
def action_toggle_dark(self): def action_toggle_dark(self):

View File

@@ -5,8 +5,7 @@ import toml
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Input, Label, Markdown, from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select
Select)
from heurams.context import config_var from heurams.context import config_var
from heurams.services.version import ver from heurams.services.version import ver

View File

@@ -99,6 +99,7 @@ class PrecachingScreen(Screen):
if not cache_file.exists(): if not cache_file.exists():
try: try:
from heurams.services.tts_service import convertor from heurams.services.tts_service import convertor
convertor(text, cache_file) convertor(text, cache_file)
return 1 return 1
except Exception as e: except Exception as e:

View File

@@ -74,11 +74,13 @@ class SyncScreen(Screen):
"""从配置文件加载同步设置""" """从配置文件加载同步设置"""
try: try:
from heurams.context import config_var from heurams.context import config_var
config_data = config_var.get().data config_data = config_var.get().data
self.sync_config = config_data.get('sync', {}).get('webdav', {}) self.sync_config = config_data.get("sync", {}).get("webdav", {})
# 创建同步服务实例 # 创建同步服务实例
from heurams.services.sync_service import create_sync_service_from_config from heurams.services.sync_service import create_sync_service_from_config
self.sync_service = create_sync_service_from_config() self.sync_service = create_sync_service_from_config()
except Exception as e: except Exception as e:
@@ -89,22 +91,22 @@ class SyncScreen(Screen):
"""更新 UI 显示配置信息""" """更新 UI 显示配置信息"""
try: try:
# 更新服务器 URL # 更新服务器 URL
url = self.sync_config.get('url', '未配置') url = self.sync_config.get("url", "未配置")
url_widget = self.query_one("#server_url") url_widget = self.query_one("#server_url")
url_widget.update(url if url else '未配置') # type: ignore url_widget.update(url if url else "未配置") # type: ignore
# 更新远程路径 # 更新远程路径
remote_path = self.sync_config.get('remote_path', '/heurams/') remote_path = self.sync_config.get("remote_path", "/heurams/")
path_widget = self.query_one("#remote_path") path_widget = self.query_one("#remote_path")
path_widget.update(remote_path) # type: ignore path_widget.update(remote_path) # type: ignore
# 更新同步模式 # 更新同步模式
sync_mode = self.sync_config.get('sync_mode', 'bidirectional') sync_mode = self.sync_config.get("sync_mode", "bidirectional")
mode_widget = self.query_one("#sync_mode") mode_widget = self.query_one("#sync_mode")
mode_map = { mode_map = {
'bidirectional': '双向同步', "bidirectional": "双向同步",
'upload_only': '仅上传', "upload_only": "仅上传",
'download_only': '仅下载', "download_only": "仅下载",
} }
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
@@ -145,7 +147,7 @@ class SyncScreen(Screen):
self.log_messages.append(log_line) self.log_messages.append(log_line)
# 保持日志行数不超过最大值 # 保持日志行数不超过最大值
if len(self.log_messages) > self.max_log_lines: if len(self.log_messages) > self.max_log_lines:
self.log_messages = self.log_messages[-self.max_log_lines:] self.log_messages = self.log_messages[-self.max_log_lines :]
# 更新日志显示 # 更新日志显示
try: try:
@@ -218,44 +220,60 @@ class SyncScreen(Screen):
try: try:
# 获取需要同步的本地目录 # 获取需要同步的本地目录
from heurams.context import config_var from heurams.context import config_var
config = config_var.get() config = config_var.get()
paths = config.get('paths', {}) paths = config.get("paths", {})
# 同步 nucleon 目录 # 同步 nucleon 目录
nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon')) nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon"))
if nucleon_dir.exists(): if nucleon_dir.exists():
self.log_message(f"同步 nucleon 目录: {nucleon_dir}") self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
self.update_status(f"同步 nucleon 目录...", progress=10) self.update_status(f"同步 nucleon 目录...", progress=10)
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"nucleon 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 electron 目录 # 同步 electron 目录
electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron')) electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron"))
if electron_dir.exists(): if electron_dir.exists():
self.log_message(f"同步 electron 目录: {electron_dir}") self.log_message(f"同步 electron 目录: {electron_dir}")
self.update_status(f"同步 electron 目录...", progress=60) self.update_status(f"同步 electron 目录...", progress=60)
result = self.sync_service.sync_directory(electron_dir) # type: ignore result = self.sync_service.sync_directory(electron_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"electron 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 orbital 目录(如果存在) # 同步 orbital 目录(如果存在)
orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital')) orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital"))
if orbital_dir.exists(): if orbital_dir.exists():
self.log_message(f"同步 orbital 目录: {orbital_dir}") self.log_message(f"同步 orbital 目录: {orbital_dir}")
self.update_status(f"同步 orbital 目录...", progress=80) self.update_status(f"同步 orbital 目录...", progress=80)
result = self.sync_service.sync_directory(orbital_dir) # type: ignore result = self.sync_service.sync_directory(orbital_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"orbital 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步完成 # 同步完成
self.update_status("同步完成", progress=100) self.update_status("同步完成", progress=100)

View File

@@ -50,7 +50,8 @@ class Recognition(BasePuzzleWidget):
def compose(self): def compose(self):
from heurams.context import config_var from heurams.context import config_var
autovoice = config_var.get()['interface']['memorizor']['autovoice']
autovoice = config_var.get()["interface"]["memorizor"]["autovoice"]
if autovoice: if autovoice:
self.screen.action_play_voice() # type: ignore self.screen.action_play_voice() # type: ignore
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia] cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
@@ -72,7 +73,7 @@ class Recognition(BasePuzzleWidget):
primary = cfg["primary"] primary = cfg["primary"]
with Center(): with Center():
for i in cfg['top_dim']: for i in cfg["top_dim"]:
yield Static(f"[dim]{i}[/]") yield Static(f"[dim]{i}[/]")
yield Label("") yield Label("")

View File

@@ -10,15 +10,25 @@ MIT 许可证
import datetime import datetime
import json import json
import os import os
from typing import TypedDict
import pathlib import pathlib
from typing import TypedDict
from heurams.context import config_var from heurams.context import config_var
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF, from heurams.kernel.algorithms.sm15m_calc import (
RANGE_AF, RANGE_REPETITION, MAX_AF,
SM, THRESHOLD_RECALL, Item) MIN_AF,
NOTCH_AF,
RANGE_AF,
RANGE_REPETITION,
SM,
THRESHOLD_RECALL,
Item,
)
# 全局状态文件路径 # 全局状态文件路径
_GLOBAL_STATE_FILE = os.path.expanduser(pathlib.Path(config_var.get()['paths']['global_dir']) / 'sm15m_global_state.json') _GLOBAL_STATE_FILE = os.path.expanduser(
pathlib.Path(config_var.get()["paths"]["global_dir"]) / "sm15m_global_state.json"
)
def _get_global_sm(): def _get_global_sm():

View File

@@ -86,8 +86,8 @@ class Atom:
# eval 环境设置 # eval 环境设置
def eval_with_env(s: str): def eval_with_env(s: str):
default = config_var.get()["puzzles"] default = config_var.get()["puzzles"]
payload = self.registry['nucleon'].payload payload = self.registry["nucleon"].payload
metadata = self.registry['nucleon'].metadata metadata = self.registry["nucleon"].metadata
eval_value = eval(s) eval_value = eval(s)
if isinstance(eval_value, (int, float)): if isinstance(eval_value, (int, float)):
ret = str(eval_value) ret = str(eval_value)
@@ -117,10 +117,11 @@ class Atom:
logger.debug("发现 eval 表达式: '%s'", data[5:]) logger.debug("发现 eval 表达式: '%s'", data[5:])
return modifier(data[5:]) return modifier(data[5:])
return data return data
try: try:
traverse(self.registry['nucleon'].payload, eval_with_env) traverse(self.registry["nucleon"].payload, eval_with_env)
traverse(self.registry['nucleon'].metadata, eval_with_env) traverse(self.registry["nucleon"].metadata, eval_with_env)
traverse(self.registry['orbital'], eval_with_env) traverse(self.registry["orbital"], eval_with_env)
except Exception as e: except Exception as e:
ret = f"此 eval 实例发生错误: {e}" ret = f"此 eval 实例发生错误: {e}"
logger.warning(ret) logger.warning(ret)

View File

@@ -18,9 +18,12 @@ class Electron:
algo: 使用的算法模块标识 algo: 使用的算法模块标识
""" """
if algo_name == "": if algo_name == "":
algo_name = config_var.get()['algorithm']['default'] algo_name = config_var.get()["algorithm"]["default"]
logger.debug( logger.debug(
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", ident, algo_name, algodata "创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s",
ident,
algo_name,
algodata,
) )
self.algodata = algodata self.algodata = algodata
self.ident = ident self.ident = ident
@@ -31,7 +34,9 @@ class Electron:
self.algodata[self.algo.algo_name] = {} self.algodata[self.algo.algo_name] = {}
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo) logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
if not self.algodata[self.algo.algo_name]: if not self.algodata[self.algo.algo_name]:
logger.debug(f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}") logger.debug(
f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}"
)
self._default_init(self.algo.defaults) self._default_init(self.algo.defaults)
else: else:
logger.debug("算法数据已存在, 跳过默认初始化") logger.debug("算法数据已存在, 跳过默认初始化")

View File

@@ -2,8 +2,8 @@ import pathlib
import edge_tts import edge_tts
from heurams.services.logger import get_logger
from heurams.context import config_var from heurams.context import config_var
from heurams.services.logger import get_logger
from .base import BaseTTS from .base import BaseTTS
@@ -19,7 +19,7 @@ class EdgeTTS(BaseTTS):
try: try:
communicate = edge_tts.Communicate( communicate = edge_tts.Communicate(
text, text,
config_var.get()['providers']['tts']['edgetts']["voice"], config_var.get()["providers"]["tts"]["edgetts"]["voice"],
) )
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频") logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
communicate.save_sync(str(path)) communicate.save_sync(str(path))

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class SyncMode(Enum): class SyncMode(Enum):
"""同步模式枚举""" """同步模式枚举"""
BIDIRECTIONAL = "bidirectional" BIDIRECTIONAL = "bidirectional"
UPLOAD_ONLY = "upload_only" UPLOAD_ONLY = "upload_only"
DOWNLOAD_ONLY = "download_only" DOWNLOAD_ONLY = "download_only"
@@ -25,6 +26,7 @@ class SyncMode(Enum):
class ConflictStrategy(Enum): class ConflictStrategy(Enum):
"""冲突解决策略枚举""" """冲突解决策略枚举"""
NEWER = "newer" # 较新文件覆盖较旧文件 NEWER = "newer" # 较新文件覆盖较旧文件
ASK = "ask" # 用户手动选择 ASK = "ask" # 用户手动选择
KEEP_BOTH = "keep_both" # 保留双方(重命名) KEEP_BOTH = "keep_both" # 保留双方(重命名)
@@ -33,6 +35,7 @@ class ConflictStrategy(Enum):
@dataclass @dataclass
class SyncConfig: class SyncConfig:
"""同步配置数据类""" """同步配置数据类"""
enabled: bool = False enabled: bool = False
url: str = "" url: str = ""
username: str = "" username: str = ""
@@ -59,12 +62,12 @@ class SyncService:
return return
options = { options = {
'webdav_hostname': self.config.url, "webdav_hostname": self.config.url,
'webdav_login': self.config.username, "webdav_login": self.config.username,
'webdav_password': self.config.password, "webdav_password": self.config.password,
'webdav_root': self.config.remote_path, "webdav_root": self.config.remote_path,
'verify_ssl': self.config.verify_ssl, "verify_ssl": self.config.verify_ssl,
'disable_check': True, # 不检查服务器支持的功能 "disable_check": True, # 不检查服务器支持的功能
} }
try: try:
@@ -98,10 +101,10 @@ class SyncService:
rel_path = file_path.relative_to(local_dir) rel_path = file_path.relative_to(local_dir)
stat = file_path.stat() stat = file_path.stat()
files[str(rel_path)] = { files[str(rel_path)] = {
'path': file_path, "path": file_path,
'size': stat.st_size, "size": stat.st_size,
'mtime': stat.st_mtime, "mtime": stat.st_mtime,
'hash': self._calculate_hash(file_path), "hash": self._calculate_hash(file_path),
} }
return files return files
@@ -114,14 +117,14 @@ class SyncService:
remote_list = self.client.list(recursive=True) remote_list = self.client.list(recursive=True)
files = {} files = {}
for item in remote_list: for item in remote_list:
if not item.endswith('/'): # 忽略目录 if not item.endswith("/"): # 忽略目录
rel_path = item.lstrip('/') rel_path = item.lstrip("/")
try: try:
info = self.client.info(item) info = self.client.info(item)
files[rel_path] = { files[rel_path] = {
'path': item, "path": item,
'size': info.get('size', 0), "size": info.get("size", 0),
'mtime': self._parse_remote_mtime(info), "mtime": self._parse_remote_mtime(info),
} }
except Exception as e: except Exception as e:
logger.warning("无法获取远程文件信息 %s: %s", item, e) logger.warning("无法获取远程文件信息 %s: %s", item, e)
@@ -134,8 +137,8 @@ class SyncService:
"""计算文件的 SHA-256 哈希值""" """计算文件的 SHA-256 哈希值"""
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
for block in iter(lambda: f.read(block_size), b''): for block in iter(lambda: f.read(block_size), b""):
sha256.update(block) sha256.update(block)
return sha256.hexdigest() return sha256.hexdigest()
except Exception as e: except Exception as e:
@@ -160,14 +163,14 @@ class SyncService:
""" """
if not self.client: if not self.client:
logger.error("WebDAV 客户端未初始化") logger.error("WebDAV 客户端未初始化")
return {'success': False, 'error': '客户端未初始化'} return {"success": False, "error": "客户端未初始化"}
results = { results = {
'uploaded': 0, "uploaded": 0,
'downloaded': 0, "downloaded": 0,
'conflicts': 0, "conflicts": 0,
'errors': 0, "errors": 0,
'success': True, "success": True,
} }
try: try:
@@ -180,29 +183,33 @@ class SyncService:
# 根据同步模式处理文件 # 根据同步模式处理文件
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]: if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
stats = self._upload_files(local_dir, local_files, remote_files) stats = self._upload_files(local_dir, local_files, remote_files)
results['uploaded'] += stats.get('uploaded', 0) results["uploaded"] += stats.get("uploaded", 0)
results['conflicts'] += stats.get('conflicts', 0) results["conflicts"] += stats.get("conflicts", 0)
results['errors'] += stats.get('errors', 0) results["errors"] += stats.get("errors", 0)
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]: if self.config.sync_mode in [
SyncMode.BIDIRECTIONAL,
SyncMode.DOWNLOAD_ONLY,
]:
stats = self._download_files(local_dir, local_files, remote_files) stats = self._download_files(local_dir, local_files, remote_files)
results['downloaded'] += stats.get('downloaded', 0) results["downloaded"] += stats.get("downloaded", 0)
results['conflicts'] += stats.get('conflicts', 0) results["conflicts"] += stats.get("conflicts", 0)
results['errors'] += stats.get('errors', 0) results["errors"] += stats.get("errors", 0)
logger.info("同步完成: %s", results) logger.info("同步完成: %s", results)
return results return results
except Exception as e: except Exception as e:
logger.error("同步过程中发生错误: %s", e) logger.error("同步过程中发生错误: %s", e)
results['success'] = False results["success"] = False
results['error'] = str(e) results["error"] = str(e)
return results return results
def _upload_files(self, local_dir: pathlib.Path, def _upload_files(
local_files: dict, remote_files: dict) -> typing.Dict[str, int]: self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""上传文件到远程服务器""" """上传文件到远程服务器"""
stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0} stats = {"uploaded": 0, "errors": 0, "conflicts": 0}
for rel_path, local_info in local_files.items(): for rel_path, local_info in local_files.items():
remote_info = remote_files.get(rel_path) remote_info = remote_files.get(rel_path)
@@ -216,20 +223,31 @@ class SyncService:
should_upload = True # 远程不存在 should_upload = True # 远程不存在
else: else:
# 检查冲突 # 检查冲突
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime: if local_mtime != remote_mtime:
# 存在冲突 # 存在冲突
stats['conflicts'] += 1 stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info) should_upload, should_download = self._handle_conflict(
local_info, remote_info
)
if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH: if (
should_upload
and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH
):
# 重命名远程文件避免覆盖 # 重命名远程文件避免覆盖
conflict_suffix = f".conflict_{int(remote_mtime)}" conflict_suffix = f".conflict_{int(remote_mtime)}"
name, ext = os.path.splitext(rel_path) name, ext = os.path.splitext(rel_path)
new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}" new_rel_path = (
remote_path = os.path.join(self.config.remote_path, new_rel_path) f"{name}{conflict_suffix}{ext}"
if ext
else f"{name}{conflict_suffix}"
)
remote_path = os.path.join(
self.config.remote_path, new_rel_path
)
conflict_resolved = True conflict_resolved = True
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path) logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
else: else:
@@ -238,19 +256,20 @@ class SyncService:
if should_upload: if should_upload:
try: try:
self.client.upload_file(local_info['path'], remote_path) self.client.upload_file(local_info["path"], remote_path)
stats['uploaded'] += 1 stats["uploaded"] += 1
logger.debug("上传文件: %s -> %s", rel_path, remote_path) logger.debug("上传文件: %s -> %s", rel_path, remote_path)
except Exception as e: except Exception as e:
logger.error("上传文件失败 %s: %s", rel_path, e) logger.error("上传文件失败 %s: %s", rel_path, e)
stats['errors'] += 1 stats["errors"] += 1
return stats return stats
def _download_files(self, local_dir: pathlib.Path, def _download_files(
local_files: dict, remote_files: dict) -> typing.Dict[str, int]: self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""从远程服务器下载文件""" """从远程服务器下载文件"""
stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0} stats = {"downloaded": 0, "errors": 0, "conflicts": 0}
for rel_path, remote_info in remote_files.items(): for rel_path, remote_info in remote_files.items():
local_info = local_files.get(rel_path) local_info = local_files.get(rel_path)
@@ -261,13 +280,15 @@ class SyncService:
should_download = True # 本地不存在 should_download = True # 本地不存在
else: else:
# 检查冲突 # 检查冲突
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime: if local_mtime != remote_mtime:
# 存在冲突 # 存在冲突
stats['conflicts'] += 1 stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info) should_upload, should_download = self._handle_conflict(
local_info, remote_info
)
# 如果应该上传,则不应该下载(冲突已在上传侧处理) # 如果应该上传,则不应该下载(冲突已在上传侧处理)
if should_upload: if should_upload:
should_download = False should_download = False
@@ -279,24 +300,26 @@ class SyncService:
try: try:
local_path = local_dir / rel_path local_path = local_dir / rel_path
local_path.parent.mkdir(parents=True, exist_ok=True) local_path.parent.mkdir(parents=True, exist_ok=True)
self.client.download_file(remote_info['path'], str(local_path)) self.client.download_file(remote_info["path"], str(local_path))
stats['downloaded'] += 1 stats["downloaded"] += 1
logger.debug("下载文件: %s -> %s", rel_path, local_path) logger.debug("下载文件: %s -> %s", rel_path, local_path)
except Exception as e: except Exception as e:
logger.error("下载文件失败 %s: %s", rel_path, e) logger.error("下载文件失败 %s: %s", rel_path, e)
stats['errors'] += 1 stats["errors"] += 1
return stats return stats
def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]: def _handle_conflict(
self, local_info: dict, remote_info: dict
) -> typing.Tuple[bool, bool]:
""" """
处理文件冲突 处理文件冲突
Returns: Returns:
(should_upload, should_download) - 是否应该上传和下载 (should_upload, should_download) - 是否应该上传和下载
""" """
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if self.config.conflict_strategy == ConflictStrategy.NEWER: if self.config.conflict_strategy == ConflictStrategy.NEWER:
# 较新文件覆盖较旧文件 # 较新文件覆盖较旧文件
@@ -318,8 +341,11 @@ class SyncService:
elif self.config.conflict_strategy == ConflictStrategy.ASK: elif self.config.conflict_strategy == ConflictStrategy.ASK:
# 用户手动选择 - 记录冲突,跳过 # 用户手动选择 - 记录冲突,跳过
# 返回 False, False 跳过,等待用户决定 # 返回 False, False 跳过,等待用户决定
logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s", logger.warning(
local_mtime, remote_mtime) "文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
local_mtime,
remote_mtime,
)
return False, False return False, False
return False, False return False, False
@@ -328,11 +354,11 @@ class SyncService:
"""判断是否需要上传(本地较新或哈希不同)""" """判断是否需要上传(本地较新或哈希不同)"""
# 这里实现简单的基于时间的比较 # 这里实现简单的基于时间的比较
# 实际应该使用哈希比较更可靠 # 实际应该使用哈希比较更可靠
return local_info.get('mtime', 0) > remote_info.get('mtime', 0) return local_info.get("mtime", 0) > remote_info.get("mtime", 0)
def _should_download(self, local_info: dict, remote_info: dict) -> bool: def _should_download(self, local_info: dict, remote_info: dict) -> bool:
"""判断是否需要下载(远程较新)""" """判断是否需要下载(远程较新)"""
return remote_info.get('mtime', 0) > local_info.get('mtime', 0) return remote_info.get("mtime", 0) > local_info.get("mtime", 0)
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool: def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
"""上传单个文件""" """上传单个文件"""
@@ -382,20 +408,22 @@ def create_sync_service_from_config() -> typing.Optional[SyncService]:
try: try:
from heurams.context import config_var from heurams.context import config_var
sync_config = config_var.get()['providers']['sync']['webdav'] sync_config = config_var.get()["providers"]["sync"]["webdav"]
if not sync_config.get('enabled', False): if not sync_config.get("enabled", False):
logger.debug("同步服务未启用") logger.debug("同步服务未启用")
return None return None
config = SyncConfig( config = SyncConfig(
enabled=sync_config.get('enabled', False), enabled=sync_config.get("enabled", False),
url=sync_config.get('url', ''), url=sync_config.get("url", ""),
username=sync_config.get('username', ''), username=sync_config.get("username", ""),
password=sync_config.get('password', ''), password=sync_config.get("password", ""),
remote_path=sync_config.get('remote_path', '/heurams/'), remote_path=sync_config.get("remote_path", "/heurams/"),
sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')), sync_mode=SyncMode(sync_config.get("sync_mode", "bidirectional")),
conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')), conflict_strategy=ConflictStrategy(
verify_ssl=sync_config.get('verify_ssl', True), sync_config.get("conflict_strategy", "newer")
),
verify_ssl=sync_config.get("verify_ssl", True),
) )
service = SyncService(config) service = SyncService(config)

View File

@@ -6,11 +6,16 @@ import pathlib
import tempfile import tempfile
import time import time
import unittest import unittest
from unittest.mock import MagicMock, patch, Mock from unittest.mock import MagicMock, Mock, patch
from heurams.context import ConfigContext from heurams.context import ConfigContext
from heurams.services.config import ConfigFile from heurams.services.config import ConfigFile
from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy from heurams.services.sync_service import (
ConflictStrategy,
SyncConfig,
SyncMode,
SyncService,
)
class TestSyncServiceUnit(unittest.TestCase): class TestSyncServiceUnit(unittest.TestCase):
@@ -44,7 +49,7 @@ class TestSyncServiceUnit(unittest.TestCase):
"""在每个测试之后清理.""" """在每个测试之后清理."""
self.temp_dir.cleanup() self.temp_dir.cleanup()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_service_initialization(self, mock_client_class): def test_sync_service_initialization(self, mock_client_class):
"""测试同步服务初始化.""" """测试同步服务初始化."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -56,7 +61,7 @@ class TestSyncServiceUnit(unittest.TestCase):
self.assertIsNotNone(service.client) self.assertIsNotNone(service.client)
self.assertEqual(service.config, self.config) self.assertEqual(service.config, self.config)
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_service_disabled(self, mock_client_class): def test_sync_service_disabled(self, mock_client_class):
"""测试同步服务未启用.""" """测试同步服务未启用."""
config = SyncConfig(enabled=False) config = SyncConfig(enabled=False)
@@ -66,7 +71,7 @@ class TestSyncServiceUnit(unittest.TestCase):
mock_client_class.assert_not_called() mock_client_class.assert_not_called()
self.assertIsNone(service.client) self.assertIsNone(service.client)
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_test_connection_success(self, mock_client_class): def test_test_connection_success(self, mock_client_class):
"""测试连接测试成功.""" """测试连接测试成功."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -78,7 +83,7 @@ class TestSyncServiceUnit(unittest.TestCase):
self.assertTrue(result) self.assertTrue(result)
self.mock_client.list.assert_called_once() self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_test_connection_failure(self, mock_client_class): def test_test_connection_failure(self, mock_client_class):
"""测试连接测试失败.""" """测试连接测试失败."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -90,7 +95,7 @@ class TestSyncServiceUnit(unittest.TestCase):
self.assertFalse(result) self.assertFalse(result)
self.mock_client.list.assert_called_once() self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_upload_file(self, mock_client_class): def test_upload_file(self, mock_client_class):
"""测试上传单个文件.""" """测试上传单个文件."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -101,7 +106,7 @@ class TestSyncServiceUnit(unittest.TestCase):
self.assertTrue(result) self.assertTrue(result)
self.mock_client.upload_file.assert_called_once() self.mock_client.upload_file.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_download_file(self, mock_client_class): def test_download_file(self, mock_client_class):
"""测试下载单个文件.""" """测试下载单个文件."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -116,7 +121,7 @@ class TestSyncServiceUnit(unittest.TestCase):
self.mock_client.download_file.assert_called_once() self.mock_client.download_file.assert_called_once()
self.assertTrue(local_path.parent.exists()) self.assertTrue(local_path.parent.exists())
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_directory_no_files(self, mock_client_class): def test_sync_directory_no_files(self, mock_client_class):
"""测试同步空目录.""" """测试同步空目录."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -126,12 +131,12 @@ class TestSyncServiceUnit(unittest.TestCase):
service = SyncService(self.config) service = SyncService(self.config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success']) self.assertTrue(result["success"])
self.assertEqual(result['uploaded'], 0) self.assertEqual(result["uploaded"], 0)
self.assertEqual(result['downloaded'], 0) self.assertEqual(result["downloaded"], 0)
self.mock_client.mkdir.assert_called_once() self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_directory_upload_only(self, mock_client_class): def test_sync_directory_upload_only(self, mock_client_class):
"""测试仅上传模式.""" """测试仅上传模式."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
@@ -151,54 +156,58 @@ class TestSyncServiceUnit(unittest.TestCase):
service = SyncService(config) service = SyncService(config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success']) self.assertTrue(result["success"])
self.mock_client.mkdir.assert_called_once() self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_conflict_strategy_newer(self, mock_client_class): def test_conflict_strategy_newer(self, mock_client_class):
"""测试 NEWER 冲突策略.""" """测试 NEWER 冲突策略."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
# 模拟远程文件存在 # 模拟远程文件存在
self.mock_client.list.return_value = ["test.txt"] self.mock_client.list.return_value = ["test.txt"]
self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'} self.mock_client.info.return_value = {
"size": 100,
"modified": "2023-01-01T00:00:00Z",
}
self.mock_client.mkdir.return_value = None self.mock_client.mkdir.return_value = None
service = SyncService(self.config) service = SyncService(self.config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success']) self.assertTrue(result["success"])
# 应该有一个冲突 # 应该有一个冲突
self.assertGreaterEqual(result.get('conflicts', 0), 0) self.assertGreaterEqual(result.get("conflicts", 0), 0)
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_create_sync_service_from_config(self, mock_client_class): def test_create_sync_service_from_config(self, mock_client_class):
"""测试从配置文件创建同步服务.""" """测试从配置文件创建同步服务."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
# 创建临时配置文件 # 创建临时配置文件
config_data = { config_data = {
'sync': { "sync": {
'webdav': { "webdav": {
'enabled': True, "enabled": True,
'url': 'https://example.com/dav/', "url": "https://example.com/dav/",
'username': 'test', "username": "test",
'password': 'test', "password": "test",
'remote_path': '/heurams/', "remote_path": "/heurams/",
'sync_mode': 'bidirectional', "sync_mode": "bidirectional",
'conflict_strategy': 'newer', "conflict_strategy": "newer",
'verify_ssl': True, "verify_ssl": True,
} }
} }
} }
# 模拟 config_var # 模拟 config_var
with patch('heurams.services.sync_service.config_var') as mock_config_var: with patch("heurams.services.sync_service.config_var") as mock_config_var:
mock_config = MagicMock() mock_config = MagicMock()
mock_config.data = config_data mock_config.data = config_data
mock_config_var.get.return_value = mock_config mock_config_var.get.return_value = mock_config
from heurams.services.sync_service import create_sync_service_from_config from heurams.services.sync_service import create_sync_service_from_config
service = create_sync_service_from_config() service = create_sync_service_from_config()
self.assertIsNotNone(service) self.assertIsNotNone(service)
@@ -228,22 +237,24 @@ class TestSyncScreenUnit(unittest.TestCase):
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
# 添加同步配置 # 添加同步配置
if 'sync' not in config_data: if "sync" not in config_data:
config_data['sync'] = {} config_data["sync"] = {}
config_data['sync']['webdav'] = { config_data["sync"]["webdav"] = {
'enabled': False, "enabled": False,
'url': '', "url": "",
'username': '', "username": "",
'password': '', "password": "",
'remote_path': '/heurams/', "remote_path": "/heurams/",
'sync_mode': 'bidirectional', "sync_mode": "bidirectional",
'conflict_strategy': 'newer', "conflict_strategy": "newer",
'verify_ssl': True, "verify_ssl": True,
} }
# 创建目录 # 创建目录
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True) pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
# 使用 ConfigContext 设置配置 # 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config) self.config_ctx = ConfigContext(self.config)
@@ -254,7 +265,7 @@ class TestSyncScreenUnit(unittest.TestCase):
self.config_ctx.__exit__(None, None, None) self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup() self.temp_dir.cleanup()
@patch('heurams.interface.screens.synctool.create_sync_service_from_config') @patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_compose(self, mock_create_service): def test_sync_screen_compose(self, mock_create_service):
"""测试 SyncScreen 的 compose 方法.""" """测试 SyncScreen 的 compose 方法."""
from heurams.interface.screens.synctool import SyncScreen from heurams.interface.screens.synctool import SyncScreen
@@ -268,12 +279,13 @@ class TestSyncScreenUnit(unittest.TestCase):
# 测试 compose 方法 # 测试 compose 方法
from textual.app import ComposeResult from textual.app import ComposeResult
result = screen.compose() result = screen.compose()
widgets = list(result) widgets = list(result)
# 检查基本部件 # 检查基本部件
from textual.widgets import Footer, Header, Button, Static, ProgressBar
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.widgets import Button, Footer, Header, ProgressBar, Static
header_present = any(isinstance(w, Header) for w in widgets) header_present = any(isinstance(w, Header) for w in widgets)
footer_present = any(isinstance(w, Footer) for w in widgets) footer_present = any(isinstance(w, Footer) for w in widgets)
@@ -284,7 +296,7 @@ class TestSyncScreenUnit(unittest.TestCase):
container_present = any(isinstance(w, ScrollableContainer) for w in widgets) container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
self.assertTrue(container_present) self.assertTrue(container_present)
@patch('heurams.interface.screens.synctool.create_sync_service_from_config') @patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_load_config(self, mock_create_service): def test_sync_screen_load_config(self, mock_create_service):
"""测试 SyncScreen 加载配置.""" """测试 SyncScreen 加载配置."""
from heurams.interface.screens.synctool import SyncScreen from heurams.interface.screens.synctool import SyncScreen