diff --git a/src/heurams/context.py b/src/heurams/context.py index 06b2e34..92eab06 100644 --- a/src/heurams/context.py +++ b/src/heurams/context.py @@ -36,8 +36,8 @@ if pathlib.Path(workdir / "config" / "config_dev.toml").exists(): print("使用开发设置") logger.debug("使用开发设置") config_var: ContextVar[ConfigFile] = ContextVar( - "config_var", default=ConfigFile(workdir / "config" / "config_dev.toml") -) + "config_var", default=ConfigFile(workdir / "config" / "config_dev.toml") + ) # runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据 diff --git a/src/heurams/interface/__main__.py b/src/heurams/interface/__main__.py index cf62f5a..a6dae83 100644 --- a/src/heurams/interface/__main__.py +++ b/src/heurams/interface/__main__.py @@ -1,9 +1,9 @@ from textual.app import App from textual.widgets import Button -from heurams.services.logger import get_logger from heurams.context import config_var from heurams.interface import HeurAMSApp +from heurams.services.logger import get_logger from .screens.about import AboutScreen from .screens.dashboard import DashboardScreen @@ -15,4 +15,4 @@ logger = get_logger(__name__) app = HeurAMSApp() if __name__ == "__main__": - app.run() \ No newline at end of file + app.run() diff --git a/src/heurams/interface/screens/dashboard.py b/src/heurams/interface/screens/dashboard.py index b7b8eb7..ed09e22 100644 --- a/src/heurams/interface/screens/dashboard.py +++ b/src/heurams/interface/screens/dashboard.py @@ -4,8 +4,7 @@ import pathlib from textual.app import ComposeResult from textual.containers import ScrollableContainer from textual.screen import Screen -from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView, - Static) +from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static import heurams.services.timer as timer import heurams.services.version as version @@ -28,6 +27,7 @@ class DashboardScreen(Screen): Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"), Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"), Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'), + Label(f"使用算法: {config_var.get()['algorithm']['default']}"), Label("选择待学习或待修改的记忆单元集:", classes="title-label"), ListView(id="union-list", classes="union-list-view"), Label( diff --git a/src/heurams/interface/screens/memorizor.py b/src/heurams/interface/screens/memorizor.py index 9394a00..9bf03f7 100644 --- a/src/heurams/interface/screens/memorizor.py +++ b/src/heurams/interface/screens/memorizor.py @@ -148,16 +148,24 @@ class MemScreen(Screen): def play_voice(self): """朗读当前内容""" - from heurams.services.audio_service import play_by_path from pathlib import Path + + from heurams.services.audio_service import play_by_path from heurams.services.hasher import get_md5 - path = Path(config_var.get()['paths']["cache_dir"]) - path = path / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav" + + path = Path(config_var.get()["paths"]["cache_dir"]) + path = ( + path + / f"{get_md5(self.atom.registry['nucleon'].metadata["formation"]["tts_text"])}.wav" + ) if path.exists(): play_by_path(path) else: from heurams.services.tts_service import convertor - convertor(self.atom.registry['nucleon'].metadata["formation"]["tts_text"], path) + + convertor( + self.atom.registry["nucleon"].metadata["formation"]["tts_text"], path + ) play_by_path(path) def action_toggle_dark(self): diff --git a/src/heurams/interface/screens/nucreator.py b/src/heurams/interface/screens/nucreator.py index 735dda6..b187528 100644 --- a/src/heurams/interface/screens/nucreator.py +++ b/src/heurams/interface/screens/nucreator.py @@ -5,8 +5,7 @@ import toml from textual.app import ComposeResult from textual.containers import ScrollableContainer from textual.screen import Screen -from textual.widgets import (Button, Footer, Header, Input, Label, Markdown, - Select) +from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select from heurams.context import config_var from heurams.services.version import ver diff --git a/src/heurams/interface/screens/precache.py b/src/heurams/interface/screens/precache.py index a8ab604..4c03bf2 100644 --- a/src/heurams/interface/screens/precache.py +++ b/src/heurams/interface/screens/precache.py @@ -99,6 +99,7 @@ class PrecachingScreen(Screen): if not cache_file.exists(): try: from heurams.services.tts_service import convertor + convertor(text, cache_file) return 1 except Exception as e: diff --git a/src/heurams/interface/screens/synctool.py b/src/heurams/interface/screens/synctool.py index b62f44b..e07a407 100644 --- a/src/heurams/interface/screens/synctool.py +++ b/src/heurams/interface/screens/synctool.py @@ -32,7 +32,7 @@ class SyncScreen(Screen): # 标题和连接状态 yield Static("WebDAV 同步工具", classes="title") yield Static("", id="status_label", classes="status") - + # 配置信息 yield Static("服务器配置", classes="section_title") with Horizontal(classes="config_info"): @@ -44,7 +44,7 @@ class SyncScreen(Screen): with Horizontal(classes="config_info"): yield Static("同步模式:", classes="config_label") yield Static("", id="sync_mode", classes="config_value") - + # 控制按钮 yield Static("控制面板", classes="section_title") with Horizontal(classes="control_buttons"): @@ -52,16 +52,16 @@ class SyncScreen(Screen): yield Button("开始同步", id="start_sync", variant="success") yield Button("暂停", id="pause_sync", variant="warning", disabled=True) yield Button("取消", id="cancel_sync", variant="error", disabled=True) - + # 进度显示 yield Static("同步进度", classes="section_title") yield ProgressBar(id="progress_bar", show_percentage=True, total=100) yield Static("", id="progress_label", classes="progress_text") - + # 日志输出 yield Static("同步日志", classes="section_title") yield Static("", id="log_output", classes="log_output") - + yield Footer() def on_mount(self): @@ -74,13 +74,15 @@ class SyncScreen(Screen): """从配置文件加载同步设置""" try: from heurams.context import config_var + config_data = config_var.get().data - self.sync_config = config_data.get('sync', {}).get('webdav', {}) - + self.sync_config = config_data.get("sync", {}).get("webdav", {}) + # 创建同步服务实例 from heurams.services.sync_service import create_sync_service_from_config + self.sync_service = create_sync_service_from_config() - + except Exception as e: self.log_message(f"加载配置失败: {e}", is_error=True) self.sync_config = {} @@ -89,34 +91,34 @@ class SyncScreen(Screen): """更新 UI 显示配置信息""" try: # 更新服务器 URL - url = self.sync_config.get('url', '未配置') + url = self.sync_config.get("url", "未配置") url_widget = self.query_one("#server_url") - url_widget.update(url if url else '未配置') # type: ignore - + url_widget.update(url if url else "未配置") # type: ignore + # 更新远程路径 - remote_path = self.sync_config.get('remote_path', '/heurams/') + remote_path = self.sync_config.get("remote_path", "/heurams/") path_widget = self.query_one("#remote_path") - path_widget.update(remote_path) # type: ignore - + path_widget.update(remote_path) # type: ignore + # 更新同步模式 - sync_mode = self.sync_config.get('sync_mode', 'bidirectional') + sync_mode = self.sync_config.get("sync_mode", "bidirectional") mode_widget = self.query_one("#sync_mode") mode_map = { - 'bidirectional': '双向同步', - 'upload_only': '仅上传', - 'download_only': '仅下载', + "bidirectional": "双向同步", + "upload_only": "仅上传", + "download_only": "仅下载", } - mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore - + mode_widget.update(mode_map.get(sync_mode, sync_mode)) # type: ignore + # 更新状态标签 status_widget = self.query_one("#status_label") if self.sync_service and self.sync_service.client: - status_widget.update("✅ 同步服务已就绪") # type: ignore + status_widget.update("✅ 同步服务已就绪") # type: ignore status_widget.add_class("ready") else: - status_widget.update("❌ 同步服务未配置或未启用") # type: ignore + status_widget.update("❌ 同步服务未配置或未启用") # type: ignore status_widget.add_class("error") - + except Exception as e: self.log_message(f"更新 UI 失败: {e}", is_error=True) @@ -124,15 +126,15 @@ class SyncScreen(Screen): """更新状态显示""" try: status_widget = self.query_one("#status_label") - status_widget.update(status) # type: ignore - + status_widget.update(status) # type: ignore + if progress is not None: progress_bar = self.query_one("#progress_bar") - progress_bar.progress = progress # type: ignore - + progress_bar.progress = progress # type: ignore + progress_label = self.query_one("#progress_label") - progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore - + progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore + except Exception as e: self.log_message(f"更新状态失败: {e}", is_error=True) @@ -141,23 +143,23 @@ class SyncScreen(Screen): timestamp = time.strftime("%H:%M:%S") prefix = "[ERROR]" if is_error else "[INFO]" log_line = f"{timestamp} {prefix} {message}" - + self.log_messages.append(log_line) # 保持日志行数不超过最大值 if len(self.log_messages) > self.max_log_lines: - self.log_messages = self.log_messages[-self.max_log_lines:] - + self.log_messages = self.log_messages[-self.max_log_lines :] + # 更新日志显示 try: log_widget = self.query_one("#log_output") - log_widget.update("\n".join(self.log_messages)) # type: ignore + log_widget.update("\n".join(self.log_messages)) # type: ignore except Exception: pass # 如果组件未就绪,忽略错误 def on_button_pressed(self, event: Button.Pressed) -> None: """处理按钮点击事件""" button_id = event.button.id - + if button_id == "test_connection": self.test_connection() elif button_id == "start_sync": @@ -166,7 +168,7 @@ class SyncScreen(Screen): self.pause_sync() elif button_id == "cancel_sync": self.cancel_sync() - + event.stop() def test_connection(self): @@ -175,10 +177,10 @@ class SyncScreen(Screen): self.log_message("同步服务未初始化,请检查配置", is_error=True) self.update_status("❌ 同步服务未初始化") return - + self.log_message("正在测试 WebDAV 连接...") self.update_status("正在测试连接...") - + try: success = self.sync_service.test_connection() if success: @@ -196,71 +198,87 @@ class SyncScreen(Screen): if not self.sync_service: self.log_message("同步服务未初始化,无法开始同步", is_error=True) return - + if self.is_syncing: self.log_message("同步已在进行中", is_error=True) return - + self.is_syncing = True self.is_paused = False self.update_button_states() - + self.log_message("开始同步数据...") self.update_status("正在同步...", progress=0) - + # 启动后台同步任务 self.run_worker(self.perform_sync, thread=True) def perform_sync(self): """执行同步任务(在后台线程中运行)""" worker = get_current_worker() - + try: # 获取需要同步的本地目录 from heurams.context import config_var + config = config_var.get() - paths = config.get('paths', {}) - + paths = config.get("paths", {}) + # 同步 nucleon 目录 - nucleon_dir = pathlib.Path(paths.get('nucleon_dir', './data/nucleon')) + nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon")) if nucleon_dir.exists(): self.log_message(f"同步 nucleon 目录: {nucleon_dir}") self.update_status(f"同步 nucleon 目录...", progress=10) - - result = self.sync_service.sync_directory(nucleon_dir) # type: ignore - if result.get('success'): - self.log_message(f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个") + + result = self.sync_service.sync_directory(nucleon_dir) # type: ignore + if result.get("success"): + self.log_message( + f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个" + ) else: - self.log_message(f"nucleon 同步失败: {result.get('error', '未知错误')}", is_error=True) - + self.log_message( + f"nucleon 同步失败: {result.get('error', '未知错误')}", + is_error=True, + ) + # 同步 electron 目录 - electron_dir = pathlib.Path(paths.get('electron_dir', './data/electron')) + electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron")) if electron_dir.exists(): self.log_message(f"同步 electron 目录: {electron_dir}") self.update_status(f"同步 electron 目录...", progress=60) - - result = self.sync_service.sync_directory(electron_dir) # type: ignore - if result.get('success'): - self.log_message(f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个") + + result = self.sync_service.sync_directory(electron_dir) # type: ignore + if result.get("success"): + self.log_message( + f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个" + ) else: - self.log_message(f"electron 同步失败: {result.get('error', '未知错误')}", is_error=True) - + self.log_message( + f"electron 同步失败: {result.get('error', '未知错误')}", + is_error=True, + ) + # 同步 orbital 目录(如果存在) - orbital_dir = pathlib.Path(paths.get('orbital_dir', './data/orbital')) + orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital")) if orbital_dir.exists(): self.log_message(f"同步 orbital 目录: {orbital_dir}") self.update_status(f"同步 orbital 目录...", progress=80) - - result = self.sync_service.sync_directory(orbital_dir) # type: ignore - if result.get('success'): - self.log_message(f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个") + + result = self.sync_service.sync_directory(orbital_dir) # type: ignore + if result.get("success"): + self.log_message( + f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个" + ) else: - self.log_message(f"orbital 同步失败: {result.get('error', '未知错误')}", is_error=True) - + self.log_message( + f"orbital 同步失败: {result.get('error', '未知错误')}", + is_error=True, + ) + # 同步完成 self.update_status("同步完成", progress=100) self.log_message("所有目录同步完成") - + except Exception as e: self.log_message(f"同步过程中发生错误: {e}", is_error=True) self.update_status("同步失败") @@ -268,16 +286,16 @@ class SyncScreen(Screen): # 重置同步状态 self.is_syncing = False self.is_paused = False - self.update_button_states() # type: ignore + self.update_button_states() # type: ignore def pause_sync(self): """暂停同步""" if not self.is_syncing: return - + self.is_paused = not self.is_paused self.update_button_states() - + if self.is_paused: self.log_message("同步已暂停") self.update_status("同步已暂停") @@ -289,11 +307,11 @@ class SyncScreen(Screen): """取消同步""" if not self.is_syncing: return - + self.is_syncing = False self.is_paused = False self.update_button_states() - + self.log_message("同步已取消") self.update_status("同步已取消") @@ -303,17 +321,17 @@ class SyncScreen(Screen): start_button = self.query_one("#start_sync") pause_button = self.query_one("#pause_sync") cancel_button = self.query_one("#cancel_sync") - + if self.is_syncing: start_button.disabled = True pause_button.disabled = False cancel_button.disabled = False - pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore + pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore else: start_button.disabled = False pause_button.disabled = True cancel_button.disabled = True - + except Exception as e: self.log_message(f"更新按钮状态失败: {e}", is_error=True) diff --git a/src/heurams/interface/widgets/recognition.py b/src/heurams/interface/widgets/recognition.py index 842082a..7b4364c 100644 --- a/src/heurams/interface/widgets/recognition.py +++ b/src/heurams/interface/widgets/recognition.py @@ -50,9 +50,10 @@ class Recognition(BasePuzzleWidget): def compose(self): from heurams.context import config_var - autovoice = config_var.get()['interface']['memorizor']['autovoice'] + + autovoice = config_var.get()["interface"]["memorizor"]["autovoice"] if autovoice: - self.screen.action_play_voice() # type: ignore + self.screen.action_play_voice() # type: ignore cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia] delim = self.atom.registry["nucleon"].metadata["formation"]["delimiter"] replace_dict = { @@ -72,7 +73,7 @@ class Recognition(BasePuzzleWidget): primary = cfg["primary"] with Center(): - for i in cfg['top_dim']: + for i in cfg["top_dim"]: yield Static(f"[dim]{i}[/]") yield Label("") diff --git a/src/heurams/kernel/algorithms/sm15m.py b/src/heurams/kernel/algorithms/sm15m.py index 95e210e..1decbb5 100644 --- a/src/heurams/kernel/algorithms/sm15m.py +++ b/src/heurams/kernel/algorithms/sm15m.py @@ -10,15 +10,25 @@ MIT 许可证 import datetime import json import os -from typing import TypedDict import pathlib +from typing import TypedDict + from heurams.context import config_var -from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF, - RANGE_AF, RANGE_REPETITION, - SM, THRESHOLD_RECALL, Item) +from heurams.kernel.algorithms.sm15m_calc import ( + MAX_AF, + MIN_AF, + NOTCH_AF, + RANGE_AF, + RANGE_REPETITION, + SM, + THRESHOLD_RECALL, + Item, +) # 全局状态文件路径 -_GLOBAL_STATE_FILE = os.path.expanduser(pathlib.Path(config_var.get()['paths']['global_dir']) / 'sm15m_global_state.json') +_GLOBAL_STATE_FILE = os.path.expanduser( + pathlib.Path(config_var.get()["paths"]["global_dir"]) / "sm15m_global_state.json" +) def _get_global_sm(): diff --git a/src/heurams/kernel/particles/atom.py b/src/heurams/kernel/particles/atom.py index 4d90ff2..dfb172c 100644 --- a/src/heurams/kernel/particles/atom.py +++ b/src/heurams/kernel/particles/atom.py @@ -86,8 +86,8 @@ class Atom: # eval 环境设置 def eval_with_env(s: str): default = config_var.get()["puzzles"] - payload = self.registry['nucleon'].payload - metadata = self.registry['nucleon'].metadata + payload = self.registry["nucleon"].payload + metadata = self.registry["nucleon"].metadata eval_value = eval(s) if isinstance(eval_value, (int, float)): ret = str(eval_value) @@ -117,10 +117,11 @@ class Atom: logger.debug("发现 eval 表达式: '%s'", data[5:]) return modifier(data[5:]) return data + try: - traverse(self.registry['nucleon'].payload, eval_with_env) - traverse(self.registry['nucleon'].metadata, eval_with_env) - traverse(self.registry['orbital'], eval_with_env) + traverse(self.registry["nucleon"].payload, eval_with_env) + traverse(self.registry["nucleon"].metadata, eval_with_env) + traverse(self.registry["orbital"], eval_with_env) except Exception as e: ret = f"此 eval 实例发生错误: {e}" logger.warning(ret) diff --git a/src/heurams/kernel/particles/electron.py b/src/heurams/kernel/particles/electron.py index cb7cf54..cc5cdd6 100644 --- a/src/heurams/kernel/particles/electron.py +++ b/src/heurams/kernel/particles/electron.py @@ -18,9 +18,12 @@ class Electron: algo: 使用的算法模块标识 """ if algo_name == "": - algo_name = config_var.get()['algorithm']['default'] + algo_name = config_var.get()["algorithm"]["default"] logger.debug( - "创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", ident, algo_name, algodata + "创建 Electron 实例, ident: '%s', algo_name: '%s', algodata: %s", + ident, + algo_name, + algodata, ) self.algodata = algodata self.ident = ident @@ -31,7 +34,9 @@ class Electron: self.algodata[self.algo.algo_name] = {} logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo) if not self.algodata[self.algo.algo_name]: - logger.debug(f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}") + logger.debug( + f"算法数据为空, 使用默认值初始化{self.algodata[self.algo.algo_name]}" + ) self._default_init(self.algo.defaults) else: logger.debug("算法数据已存在, 跳过默认初始化") diff --git a/src/heurams/kernel/particles/nucleon.py b/src/heurams/kernel/particles/nucleon.py index 098e840..175661b 100644 --- a/src/heurams/kernel/particles/nucleon.py +++ b/src/heurams/kernel/particles/nucleon.py @@ -53,4 +53,4 @@ class Nucleon: def placeholder(): """生成一个占位原子核""" logger.debug("创建 Nucleon 占位符") - return Nucleon("核子对象样例内容", {}) \ No newline at end of file + return Nucleon("核子对象样例内容", {}) diff --git a/src/heurams/providers/tts/edge_tts.py b/src/heurams/providers/tts/edge_tts.py index ee74cf0..9b8a33c 100644 --- a/src/heurams/providers/tts/edge_tts.py +++ b/src/heurams/providers/tts/edge_tts.py @@ -2,8 +2,8 @@ import pathlib import edge_tts -from heurams.services.logger import get_logger from heurams.context import config_var +from heurams.services.logger import get_logger from .base import BaseTTS @@ -19,7 +19,7 @@ class EdgeTTS(BaseTTS): try: communicate = edge_tts.Communicate( text, - config_var.get()['providers']['tts']['edgetts']["voice"], + config_var.get()["providers"]["tts"]["edgetts"]["voice"], ) logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频") communicate.save_sync(str(path)) diff --git a/src/heurams/services/sync_service.py b/src/heurams/services/sync_service.py index c8792c0..a82d5da 100644 --- a/src/heurams/services/sync_service.py +++ b/src/heurams/services/sync_service.py @@ -18,6 +18,7 @@ logger = get_logger(__name__) class SyncMode(Enum): """同步模式枚举""" + BIDIRECTIONAL = "bidirectional" UPLOAD_ONLY = "upload_only" DOWNLOAD_ONLY = "download_only" @@ -25,6 +26,7 @@ class SyncMode(Enum): class ConflictStrategy(Enum): """冲突解决策略枚举""" + NEWER = "newer" # 较新文件覆盖较旧文件 ASK = "ask" # 用户手动选择 KEEP_BOTH = "keep_both" # 保留双方(重命名) @@ -33,6 +35,7 @@ class ConflictStrategy(Enum): @dataclass class SyncConfig: """同步配置数据类""" + enabled: bool = False url: str = "" username: str = "" @@ -59,12 +62,12 @@ class SyncService: return options = { - 'webdav_hostname': self.config.url, - 'webdav_login': self.config.username, - 'webdav_password': self.config.password, - 'webdav_root': self.config.remote_path, - 'verify_ssl': self.config.verify_ssl, - 'disable_check': True, # 不检查服务器支持的功能 + "webdav_hostname": self.config.url, + "webdav_login": self.config.username, + "webdav_password": self.config.password, + "webdav_root": self.config.remote_path, + "verify_ssl": self.config.verify_ssl, + "disable_check": True, # 不检查服务器支持的功能 } try: @@ -98,10 +101,10 @@ class SyncService: rel_path = file_path.relative_to(local_dir) stat = file_path.stat() files[str(rel_path)] = { - 'path': file_path, - 'size': stat.st_size, - 'mtime': stat.st_mtime, - 'hash': self._calculate_hash(file_path), + "path": file_path, + "size": stat.st_size, + "mtime": stat.st_mtime, + "hash": self._calculate_hash(file_path), } return files @@ -114,14 +117,14 @@ class SyncService: remote_list = self.client.list(recursive=True) files = {} for item in remote_list: - if not item.endswith('/'): # 忽略目录 - rel_path = item.lstrip('/') + if not item.endswith("/"): # 忽略目录 + rel_path = item.lstrip("/") try: info = self.client.info(item) files[rel_path] = { - 'path': item, - 'size': info.get('size', 0), - 'mtime': self._parse_remote_mtime(info), + "path": item, + "size": info.get("size", 0), + "mtime": self._parse_remote_mtime(info), } except Exception as e: logger.warning("无法获取远程文件信息 %s: %s", item, e) @@ -134,8 +137,8 @@ class SyncService: """计算文件的 SHA-256 哈希值""" sha256 = hashlib.sha256() try: - with open(file_path, 'rb') as f: - for block in iter(lambda: f.read(block_size), b''): + with open(file_path, "rb") as f: + for block in iter(lambda: f.read(block_size), b""): sha256.update(block) return sha256.hexdigest() except Exception as e: @@ -151,23 +154,23 @@ class SyncService: def sync_directory(self, local_dir: pathlib.Path) -> typing.Dict[str, typing.Any]: """ 同步目录 - + Args: local_dir: 本地目录路径 - + Returns: 同步结果统计 """ if not self.client: logger.error("WebDAV 客户端未初始化") - return {'success': False, 'error': '客户端未初始化'} + return {"success": False, "error": "客户端未初始化"} results = { - 'uploaded': 0, - 'downloaded': 0, - 'conflicts': 0, - 'errors': 0, - 'success': True, + "uploaded": 0, + "downloaded": 0, + "conflicts": 0, + "errors": 0, + "success": True, } try: @@ -180,124 +183,144 @@ class SyncService: # 根据同步模式处理文件 if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.UPLOAD_ONLY]: stats = self._upload_files(local_dir, local_files, remote_files) - results['uploaded'] += stats.get('uploaded', 0) - results['conflicts'] += stats.get('conflicts', 0) - results['errors'] += stats.get('errors', 0) + results["uploaded"] += stats.get("uploaded", 0) + results["conflicts"] += stats.get("conflicts", 0) + results["errors"] += stats.get("errors", 0) - if self.config.sync_mode in [SyncMode.BIDIRECTIONAL, SyncMode.DOWNLOAD_ONLY]: + if self.config.sync_mode in [ + SyncMode.BIDIRECTIONAL, + SyncMode.DOWNLOAD_ONLY, + ]: stats = self._download_files(local_dir, local_files, remote_files) - results['downloaded'] += stats.get('downloaded', 0) - results['conflicts'] += stats.get('conflicts', 0) - results['errors'] += stats.get('errors', 0) + results["downloaded"] += stats.get("downloaded", 0) + results["conflicts"] += stats.get("conflicts", 0) + results["errors"] += stats.get("errors", 0) logger.info("同步完成: %s", results) return results except Exception as e: logger.error("同步过程中发生错误: %s", e) - results['success'] = False - results['error'] = str(e) + results["success"] = False + results["error"] = str(e) return results - def _upload_files(self, local_dir: pathlib.Path, - local_files: dict, remote_files: dict) -> typing.Dict[str, int]: + def _upload_files( + self, local_dir: pathlib.Path, local_files: dict, remote_files: dict + ) -> typing.Dict[str, int]: """上传文件到远程服务器""" - stats = {'uploaded': 0, 'errors': 0, 'conflicts': 0} - + stats = {"uploaded": 0, "errors": 0, "conflicts": 0} + for rel_path, local_info in local_files.items(): remote_info = remote_files.get(rel_path) - + # 判断是否需要上传 should_upload = False conflict_resolved = False remote_path = os.path.join(self.config.remote_path, rel_path) - + if not remote_info: should_upload = True # 远程不存在 else: # 检查冲突 - local_mtime = local_info.get('mtime', 0) - remote_mtime = remote_info.get('mtime', 0) - + local_mtime = local_info.get("mtime", 0) + remote_mtime = remote_info.get("mtime", 0) + if local_mtime != remote_mtime: # 存在冲突 - stats['conflicts'] += 1 - should_upload, should_download = self._handle_conflict(local_info, remote_info) - - if should_upload and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH: + stats["conflicts"] += 1 + should_upload, should_download = self._handle_conflict( + local_info, remote_info + ) + + if ( + should_upload + and self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH + ): # 重命名远程文件避免覆盖 conflict_suffix = f".conflict_{int(remote_mtime)}" name, ext = os.path.splitext(rel_path) - new_rel_path = f"{name}{conflict_suffix}{ext}" if ext else f"{name}{conflict_suffix}" - remote_path = os.path.join(self.config.remote_path, new_rel_path) + new_rel_path = ( + f"{name}{conflict_suffix}{ext}" + if ext + else f"{name}{conflict_suffix}" + ) + remote_path = os.path.join( + self.config.remote_path, new_rel_path + ) conflict_resolved = True logger.debug("冲突文件重命名: %s -> %s", rel_path, new_rel_path) else: # 时间相同,无需上传 should_upload = False - + if should_upload: try: - self.client.upload_file(local_info['path'], remote_path) - stats['uploaded'] += 1 + self.client.upload_file(local_info["path"], remote_path) + stats["uploaded"] += 1 logger.debug("上传文件: %s -> %s", rel_path, remote_path) except Exception as e: logger.error("上传文件失败 %s: %s", rel_path, e) - stats['errors'] += 1 - + stats["errors"] += 1 + return stats - def _download_files(self, local_dir: pathlib.Path, - local_files: dict, remote_files: dict) -> typing.Dict[str, int]: + def _download_files( + self, local_dir: pathlib.Path, local_files: dict, remote_files: dict + ) -> typing.Dict[str, int]: """从远程服务器下载文件""" - stats = {'downloaded': 0, 'errors': 0, 'conflicts': 0} - + stats = {"downloaded": 0, "errors": 0, "conflicts": 0} + for rel_path, remote_info in remote_files.items(): local_info = local_files.get(rel_path) - + # 判断是否需要下载 should_download = False if not local_info: should_download = True # 本地不存在 else: # 检查冲突 - local_mtime = local_info.get('mtime', 0) - remote_mtime = remote_info.get('mtime', 0) - + local_mtime = local_info.get("mtime", 0) + remote_mtime = remote_info.get("mtime", 0) + if local_mtime != remote_mtime: # 存在冲突 - stats['conflicts'] += 1 - should_upload, should_download = self._handle_conflict(local_info, remote_info) + stats["conflicts"] += 1 + should_upload, should_download = self._handle_conflict( + local_info, remote_info + ) # 如果应该上传,则不应该下载(冲突已在上传侧处理) if should_upload: should_download = False else: # 时间相同,无需下载 should_download = False - + if should_download: try: local_path = local_dir / rel_path local_path.parent.mkdir(parents=True, exist_ok=True) - self.client.download_file(remote_info['path'], str(local_path)) - stats['downloaded'] += 1 + self.client.download_file(remote_info["path"], str(local_path)) + stats["downloaded"] += 1 logger.debug("下载文件: %s -> %s", rel_path, local_path) except Exception as e: logger.error("下载文件失败 %s: %s", rel_path, e) - stats['errors'] += 1 - + stats["errors"] += 1 + return stats - def _handle_conflict(self, local_info: dict, remote_info: dict) -> typing.Tuple[bool, bool]: + def _handle_conflict( + self, local_info: dict, remote_info: dict + ) -> typing.Tuple[bool, bool]: """ 处理文件冲突 - + Returns: (should_upload, should_download) - 是否应该上传和下载 """ - local_mtime = local_info.get('mtime', 0) - remote_mtime = remote_info.get('mtime', 0) - + local_mtime = local_info.get("mtime", 0) + remote_mtime = remote_info.get("mtime", 0) + if self.config.conflict_strategy == ConflictStrategy.NEWER: # 较新文件覆盖较旧文件 if local_mtime > remote_mtime: @@ -306,7 +329,7 @@ class SyncService: return False, True # 下载远程较新版本 else: return False, False # 时间相同,无需操作 - + elif self.config.conflict_strategy == ConflictStrategy.KEEP_BOTH: # 保留双方 - 重命名远程文件 # 这里实现简单的重命名策略:添加冲突后缀 @@ -314,25 +337,28 @@ class SyncService: # 返回 True, False 表示上传重命名后的文件 # 重命名逻辑在调用处处理 return True, False - + elif self.config.conflict_strategy == ConflictStrategy.ASK: # 用户手动选择 - 记录冲突,跳过 # 返回 False, False 跳过,等待用户决定 - logger.warning("文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s", - local_mtime, remote_mtime) + logger.warning( + "文件冲突需要用户手动选择: local_mtime=%s, remote_mtime=%s", + local_mtime, + remote_mtime, + ) return False, False - + return False, False def _should_upload(self, local_info: dict, remote_info: dict) -> bool: """判断是否需要上传(本地较新或哈希不同)""" # 这里实现简单的基于时间的比较 # 实际应该使用哈希比较更可靠 - return local_info.get('mtime', 0) > remote_info.get('mtime', 0) + return local_info.get("mtime", 0) > remote_info.get("mtime", 0) def _should_download(self, local_info: dict, remote_info: dict) -> bool: """判断是否需要下载(远程较新)""" - return remote_info.get('mtime', 0) > local_info.get('mtime', 0) + return remote_info.get("mtime", 0) > local_info.get("mtime", 0) def upload_file(self, local_path: pathlib.Path, remote_path: str = "") -> bool: """上传单个文件""" @@ -381,30 +407,32 @@ def create_sync_service_from_config() -> typing.Optional[SyncService]: """从配置文件创建同步服务实例""" try: from heurams.context import config_var - - sync_config = config_var.get()['providers']['sync']['webdav'] - if not sync_config.get('enabled', False): + + sync_config = config_var.get()["providers"]["sync"]["webdav"] + if not sync_config.get("enabled", False): logger.debug("同步服务未启用") return None - + config = SyncConfig( - enabled=sync_config.get('enabled', False), - url=sync_config.get('url', ''), - username=sync_config.get('username', ''), - password=sync_config.get('password', ''), - remote_path=sync_config.get('remote_path', '/heurams/'), - sync_mode=SyncMode(sync_config.get('sync_mode', 'bidirectional')), - conflict_strategy=ConflictStrategy(sync_config.get('conflict_strategy', 'newer')), - verify_ssl=sync_config.get('verify_ssl', True), + enabled=sync_config.get("enabled", False), + url=sync_config.get("url", ""), + username=sync_config.get("username", ""), + password=sync_config.get("password", ""), + remote_path=sync_config.get("remote_path", "/heurams/"), + sync_mode=SyncMode(sync_config.get("sync_mode", "bidirectional")), + conflict_strategy=ConflictStrategy( + sync_config.get("conflict_strategy", "newer") + ), + verify_ssl=sync_config.get("verify_ssl", True), ) - + service = SyncService(config) if service.client is None: logger.warning("同步服务客户端创建失败") return None - + return service - + except Exception as e: logger.error("创建同步服务失败: %s", e) - return None \ No newline at end of file + return None diff --git a/tests/interface/test_synctool.py b/tests/interface/test_synctool.py index 4d2c217..096ac8a 100644 --- a/tests/interface/test_synctool.py +++ b/tests/interface/test_synctool.py @@ -6,11 +6,16 @@ import pathlib import tempfile import time import unittest -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, Mock, patch from heurams.context import ConfigContext from heurams.services.config import ConfigFile -from heurams.services.sync_service import SyncService, SyncConfig, SyncMode, ConflictStrategy +from heurams.services.sync_service import ( + ConflictStrategy, + SyncConfig, + SyncMode, + SyncService, +) class TestSyncServiceUnit(unittest.TestCase): @@ -20,14 +25,14 @@ class TestSyncServiceUnit(unittest.TestCase): """在每个测试之前运行, 设置临时目录和模拟客户端.""" self.temp_dir = tempfile.TemporaryDirectory() self.temp_path = pathlib.Path(self.temp_dir.name) - + # 创建测试文件 self.test_file = self.temp_path / "test.txt" self.test_file.write_text("测试内容") - + # 模拟 WebDAV 客户端 self.mock_client = MagicMock() - + # 创建同步配置 self.config = SyncConfig( enabled=True, @@ -44,100 +49,100 @@ class TestSyncServiceUnit(unittest.TestCase): """在每个测试之后清理.""" self.temp_dir.cleanup() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_sync_service_initialization(self, mock_client_class): """测试同步服务初始化.""" mock_client_class.return_value = self.mock_client - + service = SyncService(self.config) - + # 验证客户端已创建 mock_client_class.assert_called_once() self.assertIsNotNone(service.client) self.assertEqual(service.config, self.config) - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_sync_service_disabled(self, mock_client_class): """测试同步服务未启用.""" config = SyncConfig(enabled=False) service = SyncService(config) - + # 客户端不应初始化 mock_client_class.assert_not_called() self.assertIsNone(service.client) - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_test_connection_success(self, mock_client_class): """测试连接测试成功.""" mock_client_class.return_value = self.mock_client self.mock_client.list.return_value = [] - + service = SyncService(self.config) result = service.test_connection() - + self.assertTrue(result) self.mock_client.list.assert_called_once() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_test_connection_failure(self, mock_client_class): """测试连接测试失败.""" mock_client_class.return_value = self.mock_client self.mock_client.list.side_effect = Exception("连接失败") - + service = SyncService(self.config) result = service.test_connection() - + self.assertFalse(result) self.mock_client.list.assert_called_once() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_upload_file(self, mock_client_class): """测试上传单个文件.""" mock_client_class.return_value = self.mock_client - + service = SyncService(self.config) result = service.upload_file(self.test_file) - + self.assertTrue(result) self.mock_client.upload_file.assert_called_once() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_download_file(self, mock_client_class): """测试下载单个文件.""" mock_client_class.return_value = self.mock_client - + service = SyncService(self.config) remote_path = "/heurams/test.txt" local_path = self.temp_path / "downloaded.txt" - + result = service.download_file(remote_path, local_path) - + self.assertTrue(result) self.mock_client.download_file.assert_called_once() self.assertTrue(local_path.parent.exists()) - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_sync_directory_no_files(self, mock_client_class): """测试同步空目录.""" mock_client_class.return_value = self.mock_client self.mock_client.list.return_value = [] self.mock_client.mkdir.return_value = None - + service = SyncService(self.config) result = service.sync_directory(self.temp_path) - - self.assertTrue(result['success']) - self.assertEqual(result['uploaded'], 0) - self.assertEqual(result['downloaded'], 0) + + self.assertTrue(result["success"]) + self.assertEqual(result["uploaded"], 0) + self.assertEqual(result["downloaded"], 0) self.mock_client.mkdir.assert_called_once() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_sync_directory_upload_only(self, mock_client_class): """测试仅上传模式.""" mock_client_class.return_value = self.mock_client self.mock_client.list.return_value = [] self.mock_client.mkdir.return_value = None - + config = SyncConfig( enabled=True, url="https://example.com/dav/", @@ -147,60 +152,64 @@ class TestSyncServiceUnit(unittest.TestCase): sync_mode=SyncMode.UPLOAD_ONLY, conflict_strategy=ConflictStrategy.NEWER, ) - + service = SyncService(config) result = service.sync_directory(self.temp_path) - - self.assertTrue(result['success']) + + self.assertTrue(result["success"]) self.mock_client.mkdir.assert_called_once() - @patch('heurams.services.sync_service.Client') + @patch("heurams.services.sync_service.Client") def test_conflict_strategy_newer(self, mock_client_class): """测试 NEWER 冲突策略.""" mock_client_class.return_value = self.mock_client - + # 模拟远程文件存在 self.mock_client.list.return_value = ["test.txt"] - self.mock_client.info.return_value = {'size': 100, 'modified': '2023-01-01T00:00:00Z'} + self.mock_client.info.return_value = { + "size": 100, + "modified": "2023-01-01T00:00:00Z", + } self.mock_client.mkdir.return_value = None - + service = SyncService(self.config) result = service.sync_directory(self.temp_path) - - self.assertTrue(result['success']) - # 应该有一个冲突 - self.assertGreaterEqual(result.get('conflicts', 0), 0) - @patch('heurams.services.sync_service.Client') + self.assertTrue(result["success"]) + # 应该有一个冲突 + self.assertGreaterEqual(result.get("conflicts", 0), 0) + + @patch("heurams.services.sync_service.Client") def test_create_sync_service_from_config(self, mock_client_class): """测试从配置文件创建同步服务.""" mock_client_class.return_value = self.mock_client - + # 创建临时配置文件 config_data = { - 'sync': { - 'webdav': { - 'enabled': True, - 'url': 'https://example.com/dav/', - 'username': 'test', - 'password': 'test', - 'remote_path': '/heurams/', - 'sync_mode': 'bidirectional', - 'conflict_strategy': 'newer', - 'verify_ssl': True, + "sync": { + "webdav": { + "enabled": True, + "url": "https://example.com/dav/", + "username": "test", + "password": "test", + "remote_path": "/heurams/", + "sync_mode": "bidirectional", + "conflict_strategy": "newer", + "verify_ssl": True, } } } - + # 模拟 config_var - with patch('heurams.services.sync_service.config_var') as mock_config_var: + with patch("heurams.services.sync_service.config_var") as mock_config_var: mock_config = MagicMock() mock_config.data = config_data mock_config_var.get.return_value = mock_config - + from heurams.services.sync_service import create_sync_service_from_config + service = create_sync_service_from_config() - + self.assertIsNotNone(service) self.assertIsNotNone(service.client) @@ -212,39 +221,41 @@ class TestSyncScreenUnit(unittest.TestCase): """在每个测试之前运行.""" self.temp_dir = tempfile.TemporaryDirectory() self.temp_path = pathlib.Path(self.temp_dir.name) - + # 创建默认配置 default_config_path = ( pathlib.Path(__file__).parent.parent.parent / "src/heurams/default/config/config.toml" ) self.config = ConfigFile(default_config_path) - + # 更新配置中的路径 config_data = self.config.data config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon") config_data["paths"]["electron_dir"] = str(self.temp_path / "electron") config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital") config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") - + # 添加同步配置 - if 'sync' not in config_data: - config_data['sync'] = {} - config_data['sync']['webdav'] = { - 'enabled': False, - 'url': '', - 'username': '', - 'password': '', - 'remote_path': '/heurams/', - 'sync_mode': 'bidirectional', - 'conflict_strategy': 'newer', - 'verify_ssl': True, + if "sync" not in config_data: + config_data["sync"] = {} + config_data["sync"]["webdav"] = { + "enabled": False, + "url": "", + "username": "", + "password": "", + "remote_path": "/heurams/", + "sync_mode": "bidirectional", + "conflict_strategy": "newer", + "verify_ssl": True, } - + # 创建目录 for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: - pathlib.Path(config_data["paths"][dir_key]).mkdir(parents=True, exist_ok=True) - + pathlib.Path(config_data["paths"][dir_key]).mkdir( + parents=True, exist_ok=True + ) + # 使用 ConfigContext 设置配置 self.config_ctx = ConfigContext(self.config) self.config_ctx.__enter__() @@ -254,52 +265,53 @@ class TestSyncScreenUnit(unittest.TestCase): self.config_ctx.__exit__(None, None, None) self.temp_dir.cleanup() - @patch('heurams.interface.screens.synctool.create_sync_service_from_config') + @patch("heurams.interface.screens.synctool.create_sync_service_from_config") def test_sync_screen_compose(self, mock_create_service): """测试 SyncScreen 的 compose 方法.""" from heurams.interface.screens.synctool import SyncScreen - + # 模拟同步服务 mock_service = MagicMock() mock_service.client = MagicMock() mock_create_service.return_value = mock_service - + screen = SyncScreen() - + # 测试 compose 方法 from textual.app import ComposeResult + result = screen.compose() widgets = list(result) - + # 检查基本部件 - from textual.widgets import Footer, Header, Button, Static, ProgressBar from textual.containers import ScrollableContainer - + from textual.widgets import Button, Footer, Header, ProgressBar, Static + header_present = any(isinstance(w, Header) for w in widgets) footer_present = any(isinstance(w, Footer) for w in widgets) self.assertTrue(header_present) self.assertTrue(footer_present) - + # 检查容器 container_present = any(isinstance(w, ScrollableContainer) for w in widgets) self.assertTrue(container_present) - @patch('heurams.interface.screens.synctool.create_sync_service_from_config') + @patch("heurams.interface.screens.synctool.create_sync_service_from_config") def test_sync_screen_load_config(self, mock_create_service): """测试 SyncScreen 加载配置.""" from heurams.interface.screens.synctool import SyncScreen - + mock_service = MagicMock() mock_service.client = MagicMock() mock_create_service.return_value = mock_service - + screen = SyncScreen() screen.load_config() - + # 验证配置已加载 self.assertIsNotNone(screen.sync_config) mock_create_service.assert_called_once() if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()