style: 格式化代码
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
from textual.app import App
|
from textual.app import App
|
||||||
from textual.widgets import Button
|
from textual.widgets import Button
|
||||||
|
|
||||||
from heurams.services.logger import get_logger
|
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
from heurams.interface import HeurAMSApp
|
from heurams.interface import HeurAMSApp
|
||||||
|
from heurams.services.logger import get_logger
|
||||||
|
|
||||||
from .screens.about import AboutScreen
|
from .screens.about import AboutScreen
|
||||||
from .screens.dashboard import DashboardScreen
|
from .screens.dashboard import DashboardScreen
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ import pathlib
|
|||||||
from textual.app import ComposeResult
|
from textual.app import ComposeResult
|
||||||
from textual.containers import ScrollableContainer
|
from textual.containers import ScrollableContainer
|
||||||
from textual.screen import Screen
|
from textual.screen import Screen
|
||||||
from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView,
|
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
|
||||||
Static)
|
|
||||||
|
|
||||||
import heurams.services.timer as timer
|
import heurams.services.timer as timer
|
||||||
import heurams.services.version as version
|
import heurams.services.version as version
|
||||||
@@ -28,6 +27,7 @@ class DashboardScreen(Screen):
|
|||||||
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
||||||
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
|
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
|
||||||
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
|
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
|
||||||
|
Label(f"使用算法: {config_var.get()['algorithm']['default']}"),
|
||||||
Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
|
Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
|
||||||
ListView(id="union-list", classes="union-list-view"),
|
ListView(id="union-list", classes="union-list-view"),
|
||||||
Label(
|
Label(
|
||||||
|
|||||||
@@ -148,16 +148,24 @@ class MemScreen(Screen):
|
|||||||
|
|
||||||
def play_voice(self):
|
def play_voice(self):
|
||||||
"""朗读当前内容"""
|
"""朗读当前内容"""
|
||||||
from heurams.services.audio_service import play_by_path
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from heurams.services.audio_service import play_by_path
|
||||||
from heurams.services.hasher import get_md5
|
from heurams.services.hasher import get_md5
|
||||||
path = Path(config_var.get()['paths']["cache_dir"])
|
|
||||||
path = path / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
|
path = Path(config_var.get()["paths"]["cache_dir"])
|
||||||
|
path = (
|
||||||
|
path
|
||||||
|
/ f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav"
|
||||||
|
)
|
||||||
if path.exists():
|
if path.exists():
|
||||||
play_by_path(path)
|
play_by_path(path)
|
||||||
else:
|
else:
|
||||||
from heurams.services.tts_service import convertor
|
from heurams.services.tts_service import convertor
|
||||||
convertor(self.atom.registry['nucleon'].metadata["formation"]["tts_text"], path)
|
|
||||||
|
convertor(
|
||||||
|
self.atom.registry["nucleon"].metadata["formation"]["tts_text"], path
|
||||||
|
)
|
||||||
play_by_path(path)
|
play_by_path(path)
|
||||||
|
|
||||||
def action_toggle_dark(self):
|
def action_toggle_dark(self):
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ import toml
|
|||||||
from textual.app import ComposeResult
|
from textual.app import ComposeResult
|
||||||
from textual.containers import ScrollableContainer
|
from textual.containers import ScrollableContainer
|
||||||
from textual.screen import Screen
|
from textual.screen import Screen
|
||||||
from textual.widgets import (Button, Footer, Header, Input, Label, Markdown,
|
from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select
|
||||||
Select)
|
|
||||||
|
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
from heurams.services.version import ver
|
from heurams.services.version import ver
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ class PrecachingScreen(Screen):
|
|||||||
if not cache_file.exists():
|
if not cache_file.exists():
|
||||||
try:
|
try:
|
||||||
from heurams.services.tts_service import convertor
|
from heurams.services.tts_service import convertor
|
||||||
|
|
||||||
convertor(text, cache_file)
|
convertor(text, cache_file)
|
||||||
return 1
|
return 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -74,11 +74,13 @@ class SyncScreen(Screen):
|
|||||||
"""从配置文件加载同步设置"""
|
"""从配置文件加载同步设置"""
|
||||||
try:
|
try:
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
|
|
||||||
config_data = config_var.get().data
|
config_data = config_var.get().data
|
||||||
self.sync_config = config_data.get('sync', {}).get('webdav', {})
|
self.sync_config = config_data.get("sync", {}).get("webdav", {})
|
||||||
|
|
||||||
# 创建同步服务实例
|
# 创建同步服务实例
|
||||||
from heurams.services.sync_service import create_sync_service_from_config
|
from heurams.services.sync_service import create_sync_service_from_config
|
||||||
|
|
||||||
self.sync_service = create_sync_service_from_config()
|
self.sync_service = create_sync_service_from_config()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -89,22 +91,22 @@ class SyncScreen(Screen):
|
|||||||
"""更新 UI 显示配置信息"""
|
"""更新 UI 显示配置信息"""
|
||||||
try:
|
try:
|
||||||
# 更新服务器 URL
|
# 更新服务器 URL
|
||||||
url = self.sync_config.get('url', '未配置')
|
url = self.sync_config.get("url", "未配置")
|
||||||
url_widget = self.query_one("#server_url")
|
url_widget = self.query_one("#server_url")
|
||||||
url_widget.update(url if url else '未配置') # type: ignore
|
url_widget.update(url if url else "未配置") # type: ignore
|
||||||
|
|
||||||
# 更新远程路径
|
# 更新远程路径
|
||||||
remote_path = self.sync_config.get('remote_path', '/heurams/')
|
remote_path = self.sync_config.get("remote_path", "/heurams/")
|
||||||
path_widget = self.query_one("#remote_path")
|
path_widget = self.query_one("#remote_path")
|
||||||
path_widget.update(remote_path) # type: ignore
|
path_widget.update(remote_path) # type: ignore
|
||||||
|
|
||||||
# 更新同步模式
|
# 更新同步模式
|
||||||
sync_mode = self.sync_config.get('sync_mode', 'bidirectional')
|
sync_mode = self.sync_config.get("sync_mode", "bidirectional")
|
||||||
mode_widget = self.query_one("#sync_mode")
|
mode_widget = self.query_one("#sync_mode")
|
||||||
mode_map = {
|
mode_map = {
|
||||||
'bidirectional': '双向同步',
|
"bidirectional": "双向同步",
|
||||||
'upload_only': '仅上传',
|
"upload_only": "仅上传",
|
||||||
'download_only': '仅下载',
|
"download_only": "仅下载",
|
||||||
}
|
}
|
||||||
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
|
mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore
|
||||||
|
|
||||||
@@ -218,44 +220,60 @@ class SyncScreen(Screen):
|
|||||||
try:
|
try:
|
||||||
# 获取需要同步的本地目录
|
# 获取需要同步的本地目录
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
|
|
||||||
config = config_var.get()
|
config = config_var.get()
|
||||||
paths = config.get('paths', {})
|
paths = config.get("paths", {})
|
||||||
|
|
||||||
# 同步 nucleon 目录
|
# 同步 nucleon 目录
|
||||||
nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon'))
|
nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon"))
|
||||||
if nucleon_dir.exists():
|
if nucleon_dir.exists():
|
||||||
self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
|
self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
|
||||||
self.update_status(f"同步 nucleon 目录...", progress=10)
|
self.update_status(f"同步 nucleon 目录...", progress=10)
|
||||||
|
|
||||||
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
|
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
|
||||||
if result.get('success'):
|
if result.get("success"):
|
||||||
self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
self.log_message(
|
||||||
|
f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
self.log_message(
|
||||||
|
f"nucleon 同步失败: {result.get('error', '未知错误')}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
# 同步 electron 目录
|
# 同步 electron 目录
|
||||||
electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron'))
|
electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron"))
|
||||||
if electron_dir.exists():
|
if electron_dir.exists():
|
||||||
self.log_message(f"同步 electron 目录: {electron_dir}")
|
self.log_message(f"同步 electron 目录: {electron_dir}")
|
||||||
self.update_status(f"同步 electron 目录...", progress=60)
|
self.update_status(f"同步 electron 目录...", progress=60)
|
||||||
|
|
||||||
result = self.sync_service.sync_directory(electron_dir) # type: ignore
|
result = self.sync_service.sync_directory(electron_dir) # type: ignore
|
||||||
if result.get('success'):
|
if result.get("success"):
|
||||||
self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
self.log_message(
|
||||||
|
f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
self.log_message(
|
||||||
|
f"electron 同步失败: {result.get('error', '未知错误')}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
# 同步 orbital 目录(如果存在)
|
# 同步 orbital 目录(如果存在)
|
||||||
orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital'))
|
orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital"))
|
||||||
if orbital_dir.exists():
|
if orbital_dir.exists():
|
||||||
self.log_message(f"同步 orbital 目录: {orbital_dir}")
|
self.log_message(f"同步 orbital 目录: {orbital_dir}")
|
||||||
self.update_status(f"同步 orbital 目录...", progress=80)
|
self.update_status(f"同步 orbital 目录...", progress=80)
|
||||||
|
|
||||||
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
|
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
|
||||||
if result.get('success'):
|
if result.get("success"):
|
||||||
self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个")
|
self.log_message(
|
||||||
|
f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True)
|
self.log_message(
|
||||||
|
f"orbital 同步失败: {result.get('error', '未知错误')}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
# 同步完成
|
# 同步完成
|
||||||
self.update_status("同步完成", progress=100)
|
self.update_status("同步完成", progress=100)
|
||||||
|
|||||||
@@ -50,7 +50,8 @@ class Recognition(BasePuzzleWidget):
|
|||||||
|
|
||||||
def compose(self):
|
def compose(self):
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
autovoice = config_var.get()['interface']['memorizor']['autovoice']
|
|
||||||
|
autovoice = config_var.get()["interface"]["memorizor"]["autovoice"]
|
||||||
if autovoice:
|
if autovoice:
|
||||||
self.screen.action_play_voice() # type: ignore
|
self.screen.action_play_voice() # type: ignore
|
||||||
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
|
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
|
||||||
@@ -72,7 +73,7 @@ class Recognition(BasePuzzleWidget):
|
|||||||
primary = cfg["primary"]
|
primary = cfg["primary"]
|
||||||
|
|
||||||
with Center():
|
with Center():
|
||||||
for i in cfg['top_dim']:
|
for i in cfg["top_dim"]:
|
||||||
yield Static(f"[dim]{i}[/]")
|
yield Static(f"[dim]{i}[/]")
|
||||||
yield Label("")
|
yield Label("")
|
||||||
|
|
||||||
|
|||||||
@@ -10,15 +10,25 @@ MIT 许可证
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import TypedDict
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF,
|
from heurams.kernel.algorithms.sm15m_calc import (
|
||||||
RANGE_AF, RANGE_REPETITION,
|
MAX_AF,
|
||||||
SM, THRESHOLD_RECALL, Item)
|
MIN_AF,
|
||||||
|
NOTCH_AF,
|
||||||
|
RANGE_AF,
|
||||||
|
RANGE_REPETITION,
|
||||||
|
SM,
|
||||||
|
THRESHOLD_RECALL,
|
||||||
|
Item,
|
||||||
|
)
|
||||||
|
|
||||||
# 全局状态文件路径
|
# 全局状态文件路径
|
||||||
_GLOBAL_STATE_FILE = os.path.expanduser(pathlib.Path(config_var.get()['paths']['global_dir']) / 'sm15m_global_state.json')
|
_GLOBAL_STATE_FILE = os.path.expanduser(
|
||||||
|
pathlib.Path(config_var.get()["paths"]["global_dir"]) / "sm15m_global_state.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_global_sm():
|
def _get_global_sm():
|
||||||
|
|||||||
@@ -86,8 +86,8 @@ class Atom:
|
|||||||
# eval 环境设置
|
# eval 环境设置
|
||||||
def eval_with_env(s: str):
|
def eval_with_env(s: str):
|
||||||
default = config_var.get()["puzzles"]
|
default = config_var.get()["puzzles"]
|
||||||
payload = self.registry['nucleon'].payload
|
payload = self.registry["nucleon"].payload
|
||||||
metadata = self.registry['nucleon'].metadata
|
metadata = self.registry["nucleon"].metadata
|
||||||
eval_value = eval(s)
|
eval_value = eval(s)
|
||||||
if isinstance(eval_value, (int, float)):
|
if isinstance(eval_value, (int, float)):
|
||||||
ret = str(eval_value)
|
ret = str(eval_value)
|
||||||
@@ -117,10 +117,11 @@ class Atom:
|
|||||||
logger.debug("发现 eval 表达式: '%s'", data[5:])
|
logger.debug("发现 eval 表达式: '%s'", data[5:])
|
||||||
return modifier(data[5:])
|
return modifier(data[5:])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
try:
|
try:
|
||||||
traverse(self.registry['nucleon'].payload, eval_with_env)
|
traverse(self.registry["nucleon"].payload, eval_with_env)
|
||||||
traverse(self.registry['nucleon'].metadata, eval_with_env)
|
traverse(self.registry["nucleon"].metadata, eval_with_env)
|
||||||
traverse(self.registry['orbital'], eval_with_env)
|
traverse(self.registry["orbital"], eval_with_env)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ret = f"此 eval 实例发生错误: {e}"
|
ret = f"此 eval 实例发生错误: {e}"
|
||||||
logger.warning(ret)
|
logger.warning(ret)
|
||||||
|
|||||||
@@ -18,9 +18,12 @@ class Electron:
|
|||||||
algo: 使用的算法模块标识
|
algo: 使用的算法模块标识
|
||||||
"""
|
"""
|
||||||
if algo_name == "":
|
if algo_name == "":
|
||||||
algo_name = config_var.get()['algorithm']['default']
|
algo_name = config_var.get()["algorithm"]["default"]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", ident, algo_name, algodata
|
"创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s",
|
||||||
|
ident,
|
||||||
|
algo_name,
|
||||||
|
algodata,
|
||||||
)
|
)
|
||||||
self.algodata = algodata
|
self.algodata = algodata
|
||||||
self.ident = ident
|
self.ident = ident
|
||||||
@@ -31,7 +34,9 @@ class Electron:
|
|||||||
self.algodata[self.algo.algo_name] = {}
|
self.algodata[self.algo.algo_name] = {}
|
||||||
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
|
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
|
||||||
if not self.algodata[self.algo.algo_name]:
|
if not self.algodata[self.algo.algo_name]:
|
||||||
logger.debug(f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}")
|
logger.debug(
|
||||||
|
f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}"
|
||||||
|
)
|
||||||
self._default_init(self.algo.defaults)
|
self._default_init(self.algo.defaults)
|
||||||
else:
|
else:
|
||||||
logger.debug("算法数据已存在, 跳过默认初始化")
|
logger.debug("算法数据已存在, 跳过默认初始化")
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import pathlib
|
|||||||
|
|
||||||
import edge_tts
|
import edge_tts
|
||||||
|
|
||||||
from heurams.services.logger import get_logger
|
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
|
from heurams.services.logger import get_logger
|
||||||
|
|
||||||
from .base import BaseTTS
|
from .base import BaseTTS
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ class EdgeTTS(BaseTTS):
|
|||||||
try:
|
try:
|
||||||
communicate = edge_tts.Communicate(
|
communicate = edge_tts.Communicate(
|
||||||
text,
|
text,
|
||||||
config_var.get()['providers']['tts']['edgetts']["voice"],
|
config_var.get()["providers"]["tts"]["edgetts"]["voice"],
|
||||||
)
|
)
|
||||||
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
|
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
|
||||||
communicate.save_sync(str(path))
|
communicate.save_sync(str(path))
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
class SyncMode(Enum):
|
class SyncMode(Enum):
|
||||||
"""同步模式枚举"""
|
"""同步模式枚举"""
|
||||||
|
|
||||||
BIDIRECTIONAL = "bidirectional"
|
BIDIRECTIONAL = "bidirectional"
|
||||||
UPLOAD_ONLY = "upload_only"
|
UPLOAD_ONLY = "upload_only"
|
||||||
DOWNLOAD_ONLY = "download_only"
|
DOWNLOAD_ONLY = "download_only"
|
||||||
@@ -25,6 +26,7 @@ class SyncMode(Enum):
|
|||||||
|
|
||||||
class ConflictStrategy(Enum):
|
class ConflictStrategy(Enum):
|
||||||
"""冲突解决策略枚举"""
|
"""冲突解决策略枚举"""
|
||||||
|
|
||||||
NEWER = "newer" # 较新文件覆盖较旧文件
|
NEWER = "newer" # 较新文件覆盖较旧文件
|
||||||
ASK = "ask" # 用户手动选择
|
ASK = "ask" # 用户手动选择
|
||||||
KEEP_BOTH = "keep_both" # 保留双方(重命名)
|
KEEP_BOTH = "keep_both" # 保留双方(重命名)
|
||||||
@@ -33,6 +35,7 @@ class ConflictStrategy(Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SyncConfig:
|
class SyncConfig:
|
||||||
"""同步配置数据类"""
|
"""同步配置数据类"""
|
||||||
|
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
url: str = ""
|
url: str = ""
|
||||||
username: str = ""
|
username: str = ""
|
||||||
@@ -59,12 +62,12 @@ class SyncService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
options = {
|
options = {
|
||||||
'webdav_hostname': self.config.url,
|
"webdav_hostname": self.config.url,
|
||||||
'webdav_login': self.config.username,
|
"webdav_login": self.config.username,
|
||||||
'webdav_password': self.config.password,
|
"webdav_password": self.config.password,
|
||||||
'webdav_root': self.config.remote_path,
|
"webdav_root": self.config.remote_path,
|
||||||
'verify_ssl': self.config.verify_ssl,
|
"verify_ssl": self.config.verify_ssl,
|
||||||
'disable_check': True, # 不检查服务器支持的功能
|
"disable_check": True, # 不检查服务器支持的功能
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -98,10 +101,10 @@ class SyncService:
|
|||||||
rel_path = file_path.relative_to(local_dir)
|
rel_path = file_path.relative_to(local_dir)
|
||||||
stat = file_path.stat()
|
stat = file_path.stat()
|
||||||
files[str(rel_path)] = {
|
files[str(rel_path)] = {
|
||||||
'path': file_path,
|
"path": file_path,
|
||||||
'size': stat.st_size,
|
"size": stat.st_size,
|
||||||
'mtime': stat.st_mtime,
|
"mtime": stat.st_mtime,
|
||||||
'hash': self._calculate_hash(file_path),
|
"hash": self._calculate_hash(file_path),
|
||||||
}
|
}
|
||||||
return files
|
return files
|
||||||
|
|
||||||
@@ -114,14 +117,14 @@ class SyncService:
|
|||||||
remote_list = self.client.list(recursive=True)
|
remote_list = self.client.list(recursive=True)
|
||||||
files = {}
|
files = {}
|
||||||
for item in remote_list:
|
for item in remote_list:
|
||||||
if not item.endswith('/'): # 忽略目录
|
if not item.endswith("/"): # 忽略目录
|
||||||
rel_path = item.lstrip('/')
|
rel_path = item.lstrip("/")
|
||||||
try:
|
try:
|
||||||
info = self.client.info(item)
|
info = self.client.info(item)
|
||||||
files[rel_path] = {
|
files[rel_path] = {
|
||||||
'path': item,
|
"path": item,
|
||||||
'size': info.get('size', 0),
|
"size": info.get("size", 0),
|
||||||
'mtime': self._parse_remote_mtime(info),
|
"mtime": self._parse_remote_mtime(info),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("无法获取远程文件信息 %s: %s", item, e)
|
logger.warning("无法获取远程文件信息 %s: %s", item, e)
|
||||||
@@ -134,8 +137,8 @@ class SyncService:
|
|||||||
"""计算文件的 SHA-256 哈希值"""
|
"""计算文件的 SHA-256 哈希值"""
|
||||||
sha256 = hashlib.sha256()
|
sha256 = hashlib.sha256()
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'rb') as f:
|
with open(file_path, "rb") as f:
|
||||||
for block in iter(lambda: f.read(block_size), b''):
|
for block in iter(lambda: f.read(block_size), b""):
|
||||||
sha256.update(block)
|
sha256.update(block)
|
||||||
return sha256.hexdigest()
|
return sha256.hexdigest()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -160,14 +163,14 @@ class SyncService:
|
|||||||
"""
|
"""
|
||||||
if not self.client:
|
if not self.client:
|
||||||
logger.error("WebDAV 客户端未初始化")
|
logger.error("WebDAV 客户端未初始化")
|
||||||
return {'success': False, 'error': '客户端未初始化'}
|
return {"success": False, "error": "客户端未初始化"}
|
||||||
|
|
||||||
results = {
|
results = {
|
||||||
'uploaded': 0,
|
"uploaded": 0,
|
||||||
'downloaded': 0,
|
"downloaded": 0,
|
||||||
'conflicts': 0,
|
"conflicts": 0,
|
||||||
'errors': 0,
|
"errors": 0,
|
||||||
'success': True,
|
"success": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -180,29 +183,33 @@ class SyncService:
|
|||||||
# 根据同步模式处理文件
|
# 根据同步模式处理文件
|
||||||
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
|
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]:
|
||||||
stats = self._upload_files(local_dir, local_files, remote_files)
|
stats = self._upload_files(local_dir, local_files, remote_files)
|
||||||
results['uploaded'] += stats.get('uploaded', 0)
|
results["uploaded"] += stats.get("uploaded", 0)
|
||||||
results['conflicts'] += stats.get('conflicts', 0)
|
results["conflicts"] += stats.get("conflicts", 0)
|
||||||
results['errors'] += stats.get('errors', 0)
|
results["errors"] += stats.get("errors", 0)
|
||||||
|
|
||||||
if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]:
|
if self.config.sync_mode in [
|
||||||
|
SyncMode.BIDIRECTIONAL,
|
||||||
|
SyncMode.DOWNLOAD_ONLY,
|
||||||
|
]:
|
||||||
stats = self._download_files(local_dir, local_files, remote_files)
|
stats = self._download_files(local_dir, local_files, remote_files)
|
||||||
results['downloaded'] += stats.get('downloaded', 0)
|
results["downloaded"] += stats.get("downloaded", 0)
|
||||||
results['conflicts'] += stats.get('conflicts', 0)
|
results["conflicts"] += stats.get("conflicts", 0)
|
||||||
results['errors'] += stats.get('errors', 0)
|
results["errors"] += stats.get("errors", 0)
|
||||||
|
|
||||||
logger.info("同步完成: %s", results)
|
logger.info("同步完成: %s", results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("同步过程中发生错误: %s", e)
|
logger.error("同步过程中发生错误: %s", e)
|
||||||
results['success'] = False
|
results["success"] = False
|
||||||
results['error'] = str(e)
|
results["error"] = str(e)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _upload_files(self, local_dir: pathlib.Path,
|
def _upload_files(
|
||||||
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
|
self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
|
||||||
|
) -> typing.Dict[str, int]:
|
||||||
"""上传文件到远程服务器"""
|
"""上传文件到远程服务器"""
|
||||||
stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0}
|
stats = {"uploaded": 0, "errors": 0, "conflicts": 0}
|
||||||
|
|
||||||
for rel_path, local_info in local_files.items():
|
for rel_path, local_info in local_files.items():
|
||||||
remote_info = remote_files.get(rel_path)
|
remote_info = remote_files.get(rel_path)
|
||||||
@@ -216,20 +223,31 @@ class SyncService:
|
|||||||
should_upload = True # 远程不存在
|
should_upload = True # 远程不存在
|
||||||
else:
|
else:
|
||||||
# 检查冲突
|
# 检查冲突
|
||||||
local_mtime = local_info.get('mtime', 0)
|
local_mtime = local_info.get("mtime", 0)
|
||||||
remote_mtime = remote_info.get('mtime', 0)
|
remote_mtime = remote_info.get("mtime", 0)
|
||||||
|
|
||||||
if local_mtime != remote_mtime:
|
if local_mtime != remote_mtime:
|
||||||
# 存在冲突
|
# 存在冲突
|
||||||
stats['conflicts'] += 1
|
stats["conflicts"] += 1
|
||||||
should_upload, should_download = self._handle_conflict(local_info, remote_info)
|
should_upload, should_download = self._handle_conflict(
|
||||||
|
local_info, remote_info
|
||||||
|
)
|
||||||
|
|
||||||
if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH:
|
if (
|
||||||
|
should_upload
|
||||||
|
and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH
|
||||||
|
):
|
||||||
# 重命名远程文件避免覆盖
|
# 重命名远程文件避免覆盖
|
||||||
conflict_suffix = f".conflict_{int(remote_mtime)}"
|
conflict_suffix = f".conflict_{int(remote_mtime)}"
|
||||||
name, ext = os.path.splitext(rel_path)
|
name, ext = os.path.splitext(rel_path)
|
||||||
new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}"
|
new_rel_path = (
|
||||||
remote_path = os.path.join(self.config.remote_path, new_rel_path)
|
f"{name}{conflict_suffix}{ext}"
|
||||||
|
if ext
|
||||||
|
else f"{name}{conflict_suffix}"
|
||||||
|
)
|
||||||
|
remote_path = os.path.join(
|
||||||
|
self.config.remote_path, new_rel_path
|
||||||
|
)
|
||||||
conflict_resolved = True
|
conflict_resolved = True
|
||||||
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
|
logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path)
|
||||||
else:
|
else:
|
||||||
@@ -238,19 +256,20 @@ class SyncService:
|
|||||||
|
|
||||||
if should_upload:
|
if should_upload:
|
||||||
try:
|
try:
|
||||||
self.client.upload_file(local_info['path'], remote_path)
|
self.client.upload_file(local_info["path"], remote_path)
|
||||||
stats['uploaded'] += 1
|
stats["uploaded"] += 1
|
||||||
logger.debug("上传文件: %s -> %s", rel_path, remote_path)
|
logger.debug("上传文件: %s -> %s", rel_path, remote_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("上传文件失败 %s: %s", rel_path, e)
|
logger.error("上传文件失败 %s: %s", rel_path, e)
|
||||||
stats['errors'] += 1
|
stats["errors"] += 1
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _download_files(self, local_dir: pathlib.Path,
|
def _download_files(
|
||||||
local_files: dict, remote_files: dict) -> typing.Dict[str, int]:
|
self, local_dir: pathlib.Path, local_files: dict, remote_files: dict
|
||||||
|
) -> typing.Dict[str, int]:
|
||||||
"""从远程服务器下载文件"""
|
"""从远程服务器下载文件"""
|
||||||
stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0}
|
stats = {"downloaded": 0, "errors": 0, "conflicts": 0}
|
||||||
|
|
||||||
for rel_path, remote_info in remote_files.items():
|
for rel_path, remote_info in remote_files.items():
|
||||||
local_info = local_files.get(rel_path)
|
local_info = local_files.get(rel_path)
|
||||||
@@ -261,13 +280,15 @@ class SyncService:
|
|||||||
should_download = True # 本地不存在
|
should_download = True # 本地不存在
|
||||||
else:
|
else:
|
||||||
# 检查冲突
|
# 检查冲突
|
||||||
local_mtime = local_info.get('mtime', 0)
|
local_mtime = local_info.get("mtime", 0)
|
||||||
remote_mtime = remote_info.get('mtime', 0)
|
remote_mtime = remote_info.get("mtime", 0)
|
||||||
|
|
||||||
if local_mtime != remote_mtime:
|
if local_mtime != remote_mtime:
|
||||||
# 存在冲突
|
# 存在冲突
|
||||||
stats['conflicts'] += 1
|
stats["conflicts"] += 1
|
||||||
should_upload, should_download = self._handle_conflict(local_info, remote_info)
|
should_upload, should_download = self._handle_conflict(
|
||||||
|
local_info, remote_info
|
||||||
|
)
|
||||||
# 如果应该上传,则不应该下载(冲突已在上传侧处理)
|
# 如果应该上传,则不应该下载(冲突已在上传侧处理)
|
||||||
if should_upload:
|
if should_upload:
|
||||||
should_download = False
|
should_download = False
|
||||||
@@ -279,24 +300,26 @@ class SyncService:
|
|||||||
try:
|
try:
|
||||||
local_path = local_dir / rel_path
|
local_path = local_dir / rel_path
|
||||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self.client.download_file(remote_info['path'], str(local_path))
|
self.client.download_file(remote_info["path"], str(local_path))
|
||||||
stats['downloaded'] += 1
|
stats["downloaded"] += 1
|
||||||
logger.debug("下载文件: %s -> %s", rel_path, local_path)
|
logger.debug("下载文件: %s -> %s", rel_path, local_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("下载文件失败 %s: %s", rel_path, e)
|
logger.error("下载文件失败 %s: %s", rel_path, e)
|
||||||
stats['errors'] += 1
|
stats["errors"] += 1
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]:
|
def _handle_conflict(
|
||||||
|
self, local_info: dict, remote_info: dict
|
||||||
|
) -> typing.Tuple[bool, bool]:
|
||||||
"""
|
"""
|
||||||
处理文件冲突
|
处理文件冲突
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(should_upload, should_download) - 是否应该上传和下载
|
(should_upload, should_download) - 是否应该上传和下载
|
||||||
"""
|
"""
|
||||||
local_mtime = local_info.get('mtime', 0)
|
local_mtime = local_info.get("mtime", 0)
|
||||||
remote_mtime = remote_info.get('mtime', 0)
|
remote_mtime = remote_info.get("mtime", 0)
|
||||||
|
|
||||||
if self.config.conflict_strategy == ConflictStrategy.NEWER:
|
if self.config.conflict_strategy == ConflictStrategy.NEWER:
|
||||||
# 较新文件覆盖较旧文件
|
# 较新文件覆盖较旧文件
|
||||||
@@ -318,8 +341,11 @@ class SyncService:
|
|||||||
elif self.config.conflict_strategy == ConflictStrategy.ASK:
|
elif self.config.conflict_strategy == ConflictStrategy.ASK:
|
||||||
# 用户手动选择 - 记录冲突,跳过
|
# 用户手动选择 - 记录冲突,跳过
|
||||||
# 返回 False, False 跳过,等待用户决定
|
# 返回 False, False 跳过,等待用户决定
|
||||||
logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
|
logger.warning(
|
||||||
local_mtime, remote_mtime)
|
"文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s",
|
||||||
|
local_mtime,
|
||||||
|
remote_mtime,
|
||||||
|
)
|
||||||
return False, False
|
return False, False
|
||||||
|
|
||||||
return False, False
|
return False, False
|
||||||
@@ -328,11 +354,11 @@ class SyncService:
|
|||||||
"""判断是否需要上传(本地较新或哈希不同)"""
|
"""判断是否需要上传(本地较新或哈希不同)"""
|
||||||
# 这里实现简单的基于时间的比较
|
# 这里实现简单的基于时间的比较
|
||||||
# 实际应该使用哈希比较更可靠
|
# 实际应该使用哈希比较更可靠
|
||||||
return local_info.get('mtime', 0) > remote_info.get('mtime', 0)
|
return local_info.get("mtime", 0) > remote_info.get("mtime", 0)
|
||||||
|
|
||||||
def _should_download(self, local_info: dict, remote_info: dict) -> bool:
|
def _should_download(self, local_info: dict, remote_info: dict) -> bool:
|
||||||
"""判断是否需要下载(远程较新)"""
|
"""判断是否需要下载(远程较新)"""
|
||||||
return remote_info.get('mtime', 0) > local_info.get('mtime', 0)
|
return remote_info.get("mtime", 0) > local_info.get("mtime", 0)
|
||||||
|
|
||||||
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
|
def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool:
|
||||||
"""上传单个文件"""
|
"""上传单个文件"""
|
||||||
@@ -382,20 +408,22 @@ def create_sync_service_from_config() -> typing.Optional[SyncService]:
|
|||||||
try:
|
try:
|
||||||
from heurams.context import config_var
|
from heurams.context import config_var
|
||||||
|
|
||||||
sync_config = config_var.get()['providers']['sync']['webdav']
|
sync_config = config_var.get()["providers"]["sync"]["webdav"]
|
||||||
if not sync_config.get('enabled', False):
|
if not sync_config.get("enabled", False):
|
||||||
logger.debug("同步服务未启用")
|
logger.debug("同步服务未启用")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
config = SyncConfig(
|
config = SyncConfig(
|
||||||
enabled=sync_config.get('enabled', False),
|
enabled=sync_config.get("enabled", False),
|
||||||
url=sync_config.get('url', ''),
|
url=sync_config.get("url", ""),
|
||||||
username=sync_config.get('username', ''),
|
username=sync_config.get("username", ""),
|
||||||
password=sync_config.get('password', ''),
|
password=sync_config.get("password", ""),
|
||||||
remote_path=sync_config.get('remote_path', '/heurams/'),
|
remote_path=sync_config.get("remote_path", "/heurams/"),
|
||||||
sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')),
|
sync_mode=SyncMode(sync_config.get("sync_mode", "bidirectional")),
|
||||||
conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')),
|
conflict_strategy=ConflictStrategy(
|
||||||
verify_ssl=sync_config.get('verify_ssl', True),
|
sync_config.get("conflict_strategy", "newer")
|
||||||
|
),
|
||||||
|
verify_ssl=sync_config.get("verify_ssl", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
service = SyncService(config)
|
service = SyncService(config)
|
||||||
|
|||||||
@@ -6,11 +6,16 @@ import pathlib
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch, Mock
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
from heurams.context import ConfigContext
|
from heurams.context import ConfigContext
|
||||||
from heurams.services.config import ConfigFile
|
from heurams.services.config import ConfigFile
|
||||||
from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy
|
from heurams.services.sync_service import (
|
||||||
|
ConflictStrategy,
|
||||||
|
SyncConfig,
|
||||||
|
SyncMode,
|
||||||
|
SyncService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSyncServiceUnit(unittest.TestCase):
|
class TestSyncServiceUnit(unittest.TestCase):
|
||||||
@@ -44,7 +49,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
"""在每个测试之后清理."""
|
"""在每个测试之后清理."""
|
||||||
self.temp_dir.cleanup()
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_sync_service_initialization(self, mock_client_class):
|
def test_sync_service_initialization(self, mock_client_class):
|
||||||
"""测试同步服务初始化."""
|
"""测试同步服务初始化."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -56,7 +61,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
self.assertIsNotNone(service.client)
|
self.assertIsNotNone(service.client)
|
||||||
self.assertEqual(service.config, self.config)
|
self.assertEqual(service.config, self.config)
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_sync_service_disabled(self, mock_client_class):
|
def test_sync_service_disabled(self, mock_client_class):
|
||||||
"""测试同步服务未启用."""
|
"""测试同步服务未启用."""
|
||||||
config = SyncConfig(enabled=False)
|
config = SyncConfig(enabled=False)
|
||||||
@@ -66,7 +71,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
mock_client_class.assert_not_called()
|
mock_client_class.assert_not_called()
|
||||||
self.assertIsNone(service.client)
|
self.assertIsNone(service.client)
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_test_connection_success(self, mock_client_class):
|
def test_test_connection_success(self, mock_client_class):
|
||||||
"""测试连接测试成功."""
|
"""测试连接测试成功."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -78,7 +83,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
self.assertTrue(result)
|
self.assertTrue(result)
|
||||||
self.mock_client.list.assert_called_once()
|
self.mock_client.list.assert_called_once()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_test_connection_failure(self, mock_client_class):
|
def test_test_connection_failure(self, mock_client_class):
|
||||||
"""测试连接测试失败."""
|
"""测试连接测试失败."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -90,7 +95,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
self.mock_client.list.assert_called_once()
|
self.mock_client.list.assert_called_once()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_upload_file(self, mock_client_class):
|
def test_upload_file(self, mock_client_class):
|
||||||
"""测试上传单个文件."""
|
"""测试上传单个文件."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -101,7 +106,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
self.assertTrue(result)
|
self.assertTrue(result)
|
||||||
self.mock_client.upload_file.assert_called_once()
|
self.mock_client.upload_file.assert_called_once()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_download_file(self, mock_client_class):
|
def test_download_file(self, mock_client_class):
|
||||||
"""测试下载单个文件."""
|
"""测试下载单个文件."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -116,7 +121,7 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
self.mock_client.download_file.assert_called_once()
|
self.mock_client.download_file.assert_called_once()
|
||||||
self.assertTrue(local_path.parent.exists())
|
self.assertTrue(local_path.parent.exists())
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_sync_directory_no_files(self, mock_client_class):
|
def test_sync_directory_no_files(self, mock_client_class):
|
||||||
"""测试同步空目录."""
|
"""测试同步空目录."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -126,12 +131,12 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
service = SyncService(self.config)
|
service = SyncService(self.config)
|
||||||
result = service.sync_directory(self.temp_path)
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
self.assertTrue(result['success'])
|
self.assertTrue(result["success"])
|
||||||
self.assertEqual(result['uploaded'], 0)
|
self.assertEqual(result["uploaded"], 0)
|
||||||
self.assertEqual(result['downloaded'], 0)
|
self.assertEqual(result["downloaded"], 0)
|
||||||
self.mock_client.mkdir.assert_called_once()
|
self.mock_client.mkdir.assert_called_once()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_sync_directory_upload_only(self, mock_client_class):
|
def test_sync_directory_upload_only(self, mock_client_class):
|
||||||
"""测试仅上传模式."""
|
"""测试仅上传模式."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
@@ -151,54 +156,58 @@ class TestSyncServiceUnit(unittest.TestCase):
|
|||||||
service = SyncService(config)
|
service = SyncService(config)
|
||||||
result = service.sync_directory(self.temp_path)
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
self.assertTrue(result['success'])
|
self.assertTrue(result["success"])
|
||||||
self.mock_client.mkdir.assert_called_once()
|
self.mock_client.mkdir.assert_called_once()
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_conflict_strategy_newer(self, mock_client_class):
|
def test_conflict_strategy_newer(self, mock_client_class):
|
||||||
"""测试 NEWER 冲突策略."""
|
"""测试 NEWER 冲突策略."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
# 模拟远程文件存在
|
# 模拟远程文件存在
|
||||||
self.mock_client.list.return_value = ["test.txt"]
|
self.mock_client.list.return_value = ["test.txt"]
|
||||||
self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'}
|
self.mock_client.info.return_value = {
|
||||||
|
"size": 100,
|
||||||
|
"modified": "2023-01-01T00:00:00Z",
|
||||||
|
}
|
||||||
self.mock_client.mkdir.return_value = None
|
self.mock_client.mkdir.return_value = None
|
||||||
|
|
||||||
service = SyncService(self.config)
|
service = SyncService(self.config)
|
||||||
result = service.sync_directory(self.temp_path)
|
result = service.sync_directory(self.temp_path)
|
||||||
|
|
||||||
self.assertTrue(result['success'])
|
self.assertTrue(result["success"])
|
||||||
# 应该有一个冲突
|
# 应该有一个冲突
|
||||||
self.assertGreaterEqual(result.get('conflicts', 0), 0)
|
self.assertGreaterEqual(result.get("conflicts", 0), 0)
|
||||||
|
|
||||||
@patch('heurams.services.sync_service.Client')
|
@patch("heurams.services.sync_service.Client")
|
||||||
def test_create_sync_service_from_config(self, mock_client_class):
|
def test_create_sync_service_from_config(self, mock_client_class):
|
||||||
"""测试从配置文件创建同步服务."""
|
"""测试从配置文件创建同步服务."""
|
||||||
mock_client_class.return_value = self.mock_client
|
mock_client_class.return_value = self.mock_client
|
||||||
|
|
||||||
# 创建临时配置文件
|
# 创建临时配置文件
|
||||||
config_data = {
|
config_data = {
|
||||||
'sync': {
|
"sync": {
|
||||||
'webdav': {
|
"webdav": {
|
||||||
'enabled': True,
|
"enabled": True,
|
||||||
'url': 'https://example.com/dav/',
|
"url": "https://example.com/dav/",
|
||||||
'username': 'test',
|
"username": "test",
|
||||||
'password': 'test',
|
"password": "test",
|
||||||
'remote_path': '/heurams/',
|
"remote_path": "/heurams/",
|
||||||
'sync_mode': 'bidirectional',
|
"sync_mode": "bidirectional",
|
||||||
'conflict_strategy': 'newer',
|
"conflict_strategy": "newer",
|
||||||
'verify_ssl': True,
|
"verify_ssl": True,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 模拟 config_var
|
# 模拟 config_var
|
||||||
with patch('heurams.services.sync_service.config_var') as mock_config_var:
|
with patch("heurams.services.sync_service.config_var") as mock_config_var:
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.data = config_data
|
mock_config.data = config_data
|
||||||
mock_config_var.get.return_value = mock_config
|
mock_config_var.get.return_value = mock_config
|
||||||
|
|
||||||
from heurams.services.sync_service import create_sync_service_from_config
|
from heurams.services.sync_service import create_sync_service_from_config
|
||||||
|
|
||||||
service = create_sync_service_from_config()
|
service = create_sync_service_from_config()
|
||||||
|
|
||||||
self.assertIsNotNone(service)
|
self.assertIsNotNone(service)
|
||||||
@@ -228,22 +237,24 @@ class TestSyncScreenUnit(unittest.TestCase):
|
|||||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
||||||
|
|
||||||
# 添加同步配置
|
# 添加同步配置
|
||||||
if 'sync' not in config_data:
|
if "sync" not in config_data:
|
||||||
config_data['sync'] = {}
|
config_data["sync"] = {}
|
||||||
config_data['sync']['webdav'] = {
|
config_data["sync"]["webdav"] = {
|
||||||
'enabled': False,
|
"enabled": False,
|
||||||
'url': '',
|
"url": "",
|
||||||
'username': '',
|
"username": "",
|
||||||
'password': '',
|
"password": "",
|
||||||
'remote_path': '/heurams/',
|
"remote_path": "/heurams/",
|
||||||
'sync_mode': 'bidirectional',
|
"sync_mode": "bidirectional",
|
||||||
'conflict_strategy': 'newer',
|
"conflict_strategy": "newer",
|
||||||
'verify_ssl': True,
|
"verify_ssl": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建目录
|
# 创建目录
|
||||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
||||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True)
|
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
||||||
|
parents=True, exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
# 使用 ConfigContext 设置配置
|
# 使用 ConfigContext 设置配置
|
||||||
self.config_ctx = ConfigContext(self.config)
|
self.config_ctx = ConfigContext(self.config)
|
||||||
@@ -254,7 +265,7 @@ class TestSyncScreenUnit(unittest.TestCase):
|
|||||||
self.config_ctx.__exit__(None, None, None)
|
self.config_ctx.__exit__(None, None, None)
|
||||||
self.temp_dir.cleanup()
|
self.temp_dir.cleanup()
|
||||||
|
|
||||||
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
|
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
|
||||||
def test_sync_screen_compose(self, mock_create_service):
|
def test_sync_screen_compose(self, mock_create_service):
|
||||||
"""测试 SyncScreen 的 compose 方法."""
|
"""测试 SyncScreen 的 compose 方法."""
|
||||||
from heurams.interface.screens.synctool import SyncScreen
|
from heurams.interface.screens.synctool import SyncScreen
|
||||||
@@ -268,12 +279,13 @@ class TestSyncScreenUnit(unittest.TestCase):
|
|||||||
|
|
||||||
# 测试 compose 方法
|
# 测试 compose 方法
|
||||||
from textual.app import ComposeResult
|
from textual.app import ComposeResult
|
||||||
|
|
||||||
result = screen.compose()
|
result = screen.compose()
|
||||||
widgets = list(result)
|
widgets = list(result)
|
||||||
|
|
||||||
# 检查基本部件
|
# 检查基本部件
|
||||||
from textual.widgets import Footer, Header, Button, Static, ProgressBar
|
|
||||||
from textual.containers import ScrollableContainer
|
from textual.containers import ScrollableContainer
|
||||||
|
from textual.widgets import Button, Footer, Header, ProgressBar, Static
|
||||||
|
|
||||||
header_present = any(isinstance(w, Header) for w in widgets)
|
header_present = any(isinstance(w, Header) for w in widgets)
|
||||||
footer_present = any(isinstance(w, Footer) for w in widgets)
|
footer_present = any(isinstance(w, Footer) for w in widgets)
|
||||||
@@ -284,7 +296,7 @@ class TestSyncScreenUnit(unittest.TestCase):
|
|||||||
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
||||||
self.assertTrue(container_present)
|
self.assertTrue(container_present)
|
||||||
|
|
||||||
@patch('heurams.interface.screens.synctool.create_sync_service_from_config')
|
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
|
||||||
def test_sync_screen_load_config(self, mock_create_service):
|
def test_sync_screen_load_config(self, mock_create_service):
|
||||||
"""测试 SyncScreen 加载配置."""
|
"""测试 SyncScreen 加载配置."""
|
||||||
from heurams.interface.screens.synctool import SyncScreen
|
from heurams.interface.screens.synctool import SyncScreen
|
||||||
|
|||||||
Reference in New Issue
Block a user