style: 格式化代码

This commit is contained in:
2025-12-21 07:56:10 +08:00
parent 1efe034a59
commit a2e12c7462
15 changed files with 373 additions and 290 deletions

View File

@@ -36,8 +36,8 @@ if pathlib.Path(workdir / "config" / "config_dev.toml").exists():
print("使用开发设置") print("使用开发设置")
logger.debug("使用开发设置") logger.debug("使用开发设置")
config_var: ContextVar[ConfigFile] = ContextVar( config_var: ContextVar[ConfigFile] = ContextVar(
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml") "config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
) )
# runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据 # runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据

View File

@@ -1,9 +1,9 @@
from textual.app import App from textual.app import App
from textual.widgets import Button from textual.widgets import Button
from heurams.services.logger import get_logger
from heurams.context import config_var from heurams.context import config_var
from heurams.interface import HeurAMSApp from heurams.interface import HeurAMSApp
from heurams.services.logger import get_logger
from .screens.about import AboutScreen from .screens.about import AboutScreen
from .screens.dashboard import DashboardScreen from .screens.dashboard import DashboardScreen
@@ -15,4 +15,4 @@ logger = get_logger(__name__)
app = HeurAMSApp() app = HeurAMSApp()
if __name__ == "__main__": if __name__ == "__main__":
app.run() app.run()

View File

@@ -4,8 +4,7 @@ import pathlib
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView, from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
Static)
import heurams.services.timer as timer import heurams.services.timer as timer
import heurams.services.version as version import heurams.services.version as version
@@ -28,6 +27,7 @@ class DashboardScreen(Screen):
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"), Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"), Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'), Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
Label(f"使用算法: {config_var.get()['algorithm']['default']}"),
Label("选择待学习或待修改的记忆单元集:", classes="title-label"), Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
ListView(id="union-list", classes="union-list-view"), ListView(id="union-list", classes="union-list-view"),
Label( Label(

View File

@@ -148,16 +148,24 @@ class MemScreen(Screen):
def play_voice(self): def play_voice(self):
"""朗读当前内容""" """朗读当前内容"""
from heurams.services.audio_service import play_by_path
from pathlib import Path from pathlib import Path
from heurams.services.audio_service import play_by_path
from heurams.services.hasher import get_md5 from heurams.services.hasher import get_md5
path = Path(config_var.get()['paths']["cache_dir"])
path = path / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav" path = Path(config_var.get()["paths"]["cache_dir"])
path = (
path
/ f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
)
if path.exists(): if path.exists():
play_by_path(path) play_by_path(path)
else: else:
from heurams.services.tts_service import convertor from heurams.services.tts_service import convertor
convertor(self.atom.registry['nucleon'].metadata["formation"]["tts_text"], path)
convertor(
self.atom.registry["nucleon"].metadata["formation"]["tts_text"], path
)
play_by_path(path) play_by_path(path)
def action_toggle_dark(self): def action_toggle_dark(self):

View File

@@ -5,8 +5,7 @@ import toml
from textual.app import ComposeResult from textual.app import ComposeResult
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.screen import Screen from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Input, Label, Markdown, from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select
Select)
from heurams.context import config_var from heurams.context import config_var
from heurams.services.version import ver from heurams.services.version import ver

View File

@@ -99,6 +99,7 @@ class PrecachingScreen(Screen):
if not cache_file.exists(): if not cache_file.exists():
try: try:
from heurams.services.tts_service import convertor from heurams.services.tts_service import convertor
convertor(text, cache_file) convertor(text, cache_file)
return 1 return 1
except Exception as e: except Exception as e:

View File

@@ -32,7 +32,7 @@ class SyncScreen(Screen):
# 标题和连接状态 # 标题和连接状态
yield Static("WebDAV 同步工具", classes="title") yield Static("WebDAV 同步工具", classes="title")
yield Static("", id="status_label", classes="status") yield Static("", id="status_label", classes="status")
# 配置信息 # 配置信息
yield Static("服务器配置", classes="section_title") yield Static("服务器配置", classes="section_title")
with Horizontal(classes="config_info"): with Horizontal(classes="config_info"):
@@ -44,7 +44,7 @@ class SyncScreen(Screen):
with Horizontal(classes="config_info"): with Horizontal(classes="config_info"):
yield Static("同步模式:", classes="config_label") yield Static("同步模式:", classes="config_label")
yield Static("", id="sync_mode", classes="config_value") yield Static("", id="sync_mode", classes="config_value")
# 控制按钮 # 控制按钮
yield Static("控制面板", classes="section_title") yield Static("控制面板", classes="section_title")
with Horizontal(classes="control_buttons"): with Horizontal(classes="control_buttons"):
@@ -52,16 +52,16 @@ class SyncScreen(Screen):
yield Button("开始同步", id="start_sync", variant="success") yield Button("开始同步", id="start_sync", variant="success")
yield Button("暂停", id="pause_sync", variant="warning", disabled=True) yield Button("暂停", id="pause_sync", variant="warning", disabled=True)
yield Button("取消", id="cancel_sync", variant="error", disabled=True) yield Button("取消", id="cancel_sync", variant="error", disabled=True)
# 进度显示 # 进度显示
yield Static("同步进度", classes="section_title") yield Static("同步进度", classes="section_title")
yield ProgressBar(id="progress_bar", show_percentage=True, total=100) yield ProgressBar(id="progress_bar", show_percentage=True, total=100)
yield Static("", id="progress_label", classes="progress_text") yield Static("", id="progress_label", classes="progress_text")
# 日志输出 # 日志输出
yield Static("同步日志", classes="section_title") yield Static("同步日志", classes="section_title")
yield Static("", id="log_output", classes="log_output") yield Static("", id="log_output", classes="log_output")
yield Footer() yield Footer()
def on_mount(self): def on_mount(self):
@@ -74,13 +74,15 @@ class SyncScreen(Screen):
"""从配置文件加载同步设置""" """从配置文件加载同步设置"""
try: try:
from heurams.context import config_var from heurams.context import config_var
config_data = config_var.get().data config_data = config_var.get().data
self.sync_config = config_data.get('sync', {}).get('webdav', {}) self.sync_config = config_data.get("sync", {}).get("webdav", {})
# 创建同步服务实例 # 创建同步服务实例
from heurams.services.sync_service import create_sync_service_from_config from heurams.services.sync_service import create_sync_service_from_config
self.sync_service = create_sync_service_from_config() self.sync_service = create_sync_service_from_config()
except Exception as e: except Exception as e:
self.log_message(f"加载配置失败: {e}", is_error=True) self.log_message(f"加载配置失败: {e}", is_error=True)
self.sync_config = {} self.sync_config = {}
@@ -89,34 +91,34 @@ class SyncScreen(Screen):
"""更新 UI 显示配置信息""" """更新 UI 显示配置信息"""
try: try:
# 更新服务器 URL # 更新服务器 URL
url = self.sync_config.get('url', '未配置') url = self.sync_config.get("url", "未配置")
url_widget = self.query_one("#server_url") url_widget = self.query_one("#server_url")
url_widget.update(url if url else '未配置') # type: ignore url_widget.update(url if url else "未配置") # type: ignore
# 更新远程路径 # 更新远程路径
remote_path = self.sync_config.get('remote_path', '/heurams/') remote_path = self.sync_config.get("remote_path", "/heurams/")
path_widget = self.query_one("#remote_path") path_widget = self.query_one("#remote_path")
path_widget.update(remote_path) # type: ignore path_widget.update(remote_path) # type: ignore
# 更新同步模式 # 更新同步模式
sync_mode = self.sync_config.get('sync_mode', 'bidirectional') sync_mode = self.sync_config.get("sync_mode", "bidirectional")
mode_widget = self.query_one("#sync_mode") mode_widget = self.query_one("#sync_mode")
mode_map = { mode_map = {
'bidirectional': '双向同步', "bidirectional": "双向同步",
'upload_only': '仅上传', "upload_only": "仅上传",
'download_only': '仅下载', "download_only": "仅下载",
} }
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
# 更新状态标签 # 更新状态标签
status_widget = self.query_one("#status_label") status_widget = self.query_one("#status_label")
if self.sync_service and self.sync_service.client: if self.sync_service and self.sync_service.client:
status_widget.update("✅ 同步服务已就绪") # type: ignore status_widget.update("✅ 同步服务已就绪") # type: ignore
status_widget.add_class("ready") status_widget.add_class("ready")
else: else:
status_widget.update("❌ 同步服务未配置或未启用") # type: ignore status_widget.update("❌ 同步服务未配置或未启用") # type: ignore
status_widget.add_class("error") status_widget.add_class("error")
except Exception as e: except Exception as e:
self.log_message(f"更新 UI 失败: {e}", is_error=True) self.log_message(f"更新 UI 失败: {e}", is_error=True)
@@ -124,15 +126,15 @@ class SyncScreen(Screen):
"""更新状态显示""" """更新状态显示"""
try: try:
status_widget = self.query_one("#status_label") status_widget = self.query_one("#status_label")
status_widget.update(status) # type: ignore status_widget.update(status) # type: ignore
if progress is not None: if progress is not None:
progress_bar = self.query_one("#progress_bar") progress_bar = self.query_one("#progress_bar")
progress_bar.progress = progress # type: ignore progress_bar.progress = progress # type: ignore
progress_label = self.query_one("#progress_label") progress_label = self.query_one("#progress_label")
progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore
except Exception as e: except Exception as e:
self.log_message(f"更新状态失败: {e}", is_error=True) self.log_message(f"更新状态失败: {e}", is_error=True)
@@ -141,23 +143,23 @@ class SyncScreen(Screen):
timestamp = time.strftime("%H:%M:%S") timestamp = time.strftime("%H:%M:%S")
prefix = "[ERROR]" if is_error else "[INFO]" prefix = "[ERROR]" if is_error else "[INFO]"
log_line = f"{timestamp} {prefix} {message}" log_line = f"{timestamp} {prefix} {message}"
self.log_messages.append(log_line) self.log_messages.append(log_line)
# 保持日志行数不超过最大值 # 保持日志行数不超过最大值
if len(self.log_messages) > self.max_log_lines: if len(self.log_messages) > self.max_log_lines:
self.log_messages = self.log_messages[-self.max_log_lines:] self.log_messages = self.log_messages[-self.max_log_lines :]
# 更新日志显示 # 更新日志显示
try: try:
log_widget = self.query_one("#log_output") log_widget = self.query_one("#log_output")
log_widget.update("\n".join(self.log_messages)) # type: ignore log_widget.update("\n".join(self.log_messages)) # type: ignore
except Exception: except Exception:
pass # 如果组件未就绪,忽略错误 pass # 如果组件未就绪,忽略错误
def on_button_pressed(self, event: Button.Pressed) -> None: def on_button_pressed(self, event: Button.Pressed) -> None:
"""处理按钮点击事件""" """处理按钮点击事件"""
button_id = event.button.id button_id = event.button.id
if button_id == "test_connection": if button_id == "test_connection":
self.test_connection() self.test_connection()
elif button_id == "start_sync": elif button_id == "start_sync":
@@ -166,7 +168,7 @@ class SyncScreen(Screen):
self.pause_sync() self.pause_sync()
elif button_id == "cancel_sync": elif button_id == "cancel_sync":
self.cancel_sync() self.cancel_sync()
event.stop() event.stop()
def test_connection(self): def test_connection(self):
@@ -175,10 +177,10 @@ class SyncScreen(Screen):
self.log_message("同步服务未初始化,请检查配置", is_error=True) self.log_message("同步服务未初始化,请检查配置", is_error=True)
self.update_status("❌ 同步服务未初始化") self.update_status("❌ 同步服务未初始化")
return return
self.log_message("正在测试 WebDAV 连接...") self.log_message("正在测试 WebDAV 连接...")
self.update_status("正在测试连接...") self.update_status("正在测试连接...")
try: try:
success = self.sync_service.test_connection() success = self.sync_service.test_connection()
if success: if success:
@@ -196,71 +198,87 @@ class SyncScreen(Screen):
if not self.sync_service: if not self.sync_service:
self.log_message("同步服务未初始化,无法开始同步", is_error=True) self.log_message("同步服务未初始化,无法开始同步", is_error=True)
return return
if self.is_syncing: if self.is_syncing:
self.log_message("同步已在进行中", is_error=True) self.log_message("同步已在进行中", is_error=True)
return return
self.is_syncing = True self.is_syncing = True
self.is_paused = False self.is_paused = False
self.update_button_states() self.update_button_states()
self.log_message("开始同步数据...") self.log_message("开始同步数据...")
self.update_status("正在同步...", progress=0) self.update_status("正在同步...", progress=0)
# 启动后台同步任务 # 启动后台同步任务
self.run_worker(self.perform_sync, thread=True) self.run_worker(self.perform_sync, thread=True)
def perform_sync(self): def perform_sync(self):
"""执行同步任务(在后台线程中运行)""" """执行同步任务(在后台线程中运行)"""
worker = get_current_worker() worker = get_current_worker()
try: try:
# 获取需要同步的本地目录 # 获取需要同步的本地目录
from heurams.context import config_var from heurams.context import config_var
config = config_var.get() config = config_var.get()
paths = config.get('paths', {}) paths = config.get("paths", {})
# 同步 nucleon 目录 # 同步 nucleon 目录
nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon')) nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon"))
if nucleon_dir.exists(): if nucleon_dir.exists():
self.log_message(f"同步 nucleon 目录: {nucleon_dir}") self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
self.update_status(f"同步 nucleon 目录...", progress=10) self.update_status(f"同步 nucleon 目录...", progress=10)
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"nucleon 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 electron 目录 # 同步 electron 目录
electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron')) electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron"))
if electron_dir.exists(): if electron_dir.exists():
self.log_message(f"同步 electron 目录: {electron_dir}") self.log_message(f"同步 electron 目录: {electron_dir}")
self.update_status(f"同步 electron 目录...", progress=60) self.update_status(f"同步 electron 目录...", progress=60)
result = self.sync_service.sync_directory(electron_dir) # type: ignore result = self.sync_service.sync_directory(electron_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"electron 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 orbital 目录(如果存在) # 同步 orbital 目录(如果存在)
orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital')) orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital"))
if orbital_dir.exists(): if orbital_dir.exists():
self.log_message(f"同步 orbital 目录: {orbital_dir}") self.log_message(f"同步 orbital 目录: {orbital_dir}")
self.update_status(f"同步 orbital 目录...", progress=80) self.update_status(f"同步 orbital 目录...", progress=80)
result = self.sync_service.sync_directory(orbital_dir) # type: ignore result = self.sync_service.sync_directory(orbital_dir) # type: ignore
if result.get('success'): if result.get("success"):
self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}") self.log_message(
f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else: else:
self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True) self.log_message(
f"orbital 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步完成 # 同步完成
self.update_status("同步完成", progress=100) self.update_status("同步完成", progress=100)
self.log_message("所有目录同步完成") self.log_message("所有目录同步完成")
except Exception as e: except Exception as e:
self.log_message(f"同步过程中发生错误: {e}", is_error=True) self.log_message(f"同步过程中发生错误: {e}", is_error=True)
self.update_status("同步失败") self.update_status("同步失败")
@@ -268,16 +286,16 @@ class SyncScreen(Screen):
# 重置同步状态 # 重置同步状态
self.is_syncing = False self.is_syncing = False
self.is_paused = False self.is_paused = False
self.update_button_states() # type: ignore self.update_button_states() # type: ignore
def pause_sync(self): def pause_sync(self):
"""暂停同步""" """暂停同步"""
if not self.is_syncing: if not self.is_syncing:
return return
self.is_paused = not self.is_paused self.is_paused = not self.is_paused
self.update_button_states() self.update_button_states()
if self.is_paused: if self.is_paused:
self.log_message("同步已暂停") self.log_message("同步已暂停")
self.update_status("同步已暂停") self.update_status("同步已暂停")
@@ -289,11 +307,11 @@ class SyncScreen(Screen):
"""取消同步""" """取消同步"""
if not self.is_syncing: if not self.is_syncing:
return return
self.is_syncing = False self.is_syncing = False
self.is_paused = False self.is_paused = False
self.update_button_states() self.update_button_states()
self.log_message("同步已取消") self.log_message("同步已取消")
self.update_status("同步已取消") self.update_status("同步已取消")
@@ -303,17 +321,17 @@ class SyncScreen(Screen):
start_button = self.query_one("#start_sync") start_button = self.query_one("#start_sync")
pause_button = self.query_one("#pause_sync") pause_button = self.query_one("#pause_sync")
cancel_button = self.query_one("#cancel_sync") cancel_button = self.query_one("#cancel_sync")
if self.is_syncing: if self.is_syncing:
start_button.disabled = True start_button.disabled = True
pause_button.disabled = False pause_button.disabled = False
cancel_button.disabled = False cancel_button.disabled = False
pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore
else: else:
start_button.disabled = False start_button.disabled = False
pause_button.disabled = True pause_button.disabled = True
cancel_button.disabled = True cancel_button.disabled = True
except Exception as e: except Exception as e:
self.log_message(f"更新按钮状态失败: {e}", is_error=True) self.log_message(f"更新按钮状态失败: {e}", is_error=True)

View File

@@ -50,9 +50,10 @@ class Recognition(BasePuzzleWidget):
def compose(self): def compose(self):
from heurams.context import config_var from heurams.context import config_var
autovoice = config_var.get()['interface']['memorizor']['autovoice']
autovoice = config_var.get()["interface"]["memorizor"]["autovoice"]
if autovoice: if autovoice:
self.screen.action_play_voice() # type: ignore self.screen.action_play_voice() # type: ignore
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia] cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
delim = self.atom.registry["nucleon"].metadata["formation"]["delimiter"] delim = self.atom.registry["nucleon"].metadata["formation"]["delimiter"]
replace_dict = { replace_dict = {
@@ -72,7 +73,7 @@ class Recognition(BasePuzzleWidget):
primary = cfg["primary"] primary = cfg["primary"]
with Center(): with Center():
for i in cfg['top_dim']: for i in cfg["top_dim"]:
yield Static(f"[dim]{i}[/]") yield Static(f"[dim]{i}[/]")
yield Label("") yield Label("")

View File

@@ -10,15 +10,25 @@ MIT 许可证
import datetime import datetime
import json import json
import os import os
from typing import TypedDict
import pathlib import pathlib
from typing import TypedDict
from heurams.context import config_var from heurams.context import config_var
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF, from heurams.kernel.algorithms.sm15m_calc import (
RANGE_AF, RANGE_REPETITION, MAX_AF,
SM, THRESHOLD_RECALL, Item) MIN_AF,
NOTCH_AF,
RANGE_AF,
RANGE_REPETITION,
SM,
THRESHOLD_RECALL,
Item,
)
# 全局状态文件路径 # 全局状态文件路径
_GLOBAL_STATE_FILE = os.path.expanduser(pathlib.Path(config_var.get()['paths']['global_dir']) / 'sm15m_global_state.json') _GLOBAL_STATE_FILE = os.path.expanduser(
pathlib.Path(config_var.get()["paths"]["global_dir"]) / "sm15m_global_state.json"
)
def _get_global_sm(): def _get_global_sm():

View File

@@ -86,8 +86,8 @@ class Atom:
# eval 环境设置 # eval 环境设置
def eval_with_env(s: str): def eval_with_env(s: str):
default = config_var.get()["puzzles"] default = config_var.get()["puzzles"]
payload = self.registry['nucleon'].payload payload = self.registry["nucleon"].payload
metadata = self.registry['nucleon'].metadata metadata = self.registry["nucleon"].metadata
eval_value = eval(s) eval_value = eval(s)
if isinstance(eval_value, (int, float)): if isinstance(eval_value, (int, float)):
ret = str(eval_value) ret = str(eval_value)
@@ -117,10 +117,11 @@ class Atom:
logger.debug("发现 eval 表达式: '%s'", data[5:]) logger.debug("发现 eval 表达式: '%s'", data[5:])
return modifier(data[5:]) return modifier(data[5:])
return data return data
try: try:
traverse(self.registry['nucleon'].payload, eval_with_env) traverse(self.registry["nucleon"].payload, eval_with_env)
traverse(self.registry['nucleon'].metadata, eval_with_env) traverse(self.registry["nucleon"].metadata, eval_with_env)
traverse(self.registry['orbital'], eval_with_env) traverse(self.registry["orbital"], eval_with_env)
except Exception as e: except Exception as e:
ret = f"此 eval 实例发生错误: {e}" ret = f"此 eval 实例发生错误: {e}"
logger.warning(ret) logger.warning(ret)

View File

@@ -18,9 +18,12 @@ class Electron:
algo: 使用的算法模块标识 algo: 使用的算法模块标识
""" """
if algo_name == "": if algo_name == "":
algo_name = config_var.get()['algorithm']['default'] algo_name = config_var.get()["algorithm"]["default"]
logger.debug( logger.debug(
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", ident, algo_name, algodata "创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s",
ident,
algo_name,
algodata,
) )
self.algodata = algodata self.algodata = algodata
self.ident = ident self.ident = ident
@@ -31,7 +34,9 @@ class Electron:
self.algodata[self.algo.algo_name] = {} self.algodata[self.algo.algo_name] = {}
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo) logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
if not self.algodata[self.algo.algo_name]: if not self.algodata[self.algo.algo_name]:
logger.debug(f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}") logger.debug(
f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}"
)
self._default_init(self.algo.defaults) self._default_init(self.algo.defaults)
else: else:
logger.debug("算法数据已存在, 跳过默认初始化") logger.debug("算法数据已存在, 跳过默认初始化")

View File

@@ -53,4 +53,4 @@ class Nucleon:
def placeholder(): def placeholder():
"""生成一个占位原子核""" """生成一个占位原子核"""
logger.debug("创建 Nucleon 占位符") logger.debug("创建 Nucleon 占位符")
return Nucleon("核子对象样例内容", {}) return Nucleon("核子对象样例内容", {})

View File

@@ -2,8 +2,8 @@ import pathlib
import edge_tts import edge_tts
from heurams.services.logger import get_logger
from heurams.context import config_var from heurams.context import config_var
from heurams.services.logger import get_logger
from .base import BaseTTS from .base import BaseTTS
@@ -19,7 +19,7 @@ class EdgeTTS(BaseTTS):
try: try:
communicate = edge_tts.Communicate( communicate = edge_tts.Communicate(
text, text,
config_var.get()['providers']['tts']['edgetts']["voice"], config_var.get()["providers"]["tts"]["edgetts"]["voice"],
) )
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频") logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
communicate.save_sync(str(path)) communicate.save_sync(str(path))

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class SyncMode(Enum): class SyncMode(Enum):
"""同步模式枚举""" """同步模式枚举"""
BIDIRECTIONAL = "bidirectional" BIDIRECTIONAL = "bidirectional"
UPLOAD_ONLY = "upload_only" UPLOAD_ONLY = "upload_only"
DOWNLOAD_ONLY = "download_only" DOWNLOAD_ONLY = "download_only"
@@ -25,6 +26,7 @@ class SyncMode(Enum):
class ConflictStrategy(Enum): class ConflictStrategy(Enum):
"""冲突解决策略枚举""" """冲突解决策略枚举"""
NEWER = "newer" # 较新文件覆盖较旧文件 NEWER = "newer" # 较新文件覆盖较旧文件
ASK = "ask" # 用户手动选择 ASK = "ask" # 用户手动选择
KEEP_BOTH = "keep_both" # 保留双方(重命名) KEEP_BOTH = "keep_both" # 保留双方(重命名)
@@ -33,6 +35,7 @@ class ConflictStrategy(Enum):
@dataclass @dataclass
class SyncConfig: class SyncConfig:
"""同步配置数据类""" """同步配置数据类"""
enabled: bool = False enabled: bool = False
url: str = "" url: str = ""
username: str = "" username: str = ""
@@ -59,12 +62,12 @@ class SyncService:
return return
options = { options = {
'webdav_hostname': self.config.url, "webdav_hostname": self.config.url,
'webdav_login': self.config.username, "webdav_login": self.config.username,
'webdav_password': self.config.password, "webdav_password": self.config.password,
'webdav_root': self.config.remote_path, "webdav_root": self.config.remote_path,
'verify_ssl': self.config.verify_ssl, "verify_ssl": self.config.verify_ssl,
'disable_check': True, # 不检查服务器支持的功能 "disable_check": True, # 不检查服务器支持的功能
} }
try: try:
@@ -98,10 +101,10 @@ class SyncService:
rel_path = file_path.relative_to(local_dir) rel_path = file_path.relative_to(local_dir)
stat = file_path.stat() stat = file_path.stat()
files[str(rel_path)] = { files[str(rel_path)] = {
'path': file_path, "path": file_path,
'size': stat.st_size, "size": stat.st_size,
'mtime': stat.st_mtime, "mtime": stat.st_mtime,
'hash': self._calculate_hash(file_path), "hash": self._calculate_hash(file_path),
} }
return files return files
@@ -114,14 +117,14 @@ class SyncService:
remote_list = self.client.list(recursive=True) remote_list = self.client.list(recursive=True)
files = {} files = {}
for item in remote_list: for item in remote_list:
if not item.endswith('/'): # 忽略目录 if not item.endswith("/"): # 忽略目录
rel_path = item.lstrip('/') rel_path = item.lstrip("/")
try: try:
info = self.client.info(item) info = self.client.info(item)
files[rel_path] = { files[rel_path] = {
'path': item, "path": item,
'size': info.get('size', 0), "size": info.get("size", 0),
'mtime': self._parse_remote_mtime(info), "mtime": self._parse_remote_mtime(info),
} }
except Exception as e: except Exception as e:
logger.warning("无法获取远程文件信息 %s: %s", item, e) logger.warning("无法获取远程文件信息 %s: %s", item, e)
@@ -134,8 +137,8 @@ class SyncService:
"""计算文件的 SHA-256 哈希值""" """计算文件的 SHA-256 哈希值"""
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
try: try:
with open(file_path, 'rb') as f: with open(file_path, "rb") as f:
for block in iter(lambda: f.read(block_size), b''): for block in iter(lambda: f.read(block_size), b""):
sha256.update(block) sha256.update(block)
return sha256.hexdigest() return sha256.hexdigest()
except Exception as e: except Exception as e:
@@ -151,23 +154,23 @@ class SyncService:
def sync_directory(self, local_dir: pathlib.Path) -> typing.Dict[str, typing.Any]: def sync_directory(self, local_dir: pathlib.Path) -> typing.Dict[str, typing.Any]:
""" """
同步目录 同步目录
Args: Args:
local_dir: 本地目录路径 local_dir: 本地目录路径
Returns: Returns:
同步结果统计 同步结果统计
""" """
if not self.client: if not self.client:
logger.error("WebDAV 客户端未初始化") logger.error("WebDAV 客户端未初始化")
return {'success': False, 'error': '客户端未初始化'} return {"success": False, "error": "客户端未初始化"}
results = { results = {
'uploaded': 0, "uploaded": 0,
'downloaded': 0, "downloaded": 0,
'conflicts': 0, "conflicts": 0,
'errors': 0, "errors": 0,
'success': True, "success": True,
} }
try: try:
@@ -180,124 +183,144 @@ class SyncService:
# 根据同步模式处理文件 # 根据同步模式处理文件
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]: if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
stats = self._upload_files(local_dir, local_files, remote_files) stats = self._upload_files(local_dir, local_files, remote_files)
results['uploaded'] += stats.get('uploaded', 0) results["uploaded"] += stats.get("uploaded", 0)
results['conflicts'] += stats.get('conflicts', 0) results["conflicts"] += stats.get("conflicts", 0)
results['errors'] += stats.get('errors', 0) results["errors"] += stats.get("errors", 0)
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]: if self.config.sync_mode in [
SyncMode.BIDIRECTIONAL,
SyncMode.DOWNLOAD_ONLY,
]:
stats = self._download_files(local_dir, local_files, remote_files) stats = self._download_files(local_dir, local_files, remote_files)
results['downloaded'] += stats.get('downloaded', 0) results["downloaded"] += stats.get("downloaded", 0)
results['conflicts'] += stats.get('conflicts', 0) results["conflicts"] += stats.get("conflicts", 0)
results['errors'] += stats.get('errors', 0) results["errors"] += stats.get("errors", 0)
logger.info("同步完成: %s", results) logger.info("同步完成: %s", results)
return results return results
except Exception as e: except Exception as e:
logger.error("同步过程中发生错误: %s", e) logger.error("同步过程中发生错误: %s", e)
results['success'] = False results["success"] = False
results['error'] = str(e) results["error"] = str(e)
return results return results
def _upload_files(self, local_dir: pathlib.Path, def _upload_files(
local_files: dict, remote_files: dict) -> typing.Dict[str, int]: self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""上传文件到远程服务器""" """上传文件到远程服务器"""
stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0} stats = {"uploaded": 0, "errors": 0, "conflicts": 0}
for rel_path, local_info in local_files.items(): for rel_path, local_info in local_files.items():
remote_info = remote_files.get(rel_path) remote_info = remote_files.get(rel_path)
# 判断是否需要上传 # 判断是否需要上传
should_upload = False should_upload = False
conflict_resolved = False conflict_resolved = False
remote_path = os.path.join(self.config.remote_path, rel_path) remote_path = os.path.join(self.config.remote_path, rel_path)
if not remote_info: if not remote_info:
should_upload = True # 远程不存在 should_upload = True # 远程不存在
else: else:
# 检查冲突 # 检查冲突
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime: if local_mtime != remote_mtime:
# 存在冲突 # 存在冲突
stats['conflicts'] += 1 stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info) should_upload, should_download = self._handle_conflict(
local_info, remote_info
if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH: )
if (
should_upload
and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH
):
# 重命名远程文件避免覆盖 # 重命名远程文件避免覆盖
conflict_suffix = f".conflict_{int(remote_mtime)}" conflict_suffix = f".conflict_{int(remote_mtime)}"
name, ext = os.path.splitext(rel_path) name, ext = os.path.splitext(rel_path)
new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}" new_rel_path = (
remote_path = os.path.join(self.config.remote_path, new_rel_path) f"{name}{conflict_suffix}{ext}"
if ext
else f"{name}{conflict_suffix}"
)
remote_path = os.path.join(
self.config.remote_path, new_rel_path
)
conflict_resolved = True conflict_resolved = True
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path) logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
else: else:
# 时间相同,无需上传 # 时间相同,无需上传
should_upload = False should_upload = False
if should_upload: if should_upload:
try: try:
self.client.upload_file(local_info['path'], remote_path) self.client.upload_file(local_info["path"], remote_path)
stats['uploaded'] += 1 stats["uploaded"] += 1
logger.debug("上传文件: %s -> %s", rel_path, remote_path) logger.debug("上传文件: %s -> %s", rel_path, remote_path)
except Exception as e: except Exception as e:
logger.error("上传文件失败 %s: %s", rel_path, e) logger.error("上传文件失败 %s: %s", rel_path, e)
stats['errors'] += 1 stats["errors"] += 1
return stats return stats
def _download_files(self, local_dir: pathlib.Path, def _download_files(
local_files: dict, remote_files: dict) -> typing.Dict[str, int]: self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""从远程服务器下载文件""" """从远程服务器下载文件"""
stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0} stats = {"downloaded": 0, "errors": 0, "conflicts": 0}
for rel_path, remote_info in remote_files.items(): for rel_path, remote_info in remote_files.items():
local_info = local_files.get(rel_path) local_info = local_files.get(rel_path)
# 判断是否需要下载 # 判断是否需要下载
should_download = False should_download = False
if not local_info: if not local_info:
should_download = True # 本地不存在 should_download = True # 本地不存在
else: else:
# 检查冲突 # 检查冲突
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime: if local_mtime != remote_mtime:
# 存在冲突 # 存在冲突
stats['conflicts'] += 1 stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info) should_upload, should_download = self._handle_conflict(
local_info, remote_info
)
# 如果应该上传,则不应该下载(冲突已在上传侧处理) # 如果应该上传,则不应该下载(冲突已在上传侧处理)
if should_upload: if should_upload:
should_download = False should_download = False
else: else:
# 时间相同,无需下载 # 时间相同,无需下载
should_download = False should_download = False
if should_download: if should_download:
try: try:
local_path = local_dir / rel_path local_path = local_dir / rel_path
local_path.parent.mkdir(parents=True, exist_ok=True) local_path.parent.mkdir(parents=True, exist_ok=True)
self.client.download_file(remote_info['path'], str(local_path)) self.client.download_file(remote_info["path"], str(local_path))
stats['downloaded'] += 1 stats["downloaded"] += 1
logger.debug("下载文件: %s -> %s", rel_path, local_path) logger.debug("下载文件: %s -> %s", rel_path, local_path)
except Exception as e: except Exception as e:
logger.error("下载文件失败 %s: %s", rel_path, e) logger.error("下载文件失败 %s: %s", rel_path, e)
stats['errors'] += 1 stats["errors"] += 1
return stats return stats
def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]: def _handle_conflict(
self, local_info: dict, remote_info: dict
) -> typing.Tuple[bool, bool]:
""" """
处理文件冲突 处理文件冲突
Returns: Returns:
(should_upload, should_download) - 是否应该上传和下载 (should_upload, should_download) - 是否应该上传和下载
""" """
local_mtime = local_info.get('mtime', 0) local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get('mtime', 0) remote_mtime = remote_info.get("mtime", 0)
if self.config.conflict_strategy == ConflictStrategy.NEWER: if self.config.conflict_strategy == ConflictStrategy.NEWER:
# 较新文件覆盖较旧文件 # 较新文件覆盖较旧文件
if local_mtime > remote_mtime: if local_mtime > remote_mtime:
@@ -306,7 +329,7 @@ class SyncService:
return False, True # 下载远程较新版本 return False, True # 下载远程较新版本
else: else:
return False, False # 时间相同,无需操作 return False, False # 时间相同,无需操作
elif self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH: elif self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
# 保留双方 - 重命名远程文件 # 保留双方 - 重命名远程文件
# 这里实现简单的重命名策略:添加冲突后缀 # 这里实现简单的重命名策略:添加冲突后缀
@@ -314,25 +337,28 @@ class SyncService:
# 返回 True, False 表示上传重命名后的文件 # 返回 True, False 表示上传重命名后的文件
# 重命名逻辑在调用处处理 # 重命名逻辑在调用处处理
return True, False return True, False
elif self.config.conflict_strategy == ConflictStrategy.ASK: elif self.config.conflict_strategy == ConflictStrategy.ASK:
# 用户手动选择 - 记录冲突,跳过 # 用户手动选择 - 记录冲突,跳过
# 返回 False, False 跳过,等待用户决定 # 返回 False, False 跳过,等待用户决定
logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s", logger.warning(
local_mtime, remote_mtime) "文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
local_mtime,
remote_mtime,
)
return False, False return False, False
return False, False return False, False
def _should_upload(self, local_info: dict, remote_info: dict) -> bool: def _should_upload(self, local_info: dict, remote_info: dict) -> bool:
"""判断是否需要上传(本地较新或哈希不同)""" """判断是否需要上传(本地较新或哈希不同)"""
# 这里实现简单的基于时间的比较 # 这里实现简单的基于时间的比较
# 实际应该使用哈希比较更可靠 # 实际应该使用哈希比较更可靠
return local_info.get('mtime', 0) > remote_info.get('mtime', 0) return local_info.get("mtime", 0) > remote_info.get("mtime", 0)
def _should_download(self, local_info: dict, remote_info: dict) -> bool: def _should_download(self, local_info: dict, remote_info: dict) -> bool:
"""判断是否需要下载(远程较新)""" """判断是否需要下载(远程较新)"""
return remote_info.get('mtime', 0) > local_info.get('mtime', 0) return remote_info.get("mtime", 0) > local_info.get("mtime", 0)
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool: def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
"""上传单个文件""" """上传单个文件"""
@@ -381,30 +407,32 @@ def create_sync_service_from_config() -> typing.Optional[SyncService]:
"""从配置文件创建同步服务实例""" """从配置文件创建同步服务实例"""
try: try:
from heurams.context import config_var from heurams.context import config_var
sync_config = config_var.get()['providers']['sync']['webdav'] sync_config = config_var.get()["providers"]["sync"]["webdav"]
if not sync_config.get('enabled', False): if not sync_config.get("enabled", False):
logger.debug("同步服务未启用") logger.debug("同步服务未启用")
return None return None
config = SyncConfig( config = SyncConfig(
enabled=sync_config.get('enabled', False), enabled=sync_config.get("enabled", False),
url=sync_config.get('url', ''), url=sync_config.get("url", ""),
username=sync_config.get('username', ''), username=sync_config.get("username", ""),
password=sync_config.get('password', ''), password=sync_config.get("password", ""),
remote_path=sync_config.get('remote_path', '/heurams/'), remote_path=sync_config.get("remote_path", "/heurams/"),
sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')), sync_mode=SyncMode(sync_config.get("sync_mode", "bidirectional")),
conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')), conflict_strategy=ConflictStrategy(
verify_ssl=sync_config.get('verify_ssl', True), sync_config.get("conflict_strategy", "newer")
),
verify_ssl=sync_config.get("verify_ssl", True),
) )
service = SyncService(config) service = SyncService(config)
if service.client is None: if service.client is None:
logger.warning("同步服务客户端创建失败") logger.warning("同步服务客户端创建失败")
return None return None
return service return service
except Exception as e: except Exception as e:
logger.error("创建同步服务失败: %s", e) logger.error("创建同步服务失败: %s", e)
return None return None

View File

@@ -6,11 +6,16 @@ import pathlib
import tempfile import tempfile
import time import time
import unittest import unittest
from unittest.mock import MagicMock, patch, Mock from unittest.mock import MagicMock, Mock, patch
from heurams.context import ConfigContext from heurams.context import ConfigContext
from heurams.services.config import ConfigFile from heurams.services.config import ConfigFile
from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy from heurams.services.sync_service import (
ConflictStrategy,
SyncConfig,
SyncMode,
SyncService,
)
class TestSyncServiceUnit(unittest.TestCase): class TestSyncServiceUnit(unittest.TestCase):
@@ -20,14 +25,14 @@ class TestSyncServiceUnit(unittest.TestCase):
"""在每个测试之前运行, 设置临时目录和模拟客户端.""" """在每个测试之前运行, 设置临时目录和模拟客户端."""
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name) self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建测试文件 # 创建测试文件
self.test_file = self.temp_path / "test.txt" self.test_file = self.temp_path / "test.txt"
self.test_file.write_text("测试内容") self.test_file.write_text("测试内容")
# 模拟 WebDAV 客户端 # 模拟 WebDAV 客户端
self.mock_client = MagicMock() self.mock_client = MagicMock()
# 创建同步配置 # 创建同步配置
self.config = SyncConfig( self.config = SyncConfig(
enabled=True, enabled=True,
@@ -44,100 +49,100 @@ class TestSyncServiceUnit(unittest.TestCase):
"""在每个测试之后清理.""" """在每个测试之后清理."""
self.temp_dir.cleanup() self.temp_dir.cleanup()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_service_initialization(self, mock_client_class): def test_sync_service_initialization(self, mock_client_class):
"""测试同步服务初始化.""" """测试同步服务初始化."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
service = SyncService(self.config) service = SyncService(self.config)
# 验证客户端已创建 # 验证客户端已创建
mock_client_class.assert_called_once() mock_client_class.assert_called_once()
self.assertIsNotNone(service.client) self.assertIsNotNone(service.client)
self.assertEqual(service.config, self.config) self.assertEqual(service.config, self.config)
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_service_disabled(self, mock_client_class): def test_sync_service_disabled(self, mock_client_class):
"""测试同步服务未启用.""" """测试同步服务未启用."""
config = SyncConfig(enabled=False) config = SyncConfig(enabled=False)
service = SyncService(config) service = SyncService(config)
# 客户端不应初始化 # 客户端不应初始化
mock_client_class.assert_not_called() mock_client_class.assert_not_called()
self.assertIsNone(service.client) self.assertIsNone(service.client)
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_test_connection_success(self, mock_client_class): def test_test_connection_success(self, mock_client_class):
"""测试连接测试成功.""" """测试连接测试成功."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = [] self.mock_client.list.return_value = []
service = SyncService(self.config) service = SyncService(self.config)
result = service.test_connection() result = service.test_connection()
self.assertTrue(result) self.assertTrue(result)
self.mock_client.list.assert_called_once() self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_test_connection_failure(self, mock_client_class): def test_test_connection_failure(self, mock_client_class):
"""测试连接测试失败.""" """测试连接测试失败."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
self.mock_client.list.side_effect = Exception("连接失败") self.mock_client.list.side_effect = Exception("连接失败")
service = SyncService(self.config) service = SyncService(self.config)
result = service.test_connection() result = service.test_connection()
self.assertFalse(result) self.assertFalse(result)
self.mock_client.list.assert_called_once() self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_upload_file(self, mock_client_class): def test_upload_file(self, mock_client_class):
"""测试上传单个文件.""" """测试上传单个文件."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
service = SyncService(self.config) service = SyncService(self.config)
result = service.upload_file(self.test_file) result = service.upload_file(self.test_file)
self.assertTrue(result) self.assertTrue(result)
self.mock_client.upload_file.assert_called_once() self.mock_client.upload_file.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_download_file(self, mock_client_class): def test_download_file(self, mock_client_class):
"""测试下载单个文件.""" """测试下载单个文件."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
service = SyncService(self.config) service = SyncService(self.config)
remote_path = "/heurams/test.txt" remote_path = "/heurams/test.txt"
local_path = self.temp_path / "downloaded.txt" local_path = self.temp_path / "downloaded.txt"
result = service.download_file(remote_path, local_path) result = service.download_file(remote_path, local_path)
self.assertTrue(result) self.assertTrue(result)
self.mock_client.download_file.assert_called_once() self.mock_client.download_file.assert_called_once()
self.assertTrue(local_path.parent.exists()) self.assertTrue(local_path.parent.exists())
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_directory_no_files(self, mock_client_class): def test_sync_directory_no_files(self, mock_client_class):
"""测试同步空目录.""" """测试同步空目录."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = [] self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None self.mock_client.mkdir.return_value = None
service = SyncService(self.config) service = SyncService(self.config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success']) self.assertTrue(result["success"])
self.assertEqual(result['uploaded'], 0) self.assertEqual(result["uploaded"], 0)
self.assertEqual(result['downloaded'], 0) self.assertEqual(result["downloaded"], 0)
self.mock_client.mkdir.assert_called_once() self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_sync_directory_upload_only(self, mock_client_class): def test_sync_directory_upload_only(self, mock_client_class):
"""测试仅上传模式.""" """测试仅上传模式."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = [] self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None self.mock_client.mkdir.return_value = None
config = SyncConfig( config = SyncConfig(
enabled=True, enabled=True,
url="https://example.com/dav/", url="https://example.com/dav/",
@@ -147,60 +152,64 @@ class TestSyncServiceUnit(unittest.TestCase):
sync_mode=SyncMode.UPLOAD_ONLY, sync_mode=SyncMode.UPLOAD_ONLY,
conflict_strategy=ConflictStrategy.NEWER, conflict_strategy=ConflictStrategy.NEWER,
) )
service = SyncService(config) service = SyncService(config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success']) self.assertTrue(result["success"])
self.mock_client.mkdir.assert_called_once() self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client') @patch("heurams.services.sync_service.Client")
def test_conflict_strategy_newer(self, mock_client_class): def test_conflict_strategy_newer(self, mock_client_class):
"""测试 NEWER 冲突策略.""" """测试 NEWER 冲突策略."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
# 模拟远程文件存在 # 模拟远程文件存在
self.mock_client.list.return_value = ["test.txt"] self.mock_client.list.return_value = ["test.txt"]
self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'} self.mock_client.info.return_value = {
"size": 100,
"modified": "2023-01-01T00:00:00Z",
}
self.mock_client.mkdir.return_value = None self.mock_client.mkdir.return_value = None
service = SyncService(self.config) service = SyncService(self.config)
result = service.sync_directory(self.temp_path) result = service.sync_directory(self.temp_path)
self.assertTrue(result['success'])
# 应该有一个冲突
self.assertGreaterEqual(result.get('conflicts', 0), 0)
@patch('heurams.services.sync_service.Client') self.assertTrue(result["success"])
# 应该有一个冲突
self.assertGreaterEqual(result.get("conflicts", 0), 0)
@patch("heurams.services.sync_service.Client")
def test_create_sync_service_from_config(self, mock_client_class): def test_create_sync_service_from_config(self, mock_client_class):
"""测试从配置文件创建同步服务.""" """测试从配置文件创建同步服务."""
mock_client_class.return_value = self.mock_client mock_client_class.return_value = self.mock_client
# 创建临时配置文件 # 创建临时配置文件
config_data = { config_data = {
'sync': { "sync": {
'webdav': { "webdav": {
'enabled': True, "enabled": True,
'url': 'https://example.com/dav/', "url": "https://example.com/dav/",
'username': 'test', "username": "test",
'password': 'test', "password": "test",
'remote_path': '/heurams/', "remote_path": "/heurams/",
'sync_mode': 'bidirectional', "sync_mode": "bidirectional",
'conflict_strategy': 'newer', "conflict_strategy": "newer",
'verify_ssl': True, "verify_ssl": True,
} }
} }
} }
# 模拟 config_var # 模拟 config_var
with patch('heurams.services.sync_service.config_var') as mock_config_var: with patch("heurams.services.sync_service.config_var") as mock_config_var:
mock_config = MagicMock() mock_config = MagicMock()
mock_config.data = config_data mock_config.data = config_data
mock_config_var.get.return_value = mock_config mock_config_var.get.return_value = mock_config
from heurams.services.sync_service import create_sync_service_from_config from heurams.services.sync_service import create_sync_service_from_config
service = create_sync_service_from_config() service = create_sync_service_from_config()
self.assertIsNotNone(service) self.assertIsNotNone(service)
self.assertIsNotNone(service.client) self.assertIsNotNone(service.client)
@@ -212,39 +221,41 @@ class TestSyncScreenUnit(unittest.TestCase):
"""在每个测试之前运行.""" """在每个测试之前运行."""
self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name) self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建默认配置 # 创建默认配置
default_config_path = ( default_config_path = (
pathlib.Path(__file__).parent.parent.parent pathlib.Path(__file__).parent.parent.parent
/ "src/heurams/default/config/config.toml" / "src/heurams/default/config/config.toml"
) )
self.config = ConfigFile(default_config_path) self.config = ConfigFile(default_config_path)
# 更新配置中的路径 # 更新配置中的路径
config_data = self.config.data config_data = self.config.data
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon") config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron") config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital") config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
# 添加同步配置 # 添加同步配置
if 'sync' not in config_data: if "sync" not in config_data:
config_data['sync'] = {} config_data["sync"] = {}
config_data['sync']['webdav'] = { config_data["sync"]["webdav"] = {
'enabled': False, "enabled": False,
'url': '', "url": "",
'username': '', "username": "",
'password': '', "password": "",
'remote_path': '/heurams/', "remote_path": "/heurams/",
'sync_mode': 'bidirectional', "sync_mode": "bidirectional",
'conflict_strategy': 'newer', "conflict_strategy": "newer",
'verify_ssl': True, "verify_ssl": True,
} }
# 创建目录 # 创建目录
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True) pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
# 使用 ConfigContext 设置配置 # 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config) self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__() self.config_ctx.__enter__()
@@ -254,52 +265,53 @@ class TestSyncScreenUnit(unittest.TestCase):
self.config_ctx.__exit__(None, None, None) self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup() self.temp_dir.cleanup()
@patch('heurams.interface.screens.synctool.create_sync_service_from_config') @patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_compose(self, mock_create_service): def test_sync_screen_compose(self, mock_create_service):
"""测试 SyncScreen 的 compose 方法.""" """测试 SyncScreen 的 compose 方法."""
from heurams.interface.screens.synctool import SyncScreen from heurams.interface.screens.synctool import SyncScreen
# 模拟同步服务 # 模拟同步服务
mock_service = MagicMock() mock_service = MagicMock()
mock_service.client = MagicMock() mock_service.client = MagicMock()
mock_create_service.return_value = mock_service mock_create_service.return_value = mock_service
screen = SyncScreen() screen = SyncScreen()
# 测试 compose 方法 # 测试 compose 方法
from textual.app import ComposeResult from textual.app import ComposeResult
result = screen.compose() result = screen.compose()
widgets = list(result) widgets = list(result)
# 检查基本部件 # 检查基本部件
from textual.widgets import Footer, Header, Button, Static, ProgressBar
from textual.containers import ScrollableContainer from textual.containers import ScrollableContainer
from textual.widgets import Button, Footer, Header, ProgressBar, Static
header_present = any(isinstance(w, Header) for w in widgets) header_present = any(isinstance(w, Header) for w in widgets)
footer_present = any(isinstance(w, Footer) for w in widgets) footer_present = any(isinstance(w, Footer) for w in widgets)
self.assertTrue(header_present) self.assertTrue(header_present)
self.assertTrue(footer_present) self.assertTrue(footer_present)
# 检查容器 # 检查容器
container_present = any(isinstance(w, ScrollableContainer) for w in widgets) container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
self.assertTrue(container_present) self.assertTrue(container_present)
@patch('heurams.interface.screens.synctool.create_sync_service_from_config') @patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_load_config(self, mock_create_service): def test_sync_screen_load_config(self, mock_create_service):
"""测试 SyncScreen 加载配置.""" """测试 SyncScreen 加载配置."""
from heurams.interface.screens.synctool import SyncScreen from heurams.interface.screens.synctool import SyncScreen
mock_service = MagicMock() mock_service = MagicMock()
mock_service.client = MagicMock() mock_service.client = MagicMock()
mock_create_service.return_value = mock_service mock_create_service.return_value = mock_service
screen = SyncScreen() screen = SyncScreen()
screen.load_config() screen.load_config()
# 验证配置已加载 # 验证配置已加载
self.assertIsNotNone(screen.sync_config) self.assertIsNotNone(screen.sync_config)
mock_create_service.assert_called_once() mock_create_service.assert_called_once()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()