style: 格式化代码

This commit is contained in:
2025-12-21 07:56:10 +08:00
parent 1efe034a59
commit a2e12c7462
15 changed files with 373 additions and 290 deletions

View File

@@ -36,8 +36,8 @@ if pathlib.Path(workdir / "config" / "config_dev.toml").exists():
print("使用开发设置")
logger.debug("使用开发设置")
config_var: ContextVar[ConfigFile] = ContextVar(
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
)
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
)
# runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据

View File

@@ -1,9 +1,9 @@
from textual.app import App
from textual.widgets import Button
from heurams.services.logger import get_logger
from heurams.context import config_var
from heurams.interface import HeurAMSApp
from heurams.services.logger import get_logger
from .screens.about import AboutScreen
from .screens.dashboard import DashboardScreen
@@ -15,4 +15,4 @@ logger = get_logger(__name__)
app = HeurAMSApp()
if __name__ == "__main__":
app.run()
app.run()

View File

@@ -4,8 +4,7 @@ import pathlib
from textual.app import ComposeResult
from textual.containers import ScrollableContainer
from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView,
Static)
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
import heurams.services.timer as timer
import heurams.services.version as version
@@ -28,6 +27,7 @@ class DashboardScreen(Screen):
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
Label(f"使用算法: {config_var.get()['algorithm']['default']}"),
Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
ListView(id="union-list", classes="union-list-view"),
Label(

View File

@@ -148,16 +148,24 @@ class MemScreen(Screen):
def play_voice(self):
"""朗读当前内容"""
from heurams.services.audio_service import play_by_path
from pathlib import Path
from heurams.services.audio_service import play_by_path
from heurams.services.hasher import get_md5
path = Path(config_var.get()['paths']["cache_dir"])
path = path / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
path = Path(config_var.get()["paths"]["cache_dir"])
path = (
path
/ f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
)
if path.exists():
play_by_path(path)
else:
from heurams.services.tts_service import convertor
convertor(self.atom.registry['nucleon'].metadata["formation"]["tts_text"], path)
convertor(
self.atom.registry["nucleon"].metadata["formation"]["tts_text"], path
)
play_by_path(path)
def action_toggle_dark(self):

View File

@@ -5,8 +5,7 @@ import toml
from textual.app import ComposeResult
from textual.containers import ScrollableContainer
from textual.screen import Screen
from textual.widgets import (Button, Footer, Header, Input, Label, Markdown,
Select)
from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select
from heurams.context import config_var
from heurams.services.version import ver

View File

@@ -99,6 +99,7 @@ class PrecachingScreen(Screen):
if not cache_file.exists():
try:
from heurams.services.tts_service import convertor
convertor(text, cache_file)
return 1
except Exception as e:

View File

@@ -32,7 +32,7 @@ class SyncScreen(Screen):
# 标题和连接状态
yield Static("WebDAV 同步工具", classes="title")
yield Static("", id="status_label", classes="status")
# 配置信息
yield Static("服务器配置", classes="section_title")
with Horizontal(classes="config_info"):
@@ -44,7 +44,7 @@ class SyncScreen(Screen):
with Horizontal(classes="config_info"):
yield Static("同步模式:", classes="config_label")
yield Static("", id="sync_mode", classes="config_value")
# 控制按钮
yield Static("控制面板", classes="section_title")
with Horizontal(classes="control_buttons"):
@@ -52,16 +52,16 @@ class SyncScreen(Screen):
yield Button("开始同步", id="start_sync", variant="success")
yield Button("暂停", id="pause_sync", variant="warning", disabled=True)
yield Button("取消", id="cancel_sync", variant="error", disabled=True)
# 进度显示
yield Static("同步进度", classes="section_title")
yield ProgressBar(id="progress_bar", show_percentage=True, total=100)
yield Static("", id="progress_label", classes="progress_text")
# 日志输出
yield Static("同步日志", classes="section_title")
yield Static("", id="log_output", classes="log_output")
yield Footer()
def on_mount(self):
@@ -74,13 +74,15 @@ class SyncScreen(Screen):
"""从配置文件加载同步设置"""
try:
from heurams.context import config_var
config_data = config_var.get().data
self.sync_config = config_data.get('sync', {}).get('webdav', {})
self.sync_config = config_data.get("sync", {}).get("webdav", {})
# 创建同步服务实例
from heurams.services.sync_service import create_sync_service_from_config
self.sync_service = create_sync_service_from_config()
except Exception as e:
self.log_message(f"加载配置失败: {e}", is_error=True)
self.sync_config = {}
@@ -89,34 +91,34 @@ class SyncScreen(Screen):
"""更新 UI 显示配置信息"""
try:
# 更新服务器 URL
url = self.sync_config.get('url', '未配置')
url = self.sync_config.get("url", "未配置")
url_widget = self.query_one("#server_url")
url_widget.update(url if url else '未配置') # type: ignore
url_widget.update(url if url else "未配置") # type: ignore
# 更新远程路径
remote_path = self.sync_config.get('remote_path', '/heurams/')
remote_path = self.sync_config.get("remote_path", "/heurams/")
path_widget = self.query_one("#remote_path")
path_widget.update(remote_path) # type: ignore
path_widget.update(remote_path) # type: ignore
# 更新同步模式
sync_mode = self.sync_config.get('sync_mode', 'bidirectional')
sync_mode = self.sync_config.get("sync_mode", "bidirectional")
mode_widget = self.query_one("#sync_mode")
mode_map = {
'bidirectional': '双向同步',
'upload_only': '仅上传',
'download_only': '仅下载',
"bidirectional": "双向同步",
"upload_only": "仅上传",
"download_only": "仅下载",
}
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
# 更新状态标签
status_widget = self.query_one("#status_label")
if self.sync_service and self.sync_service.client:
status_widget.update("✅ 同步服务已就绪") # type: ignore
status_widget.update("✅ 同步服务已就绪") # type: ignore
status_widget.add_class("ready")
else:
status_widget.update("❌ 同步服务未配置或未启用") # type: ignore
status_widget.update("❌ 同步服务未配置或未启用") # type: ignore
status_widget.add_class("error")
except Exception as e:
self.log_message(f"更新 UI 失败: {e}", is_error=True)
@@ -124,15 +126,15 @@ class SyncScreen(Screen):
"""更新状态显示"""
try:
status_widget = self.query_one("#status_label")
status_widget.update(status) # type: ignore
status_widget.update(status) # type: ignore
if progress is not None:
progress_bar = self.query_one("#progress_bar")
progress_bar.progress = progress # type: ignore
progress_bar.progress = progress # type: ignore
progress_label = self.query_one("#progress_label")
progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore
progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore
except Exception as e:
self.log_message(f"更新状态失败: {e}", is_error=True)
@@ -141,23 +143,23 @@ class SyncScreen(Screen):
timestamp = time.strftime("%H:%M:%S")
prefix = "[ERROR]" if is_error else "[INFO]"
log_line = f"{timestamp} {prefix} {message}"
self.log_messages.append(log_line)
# 保持日志行数不超过最大值
if len(self.log_messages) > self.max_log_lines:
self.log_messages = self.log_messages[-self.max_log_lines:]
self.log_messages = self.log_messages[-self.max_log_lines :]
# 更新日志显示
try:
log_widget = self.query_one("#log_output")
log_widget.update("\n".join(self.log_messages)) # type: ignore
log_widget.update("\n".join(self.log_messages)) # type: ignore
except Exception:
pass # 如果组件未就绪,忽略错误
def on_button_pressed(self, event: Button.Pressed) -> None:
"""处理按钮点击事件"""
button_id = event.button.id
if button_id == "test_connection":
self.test_connection()
elif button_id == "start_sync":
@@ -166,7 +168,7 @@ class SyncScreen(Screen):
self.pause_sync()
elif button_id == "cancel_sync":
self.cancel_sync()
event.stop()
def test_connection(self):
@@ -175,10 +177,10 @@ class SyncScreen(Screen):
self.log_message("同步服务未初始化,请检查配置", is_error=True)
self.update_status("❌ 同步服务未初始化")
return
self.log_message("正在测试 WebDAV 连接...")
self.update_status("正在测试连接...")
try:
success = self.sync_service.test_connection()
if success:
@@ -196,71 +198,87 @@ class SyncScreen(Screen):
if not self.sync_service:
self.log_message("同步服务未初始化,无法开始同步", is_error=True)
return
if self.is_syncing:
self.log_message("同步已在进行中", is_error=True)
return
self.is_syncing = True
self.is_paused = False
self.update_button_states()
self.log_message("开始同步数据...")
self.update_status("正在同步...", progress=0)
# 启动后台同步任务
self.run_worker(self.perform_sync, thread=True)
def perform_sync(self):
"""执行同步任务(在后台线程中运行)"""
worker = get_current_worker()
try:
# 获取需要同步的本地目录
from heurams.context import config_var
config = config_var.get()
paths = config.get('paths', {})
paths = config.get("paths", {})
# 同步 nucleon 目录
nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon'))
nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon"))
if nucleon_dir.exists():
self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
self.update_status(f"同步 nucleon 目录...", progress=10)
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
if result.get('success'):
self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}")
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
if result.get("success"):
self.log_message(
f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else:
self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True)
self.log_message(
f"nucleon 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 electron 目录
electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron'))
electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron"))
if electron_dir.exists():
self.log_message(f"同步 electron 目录: {electron_dir}")
self.update_status(f"同步 electron 目录...", progress=60)
result = self.sync_service.sync_directory(electron_dir) # type: ignore
if result.get('success'):
self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}")
result = self.sync_service.sync_directory(electron_dir) # type: ignore
if result.get("success"):
self.log_message(
f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else:
self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True)
self.log_message(
f"electron 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步 orbital 目录(如果存在)
orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital'))
orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital"))
if orbital_dir.exists():
self.log_message(f"同步 orbital 目录: {orbital_dir}")
self.update_status(f"同步 orbital 目录...", progress=80)
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
if result.get('success'):
self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}")
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
if result.get("success"):
self.log_message(
f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)}"
)
else:
self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True)
self.log_message(
f"orbital 同步失败: {result.get('error', '未知错误')}",
is_error=True,
)
# 同步完成
self.update_status("同步完成", progress=100)
self.log_message("所有目录同步完成")
except Exception as e:
self.log_message(f"同步过程中发生错误: {e}", is_error=True)
self.update_status("同步失败")
@@ -268,16 +286,16 @@ class SyncScreen(Screen):
# 重置同步状态
self.is_syncing = False
self.is_paused = False
self.update_button_states() # type: ignore
self.update_button_states() # type: ignore
def pause_sync(self):
"""暂停同步"""
if not self.is_syncing:
return
self.is_paused = not self.is_paused
self.update_button_states()
if self.is_paused:
self.log_message("同步已暂停")
self.update_status("同步已暂停")
@@ -289,11 +307,11 @@ class SyncScreen(Screen):
"""取消同步"""
if not self.is_syncing:
return
self.is_syncing = False
self.is_paused = False
self.update_button_states()
self.log_message("同步已取消")
self.update_status("同步已取消")
@@ -303,17 +321,17 @@ class SyncScreen(Screen):
start_button = self.query_one("#start_sync")
pause_button = self.query_one("#pause_sync")
cancel_button = self.query_one("#cancel_sync")
if self.is_syncing:
start_button.disabled = True
pause_button.disabled = False
cancel_button.disabled = False
pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore
pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore
else:
start_button.disabled = False
pause_button.disabled = True
cancel_button.disabled = True
except Exception as e:
self.log_message(f"更新按钮状态失败: {e}", is_error=True)

View File

@@ -50,9 +50,10 @@ class Recognition(BasePuzzleWidget):
def compose(self):
from heurams.context import config_var
autovoice = config_var.get()['interface']['memorizor']['autovoice']
autovoice = config_var.get()["interface"]["memorizor"]["autovoice"]
if autovoice:
self.screen.action_play_voice() # type: ignore
self.screen.action_play_voice() # type: ignore
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
delim = self.atom.registry["nucleon"].metadata["formation"]["delimiter"]
replace_dict = {
@@ -72,7 +73,7 @@ class Recognition(BasePuzzleWidget):
primary = cfg["primary"]
with Center():
for i in cfg['top_dim']:
for i in cfg["top_dim"]:
yield Static(f"[dim]{i}[/]")
yield Label("")

View File

@@ -10,15 +10,25 @@ MIT 许可证
import datetime
import json
import os
from typing import TypedDict
import pathlib
from typing import TypedDict
from heurams.context import config_var
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF,
RANGE_AF, RANGE_REPETITION,
SM, THRESHOLD_RECALL, Item)
from heurams.kernel.algorithms.sm15m_calc import (
MAX_AF,
MIN_AF,
NOTCH_AF,
RANGE_AF,
RANGE_REPETITION,
SM,
THRESHOLD_RECALL,
Item,
)
# 全局状态文件路径
_GLOBAL_STATE_FILE = os.path.expanduser(pathlib.Path(config_var.get()['paths']['global_dir']) / 'sm15m_global_state.json')
_GLOBAL_STATE_FILE = os.path.expanduser(
pathlib.Path(config_var.get()["paths"]["global_dir"]) / "sm15m_global_state.json"
)
def _get_global_sm():

View File

@@ -86,8 +86,8 @@ class Atom:
# eval 环境设置
def eval_with_env(s: str):
default = config_var.get()["puzzles"]
payload = self.registry['nucleon'].payload
metadata = self.registry['nucleon'].metadata
payload = self.registry["nucleon"].payload
metadata = self.registry["nucleon"].metadata
eval_value = eval(s)
if isinstance(eval_value, (int, float)):
ret = str(eval_value)
@@ -117,10 +117,11 @@ class Atom:
logger.debug("发现 eval 表达式: '%s'", data[5:])
return modifier(data[5:])
return data
try:
traverse(self.registry['nucleon'].payload, eval_with_env)
traverse(self.registry['nucleon'].metadata, eval_with_env)
traverse(self.registry['orbital'], eval_with_env)
traverse(self.registry["nucleon"].payload, eval_with_env)
traverse(self.registry["nucleon"].metadata, eval_with_env)
traverse(self.registry["orbital"], eval_with_env)
except Exception as e:
ret = f"此 eval 实例发生错误: {e}"
logger.warning(ret)

View File

@@ -18,9 +18,12 @@ class Electron:
algo: 使用的算法模块标识
"""
if algo_name == "":
algo_name = config_var.get()['algorithm']['default']
algo_name = config_var.get()["algorithm"]["default"]
logger.debug(
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", ident, algo_name, algodata
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s",
ident,
algo_name,
algodata,
)
self.algodata = algodata
self.ident = ident
@@ -31,7 +34,9 @@ class Electron:
self.algodata[self.algo.algo_name] = {}
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
if not self.algodata[self.algo.algo_name]:
logger.debug(f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}")
logger.debug(
f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}"
)
self._default_init(self.algo.defaults)
else:
logger.debug("算法数据已存在, 跳过默认初始化")

View File

@@ -53,4 +53,4 @@ class Nucleon:
def placeholder():
"""生成一个占位原子核"""
logger.debug("创建 Nucleon 占位符")
return Nucleon("核子对象样例内容", {})
return Nucleon("核子对象样例内容", {})

View File

@@ -2,8 +2,8 @@ import pathlib
import edge_tts
from heurams.services.logger import get_logger
from heurams.context import config_var
from heurams.services.logger import get_logger
from .base import BaseTTS
@@ -19,7 +19,7 @@ class EdgeTTS(BaseTTS):
try:
communicate = edge_tts.Communicate(
text,
config_var.get()['providers']['tts']['edgetts']["voice"],
config_var.get()["providers"]["tts"]["edgetts"]["voice"],
)
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
communicate.save_sync(str(path))

View File

@@ -18,6 +18,7 @@ logger = get_logger(__name__)
class SyncMode(Enum):
"""同步模式枚举"""
BIDIRECTIONAL = "bidirectional"
UPLOAD_ONLY = "upload_only"
DOWNLOAD_ONLY = "download_only"
@@ -25,6 +26,7 @@ class SyncMode(Enum):
class ConflictStrategy(Enum):
"""冲突解决策略枚举"""
NEWER = "newer" # 较新文件覆盖较旧文件
ASK = "ask" # 用户手动选择
KEEP_BOTH = "keep_both" # 保留双方(重命名)
@@ -33,6 +35,7 @@ class ConflictStrategy(Enum):
@dataclass
class SyncConfig:
"""同步配置数据类"""
enabled: bool = False
url: str = ""
username: str = ""
@@ -59,12 +62,12 @@ class SyncService:
return
options = {
'webdav_hostname': self.config.url,
'webdav_login': self.config.username,
'webdav_password': self.config.password,
'webdav_root': self.config.remote_path,
'verify_ssl': self.config.verify_ssl,
'disable_check': True, # 不检查服务器支持的功能
"webdav_hostname": self.config.url,
"webdav_login": self.config.username,
"webdav_password": self.config.password,
"webdav_root": self.config.remote_path,
"verify_ssl": self.config.verify_ssl,
"disable_check": True, # 不检查服务器支持的功能
}
try:
@@ -98,10 +101,10 @@ class SyncService:
rel_path = file_path.relative_to(local_dir)
stat = file_path.stat()
files[str(rel_path)] = {
'path': file_path,
'size': stat.st_size,
'mtime': stat.st_mtime,
'hash': self._calculate_hash(file_path),
"path": file_path,
"size": stat.st_size,
"mtime": stat.st_mtime,
"hash": self._calculate_hash(file_path),
}
return files
@@ -114,14 +117,14 @@ class SyncService:
remote_list = self.client.list(recursive=True)
files = {}
for item in remote_list:
if not item.endswith('/'): # 忽略目录
rel_path = item.lstrip('/')
if not item.endswith("/"): # 忽略目录
rel_path = item.lstrip("/")
try:
info = self.client.info(item)
files[rel_path] = {
'path': item,
'size': info.get('size', 0),
'mtime': self._parse_remote_mtime(info),
"path": item,
"size": info.get("size", 0),
"mtime": self._parse_remote_mtime(info),
}
except Exception as e:
logger.warning("无法获取远程文件信息 %s: %s", item, e)
@@ -134,8 +137,8 @@ class SyncService:
"""计算文件的 SHA-256 哈希值"""
sha256 = hashlib.sha256()
try:
with open(file_path, 'rb') as f:
for block in iter(lambda: f.read(block_size), b''):
with open(file_path, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()
except Exception as e:
@@ -151,23 +154,23 @@ class SyncService:
def sync_directory(self, local_dir: pathlib.Path) -> typing.Dict[str, typing.Any]:
"""
同步目录
Args:
local_dir: 本地目录路径
Returns:
同步结果统计
"""
if not self.client:
logger.error("WebDAV 客户端未初始化")
return {'success': False, 'error': '客户端未初始化'}
return {"success": False, "error": "客户端未初始化"}
results = {
'uploaded': 0,
'downloaded': 0,
'conflicts': 0,
'errors': 0,
'success': True,
"uploaded": 0,
"downloaded": 0,
"conflicts": 0,
"errors": 0,
"success": True,
}
try:
@@ -180,124 +183,144 @@ class SyncService:
# 根据同步模式处理文件
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
stats = self._upload_files(local_dir, local_files, remote_files)
results['uploaded'] += stats.get('uploaded', 0)
results['conflicts'] += stats.get('conflicts', 0)
results['errors'] += stats.get('errors', 0)
results["uploaded"] += stats.get("uploaded", 0)
results["conflicts"] += stats.get("conflicts", 0)
results["errors"] += stats.get("errors", 0)
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]:
if self.config.sync_mode in [
SyncMode.BIDIRECTIONAL,
SyncMode.DOWNLOAD_ONLY,
]:
stats = self._download_files(local_dir, local_files, remote_files)
results['downloaded'] += stats.get('downloaded', 0)
results['conflicts'] += stats.get('conflicts', 0)
results['errors'] += stats.get('errors', 0)
results["downloaded"] += stats.get("downloaded", 0)
results["conflicts"] += stats.get("conflicts", 0)
results["errors"] += stats.get("errors", 0)
logger.info("同步完成: %s", results)
return results
except Exception as e:
logger.error("同步过程中发生错误: %s", e)
results['success'] = False
results['error'] = str(e)
results["success"] = False
results["error"] = str(e)
return results
def _upload_files(self, local_dir: pathlib.Path,
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
def _upload_files(
self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""上传文件到远程服务器"""
stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0}
stats = {"uploaded": 0, "errors": 0, "conflicts": 0}
for rel_path, local_info in local_files.items():
remote_info = remote_files.get(rel_path)
# 判断是否需要上传
should_upload = False
conflict_resolved = False
remote_path = os.path.join(self.config.remote_path, rel_path)
if not remote_info:
should_upload = True # 远程不存在
else:
# 检查冲突
local_mtime = local_info.get('mtime', 0)
remote_mtime = remote_info.get('mtime', 0)
local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime:
# 存在冲突
stats['conflicts'] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info)
if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(
local_info, remote_info
)
if (
should_upload
and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH
):
# 重命名远程文件避免覆盖
conflict_suffix = f".conflict_{int(remote_mtime)}"
name, ext = os.path.splitext(rel_path)
new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}"
remote_path = os.path.join(self.config.remote_path, new_rel_path)
new_rel_path = (
f"{name}{conflict_suffix}{ext}"
if ext
else f"{name}{conflict_suffix}"
)
remote_path = os.path.join(
self.config.remote_path, new_rel_path
)
conflict_resolved = True
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
else:
# 时间相同,无需上传
should_upload = False
if should_upload:
try:
self.client.upload_file(local_info['path'], remote_path)
stats['uploaded'] += 1
self.client.upload_file(local_info["path"], remote_path)
stats["uploaded"] += 1
logger.debug("上传文件: %s -> %s", rel_path, remote_path)
except Exception as e:
logger.error("上传文件失败 %s: %s", rel_path, e)
stats['errors'] += 1
stats["errors"] += 1
return stats
def _download_files(self, local_dir: pathlib.Path,
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
def _download_files(
self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
) -> typing.Dict[str, int]:
"""从远程服务器下载文件"""
stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0}
stats = {"downloaded": 0, "errors": 0, "conflicts": 0}
for rel_path, remote_info in remote_files.items():
local_info = local_files.get(rel_path)
# 判断是否需要下载
should_download = False
if not local_info:
should_download = True # 本地不存在
else:
# 检查冲突
local_mtime = local_info.get('mtime', 0)
remote_mtime = remote_info.get('mtime', 0)
local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get("mtime", 0)
if local_mtime != remote_mtime:
# 存在冲突
stats['conflicts'] += 1
should_upload, should_download = self._handle_conflict(local_info, remote_info)
stats["conflicts"] += 1
should_upload, should_download = self._handle_conflict(
local_info, remote_info
)
# 如果应该上传,则不应该下载(冲突已在上传侧处理)
if should_upload:
should_download = False
else:
# 时间相同,无需下载
should_download = False
if should_download:
try:
local_path = local_dir / rel_path
local_path.parent.mkdir(parents=True, exist_ok=True)
self.client.download_file(remote_info['path'], str(local_path))
stats['downloaded'] += 1
self.client.download_file(remote_info["path"], str(local_path))
stats["downloaded"] += 1
logger.debug("下载文件: %s -> %s", rel_path, local_path)
except Exception as e:
logger.error("下载文件失败 %s: %s", rel_path, e)
stats['errors'] += 1
stats["errors"] += 1
return stats
def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]:
def _handle_conflict(
self, local_info: dict, remote_info: dict
) -> typing.Tuple[bool, bool]:
"""
处理文件冲突
Returns:
(should_upload, should_download) - 是否应该上传和下载
"""
local_mtime = local_info.get('mtime', 0)
remote_mtime = remote_info.get('mtime', 0)
local_mtime = local_info.get("mtime", 0)
remote_mtime = remote_info.get("mtime", 0)
if self.config.conflict_strategy == ConflictStrategy.NEWER:
# 较新文件覆盖较旧文件
if local_mtime > remote_mtime:
@@ -306,7 +329,7 @@ class SyncService:
return False, True # 下载远程较新版本
else:
return False, False # 时间相同,无需操作
elif self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
# 保留双方 - 重命名远程文件
# 这里实现简单的重命名策略:添加冲突后缀
@@ -314,25 +337,28 @@ class SyncService:
# 返回 True, False 表示上传重命名后的文件
# 重命名逻辑在调用处处理
return True, False
elif self.config.conflict_strategy == ConflictStrategy.ASK:
# 用户手动选择 - 记录冲突,跳过
# 返回 False, False 跳过,等待用户决定
logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
local_mtime, remote_mtime)
logger.warning(
"文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
local_mtime,
remote_mtime,
)
return False, False
return False, False
def _should_upload(self, local_info: dict, remote_info: dict) -> bool:
"""判断是否需要上传(本地较新或哈希不同)"""
# 这里实现简单的基于时间的比较
# 实际应该使用哈希比较更可靠
return local_info.get('mtime', 0) > remote_info.get('mtime', 0)
return local_info.get("mtime", 0) > remote_info.get("mtime", 0)
def _should_download(self, local_info: dict, remote_info: dict) -> bool:
"""判断是否需要下载(远程较新)"""
return remote_info.get('mtime', 0) > local_info.get('mtime', 0)
return remote_info.get("mtime", 0) > local_info.get("mtime", 0)
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
"""上传单个文件"""
@@ -381,30 +407,32 @@ def create_sync_service_from_config() -> typing.Optional[SyncService]:
"""从配置文件创建同步服务实例"""
try:
from heurams.context import config_var
sync_config = config_var.get()['providers']['sync']['webdav']
if not sync_config.get('enabled', False):
sync_config = config_var.get()["providers"]["sync"]["webdav"]
if not sync_config.get("enabled", False):
logger.debug("同步服务未启用")
return None
config = SyncConfig(
enabled=sync_config.get('enabled', False),
url=sync_config.get('url', ''),
username=sync_config.get('username', ''),
password=sync_config.get('password', ''),
remote_path=sync_config.get('remote_path', '/heurams/'),
sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')),
conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')),
verify_ssl=sync_config.get('verify_ssl', True),
enabled=sync_config.get("enabled", False),
url=sync_config.get("url", ""),
username=sync_config.get("username", ""),
password=sync_config.get("password", ""),
remote_path=sync_config.get("remote_path", "/heurams/"),
sync_mode=SyncMode(sync_config.get("sync_mode", "bidirectional")),
conflict_strategy=ConflictStrategy(
sync_config.get("conflict_strategy", "newer")
),
verify_ssl=sync_config.get("verify_ssl", True),
)
service = SyncService(config)
if service.client is None:
logger.warning("同步服务客户端创建失败")
return None
return service
except Exception as e:
logger.error("创建同步服务失败: %s", e)
return None
return None

View File

@@ -6,11 +6,16 @@ import pathlib
import tempfile
import time
import unittest
from unittest.mock import MagicMock, patch, Mock
from unittest.mock import MagicMock, Mock, patch
from heurams.context import ConfigContext
from heurams.services.config import ConfigFile
from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy
from heurams.services.sync_service import (
ConflictStrategy,
SyncConfig,
SyncMode,
SyncService,
)
class TestSyncServiceUnit(unittest.TestCase):
@@ -20,14 +25,14 @@ class TestSyncServiceUnit(unittest.TestCase):
"""在每个测试之前运行, 设置临时目录和模拟客户端."""
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建测试文件
self.test_file = self.temp_path / "test.txt"
self.test_file.write_text("测试内容")
# 模拟 WebDAV 客户端
self.mock_client = MagicMock()
# 创建同步配置
self.config = SyncConfig(
enabled=True,
@@ -44,100 +49,100 @@ class TestSyncServiceUnit(unittest.TestCase):
"""在每个测试之后清理."""
self.temp_dir.cleanup()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_sync_service_initialization(self, mock_client_class):
"""测试同步服务初始化."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
# 验证客户端已创建
mock_client_class.assert_called_once()
self.assertIsNotNone(service.client)
self.assertEqual(service.config, self.config)
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_sync_service_disabled(self, mock_client_class):
"""测试同步服务未启用."""
config = SyncConfig(enabled=False)
service = SyncService(config)
# 客户端不应初始化
mock_client_class.assert_not_called()
self.assertIsNone(service.client)
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_test_connection_success(self, mock_client_class):
"""测试连接测试成功."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
service = SyncService(self.config)
result = service.test_connection()
self.assertTrue(result)
self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_test_connection_failure(self, mock_client_class):
"""测试连接测试失败."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.side_effect = Exception("连接失败")
service = SyncService(self.config)
result = service.test_connection()
self.assertFalse(result)
self.mock_client.list.assert_called_once()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_upload_file(self, mock_client_class):
"""测试上传单个文件."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
result = service.upload_file(self.test_file)
self.assertTrue(result)
self.mock_client.upload_file.assert_called_once()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_download_file(self, mock_client_class):
"""测试下载单个文件."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
remote_path = "/heurams/test.txt"
local_path = self.temp_path / "downloaded.txt"
result = service.download_file(remote_path, local_path)
self.assertTrue(result)
self.mock_client.download_file.assert_called_once()
self.assertTrue(local_path.parent.exists())
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_sync_directory_no_files(self, mock_client_class):
"""测试同步空目录."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None
service = SyncService(self.config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result['success'])
self.assertEqual(result['uploaded'], 0)
self.assertEqual(result['downloaded'], 0)
self.assertTrue(result["success"])
self.assertEqual(result["uploaded"], 0)
self.assertEqual(result["downloaded"], 0)
self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_sync_directory_upload_only(self, mock_client_class):
"""测试仅上传模式."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None
config = SyncConfig(
enabled=True,
url="https://example.com/dav/",
@@ -147,60 +152,64 @@ class TestSyncServiceUnit(unittest.TestCase):
sync_mode=SyncMode.UPLOAD_ONLY,
conflict_strategy=ConflictStrategy.NEWER,
)
service = SyncService(config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result['success'])
self.assertTrue(result["success"])
self.mock_client.mkdir.assert_called_once()
@patch('heurams.services.sync_service.Client')
@patch("heurams.services.sync_service.Client")
def test_conflict_strategy_newer(self, mock_client_class):
"""测试 NEWER 冲突策略."""
mock_client_class.return_value = self.mock_client
# 模拟远程文件存在
self.mock_client.list.return_value = ["test.txt"]
self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'}
self.mock_client.info.return_value = {
"size": 100,
"modified": "2023-01-01T00:00:00Z",
}
self.mock_client.mkdir.return_value = None
service = SyncService(self.config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result['success'])
# 应该有一个冲突
self.assertGreaterEqual(result.get('conflicts', 0), 0)
@patch('heurams.services.sync_service.Client')
self.assertTrue(result["success"])
# 应该有一个冲突
self.assertGreaterEqual(result.get("conflicts", 0), 0)
@patch("heurams.services.sync_service.Client")
def test_create_sync_service_from_config(self, mock_client_class):
"""测试从配置文件创建同步服务."""
mock_client_class.return_value = self.mock_client
# 创建临时配置文件
config_data = {
'sync': {
'webdav': {
'enabled': True,
'url': 'https://example.com/dav/',
'username': 'test',
'password': 'test',
'remote_path': '/heurams/',
'sync_mode': 'bidirectional',
'conflict_strategy': 'newer',
'verify_ssl': True,
"sync": {
"webdav": {
"enabled": True,
"url": "https://example.com/dav/",
"username": "test",
"password": "test",
"remote_path": "/heurams/",
"sync_mode": "bidirectional",
"conflict_strategy": "newer",
"verify_ssl": True,
}
}
}
# 模拟 config_var
with patch('heurams.services.sync_service.config_var') as mock_config_var:
with patch("heurams.services.sync_service.config_var") as mock_config_var:
mock_config = MagicMock()
mock_config.data = config_data
mock_config_var.get.return_value = mock_config
from heurams.services.sync_service import create_sync_service_from_config
service = create_sync_service_from_config()
self.assertIsNotNone(service)
self.assertIsNotNone(service.client)
@@ -212,39 +221,41 @@ class TestSyncScreenUnit(unittest.TestCase):
"""在每个测试之前运行."""
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建默认配置
default_config_path = (
pathlib.Path(__file__).parent.parent.parent
/ "src/heurams/default/config/config.toml"
)
self.config = ConfigFile(default_config_path)
# 更新配置中的路径
config_data = self.config.data
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
# 添加同步配置
if 'sync' not in config_data:
config_data['sync'] = {}
config_data['sync']['webdav'] = {
'enabled': False,
'url': '',
'username': '',
'password': '',
'remote_path': '/heurams/',
'sync_mode': 'bidirectional',
'conflict_strategy': 'newer',
'verify_ssl': True,
if "sync" not in config_data:
config_data["sync"] = {}
config_data["sync"]["webdav"] = {
"enabled": False,
"url": "",
"username": "",
"password": "",
"remote_path": "/heurams/",
"sync_mode": "bidirectional",
"conflict_strategy": "newer",
"verify_ssl": True,
}
# 创建目录
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True)
pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
# 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__()
@@ -254,52 +265,53 @@ class TestSyncScreenUnit(unittest.TestCase):
self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup()
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_compose(self, mock_create_service):
"""测试 SyncScreen 的 compose 方法."""
from heurams.interface.screens.synctool import SyncScreen
# 模拟同步服务
mock_service = MagicMock()
mock_service.client = MagicMock()
mock_create_service.return_value = mock_service
screen = SyncScreen()
# 测试 compose 方法
from textual.app import ComposeResult
result = screen.compose()
widgets = list(result)
# 检查基本部件
from textual.widgets import Footer, Header, Button, Static, ProgressBar
from textual.containers import ScrollableContainer
from textual.widgets import Button, Footer, Header, ProgressBar, Static
header_present = any(isinstance(w, Header) for w in widgets)
footer_present = any(isinstance(w, Footer) for w in widgets)
self.assertTrue(header_present)
self.assertTrue(footer_present)
# 检查容器
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
self.assertTrue(container_present)
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_load_config(self, mock_create_service):
"""测试 SyncScreen 加载配置."""
from heurams.interface.screens.synctool import SyncScreen
mock_service = MagicMock()
mock_service.client = MagicMock()
mock_create_service.return_value = mock_service
screen = SyncScreen()
screen.load_config()
# 验证配置已加载
self.assertIsNotNone(screen.sync_config)
mock_create_service.assert_called_once()
if __name__ == "__main__":
unittest.main()
unittest.main()