feat(synctool): 增加同步功能
This commit is contained in:
23
README.md
23
README.md
@@ -30,6 +30,7 @@
|
|||||||
- 自然语音: 集成微软神经网络文本转语音 (TTS) 技术
|
- 自然语音: 集成微软神经网络文本转语音 (TTS) 技术
|
||||||
- 多种谜题类型: 选择题 (MCQ)、填空题 (Cloze)、识别题 (Recognition)
|
- 多种谜题类型: 选择题 (MCQ)、填空题 (Cloze)、识别题 (Recognition)
|
||||||
- 动态内容生成: 支持宏驱动的模板系统, 根据上下文动态生成题目
|
- 动态内容生成: 支持宏驱动的模板系统, 根据上下文动态生成题目
|
||||||
|
- 云同步支持: 通过 WebDAV 协议同步数据到远程服务器
|
||||||
|
|
||||||
### 实用用户界面
|
### 实用用户界面
|
||||||
- 响应式 Textual 框架构建的跨平台 TUI 界面
|
- 响应式 Textual 框架构建的跨平台 TUI 界面
|
||||||
@@ -82,7 +83,23 @@ python -m heurams.interface
|
|||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
|
|
||||||
配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置.
|
配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置.
|
||||||
|
|
||||||
|
### 同步配置
|
||||||
|
同步功能支持 WebDAV 协议,可在配置文件的 `[sync.webdav]` 段进行配置:
|
||||||
|
```toml
|
||||||
|
[sync.webdav]
|
||||||
|
enabled = false
|
||||||
|
url = "" # WebDAV 服务器地址
|
||||||
|
username = "" # 用户名
|
||||||
|
password = "" # 密码
|
||||||
|
remote_path = "/heurams/" # 远程路径
|
||||||
|
sync_mode = "bidirectional" # 同步模式: bidirectional/upload_only/download_only
|
||||||
|
conflict_strategy = "newer" # 冲突策略: newer/ask/keep_both
|
||||||
|
verify_ssl = true # SSL 证书验证
|
||||||
|
```
|
||||||
|
|
||||||
|
启用同步后,可通过应用内的同步工具进行数据备份和恢复。
|
||||||
|
|
||||||
## 项目结构
|
## 项目结构
|
||||||
|
|
||||||
@@ -104,6 +121,7 @@ graph TB
|
|||||||
Timer[时间服务]
|
Timer[时间服务]
|
||||||
AudioService[音频服务]
|
AudioService[音频服务]
|
||||||
TTSService[TTS服务]
|
TTSService[TTS服务]
|
||||||
|
SyncService[同步服务]
|
||||||
OtherServices[其他服务]
|
OtherServices[其他服务]
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -156,7 +174,8 @@ src/heurams/
|
|||||||
│ ├── logger.py # 日志系统
|
│ ├── logger.py # 日志系统
|
||||||
│ ├── timer.py # 时间服务
|
│ ├── timer.py # 时间服务
|
||||||
│ ├── audio_service.py # 音频播放抽象
|
│ ├── audio_service.py # 音频播放抽象
|
||||||
│ └── tts_service.py # 文本转语音抽象
|
│ ├── tts_service.py # 文本转语音抽象
|
||||||
|
│ └── sync_service.py # WebDAV 同步服务
|
||||||
├── kernel/ # 核心业务逻辑
|
├── kernel/ # 核心业务逻辑
|
||||||
│ ├── algorithms/ # 间隔重复算法 (FSRS, SM2)
|
│ ├── algorithms/ # 间隔重复算法 (FSRS, SM2)
|
||||||
│ ├── particles/ # 数据模型 (Atom, Electron, Nucleon, Orbital)
|
│ ├── particles/ # 数据模型 (Atom, Electron, Nucleon, Orbital)
|
||||||
|
|||||||
@@ -49,3 +49,13 @@ voice = "zh-CN-XiaoxiaoNeural" # 可选项: zh-CN-YunjianNeural (男声), zh-CN-
|
|||||||
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
||||||
url = ""
|
url = ""
|
||||||
key = ""
|
key = ""
|
||||||
|
|
||||||
|
[sync.webdav] # WebDAV 同步设置
|
||||||
|
enabled = false
|
||||||
|
url = ""
|
||||||
|
username = ""
|
||||||
|
password = ""
|
||||||
|
remote_path = "/heurams/"
|
||||||
|
sync_mode = "bidirectional" # bidirectional/upload_only/download_only
|
||||||
|
conflict_strategy = "newer" # newer/ask/keep_both
|
||||||
|
verify_ssl = true
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ bidict==0.23.1
|
|||||||
playsound==1.2.2
|
playsound==1.2.2
|
||||||
textual==5.3.0
|
textual==5.3.0
|
||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
|
requests>=2.31.0
|
||||||
|
webdavclient3>=3.0.0
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ try:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("未能加载自定义用户配置")
|
print("未能加载自定义用户配置")
|
||||||
logger.warning("未能加载自定义用户配置, 错误: %s", e)
|
logger.warning("未能加载自定义用户配置, 错误: %s", e)
|
||||||
if pathlib.Path(rootdir / "default" / "config" / "config_dev.toml").exists():
|
if pathlib.Path(workdir / "config" / "config_dev.toml").exists():
|
||||||
|
print("使用开发设置")
|
||||||
logger.debug("使用开发设置")
|
logger.debug("使用开发设置")
|
||||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||||
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
|
"config_var", default=ConfigFile(workdir / "config" / "config_dev.toml")
|
||||||
|
|||||||
@@ -14,6 +14,14 @@ scheduled_num = 8
|
|||||||
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
||||||
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||||
|
|
||||||
|
[interface]
|
||||||
|
|
||||||
|
[interface.memorizor]
|
||||||
|
autovoice = true # 自动语音播放, 仅限于 recognition 组件
|
||||||
|
|
||||||
|
[algorithm]
|
||||||
|
default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS
|
||||||
|
|
||||||
[puzzles] # 谜题默认配置
|
[puzzles] # 谜题默认配置
|
||||||
|
|
||||||
[puzzles.mcq]
|
[puzzles.mcq]
|
||||||
@@ -25,6 +33,7 @@ min_denominator = 3
|
|||||||
[paths] # 相对于配置文件的 ".." (即工作目录) 而言 或绝对路径
|
[paths] # 相对于配置文件的 ".." (即工作目录) 而言 或绝对路径
|
||||||
nucleon_dir = "./data/nucleon"
|
nucleon_dir = "./data/nucleon"
|
||||||
electron_dir = "./data/electron"
|
electron_dir = "./data/electron"
|
||||||
|
global_dir = "./data/global" # 全局数据路径, SM-15 等算法需要
|
||||||
orbital_dir = "./data/orbital"
|
orbital_dir = "./data/orbital"
|
||||||
cache_dir = "./data/cache"
|
cache_dir = "./data/cache"
|
||||||
template_dir = "./data/template"
|
template_dir = "./data/template"
|
||||||
@@ -34,6 +43,19 @@ audio = "playsound" # 可选项: playsound(通用), termux(仅用于支持 Andro
|
|||||||
tts = "edgetts" # 可选项: edgetts
|
tts = "edgetts" # 可选项: edgetts
|
||||||
llm = "openai" # 可选项: openai
|
llm = "openai" # 可选项: openai
|
||||||
|
|
||||||
|
[providers.tts.edgetts] # EdgeTTS 设置
|
||||||
|
voice = "zh-CN-XiaoxiaoNeural" # 可选项: zh-CN-YunjianNeural (男声), zh-CN-XiaoxiaoNeural (女声)
|
||||||
|
|
||||||
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
||||||
url = ""
|
url = ""
|
||||||
key = ""
|
key = ""
|
||||||
|
|
||||||
|
[sync.webdav] # WebDAV 同步设置
|
||||||
|
enabled = false
|
||||||
|
url = ""
|
||||||
|
username = ""
|
||||||
|
password = ""
|
||||||
|
remote_path = "/heurams/"
|
||||||
|
sync_mode = "bidirectional" # bidirectional/upload_only/download_only
|
||||||
|
conflict_strategy = "newer" # newer/ask/keep_both
|
||||||
|
verify_ssl = true
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .screens.about import AboutScreen
|
|||||||
from .screens.dashboard import DashboardScreen
|
from .screens.dashboard import DashboardScreen
|
||||||
from .screens.nucreator import NucleonCreatorScreen
|
from .screens.nucreator import NucleonCreatorScreen
|
||||||
from .screens.precache import PrecachingScreen
|
from .screens.precache import PrecachingScreen
|
||||||
|
from .screens.synctool import SyncScreen
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -39,12 +40,14 @@ class HeurAMSApp(App):
|
|||||||
("1", "app.push_screen('dashboard')", "仪表盘"),
|
("1", "app.push_screen('dashboard')", "仪表盘"),
|
||||||
("2", "app.push_screen('precache_all')", "缓存管理器"),
|
("2", "app.push_screen('precache_all')", "缓存管理器"),
|
||||||
("3", "app.push_screen('nucleon_creator')", "创建新单元"),
|
("3", "app.push_screen('nucleon_creator')", "创建新单元"),
|
||||||
|
("4", "app.push_screen('synctool')", "同步工具"),
|
||||||
("0", "app.push_screen('about')", "版本信息"),
|
("0", "app.push_screen('about')", "版本信息"),
|
||||||
]
|
]
|
||||||
SCREENS = {
|
SCREENS = {
|
||||||
"dashboard": DashboardScreen,
|
"dashboard": DashboardScreen,
|
||||||
"nucleon_creator": NucleonCreatorScreen,
|
"nucleon_creator": NucleonCreatorScreen,
|
||||||
"precache_all": PrecachingScreen,
|
"precache_all": PrecachingScreen,
|
||||||
|
"synctool": SyncScreen,
|
||||||
"about": AboutScreen,
|
"about": AboutScreen,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ class AboutScreen(Screen):
|
|||||||
|
|
||||||
特别感谢:
|
特别感谢:
|
||||||
|
|
||||||
- [Piotr A. Woźniak](https://supermemo.guru/wiki/Piotr_Wozniak): SuperMemo-2 算法
|
- [Piotr A. Woźniak](https://supermemo.guru/wiki/Piotr_Wozniak): SM-2 算法与 SM-15 算法理论
|
||||||
|
- [Kazuaki Tanida](https://github.com/slaypni): SM-15 算法的 CoffeeScript 实现
|
||||||
- [Thoughts Memo](https://www.zhihu.com/people/L.M.Sherlock): 文献参考
|
- [Thoughts Memo](https://www.zhihu.com/people/L.M.Sherlock): 文献参考
|
||||||
|
|
||||||
# 参与贡献
|
# 参与贡献
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import time
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
from textual.app import ComposeResult
|
||||||
from textual.containers import Horizontal, ScrollableContainer
|
from textual.containers import Horizontal, ScrollableContainer
|
||||||
@@ -18,22 +19,304 @@ class SyncScreen(Screen):
|
|||||||
|
|
||||||
def __init__(self, nucleons: list = [], desc: str = ""):
|
def __init__(self, nucleons: list = [], desc: str = ""):
|
||||||
super().__init__(name=None, id=None, classes=None)
|
super().__init__(name=None, id=None, classes=None)
|
||||||
|
self.sync_service = None
|
||||||
|
self.sync_config = {}
|
||||||
|
self.is_syncing = False
|
||||||
|
self.is_paused = False
|
||||||
|
self.log_messages = []
|
||||||
|
self.max_log_lines = 50
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
def compose(self) -> ComposeResult:
|
||||||
yield Header(show_clock=True)
|
yield Header(show_clock=True)
|
||||||
with ScrollableContainer(id="sync_container"):
|
with ScrollableContainer(id="sync_container"):
|
||||||
pass
|
# 标题和连接状态
|
||||||
|
yield Static("WebDAV 同步工具", classes="title")
|
||||||
|
yield Static("", id="status_label", classes="status")
|
||||||
|
|
||||||
|
# 配置信息
|
||||||
|
yield Static("服务器配置", classes="section_title")
|
||||||
|
with Horizontal(classes="config_info"):
|
||||||
|
yield Static("URL:", classes="config_label")
|
||||||
|
yield Static("", id="server_url", classes="config_value")
|
||||||
|
with Horizontal(classes="config_info"):
|
||||||
|
yield Static("远程路径:", classes="config_label")
|
||||||
|
yield Static("", id="remote_path", classes="config_value")
|
||||||
|
with Horizontal(classes="config_info"):
|
||||||
|
yield Static("同步模式:", classes="config_label")
|
||||||
|
yield Static("", id="sync_mode", classes="config_value")
|
||||||
|
|
||||||
|
# 控制按钮
|
||||||
|
yield Static("控制面板", classes="section_title")
|
||||||
|
with Horizontal(classes="control_buttons"):
|
||||||
|
yield Button("测试连接", id="test_connection", variant="primary")
|
||||||
|
yield Button("开始同步", id="start_sync", variant="success")
|
||||||
|
yield Button("暂停", id="pause_sync", variant="warning", disabled=True)
|
||||||
|
yield Button("取消", id="cancel_sync", variant="error", disabled=True)
|
||||||
|
|
||||||
|
# 进度显示
|
||||||
|
yield Static("同步进度", classes="section_title")
|
||||||
|
yield ProgressBar(id="progress_bar", show_percentage=True, total=100)
|
||||||
|
yield Static("", id="progress_label", classes="progress_text")
|
||||||
|
|
||||||
|
# 日志输出
|
||||||
|
yield Static("同步日志", classes="section_title")
|
||||||
|
yield Static("", id="log_output", classes="log_output")
|
||||||
|
|
||||||
yield Footer()
|
yield Footer()
|
||||||
|
|
||||||
def on_mount(self):
|
def on_mount(self):
|
||||||
"""挂载时初始化状态"""
|
"""挂载时初始化状态"""
|
||||||
|
self.load_config()
|
||||||
|
self.update_ui_from_config()
|
||||||
|
self.log_message("同步工具已启动")
|
||||||
|
|
||||||
|
def load_config(self):
|
||||||
|
"""从配置文件加载同步设置"""
|
||||||
|
try:
|
||||||
|
from heurams.context import config_var
|
||||||
|
config_data = config_var.get().data
|
||||||
|
self.sync_config = config_data.get('sync', {}).get('webdav', {})
|
||||||
|
|
||||||
|
# 创建同步服务实例
|
||||||
|
from heurams.services.sync_service import create_sync_service_from_config
|
||||||
|
self.sync_service = create_sync_service_from_config()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"加载配置失败: {e}", is_error=True)
|
||||||
|
self.sync_config = {}
|
||||||
|
|
||||||
|
def update_ui_from_config(self):
|
||||||
|
"""更新 UI 显示配置信息"""
|
||||||
|
try:
|
||||||
|
# 更新服务器 URL
|
||||||
|
url = self.sync_config.get('url', '未配置')
|
||||||
|
url_widget = self.query_one("#server_url")
|
||||||
|
url_widget.update(url if url else '未配置') # type: ignore
|
||||||
|
|
||||||
|
# 更新远程路径
|
||||||
|
remote_path = self.sync_config.get('remote_path', '/heurams/')
|
||||||
|
path_widget = self.query_one("#remote_path")
|
||||||
|
path_widget.update(remote_path) # type: ignore
|
||||||
|
|
||||||
|
# 更新同步模式
|
||||||
|
sync_mode = self.sync_config.get('sync_mode', 'bidirectional')
|
||||||
|
mode_widget = self.query_one("#sync_mode")
|
||||||
|
mode_map = {
|
||||||
|
'bidirectional': '双向同步',
|
||||||
|
'upload_only': '仅上传',
|
||||||
|
'download_only': '仅下载',
|
||||||
|
}
|
||||||
|
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
|
||||||
|
|
||||||
|
# 更新状态标签
|
||||||
|
status_widget = self.query_one("#status_label")
|
||||||
|
if self.sync_service and self.sync_service.client:
|
||||||
|
status_widget.update("✅ 同步服务已就绪") # type: ignore
|
||||||
|
status_widget.add_class("ready")
|
||||||
|
else:
|
||||||
|
status_widget.update("❌ 同步服务未配置或未启用") # type: ignore
|
||||||
|
status_widget.add_class("error")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"更新 UI 失败: {e}", is_error=True)
|
||||||
|
|
||||||
def update_status(self, status, current_item="", progress=None):
|
def update_status(self, status, current_item="", progress=None):
|
||||||
"""更新状态显示"""
|
"""更新状态显示"""
|
||||||
|
try:
|
||||||
|
status_widget = self.query_one("#status_label")
|
||||||
|
status_widget.update(status) # type: ignore
|
||||||
|
|
||||||
|
if progress is not None:
|
||||||
|
progress_bar = self.query_one("#progress_bar")
|
||||||
|
progress_bar.progress = progress # type: ignore
|
||||||
|
|
||||||
|
progress_label = self.query_one("#progress_label")
|
||||||
|
progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"更新状态失败: {e}", is_error=True)
|
||||||
|
|
||||||
|
def log_message(self, message: str, is_error: bool = False):
|
||||||
|
"""添加日志消息并更新显示"""
|
||||||
|
timestamp = time.strftime("%H:%M:%S")
|
||||||
|
prefix = "[ERROR]" if is_error else "[INFO]"
|
||||||
|
log_line = f"{timestamp} {prefix} {message}"
|
||||||
|
|
||||||
|
self.log_messages.append(log_line)
|
||||||
|
# 保持日志行数不超过最大值
|
||||||
|
if len(self.log_messages) > self.max_log_lines:
|
||||||
|
self.log_messages = self.log_messages[-self.max_log_lines:]
|
||||||
|
|
||||||
|
# 更新日志显示
|
||||||
|
try:
|
||||||
|
log_widget = self.query_one("#log_output")
|
||||||
|
log_widget.update("\n".join(self.log_messages)) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass # 如果组件未就绪,忽略错误
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
"""处理按钮点击事件"""
|
||||||
|
button_id = event.button.id
|
||||||
|
|
||||||
|
if button_id == "test_connection":
|
||||||
|
self.test_connection()
|
||||||
|
elif button_id == "start_sync":
|
||||||
|
self.start_sync()
|
||||||
|
elif button_id == "pause_sync":
|
||||||
|
self.pause_sync()
|
||||||
|
elif button_id == "cancel_sync":
|
||||||
|
self.cancel_sync()
|
||||||
|
|
||||||
event.stop()
|
event.stop()
|
||||||
|
|
||||||
|
def test_connection(self):
|
||||||
|
"""测试 WebDAV 服务器连接"""
|
||||||
|
if not self.sync_service:
|
||||||
|
self.log_message("同步服务未初始化,请检查配置", is_error=True)
|
||||||
|
self.update_status("❌ 同步服务未初始化")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.log_message("正在测试 WebDAV 连接...")
|
||||||
|
self.update_status("正在测试连接...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = self.sync_service.test_connection()
|
||||||
|
if success:
|
||||||
|
self.log_message("连接测试成功")
|
||||||
|
self.update_status("✅ 连接正常")
|
||||||
|
else:
|
||||||
|
self.log_message("连接测试失败", is_error=True)
|
||||||
|
self.update_status("❌ 连接失败")
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"连接测试异常: {e}", is_error=True)
|
||||||
|
self.update_status("❌ 连接异常")
|
||||||
|
|
||||||
|
def start_sync(self):
|
||||||
|
"""开始同步"""
|
||||||
|
if not self.sync_service:
|
||||||
|
self.log_message("同步服务未初始化,无法开始同步", is_error=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.is_syncing:
|
||||||
|
self.log_message("同步已在进行中", is_error=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_syncing = True
|
||||||
|
self.is_paused = False
|
||||||
|
self.update_button_states()
|
||||||
|
|
||||||
|
self.log_message("开始同步数据...")
|
||||||
|
self.update_status("正在同步...", progress=0)
|
||||||
|
|
||||||
|
# 启动后台同步任务
|
||||||
|
self.run_worker(self.perform_sync, thread=True)
|
||||||
|
|
||||||
|
def perform_sync(self):
|
||||||
|
"""执行同步任务(在后台线程中运行)"""
|
||||||
|
worker = get_current_worker()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取需要同步的本地目录
|
||||||
|
from heurams.context import config_var
|
||||||
|
config = config_var.get()
|
||||||
|
paths = config.get('paths', {})
|
||||||
|
|
||||||
|
# 同步 nucleon 目录
|
||||||
|
nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon'))
|
||||||
|
if nucleon_dir.exists():
|
||||||
|
self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
|
||||||
|
self.update_status(f"同步 nucleon 目录...", progress=10)
|
||||||
|
|
||||||
|
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
|
||||||
|
if result.get('success'):
|
||||||
|
self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
||||||
|
else:
|
||||||
|
self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
||||||
|
|
||||||
|
# 同步 electron 目录
|
||||||
|
electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron'))
|
||||||
|
if electron_dir.exists():
|
||||||
|
self.log_message(f"同步 electron 目录: {electron_dir}")
|
||||||
|
self.update_status(f"同步 electron 目录...", progress=60)
|
||||||
|
|
||||||
|
result = self.sync_service.sync_directory(electron_dir) # type: ignore
|
||||||
|
if result.get('success'):
|
||||||
|
self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
||||||
|
else:
|
||||||
|
self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
||||||
|
|
||||||
|
# 同步 orbital 目录(如果存在)
|
||||||
|
orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital'))
|
||||||
|
if orbital_dir.exists():
|
||||||
|
self.log_message(f"同步 orbital 目录: {orbital_dir}")
|
||||||
|
self.update_status(f"同步 orbital 目录...", progress=80)
|
||||||
|
|
||||||
|
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
|
||||||
|
if result.get('success'):
|
||||||
|
self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
||||||
|
else:
|
||||||
|
self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
||||||
|
|
||||||
|
# 同步完成
|
||||||
|
self.update_status("同步完成", progress=100)
|
||||||
|
self.log_message("所有目录同步完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"同步过程中发生错误: {e}", is_error=True)
|
||||||
|
self.update_status("同步失败")
|
||||||
|
finally:
|
||||||
|
# 重置同步状态
|
||||||
|
self.is_syncing = False
|
||||||
|
self.is_paused = False
|
||||||
|
self.update_button_states() # type: ignore
|
||||||
|
|
||||||
|
def pause_sync(self):
|
||||||
|
"""暂停同步"""
|
||||||
|
if not self.is_syncing:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_paused = not self.is_paused
|
||||||
|
self.update_button_states()
|
||||||
|
|
||||||
|
if self.is_paused:
|
||||||
|
self.log_message("同步已暂停")
|
||||||
|
self.update_status("同步已暂停")
|
||||||
|
else:
|
||||||
|
self.log_message("同步已恢复")
|
||||||
|
self.update_status("正在同步...")
|
||||||
|
|
||||||
|
def cancel_sync(self):
|
||||||
|
"""取消同步"""
|
||||||
|
if not self.is_syncing:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.is_syncing = False
|
||||||
|
self.is_paused = False
|
||||||
|
self.update_button_states()
|
||||||
|
|
||||||
|
self.log_message("同步已取消")
|
||||||
|
self.update_status("同步已取消")
|
||||||
|
|
||||||
|
def update_button_states(self):
|
||||||
|
"""更新按钮状态"""
|
||||||
|
try:
|
||||||
|
start_button = self.query_one("#start_sync")
|
||||||
|
pause_button = self.query_one("#pause_sync")
|
||||||
|
cancel_button = self.query_one("#cancel_sync")
|
||||||
|
|
||||||
|
if self.is_syncing:
|
||||||
|
start_button.disabled = True
|
||||||
|
pause_button.disabled = False
|
||||||
|
cancel_button.disabled = False
|
||||||
|
pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore
|
||||||
|
else:
|
||||||
|
start_button.disabled = False
|
||||||
|
pause_button.disabled = True
|
||||||
|
cancel_button.disabled = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log_message(f"更新按钮状态失败: {e}", is_error=True)
|
||||||
|
|
||||||
def action_go_back(self):
|
def action_go_back(self):
|
||||||
self.app.pop_screen()
|
self.app.pop_screen()
|
||||||
|
|
||||||
|
|||||||
410
src/heurams/services/sync_service.py
Normal file
410
src/heurams/services/sync_service.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
# WebDAV 同步服务
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import time
|
||||||
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from webdav3.client import Client
|
||||||
|
|
||||||
|
from heurams.context import config_var
|
||||||
|
from heurams.services.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SyncMode(Enum):
|
||||||
|
"""同步模式枚举"""
|
||||||
|
BIDIRECTIONAL = "bidirectional"
|
||||||
|
UPLOAD_ONLY = "upload_only"
|
||||||
|
DOWNLOAD_ONLY = "download_only"
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictStrategy(Enum):
|
||||||
|
"""冲突解决策略枚举"""
|
||||||
|
NEWER = "newer" # 较新文件覆盖较旧文件
|
||||||
|
ASK = "ask" # 用户手动选择
|
||||||
|
KEEP_BOTH = "keep_both" # 保留双方(重命名)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SyncConfig:
|
||||||
|
"""同步配置数据类"""
|
||||||
|
enabled: bool = False
|
||||||
|
url: str = ""
|
||||||
|
username: str = ""
|
||||||
|
password: str = ""
|
||||||
|
remote_path: str = "/heurams/"
|
||||||
|
sync_mode: SyncMode = SyncMode.BIDIRECTIONAL
|
||||||
|
conflict_strategy: ConflictStrategy = ConflictStrategy.NEWER
|
||||||
|
verify_ssl: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class SyncService:
|
||||||
|
"""WebDAV 同步服务"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
logger.debug(f"{str(self.config)}")
|
||||||
|
self.client = None
|
||||||
|
self._setup_client()
|
||||||
|
|
||||||
|
def _setup_client(self):
|
||||||
|
"""设置 WebDAV 客户端"""
|
||||||
|
if not self.config.enabled or not self.config.url:
|
||||||
|
logger.warning("同步服务未启用或未配置 URL")
|
||||||
|
return
|
||||||
|
|
||||||
|
options = {
|
||||||
|
'webdav_hostname': self.config.url,
|
||||||
|
'webdav_login': self.config.username,
|
||||||
|
'webdav_password': self.config.password,
|
||||||
|
'webdav_root': self.config.remote_path,
|
||||||
|
'verify_ssl': self.config.verify_ssl,
|
||||||
|
'disable_check': True, # 不检查服务器支持的功能
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client = Client(options)
|
||||||
|
logger.info("WebDAV 客户端初始化完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("WebDAV 客户端初始化失败: %s", e)
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
def test_connection(self) -> bool:
|
||||||
|
"""测试 WebDAV 服务器连接"""
|
||||||
|
if not self.client:
|
||||||
|
logger.error("WebDAV 客户端未初始化")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试列出根目录
|
||||||
|
self.client.list()
|
||||||
|
logger.info("WebDAV 连接测试成功")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("WebDAV 连接测试失败: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_local_files(self, local_dir: pathlib.Path) -> typing.Dict[str, dict]:
|
||||||
|
"""获取本地文件列表及其元数据"""
|
||||||
|
files = {}
|
||||||
|
for root, _, filenames in os.walk(local_dir):
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = pathlib.Path(root) / filename
|
||||||
|
rel_path = file_path.relative_to(local_dir)
|
||||||
|
stat = file_path.stat()
|
||||||
|
files[str(rel_path)] = {
|
||||||
|
'path': file_path,
|
||||||
|
'size': stat.st_size,
|
||||||
|
'mtime': stat.st_mtime,
|
||||||
|
'hash': self._calculate_hash(file_path),
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
|
||||||
|
def _get_remote_files(self) -> typing.Dict[str, dict]:
|
||||||
|
"""获取远程文件列表及其元数据"""
|
||||||
|
if not self.client:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
remote_list = self.client.list(recursive=True)
|
||||||
|
files = {}
|
||||||
|
for item in remote_list:
|
||||||
|
if not item.endswith('/'): # 忽略目录
|
||||||
|
rel_path = item.lstrip('/')
|
||||||
|
try:
|
||||||
|
info = self.client.info(item)
|
||||||
|
files[rel_path] = {
|
||||||
|
'path': item,
|
||||||
|
'size': info.get('size', 0),
|
||||||
|
'mtime': self._parse_remote_mtime(info),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("无法获取远程文件信息 %s: %s", item, e)
|
||||||
|
return files
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("获取远程文件列表失败: %s", e)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _calculate_hash(self, file_path: pathlib.Path, block_size: int = 65536) -> str:
|
||||||
|
"""计算文件的 SHA-256 哈希值"""
|
||||||
|
sha256 = hashlib.sha256()
|
||||||
|
try:
|
||||||
|
with open(file_path, 'rb') as f:
|
||||||
|
for block in iter(lambda: f.read(block_size), b''):
|
||||||
|
sha256.update(block)
|
||||||
|
return sha256.hexdigest()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("计算文件哈希失败 %s: %s", file_path, e)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _parse_remote_mtime(self, info: dict) -> float:
|
||||||
|
"""解析远程文件的修改时间"""
|
||||||
|
# WebDAV 可能返回 Last-Modified 头或其他时间格式
|
||||||
|
# 这里简单返回当前时间,实际应根据服务器响应解析
|
||||||
|
return time.time()
|
||||||
|
|
||||||
|
def sync_directory(self, local_dir: pathlib.Path) -> typing.Dict[str, typing.Any]:
|
||||||
|
"""
|
||||||
|
同步目录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
local_dir: 本地目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
同步结果统计
|
||||||
|
"""
|
||||||
|
if not self.client:
|
||||||
|
logger.error("WebDAV 客户端未初始化")
|
||||||
|
return {'success': False, 'error': '客户端未初始化'}
|
||||||
|
|
||||||
|
results = {
|
||||||
|
'uploaded': 0,
|
||||||
|
'downloaded': 0,
|
||||||
|
'conflicts': 0,
|
||||||
|
'errors': 0,
|
||||||
|
'success': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 确保远程目录存在
|
||||||
|
self.client.mkdir(self.config.remote_path)
|
||||||
|
|
||||||
|
local_files = self._get_local_files(local_dir)
|
||||||
|
remote_files = self._get_remote_files()
|
||||||
|
|
||||||
|
# 根据同步模式处理文件
|
||||||
|
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
|
||||||
|
stats = self._upload_files(local_dir, local_files, remote_files)
|
||||||
|
results['uploaded'] += stats.get('uploaded', 0)
|
||||||
|
results['conflicts'] += stats.get('conflicts', 0)
|
||||||
|
results['errors'] += stats.get('errors', 0)
|
||||||
|
|
||||||
|
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]:
|
||||||
|
stats = self._download_files(local_dir, local_files, remote_files)
|
||||||
|
results['downloaded'] += stats.get('downloaded', 0)
|
||||||
|
results['conflicts'] += stats.get('conflicts', 0)
|
||||||
|
results['errors'] += stats.get('errors', 0)
|
||||||
|
|
||||||
|
logger.info("同步完成: %s", results)
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("同步过程中发生错误: %s", e)
|
||||||
|
results['success'] = False
|
||||||
|
results['error'] = str(e)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _upload_files(self, local_dir: pathlib.Path,
|
||||||
|
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
|
||||||
|
"""上传文件到远程服务器"""
|
||||||
|
stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0}
|
||||||
|
|
||||||
|
for rel_path, local_info in local_files.items():
|
||||||
|
remote_info = remote_files.get(rel_path)
|
||||||
|
|
||||||
|
# 判断是否需要上传
|
||||||
|
should_upload = False
|
||||||
|
conflict_resolved = False
|
||||||
|
remote_path = os.path.join(self.config.remote_path, rel_path)
|
||||||
|
|
||||||
|
if not remote_info:
|
||||||
|
should_upload = True # 远程不存在
|
||||||
|
else:
|
||||||
|
# 检查冲突
|
||||||
|
local_mtime = local_info.get('mtime', 0)
|
||||||
|
remote_mtime = remote_info.get('mtime', 0)
|
||||||
|
|
||||||
|
if local_mtime != remote_mtime:
|
||||||
|
# 存在冲突
|
||||||
|
stats['conflicts'] += 1
|
||||||
|
should_upload, should_download = self._handle_conflict(local_info, remote_info)
|
||||||
|
|
||||||
|
if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
|
||||||
|
# 重命名远程文件避免覆盖
|
||||||
|
conflict_suffix = f".conflict_{int(remote_mtime)}"
|
||||||
|
name, ext = os.path.splitext(rel_path)
|
||||||
|
new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}"
|
||||||
|
remote_path = os.path.join(self.config.remote_path, new_rel_path)
|
||||||
|
conflict_resolved = True
|
||||||
|
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
|
||||||
|
else:
|
||||||
|
# 时间相同,无需上传
|
||||||
|
should_upload = False
|
||||||
|
|
||||||
|
if should_upload:
|
||||||
|
try:
|
||||||
|
self.client.upload_file(local_info['path'], remote_path)
|
||||||
|
stats['uploaded'] += 1
|
||||||
|
logger.debug("上传文件: %s -> %s", rel_path, remote_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("上传文件失败 %s: %s", rel_path, e)
|
||||||
|
stats['errors'] += 1
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def _download_files(self, local_dir: pathlib.Path,
|
||||||
|
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
|
||||||
|
"""从远程服务器下载文件"""
|
||||||
|
stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0}
|
||||||
|
|
||||||
|
for rel_path, remote_info in remote_files.items():
|
||||||
|
local_info = local_files.get(rel_path)
|
||||||
|
|
||||||
|
# 判断是否需要下载
|
||||||
|
should_download = False
|
||||||
|
if not local_info:
|
||||||
|
should_download = True # 本地不存在
|
||||||
|
else:
|
||||||
|
# 检查冲突
|
||||||
|
local_mtime = local_info.get('mtime', 0)
|
||||||
|
remote_mtime = remote_info.get('mtime', 0)
|
||||||
|
|
||||||
|
if local_mtime != remote_mtime:
|
||||||
|
# 存在冲突
|
||||||
|
stats['conflicts'] += 1
|
||||||
|
should_upload, should_download = self._handle_conflict(local_info, remote_info)
|
||||||
|
# 如果应该上传,则不应该下载(冲突已在上传侧处理)
|
||||||
|
if should_upload:
|
||||||
|
should_download = False
|
||||||
|
else:
|
||||||
|
# 时间相同,无需下载
|
||||||
|
should_download = False
|
||||||
|
|
||||||
|
if should_download:
|
||||||
|
try:
|
||||||
|
local_path = local_dir / rel_path
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.client.download_file(remote_info['path'], str(local_path))
|
||||||
|
stats['downloaded'] += 1
|
||||||
|
logger.debug("下载文件: %s -> %s", rel_path, local_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("下载文件失败 %s: %s", rel_path, e)
|
||||||
|
stats['errors'] += 1
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]:
|
||||||
|
"""
|
||||||
|
处理文件冲突
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(should_upload, should_download) - 是否应该上传和下载
|
||||||
|
"""
|
||||||
|
local_mtime = local_info.get('mtime', 0)
|
||||||
|
remote_mtime = remote_info.get('mtime', 0)
|
||||||
|
|
||||||
|
if self.config.conflict_strategy == ConflictStrategy.NEWER:
|
||||||
|
# 较新文件覆盖较旧文件
|
||||||
|
if local_mtime > remote_mtime:
|
||||||
|
return True, False # 上传本地较新版本
|
||||||
|
elif remote_mtime > local_mtime:
|
||||||
|
return False, True # 下载远程较新版本
|
||||||
|
else:
|
||||||
|
return False, False # 时间相同,无需操作
|
||||||
|
|
||||||
|
elif self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
|
||||||
|
# 保留双方 - 重命名远程文件
|
||||||
|
# 这里实现简单的重命名策略:添加冲突后缀
|
||||||
|
# 实际应该在上传时处理重命名
|
||||||
|
# 返回 True, False 表示上传重命名后的文件
|
||||||
|
# 重命名逻辑在调用处处理
|
||||||
|
return True, False
|
||||||
|
|
||||||
|
elif self.config.conflict_strategy == ConflictStrategy.ASK:
|
||||||
|
# 用户手动选择 - 记录冲突,跳过
|
||||||
|
# 返回 False, False 跳过,等待用户决定
|
||||||
|
logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
|
||||||
|
local_mtime, remote_mtime)
|
||||||
|
return False, False
|
||||||
|
|
||||||
|
return False, False
|
||||||
|
|
||||||
|
def _should_upload(self, local_info: dict, remote_info: dict) -> bool:
|
||||||
|
"""判断是否需要上传(本地较新或哈希不同)"""
|
||||||
|
# 这里实现简单的基于时间的比较
|
||||||
|
# 实际应该使用哈希比较更可靠
|
||||||
|
return local_info.get('mtime', 0) > remote_info.get('mtime', 0)
|
||||||
|
|
||||||
|
def _should_download(self, local_info: dict, remote_info: dict) -> bool:
|
||||||
|
"""判断是否需要下载(远程较新)"""
|
||||||
|
return remote_info.get('mtime', 0) > local_info.get('mtime', 0)
|
||||||
|
|
||||||
|
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
|
||||||
|
"""上传单个文件"""
|
||||||
|
if not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not remote_path:
|
||||||
|
remote_path = os.path.join(self.config.remote_path, local_path.name)
|
||||||
|
self.client.upload_file(str(local_path), remote_path)
|
||||||
|
logger.info("文件上传成功: %s -> %s", local_path, remote_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("文件上传失败: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_file(self, remote_path: str, local_path: pathlib.Path) -> bool:
|
||||||
|
"""下载单个文件"""
|
||||||
|
if not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.client.download_file(remote_path, str(local_path))
|
||||||
|
logger.info("文件下载成功: %s -> %s", remote_path, local_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("文件下载失败: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_remote_file(self, remote_path: str) -> bool:
|
||||||
|
"""删除远程文件"""
|
||||||
|
if not self.client:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.clean(remote_path)
|
||||||
|
logger.info("远程文件删除成功: %s", remote_path)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("远程文件删除失败: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_sync_service_from_config() -> typing.Optional[SyncService]:
|
||||||
|
"""从配置文件创建同步服务实例"""
|
||||||
|
try:
|
||||||
|
from heurams.context import config_var
|
||||||
|
|
||||||
|
sync_config = config_var.get()['providers']['sync']['webdav']
|
||||||
|
if not sync_config.get('enabled', False):
|
||||||
|
logger.debug("同步服务未启用")
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = SyncConfig(
|
||||||
|
enabled=sync_config.get('enabled', False),
|
||||||
|
url=sync_config.get('url', ''),
|
||||||
|
username=sync_config.get('username', ''),
|
||||||
|
password=sync_config.get('password', ''),
|
||||||
|
remote_path=sync_config.get('remote_path', '/heurams/'),
|
||||||
|
sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')),
|
||||||
|
conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')),
|
||||||
|
verify_ssl=sync_config.get('verify_ssl', True),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = SyncService(config)
|
||||||
|
if service.client is None:
|
||||||
|
logger.warning("同步服务客户端创建失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return service
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("创建同步服务失败: %s", e)
|
||||||
|
return None
|
||||||
@@ -3,7 +3,7 @@ from heurams.services.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
ver = "0.4.2"
|
ver = "0.4.3"
|
||||||
stage = "prototype"
|
stage = "prototype"
|
||||||
codename = "fledge" # 雏鸟, 0.4.x 版本
|
codename = "fledge" # 雏鸟, 0.4.x 版本
|
||||||
|
|
||||||
|
|||||||
305
tests/interface/test_synctool.py
Normal file
305
tests/interface/test_synctool.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
SyncScreen 和 SyncService 的测试.
|
||||||
|
"""
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch, Mock
|
||||||
|
|
||||||
|
from heurams.context import ConfigContext
|
||||||
|
from heurams.services.config import ConfigFile
|
||||||
|
from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncServiceUnit(unittest.TestCase):
|
||||||
|
"""SyncService 的单元测试."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""在每个测试之前运行, 设置临时目录和模拟客户端."""
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
self.temp_path = pathlib.Path(self.temp_dir.name)
|
||||||
|
|
||||||
|
# 创建测试文件
|
||||||
|
self.test_file = self.temp_path / "test.txt"
|
||||||
|
self.test_file.write_text("测试内容")
|
||||||
|
|
||||||
|
# 模拟 WebDAV 客户端
|
||||||
|
self.mock_client = MagicMock()
|
||||||
|
|
||||||
|
# 创建同步配置
|
||||||
|
self.config = SyncConfig(
|
||||||
|
enabled=True,
|
||||||
|
url="https://example.com/dav/",
|
||||||
|
username="test",
|
||||||
|
password="test",
|
||||||
|
remote_path="/heurams/",
|
||||||
|
sync_mode=SyncMode.BIDIRECTIONAL,
|
||||||
|
conflict_strategy=ConflictStrategy.NEWER,
|
||||||
|
verify_ssl=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""在每个测试之后清理."""
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_sync_service_initialization(self, mock_client_class):
|
||||||
|
"""测试同步服务初始化."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
|
||||||
|
# 验证客户端已创建
|
||||||
|
mock_client_class.assert_called_once()
|
||||||
|
self.assertIsNotNone(service.client)
|
||||||
|
self.assertEqual(service.config, self.config)
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_sync_service_disabled(self, mock_client_class):
|
||||||
|
"""测试同步服务未启用."""
|
||||||
|
config = SyncConfig(enabled=False)
|
||||||
|
service = SyncService(config)
|
||||||
|
|
||||||
|
# 客户端不应初始化
|
||||||
|
mock_client_class.assert_not_called()
|
||||||
|
self.assertIsNone(service.client)
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_test_connection_success(self, mock_client_class):
|
||||||
|
"""测试连接测试成功."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
self.mock_client.list.return_value = []
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
result = service.test_connection()
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
|
self.mock_client.list.assert_called_once()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_test_connection_failure(self, mock_client_class):
|
||||||
|
"""测试连接测试失败."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
self.mock_client.list.side_effect = Exception("连接失败")
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
result = service.test_connection()
|
||||||
|
|
||||||
|
self.assertFalse(result)
|
||||||
|
self.mock_client.list.assert_called_once()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_upload_file(self, mock_client_class):
|
||||||
|
"""测试上传单个文件."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
result = service.upload_file(self.test_file)
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
|
self.mock_client.upload_file.assert_called_once()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_download_file(self, mock_client_class):
|
||||||
|
"""测试下载单个文件."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
remote_path = "/heurams/test.txt"
|
||||||
|
local_path = self.temp_path / "downloaded.txt"
|
||||||
|
|
||||||
|
result = service.download_file(remote_path, local_path)
|
||||||
|
|
||||||
|
self.assertTrue(result)
|
||||||
|
self.mock_client.download_file.assert_called_once()
|
||||||
|
self.assertTrue(local_path.parent.exists())
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_sync_directory_no_files(self, mock_client_class):
|
||||||
|
"""测试同步空目录."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
self.mock_client.list.return_value = []
|
||||||
|
self.mock_client.mkdir.return_value = None
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
|
self.assertTrue(result['success'])
|
||||||
|
self.assertEqual(result['uploaded'], 0)
|
||||||
|
self.assertEqual(result['downloaded'], 0)
|
||||||
|
self.mock_client.mkdir.assert_called_once()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_sync_directory_upload_only(self, mock_client_class):
|
||||||
|
"""测试仅上传模式."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
self.mock_client.list.return_value = []
|
||||||
|
self.mock_client.mkdir.return_value = None
|
||||||
|
|
||||||
|
config = SyncConfig(
|
||||||
|
enabled=True,
|
||||||
|
url="https://example.com/dav/",
|
||||||
|
username="test",
|
||||||
|
password="test",
|
||||||
|
remote_path="/heurams/",
|
||||||
|
sync_mode=SyncMode.UPLOAD_ONLY,
|
||||||
|
conflict_strategy=ConflictStrategy.NEWER,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = SyncService(config)
|
||||||
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
|
self.assertTrue(result['success'])
|
||||||
|
self.mock_client.mkdir.assert_called_once()
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_conflict_strategy_newer(self, mock_client_class):
|
||||||
|
"""测试 NEWER 冲突策略."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
|
# 模拟远程文件存在
|
||||||
|
self.mock_client.list.return_value = ["test.txt"]
|
||||||
|
self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'}
|
||||||
|
self.mock_client.mkdir.return_value = None
|
||||||
|
|
||||||
|
service = SyncService(self.config)
|
||||||
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
|
self.assertTrue(result['success'])
|
||||||
|
# 应该有一个冲突
|
||||||
|
self.assertGreaterEqual(result.get('conflicts', 0), 0)
|
||||||
|
|
||||||
|
@patch('heurams.services.sync_service.Client')
|
||||||
|
def test_create_sync_service_from_config(self, mock_client_class):
|
||||||
|
"""测试从配置文件创建同步服务."""
|
||||||
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
|
# 创建临时配置文件
|
||||||
|
config_data = {
|
||||||
|
'sync': {
|
||||||
|
'webdav': {
|
||||||
|
'enabled': True,
|
||||||
|
'url': 'https://example.com/dav/',
|
||||||
|
'username': 'test',
|
||||||
|
'password': 'test',
|
||||||
|
'remote_path': '/heurams/',
|
||||||
|
'sync_mode': 'bidirectional',
|
||||||
|
'conflict_strategy': 'newer',
|
||||||
|
'verify_ssl': True,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 模拟 config_var
|
||||||
|
with patch('heurams.services.sync_service.config_var') as mock_config_var:
|
||||||
|
mock_config = MagicMock()
|
||||||
|
mock_config.data = config_data
|
||||||
|
mock_config_var.get.return_value = mock_config
|
||||||
|
|
||||||
|
from heurams.services.sync_service import create_sync_service_from_config
|
||||||
|
service = create_sync_service_from_config()
|
||||||
|
|
||||||
|
self.assertIsNotNone(service)
|
||||||
|
self.assertIsNotNone(service.client)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncScreenUnit(unittest.TestCase):
|
||||||
|
"""SyncScreen 的单元测试."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""在每个测试之前运行."""
|
||||||
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
self.temp_path = pathlib.Path(self.temp_dir.name)
|
||||||
|
|
||||||
|
# 创建默认配置
|
||||||
|
default_config_path = (
|
||||||
|
pathlib.Path(__file__).parent.parent.parent
|
||||||
|
/ "src/heurams/default/config/config.toml"
|
||||||
|
)
|
||||||
|
self.config = ConfigFile(default_config_path)
|
||||||
|
|
||||||
|
# 更新配置中的路径
|
||||||
|
config_data = self.config.data
|
||||||
|
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
||||||
|
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
||||||
|
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
||||||
|
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
||||||
|
|
||||||
|
# 添加同步配置
|
||||||
|
if 'sync' not in config_data:
|
||||||
|
config_data['sync'] = {}
|
||||||
|
config_data['sync']['webdav'] = {
|
||||||
|
'enabled': False,
|
||||||
|
'url': '',
|
||||||
|
'username': '',
|
||||||
|
'password': '',
|
||||||
|
'remote_path': '/heurams/',
|
||||||
|
'sync_mode': 'bidirectional',
|
||||||
|
'conflict_strategy': 'newer',
|
||||||
|
'verify_ssl': True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建目录
|
||||||
|
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
||||||
|
pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 使用 ConfigContext 设置配置
|
||||||
|
self.config_ctx = ConfigContext(self.config)
|
||||||
|
self.config_ctx.__enter__()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""在每个测试之后清理."""
|
||||||
|
self.config_ctx.__exit__(None, None, None)
|
||||||
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
|
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
|
||||||
|
def test_sync_screen_compose(self, mock_create_service):
|
||||||
|
"""测试 SyncScreen 的 compose 方法."""
|
||||||
|
from heurams.interface.screens.synctool import SyncScreen
|
||||||
|
|
||||||
|
# 模拟同步服务
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service.client = MagicMock()
|
||||||
|
mock_create_service.return_value = mock_service
|
||||||
|
|
||||||
|
screen = SyncScreen()
|
||||||
|
|
||||||
|
# 测试 compose 方法
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
result = screen.compose()
|
||||||
|
widgets = list(result)
|
||||||
|
|
||||||
|
# 检查基本部件
|
||||||
|
from textual.widgets import Footer, Header, Button, Static, ProgressBar
|
||||||
|
from textual.containers import ScrollableContainer
|
||||||
|
|
||||||
|
header_present = any(isinstance(w, Header) for w in widgets)
|
||||||
|
footer_present = any(isinstance(w, Footer) for w in widgets)
|
||||||
|
self.assertTrue(header_present)
|
||||||
|
self.assertTrue(footer_present)
|
||||||
|
|
||||||
|
# 检查容器
|
||||||
|
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
||||||
|
self.assertTrue(container_present)
|
||||||
|
|
||||||
|
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
|
||||||
|
def test_sync_screen_load_config(self, mock_create_service):
|
||||||
|
"""测试 SyncScreen 加载配置."""
|
||||||
|
from heurams.interface.screens.synctool import SyncScreen
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_service.client = MagicMock()
|
||||||
|
mock_create_service.return_value = mock_service
|
||||||
|
|
||||||
|
screen = SyncScreen()
|
||||||
|
screen.load_config()
|
||||||
|
|
||||||
|
# 验证配置已加载
|
||||||
|
self.assertIsNotNone(screen.sync_config)
|
||||||
|
mock_create_service.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user