From 65486794b7b77c555128d66ef272f44215b10b41 Mon Sep 17 00:00:00 2001 From: david-ajax Date: Sun, 4 Jan 2026 04:46:19 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=94=B9=E8=BF=9B=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/simplemem.py | 28 +- src/heurams/interface/__init__.py | 4 +- src/heurams/interface/screens/dashboard.py | 6 +- src/heurams/interface/screens/memoqueue.py | 36 +-- src/heurams/interface/screens/precache.py | 7 +- src/heurams/interface/screens/preparation.py | 1 + src/heurams/interface/shim.py | 1 + src/heurams/interface/widgets/mcq_puzzle.py | 2 +- src/heurams/kernel/algorithms/__init__.py | 6 - src/heurams/kernel/evaluators/__init__.py | 37 --- src/heurams/kernel/particles/__init__.py | 19 +- src/heurams/kernel/particles/nucleon.py | 8 +- src/heurams/kernel/reactor/__init__.py | 4 - src/heurams/kernel/reactor/fission.py | 9 +- src/heurams/kernel/reactor/phaser.py | 5 +- src/heurams/kernel/reactor/procession.py | 13 +- src/heurams/kernel/repolib/__init__.py | 4 +- src/heurams/kernel/repolib/aux.py | 5 - src/heurams/services/textproc.py | 2 +- src/heurams/utils/__init__.py | 5 + tests/interface/test_dashboard.py | 153 --------- tests/interface/test_synctool.py | 317 ------------------- tests/kernel/algorithms/__init__.py | 0 tests/kernel/algorithms/test_sm2.py | 186 ----------- tests/kernel/particles/__init__.py | 0 tests/kernel/particles/test_atom.py | 202 ------------ tests/kernel/particles/test_electron.py | 179 ----------- tests/kernel/particles/test_nucleon.py | 108 ------- tests/kernel/puzzles/__init__.py | 0 tests/kernel/puzzles/test_base.py | 23 -- tests/kernel/puzzles/test_cloze.py | 51 --- tests/kernel/puzzles/test_mcq.py | 122 ------- tests/kernel/reactor/__init__.py | 0 tests/kernel/reactor/test_phaser.py | 114 ------- 34 files changed, 87 insertions(+), 1570 deletions(-) delete mode 100644 src/heurams/kernel/repolib/aux.py delete mode 100644 tests/interface/test_dashboard.py delete mode 100644 tests/interface/test_synctool.py delete mode 100644 tests/kernel/algorithms/__init__.py delete mode 100644 tests/kernel/algorithms/test_sm2.py delete mode 100644 tests/kernel/particles/__init__.py delete mode 100644 tests/kernel/particles/test_atom.py delete mode 100644 tests/kernel/particles/test_electron.py delete mode 100644 tests/kernel/particles/test_nucleon.py delete mode 100644 tests/kernel/puzzles/__init__.py delete mode 100644 tests/kernel/puzzles/test_base.py delete mode 100644 tests/kernel/puzzles/test_cloze.py delete mode 100644 tests/kernel/puzzles/test_mcq.py delete mode 100644 tests/kernel/reactor/__init__.py delete mode 100644 tests/kernel/reactor/test_phaser.py diff --git a/examples/simplemem.py b/examples/simplemem.py index fa5e9a7..e33cbbc 100644 --- a/examples/simplemem.py +++ b/examples/simplemem.py @@ -3,6 +3,7 @@ import heurams.kernel.particles as pt from heurams.services.textproc import truncate from pathlib import Path import time + repo = repolib.Repo.create_from_repodir(Path("./test_repo")) alist = list() print(repo.ident_index) @@ -17,38 +18,39 @@ for i in repo.ident_index: input() a = pt.Atom(n, e, repo.orbitic_data) alist.append(a) - #e.activate() - #e.revisor(5, True) + # e.activate() + # e.revisor(5, True) print(repr(a)) # print(repr(e)) print(repo) input() import heurams.kernel.reactor as rt + ph: rt.Phaser = rt.Phaser(alist) print(ph) -pr: rt.Procession = ph.current_procession() # type: ignore +pr: rt.Procession = ph.current_procession() # type: ignore print(pr) pr.forward() print(pr) -pr.forward() # 如果过界了? -print(pr) # 静默设置状态 无报错 +pr.forward() # 如果过界了? +print(pr) # 静默设置状态 无报错 pr.forward() print(pr) -pr = ph.current_procession() # type: ignore # 下一个队列 +pr = ph.current_procession() # type: ignore # 下一个队列 print(pr) pr.forward() print(pr) -pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? print(pr) pr.forward() -pr.append() # 如果记忆失败了? -pr.append() # 如果记忆失败了? -pr.append() # 如果记忆失败了? -pr.append() # 如果记忆失败了? -pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? +pr.append() # 如果记忆失败了? # 重复项目只会占据一个车尾 print(pr) pr.forward() print(pr) -pr = ph.current_procession() # type: ignore +pr = ph.current_procession() # type: ignore print(pr) diff --git a/src/heurams/interface/__init__.py b/src/heurams/interface/__init__.py index 16704b2..f36099f 100644 --- a/src/heurams/interface/__init__.py +++ b/src/heurams/interface/__init__.py @@ -17,9 +17,9 @@ def environment_check(): from pathlib import Path logger.debug("检查环境路径") - subdir = ['cache/voice', 'repo', 'global', 'config'] + subdir = ["cache/voice", "repo", "global", "config"] for i in subdir: - i = Path(config_var.get()['paths']['data']) / i + i = Path(config_var.get()["paths"]["data"]) / i if not i.exists(): logger.info("创建目录: %s", i) print(f"创建 {i}") diff --git a/src/heurams/interface/screens/dashboard.py b/src/heurams/interface/screens/dashboard.py index 5a7d6c0..34f8902 100644 --- a/src/heurams/interface/screens/dashboard.py +++ b/src/heurams/interface/screens/dashboard.py @@ -14,6 +14,8 @@ from heurams.kernel.particles import * from heurams.kernel.repolib import * from heurams.services.logger import get_logger +import heurams.kernel.particles as pt +from pathlib import Path from .about import AboutScreen from .preparation import PreparationScreen @@ -42,7 +44,9 @@ class DashboardScreen(Screen): yield Header(show_clock=True) yield ScrollableContainer( Label('欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"), - Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})"), + Label( + f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})" + ), Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"), Label("选择待学习或待修改的项目:", classes="title-label"), ListView(id="repo-list", classes="repo-list-view"), diff --git a/src/heurams/interface/screens/memoqueue.py b/src/heurams/interface/screens/memoqueue.py index df5afe9..36bf9b3 100644 --- a/src/heurams/interface/screens/memoqueue.py +++ b/src/heurams/interface/screens/memoqueue.py @@ -41,9 +41,9 @@ class MemScreen(Screen): def __init__( self, phaser: Phaser, - name = None, - id = None, - classes = None, + name=None, + id=None, + classes=None, ) -> None: super().__init__(name, id, classes) self.phaser = phaser @@ -59,7 +59,7 @@ class MemScreen(Screen): def update_state(self): """更新状态机""" self.procession: Procession = self.phaser.current_procession() # type: ignore - self.atom: pt.Atom = self.procession.current_atom # type: ignore + self.atom: pt.Atom = self.procession.current_atom # type: ignore def on_mount(self): self.mount_puzzle() @@ -69,14 +69,12 @@ class MemScreen(Screen): try: self.fission = self.procession.get_fission() puzzle = self.fission.get_current_puzzle() - # logger.debug(puzzle_debug) - return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore - atom=self.atom, alia=puzzle["alia"] # type: ignore + return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore + atom=self.atom, alia=puzzle["alia"] # type: ignore ) except (KeyError, StopIteration, AttributeError) as e: logger.debug(f"调度展开出错: {e}") return Static(f"无法生成谜题 {e}") - # logger.debug(shim.puzzle2widget[puzzle_debug["puzzle"]]) def _get_progress_text(self): s = f"阶段: {self.procession.phase.name}\n" @@ -117,32 +115,28 @@ class MemScreen(Screen): from heurams.services.audio_service import play_by_path from heurams.services.hasher import get_md5 - path = Path(config_var.get()["paths"]['data']) / 'cache' / 'voice' - path = ( - path - / f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav" - ) + path = Path(config_var.get()["paths"]["data"]) / "cache" / "voice" + path = path / f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav" if path.exists(): play_by_path(path) else: from heurams.services.tts_service import convertor - convertor( - self.atom.registry["nucleon"]["tts_text"], path - ) + + convertor(self.atom.registry["nucleon"]["tts_text"], path) play_by_path(path) def watch_rating(self, old_rating, new_rating) -> None: - self.update_state() # 刷新状态 - if self.procession == None: # 已经完成记忆 + self.update_state() # 刷新状态 + if self.procession == None: # 已经完成记忆 return - if new_rating == -1: # 安全值 + if new_rating == -1: # 安全值 return - forwards = 1 if new_rating >= 4 else 0 # 准许前进 + forwards = 1 if new_rating >= 4 else 0 # 准许前进 self.rating = -1 logger.debug(f"试图前进: {"允许" if forwards else "禁止"}") if forwards: ret = self.procession.forward(1) - if ret == 0: # 若结束了此次队列 + if ret == 0: # 若结束了此次队列 self.update_state() if self.procession.phase == PhaserState.FINISHED: # 若所有队列都结束了 logger.debug(f"记忆进程结束") diff --git a/src/heurams/interface/screens/precache.py b/src/heurams/interface/screens/precache.py index c716788..0bb54d5 100644 --- a/src/heurams/interface/screens/precache.py +++ b/src/heurams/interface/screens/precache.py @@ -12,7 +12,8 @@ import heurams.kernel.particles as pt import heurams.services.hasher as hasher from heurams.context import * -cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / 'voice' +cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / "voice" + class PrecachingScreen(Screen): """预缓存音频文件屏幕 @@ -204,9 +205,7 @@ class PrecachingScreen(Screen): from heurams.context import config_var, rootdir, workdir - shutil.rmtree( - cache_dir, ignore_errors=True - ) + shutil.rmtree(cache_dir, ignore_errors=True) self.update_status("已清空", "音频缓存已清空", 0) except Exception as e: self.update_status("错误", f"清空缓存失败: {e}") diff --git a/src/heurams/interface/screens/preparation.py b/src/heurams/interface/screens/preparation.py index 2fc1c77..17bd421 100644 --- a/src/heurams/interface/screens/preparation.py +++ b/src/heurams/interface/screens/preparation.py @@ -128,6 +128,7 @@ class PreparationScreen(Screen): atoms_to_provide.append(i) from .memoqueue import MemScreen import heurams.kernel.reactor as rt + pheser = rt.Phaser(atoms_to_provide) memscreen = MemScreen(pheser) self.app.push_screen(memscreen) diff --git a/src/heurams/interface/shim.py b/src/heurams/interface/shim.py index 9268938..32aa73d 100644 --- a/src/heurams/interface/shim.py +++ b/src/heurams/interface/shim.py @@ -1,4 +1,5 @@ """Kernel 操作辅助函数库""" + import heurams.interface.widgets as pzw import heurams.kernel.evaluators as pz diff --git a/src/heurams/interface/widgets/mcq_puzzle.py b/src/heurams/interface/widgets/mcq_puzzle.py index d08ccd5..10d6780 100644 --- a/src/heurams/interface/widgets/mcq_puzzle.py +++ b/src/heurams/interface/widgets/mcq_puzzle.py @@ -54,7 +54,7 @@ class MCQPuzzle(BasePuzzleWidget): self._load() def _load(self): - cfg = self.atom.registry["orbital"]["puzzles"][self.alia] + cfg = self.atom.registry["nucleon"]["puzzles"][self.alia] self.puzzle = pz.MCQPuzzle( cfg["mapping"], cfg["jammer"], int(cfg["max_riddles_num"]), cfg["prefix"] ) diff --git a/src/heurams/kernel/algorithms/__init__.py b/src/heurams/kernel/algorithms/__init__.py index 1914f78..e0b7ae5 100644 --- a/src/heurams/kernel/algorithms/__init__.py +++ b/src/heurams/kernel/algorithms/__init__.py @@ -1,11 +1,7 @@ -from heurams.services.logger import get_logger - from .base import BaseAlgorithm from .sm2 import SM2Algorithm from .sm15m import SM15MAlgorithm -logger = get_logger(__name__) - __all__ = [ "SM2Algorithm", "BaseAlgorithm", @@ -17,5 +13,3 @@ algorithms = { "SM-15M": SM15MAlgorithm, "Base": BaseAlgorithm, } - -logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys())) diff --git a/src/heurams/kernel/evaluators/__init__.py b/src/heurams/kernel/evaluators/__init__.py index 8dd3a86..6bc5688 100644 --- a/src/heurams/kernel/evaluators/__init__.py +++ b/src/heurams/kernel/evaluators/__init__.py @@ -6,8 +6,6 @@ Evaluator 模块 - 生成评估模块 from heurams.services.logger import get_logger -logger = get_logger(__name__) - from .base import BaseEvaluator from .cloze import ClozePuzzle from .mcq import MCQPuzzle @@ -26,38 +24,3 @@ puzzles = { "recognition": RecognitionPuzzle, "base": BaseEvaluator, } - - -@staticmethod -def create_by_dict(config_dict: dict) -> BaseEvaluator: - """ - 根据配置字典创建谜题 - - Args: - config_dict: 配置字典, 包含谜题类型和参数 - - Returns: - BasePuzzle: 谜题实例 - - Raises: - ValueError: 当配置无效时抛出 - """ - logger.debug( - "puzzles.create_by_dict: config_dict keys=%s", list(config_dict.keys()) - ) - puzzle_type = config_dict.get("type") - - if puzzle_type == "cloze": - return puzzles["cloze"]( - text=config_dict["text"], - min_denominator=config_dict.get("min_denominator", 7), - ) - elif puzzle_type == "mcq": - return puzzles["mcq"]( - mapping=config_dict["mapping"], - jammer=config_dict.get("jammer", []), - max_riddles_num=config_dict.get("max_riddles_num", 2), - prefix=config_dict.get("prefix", ""), - ) - else: - raise ValueError(f"未知的谜题类型: {puzzle_type}") diff --git a/src/heurams/kernel/particles/__init__.py b/src/heurams/kernel/particles/__init__.py index 5714307..211e9d5 100644 --- a/src/heurams/kernel/particles/__init__.py +++ b/src/heurams/kernel/particles/__init__.py @@ -1,4 +1,21 @@ from .atom import Atom from .electron import Electron from .nucleon import Nucleon -#from .orbital import Orbital +from .placeholders import ( + AtomPlaceholder, + NucleonPlaceholder, + ElectronPlaceholder, + orbital_placeholder, +) + +# from .orbital import Orbital + +__all__ = [ + "Atom", + "Electron", + "Nucleon", + "AtomPlaceholder", + "NucleonPlaceholder", + "ElectronPlaceholder", + "orbital_placeholder", +] diff --git a/src/heurams/kernel/particles/nucleon.py b/src/heurams/kernel/particles/nucleon.py index 1a59732..e34875f 100644 --- a/src/heurams/kernel/particles/nucleon.py +++ b/src/heurams/kernel/particles/nucleon.py @@ -13,9 +13,11 @@ class Nucleon: def __init__(self, ident, payload, common): self.ident = ident - env = {"payload": payload, - "default": config_var.get()['puzzles'], - "nucleon": (payload | common)} + env = { + "payload": payload, + "default": config_var.get()["puzzles"], + "nucleon": (payload | common), + } self.evalizer = Evalizer(environment=env) self.data: dict = self.evalizer(deepcopy((payload | common))) # type: ignore diff --git a/src/heurams/kernel/reactor/__init__.py b/src/heurams/kernel/reactor/__init__.py index dc87681..f9242fa 100644 --- a/src/heurams/kernel/reactor/__init__.py +++ b/src/heurams/kernel/reactor/__init__.py @@ -5,8 +5,4 @@ from .phaser import Phaser from .procession import Procession from .states import PhaserState, ProcessionState -logger = get_logger(__name__) - __all__ = ["PhaserState", "ProcessionState", "Procession", "Fission", "Phaser"] - -logger.debug("反应堆模块已加载") diff --git a/src/heurams/kernel/reactor/fission.py b/src/heurams/kernel/reactor/fission.py index 866265e..9b665cc 100644 --- a/src/heurams/kernel/reactor/fission.py +++ b/src/heurams/kernel/reactor/fission.py @@ -20,7 +20,7 @@ class Fission: phase_state.value if isinstance(phase_state, PhaserState) else phase_state ) - self.orbital_schedule = atom.registry['orbital']["phases"][phase_value] # type: ignore + self.orbital_schedule = atom.registry["orbital"]["phases"][phase_value] # type: ignore self.orbital_puzzles = atom.registry["nucleon"]["puzzles"] self.puzzles = list() @@ -34,7 +34,6 @@ class Fission: { "puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]], "alia": item, - "finished": 0, } ) possibility -= 1 @@ -44,7 +43,6 @@ class Fission: { "puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]], "alia": item, - "finished": 0, } ) @@ -53,7 +51,7 @@ class Fission: def get_puzzles(self): return self.puzzles - def get_current_puzzle(self, forward = 0): + def get_current_puzzle(self, forward=0): if forward: if len(self.puzzles) <= self.cursor + 1: return 0 @@ -61,10 +59,9 @@ class Fission: return self.puzzles[self.cursor] else: return self.puzzles[self.cursor] - def check_passed(self): for i in self.puzzles: if i["finished"] == 0: return 0 - return 1 \ No newline at end of file + return 1 diff --git a/src/heurams/kernel/reactor/phaser.py b/src/heurams/kernel/reactor/phaser.py index 9fcbd75..03dcf75 100644 --- a/src/heurams/kernel/reactor/phaser.py +++ b/src/heurams/kernel/reactor/phaser.py @@ -130,12 +130,13 @@ class Phaser(Machine): def __repr__(self): from heurams.services.textproc import truncate from tabulate import tabulate as tabu + lst = [ { "Type": "Phaser", "State": self.state, "Processions": list(map(lambda f: (f.name_), self.processions)), - "Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore + "Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore }, ] - return str(tabu(lst, headers="keys")) + '\n' + return str(tabu(lst, headers="keys")) + "\n" diff --git a/src/heurams/kernel/reactor/procession.py b/src/heurams/kernel/reactor/procession.py index e013a1f..07daeb5 100644 --- a/src/heurams/kernel/reactor/procession.py +++ b/src/heurams/kernel/reactor/procession.py @@ -63,8 +63,7 @@ class Procession(Machine): logger.debug("Procession 进入 FINISHED 状态") def forward(self, step=1): - """将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0) - """ + """将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)""" logger.debug("Procession.forward: step=%d, 当前 cursor=%d", step, self.cursor) self.cursor += step if self.cursor >= len(self.queue): @@ -84,8 +83,7 @@ class Procession(Machine): return 0 def append(self, atom=None): - """追加(回忆失败的)原子(默认为当前原子)到队列末端 - """ + """追加(回忆失败的)原子(默认为当前原子)到队列末端""" if atom is None: atom = self.current_atom logger.debug("Procession.append: atom=%s", atom.ident if atom else "None") @@ -118,10 +116,11 @@ class Procession(Machine): return empty def get_fission(self): - return Fission(atom=self.current_atom, phase_state=self.phase) # type: ignore + return Fission(atom=self.current_atom, phase_state=self.phase) # type: ignore def __repr__(self): from heurams.services.textproc import truncate + dic = [ { "Type": "Procession", @@ -129,7 +128,7 @@ class Procession(Machine): "State": self.state, "Progress": f"{self.cursor + 1} / {len(self.queue)}", "Queue": list(map(lambda f: truncate(f.ident), self.queue)), - "Current Atom": self.current_atom.ident, # type: ignore + "Current Atom": self.current_atom.ident, # type: ignore } ] - return str(tabu(dic, headers="keys")) + '\n' + return str(tabu(dic, headers="keys")) + "\n" diff --git a/src/heurams/kernel/repolib/__init__.py b/src/heurams/kernel/repolib/__init__.py index 7ecd655..ed5311c 100644 --- a/src/heurams/kernel/repolib/__init__.py +++ b/src/heurams/kernel/repolib/__init__.py @@ -1 +1,3 @@ -from .repo import * +from .repo import Repo, RepoManifest + +__all__ = ["Repo", "RepoManifest"] diff --git a/src/heurams/kernel/repolib/aux.py b/src/heurams/kernel/repolib/aux.py deleted file mode 100644 index 068231e..0000000 --- a/src/heurams/kernel/repolib/aux.py +++ /dev/null @@ -1,5 +0,0 @@ -from ...utils.lict import Lict - - -def merge(x: Lict, y: Lict): - return Lict(list(x.values()) + list(y.values())) diff --git a/src/heurams/services/textproc.py b/src/heurams/services/textproc.py index 729cc43..fd78134 100644 --- a/src/heurams/services/textproc.py +++ b/src/heurams/services/textproc.py @@ -1,4 +1,4 @@ def truncate(text): if len(text) <= 3: return text - return text[:3] + ">" \ No newline at end of file + return text[:3] + ">" diff --git a/src/heurams/utils/__init__.py b/src/heurams/utils/__init__.py index e69de29..2e09240 100644 --- a/src/heurams/utils/__init__.py +++ b/src/heurams/utils/__init__.py @@ -0,0 +1,5 @@ +from .evalizor import Evalizer +from .lict import Lict +from .refvar import RefVar + +__all__ = ["Evalizer", "Lict", "RefVar"] diff --git a/tests/interface/test_dashboard.py b/tests/interface/test_dashboard.py deleted file mode 100644 index 42ca79a..0000000 --- a/tests/interface/test_dashboard.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python3 -""" -DashboardScreen 的测试, 包括单元测试和 pilot 测试. -""" -import pathlib -import tempfile -import time -import unittest -from unittest.mock import MagicMock, patch - -from textual.pilot import Pilot - -from heurams.context import ConfigContext -from heurams.interface.__main__ import HeurAMSApp -from heurams.interface.screens.dashboard import DashboardScreen -from heurams.services.config import ConfigFile - - -class TestDashboardScreenUnit(unittest.TestCase): - """DashboardScreen 的单元测试(不启动完整应用).""" - - def setUp(self): - """在每个测试之前运行, 设置临时目录和配置.""" - # 创建临时目录用于测试数据 - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = pathlib.Path(self.temp_dir.name) - - # 创建默认配置, 并修改路径指向临时目录 - default_config_path = ( - pathlib.Path(__file__).parent.parent.parent - / "src/heurams/default/config/config.toml" - ) - self.config = ConfigFile(default_config_path) - # 更新配置中的路径 - config_data = self.config.data - config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon") - config_data["paths"]["electron_dir"] = str(self.temp_path / "electron") - config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital") - config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") - # 禁用快速通过, 避免测试干扰 - config_data["quick_pass"] = 0 - # 禁用时间覆盖 - config_data["daystamp_override"] = -1 - config_data["timestamp_override"] = -1 - - # 创建目录 - for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: - pathlib.Path(config_data["paths"][dir_key]).mkdir( - parents=True, exist_ok=True - ) - - # 使用 ConfigContext 设置配置 - self.config_ctx = ConfigContext(self.config) - self.config_ctx.__enter__() - - def tearDown(self): - """在每个测试之后清理.""" - self.config_ctx.__exit__(None, None, None) - self.temp_dir.cleanup() - - def test_compose(self): - """测试 compose 方法返回正确的部件.""" - screen = DashboardScreen() - # 手动调用 compose 并收集部件 - from textual.app import ComposeResult - - result = screen.compose() - widgets = list(result) - # 检查是否包含 Header 和 Footer - from textual.widgets import Footer, Header - - header_present = any(isinstance(w, Header) for w in widgets) - footer_present = any(isinstance(w, Footer) for w in widgets) - self.assertTrue(header_present) - self.assertTrue(footer_present) - # 检查是否有 ScrollableContainer - from textual.containers import ScrollableContainer - - container_present = any(isinstance(w, ScrollableContainer) for w in widgets) - self.assertTrue(container_present) - # 使用 query_one 查找 union-list, 即使屏幕未挂载也可能有效 - list_view = screen.query_one("#union-list") - self.assertIsNotNone(list_view) - self.assertEqual(list_view.id, "union-list") - self.assertEqual(list_view.__class__.__name__, "ListView") - - def test_item_desc_generator(self): - """测试 item_desc_generator 函数.""" - screen = DashboardScreen() - # 模拟一个文件名 - filename = "test.toml" - result = screen.analyser(filename) - self.assertIsInstance(result, dict) - self.assertIn(0, result) - self.assertIn(1, result) - # 检查内容 - self.assertIn("test.toml", result[0]) - # 由于 electron 文件不存在, 应显示“尚未激活” - self.assertIn("尚未激活", result[1]) - - -@unittest.skip("Pilot 测试需要进一步配置, 暂不运行") -class TestDashboardScreenPilot(unittest.TestCase): - """使用 Textual Pilot 的集成测试.""" - - def setUp(self): - """配置临时目录和配置.""" - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = pathlib.Path(self.temp_dir.name) - - default_config_path = ( - pathlib.Path(__file__).parent.parent.parent - / "src/heurams/default/config/config.toml" - ) - self.config = ConfigFile(default_config_path) - config_data = self.config.data - config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon") - config_data["paths"]["electron_dir"] = str(self.temp_path / "electron") - config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital") - config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") - config_data["quick_pass"] = 0 - config_data["daystamp_override"] = -1 - config_data["timestamp_override"] = -1 - - for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: - pathlib.Path(config_data["paths"][dir_key]).mkdir( - parents=True, exist_ok=True - ) - - self.config_ctx = ConfigContext(self.config) - self.config_ctx.__enter__() - - def tearDown(self): - self.config_ctx.__exit__(None, None, None) - self.temp_dir.cleanup() - - def test_dashboard_loads_with_pilot(self): - """使用 Pilot 测试 DashboardScreen 加载.""" - with patch("heurams.interface.__main__.environment_check"): - app = HeurAMSApp() - # 注意: Pilot 在 Textual 6.9.0 中的用法可能不同 - # 以下为示例代码, 可能需要调整 - pilot = Pilot(app) - # 等待应用启动 - pilot.pause() - screen = app.screen - self.assertEqual(screen.__class__.__name__, "DashboardScreen") - union_list = app.query_one("#union-list") - self.assertIsNotNone(union_list) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/interface/test_synctool.py b/tests/interface/test_synctool.py deleted file mode 100644 index 096ac8a..0000000 --- a/tests/interface/test_synctool.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -""" -SyncScreen 和 SyncService 的测试. -""" -import pathlib -import tempfile -import time -import unittest -from unittest.mock import MagicMock, Mock, patch - -from heurams.context import ConfigContext -from heurams.services.config import ConfigFile -from heurams.services.sync_service import ( - ConflictStrategy, - SyncConfig, - SyncMode, - SyncService, -) - - -class TestSyncServiceUnit(unittest.TestCase): - """SyncService 的单元测试.""" - - def setUp(self): - """在每个测试之前运行, 设置临时目录和模拟客户端.""" - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = pathlib.Path(self.temp_dir.name) - - # 创建测试文件 - self.test_file = self.temp_path / "test.txt" - self.test_file.write_text("测试内容") - - # 模拟 WebDAV 客户端 - self.mock_client = MagicMock() - - # 创建同步配置 - self.config = SyncConfig( - enabled=True, - url="https://example.com/dav/", - username="test", - password="test", - remote_path="/heurams/", - sync_mode=SyncMode.BIDIRECTIONAL, - conflict_strategy=ConflictStrategy.NEWER, - verify_ssl=True, - ) - - def tearDown(self): - """在每个测试之后清理.""" - self.temp_dir.cleanup() - - @patch("heurams.services.sync_service.Client") - def test_sync_service_initialization(self, mock_client_class): - """测试同步服务初始化.""" - mock_client_class.return_value = self.mock_client - - service = SyncService(self.config) - - # 验证客户端已创建 - mock_client_class.assert_called_once() - self.assertIsNotNone(service.client) - self.assertEqual(service.config, self.config) - - @patch("heurams.services.sync_service.Client") - def test_sync_service_disabled(self, mock_client_class): - """测试同步服务未启用.""" - config = SyncConfig(enabled=False) - service = SyncService(config) - - # 客户端不应初始化 - mock_client_class.assert_not_called() - self.assertIsNone(service.client) - - @patch("heurams.services.sync_service.Client") - def test_test_connection_success(self, mock_client_class): - """测试连接测试成功.""" - mock_client_class.return_value = self.mock_client - self.mock_client.list.return_value = [] - - service = SyncService(self.config) - result = service.test_connection() - - self.assertTrue(result) - self.mock_client.list.assert_called_once() - - @patch("heurams.services.sync_service.Client") - def test_test_connection_failure(self, mock_client_class): - """测试连接测试失败.""" - mock_client_class.return_value = self.mock_client - self.mock_client.list.side_effect = Exception("连接失败") - - service = SyncService(self.config) - result = service.test_connection() - - self.assertFalse(result) - self.mock_client.list.assert_called_once() - - @patch("heurams.services.sync_service.Client") - def test_upload_file(self, mock_client_class): - """测试上传单个文件.""" - mock_client_class.return_value = self.mock_client - - service = SyncService(self.config) - result = service.upload_file(self.test_file) - - self.assertTrue(result) - self.mock_client.upload_file.assert_called_once() - - @patch("heurams.services.sync_service.Client") - def test_download_file(self, mock_client_class): - """测试下载单个文件.""" - mock_client_class.return_value = self.mock_client - - service = SyncService(self.config) - remote_path = "/heurams/test.txt" - local_path = self.temp_path / "downloaded.txt" - - result = service.download_file(remote_path, local_path) - - self.assertTrue(result) - self.mock_client.download_file.assert_called_once() - self.assertTrue(local_path.parent.exists()) - - @patch("heurams.services.sync_service.Client") - def test_sync_directory_no_files(self, mock_client_class): - """测试同步空目录.""" - mock_client_class.return_value = self.mock_client - self.mock_client.list.return_value = [] - self.mock_client.mkdir.return_value = None - - service = SyncService(self.config) - result = service.sync_directory(self.temp_path) - - self.assertTrue(result["success"]) - self.assertEqual(result["uploaded"], 0) - self.assertEqual(result["downloaded"], 0) - self.mock_client.mkdir.assert_called_once() - - @patch("heurams.services.sync_service.Client") - def test_sync_directory_upload_only(self, mock_client_class): - """测试仅上传模式.""" - mock_client_class.return_value = self.mock_client - self.mock_client.list.return_value = [] - self.mock_client.mkdir.return_value = None - - config = SyncConfig( - enabled=True, - url="https://example.com/dav/", - username="test", - password="test", - remote_path="/heurams/", - sync_mode=SyncMode.UPLOAD_ONLY, - conflict_strategy=ConflictStrategy.NEWER, - ) - - service = SyncService(config) - result = service.sync_directory(self.temp_path) - - self.assertTrue(result["success"]) - self.mock_client.mkdir.assert_called_once() - - @patch("heurams.services.sync_service.Client") - def test_conflict_strategy_newer(self, mock_client_class): - """测试 NEWER 冲突策略.""" - mock_client_class.return_value = self.mock_client - - # 模拟远程文件存在 - self.mock_client.list.return_value = ["test.txt"] - self.mock_client.info.return_value = { - "size": 100, - "modified": "2023-01-01T00:00:00Z", - } - self.mock_client.mkdir.return_value = None - - service = SyncService(self.config) - result = service.sync_directory(self.temp_path) - - self.assertTrue(result["success"]) - # 应该有一个冲突 - self.assertGreaterEqual(result.get("conflicts", 0), 0) - - @patch("heurams.services.sync_service.Client") - def test_create_sync_service_from_config(self, mock_client_class): - """测试从配置文件创建同步服务.""" - mock_client_class.return_value = self.mock_client - - # 创建临时配置文件 - config_data = { - "sync": { - "webdav": { - "enabled": True, - "url": "https://example.com/dav/", - "username": "test", - "password": "test", - "remote_path": "/heurams/", - "sync_mode": "bidirectional", - "conflict_strategy": "newer", - "verify_ssl": True, - } - } - } - - # 模拟 config_var - with patch("heurams.services.sync_service.config_var") as mock_config_var: - mock_config = MagicMock() - mock_config.data = config_data - mock_config_var.get.return_value = mock_config - - from heurams.services.sync_service import create_sync_service_from_config - - service = create_sync_service_from_config() - - self.assertIsNotNone(service) - self.assertIsNotNone(service.client) - - -class TestSyncScreenUnit(unittest.TestCase): - """SyncScreen 的单元测试.""" - - def setUp(self): - """在每个测试之前运行.""" - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = pathlib.Path(self.temp_dir.name) - - # 创建默认配置 - default_config_path = ( - pathlib.Path(__file__).parent.parent.parent - / "src/heurams/default/config/config.toml" - ) - self.config = ConfigFile(default_config_path) - - # 更新配置中的路径 - config_data = self.config.data - config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon") - config_data["paths"]["electron_dir"] = str(self.temp_path / "electron") - config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital") - config_data["paths"]["cache_dir"] = str(self.temp_path / "cache") - - # 添加同步配置 - if "sync" not in config_data: - config_data["sync"] = {} - config_data["sync"]["webdav"] = { - "enabled": False, - "url": "", - "username": "", - "password": "", - "remote_path": "/heurams/", - "sync_mode": "bidirectional", - "conflict_strategy": "newer", - "verify_ssl": True, - } - - # 创建目录 - for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]: - pathlib.Path(config_data["paths"][dir_key]).mkdir( - parents=True, exist_ok=True - ) - - # 使用 ConfigContext 设置配置 - self.config_ctx = ConfigContext(self.config) - self.config_ctx.__enter__() - - def tearDown(self): - """在每个测试之后清理.""" - self.config_ctx.__exit__(None, None, None) - self.temp_dir.cleanup() - - @patch("heurams.interface.screens.synctool.create_sync_service_from_config") - def test_sync_screen_compose(self, mock_create_service): - """测试 SyncScreen 的 compose 方法.""" - from heurams.interface.screens.synctool import SyncScreen - - # 模拟同步服务 - mock_service = MagicMock() - mock_service.client = MagicMock() - mock_create_service.return_value = mock_service - - screen = SyncScreen() - - # 测试 compose 方法 - from textual.app import ComposeResult - - result = screen.compose() - widgets = list(result) - - # 检查基本部件 - from textual.containers import ScrollableContainer - from textual.widgets import Button, Footer, Header, ProgressBar, Static - - header_present = any(isinstance(w, Header) for w in widgets) - footer_present = any(isinstance(w, Footer) for w in widgets) - self.assertTrue(header_present) - self.assertTrue(footer_present) - - # 检查容器 - container_present = any(isinstance(w, ScrollableContainer) for w in widgets) - self.assertTrue(container_present) - - @patch("heurams.interface.screens.synctool.create_sync_service_from_config") - def test_sync_screen_load_config(self, mock_create_service): - """测试 SyncScreen 加载配置.""" - from heurams.interface.screens.synctool import SyncScreen - - mock_service = MagicMock() - mock_service.client = MagicMock() - mock_create_service.return_value = mock_service - - screen = SyncScreen() - screen.load_config() - - # 验证配置已加载 - self.assertIsNotNone(screen.sync_config) - mock_create_service.assert_called_once() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/algorithms/__init__.py b/tests/kernel/algorithms/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/kernel/algorithms/test_sm2.py b/tests/kernel/algorithms/test_sm2.py deleted file mode 100644 index ecc1a24..0000000 --- a/tests/kernel/algorithms/test_sm2.py +++ /dev/null @@ -1,186 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from heurams.kernel.algorithms.sm2 import SM2Algorithm - - -class TestSM2Algorithm(unittest.TestCase): - """测试 SM2Algorithm 类""" - - def setUp(self): - # 模拟 timer 函数 - self.timestamp_patcher = patch( - "heurams.kernel.algorithms.sm2.timer.get_timestamp" - ) - self.daystamp_patcher = patch( - "heurams.kernel.algorithms.sm2.timer.get_daystamp" - ) - self.mock_get_timestamp = self.timestamp_patcher.start() - self.mock_get_daystamp = self.daystamp_patcher.start() - - # 设置固定返回值 - self.mock_get_timestamp.return_value = 1000.0 - self.mock_get_daystamp.return_value = 100 - - def tearDown(self): - self.timestamp_patcher.stop() - self.daystamp_patcher.stop() - - def test_defaults(self): - """测试默认值""" - defaults = SM2Algorithm.defaults - self.assertEqual(defaults["efactor"], 2.5) - self.assertEqual(defaults["real_rept"], 0) - self.assertEqual(defaults["rept"], 0) - self.assertEqual(defaults["interval"], 0) - self.assertEqual(defaults["last_date"], 0) - self.assertEqual(defaults["next_date"], 0) - self.assertEqual(defaults["is_activated"], 0) - # last_modify 是动态的, 仅检查存在性 - self.assertIn("last_modify", defaults) - - def test_revisor_feedback_minus_one(self): - """测试 feedback = -1 时跳过更新""" - algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()} - SM2Algorithm.revisor(algodata, feedback=-1) - # 数据应保持不变 - self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5) - self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0) - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 0) - - def test_revisor_feedback_less_than_3(self): - """测试 feedback < 3 重置 rept 和 interval""" - algodata = { - SM2Algorithm.algo_name: { - "efactor": 2.5, - "rept": 5, - "interval": 10, - "real_rept": 3, - } - } - SM2Algorithm.revisor(algodata, feedback=2) - self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0) - # rept=0 导致 interval 被设置为 1 - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1) - self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 4) # 递增 - - def test_revisor_feedback_greater_equal_3(self): - """测试 feedback >= 3 递增 rept""" - algodata = { - SM2Algorithm.algo_name: { - "efactor": 2.5, - "rept": 2, - "interval": 6, - "real_rept": 2, - } - } - SM2Algorithm.revisor(algodata, feedback=4) - self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 3) - self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 3) - # interval 应根据 rept 和 efactor 重新计算 - # rept=3, interval = round(6 * 2.5) = 15 - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15) - - def test_revisor_new_activation(self): - """测试 is_new_activation 重置 rept 和 efactor""" - algodata = { - SM2Algorithm.algo_name: { - "efactor": 3.0, - "rept": 5, - "interval": 20, - "real_rept": 5, - } - } - SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True) - self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0) - self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5) - # interval 应为 1(因为 rept=0) - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1) - - def test_revisor_efactor_calculation(self): - """测试 efactor 计算""" - algodata = { - SM2Algorithm.algo_name: { - "efactor": 2.5, - "rept": 1, - "interval": 6, - "real_rept": 1, - } - } - SM2Algorithm.revisor(algodata, feedback=5) - # efactor = 2.5 + (0.1 - (5-5)*(0.08 + (5-5)*0.02)) = 2.5 + 0.1 = 2.6 - self.assertAlmostEqual( - algodata[SM2Algorithm.algo_name]["efactor"], 2.6, places=6 - ) - - # 测试 efactor 下限为 1.3 - algodata[SM2Algorithm.algo_name]["efactor"] = 1.2 - SM2Algorithm.revisor(algodata, feedback=5) - self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 1.3) - - def test_revisor_interval_calculation(self): - """测试 interval 计算规则""" - algodata = { - SM2Algorithm.algo_name: { - "efactor": 2.5, - "rept": 0, - "interval": 0, - "real_rept": 0, - } - } - SM2Algorithm.revisor(algodata, feedback=4) - # rept 从 0 递增到 1, 因此 interval 应为 6 - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 6) - - # 现在 rept=1, 再次调用 revisor 递增到 2 - SM2Algorithm.revisor(algodata, feedback=4) - # rept=2, interval = round(6 * 2.5) = 15 - self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15) - - # 单独测试 rept=1 的情况 - algodata2 = { - SM2Algorithm.algo_name: { - "efactor": 2.5, - "rept": 1, - "interval": 0, - "real_rept": 0, - } - } - SM2Algorithm.revisor(algodata2, feedback=4) - # rept 递增到 2, interval = round(0 * 2.5) = 0 - self.assertEqual(algodata2[SM2Algorithm.algo_name]["interval"], 0) - - def test_revisor_updates_dates(self): - """测试更新日期字段""" - algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()} - self.mock_get_daystamp.return_value = 200 - SM2Algorithm.revisor(algodata, feedback=5) - self.assertEqual(algodata[SM2Algorithm.algo_name]["last_date"], 200) - self.assertEqual( - algodata[SM2Algorithm.algo_name]["next_date"], - 200 + algodata[SM2Algorithm.algo_name]["interval"], - ) - self.assertEqual(algodata[SM2Algorithm.algo_name]["last_modify"], 1000.0) - - def test_is_due(self): - """测试 is_due 方法""" - algodata = {SM2Algorithm.algo_name: {"next_date": 100}} - self.mock_get_daystamp.return_value = 150 - self.assertTrue(SM2Algorithm.is_due(algodata)) - - algodata[SM2Algorithm.algo_name]["next_date"] = 200 - self.assertFalse(SM2Algorithm.is_due(algodata)) - - def test_rate(self): - """测试 rate 方法""" - algodata = {SM2Algorithm.algo_name: {"efactor": 2.7}} - self.assertEqual(SM2Algorithm.get_rating(algodata), "2.7") - - def test_nextdate(self): - """测试 nextdate 方法""" - algodata = {SM2Algorithm.algo_name: {"next_date": 12345}} - self.assertEqual(SM2Algorithm.nextdate(algodata), 12345) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/particles/__init__.py b/tests/kernel/particles/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/kernel/particles/test_atom.py b/tests/kernel/particles/test_atom.py deleted file mode 100644 index 09371d5..0000000 --- a/tests/kernel/particles/test_atom.py +++ /dev/null @@ -1,202 +0,0 @@ -import json -import pathlib -import tempfile -import unittest -from unittest.mock import MagicMock, patch - -import toml - -from heurams.context import ConfigContext -from heurams.kernel.particles.atom import Atom, atom_registry -from heurams.kernel.particles.electron import Electron -from heurams.kernel.particles.nucleon import Nucleon -from heurams.kernel.particles.orbital import Orbital -from heurams.services.config import ConfigFile - - -class TestAtom(unittest.TestCase): - """测试 Atom 类""" - - def setUp(self): - """在每个测试之前运行""" - # 创建临时目录用于持久化测试 - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = pathlib.Path(self.temp_dir.name) - - # 创建默认配置 - self.config = ConfigFile( - pathlib.Path(__file__).parent.parent.parent.parent - / "src/heurams/default/config/config.toml" - ) - - # 使用 ConfigContext 设置配置 - self.config_ctx = ConfigContext(self.config) - self.config_ctx.__enter__() - - # 清空全局注册表 - atom_registry.clear() - - def tearDown(self): - """在每个测试之后运行""" - self.config_ctx.__exit__(None, None, None) - self.temp_dir.cleanup() - atom_registry.clear() - - def test_init(self): - """测试 Atom 初始化""" - atom = Atom("test_atom") - self.assertEqual(atom.ident, "test_atom") - self.assertIn("test_atom", atom_registry) - self.assertEqual(atom_registry["test_atom"], atom) - - # 检查 registry 默认值 - self.assertIsNone(atom.registry["nucleon"]) - self.assertIsNone(atom.registry["electron"]) - self.assertIsNone(atom.registry["orbital"]) - self.assertEqual(atom.registry["nucleon_fmt"], "toml") - self.assertEqual(atom.registry["electron_fmt"], "json") - self.assertEqual(atom.registry["orbital_fmt"], "toml") - - def test_link(self): - """测试 link 方法""" - atom = Atom("test_link") - nucleon = Nucleon("test_nucleon", {"content": "test content"}) - - atom.link("nucleon", nucleon) - self.assertEqual(atom.registry["nucleon"], nucleon) - - # 测试链接不支持的键 - with self.assertRaises(ValueError): - atom.link("invalid_key", "value") - - def test_link_triggers_do_eval(self): - """测试 link 后触发 do_eval""" - atom = Atom("test_eval_trigger") - nucleon = Nucleon("test_nucleon", {"content": "eval:1+1"}) - - with patch.object(atom, "do_eval") as mock_do_eval: - atom.link("nucleon", nucleon) - mock_do_eval.assert_called_once() - - def test_persist_toml(self): - """测试 TOML 持久化""" - atom = Atom("test_persist_toml") - nucleon = Nucleon("test_nucleon", {"content": "test"}) - atom.link("nucleon", nucleon) - - # 设置路径 - test_path = self.temp_path / "test.toml" - atom.link("nucleon_path", test_path) - - atom.persist("nucleon") - - # 验证文件存在且内容正确 - self.assertTrue(test_path.exists()) - with open(test_path, "r") as f: - data = toml.load(f) - self.assertEqual(data["ident"], "test_nucleon") - self.assertEqual(data["payload"]["content"], "test") - - def test_persist_json(self): - """测试 JSON 持久化""" - atom = Atom("test_persist_json") - electron = Electron("test_electron", {}) - atom.link("electron", electron) - - test_path = self.temp_path / "test.json" - atom.link("electron_path", test_path) - - atom.persist("electron") - - self.assertTrue(test_path.exists()) - with open(test_path, "r") as f: - data = json.load(f) - self.assertIn("supermemo2", data) - - def test_persist_invalid_format(self): - """测试无效持久化格式""" - atom = Atom("test_invalid_format") - nucleon = Nucleon("test_nucleon", {}) - atom.link("nucleon", nucleon) - atom.link("nucleon_path", self.temp_path / "test.txt") - atom.registry["nucleon_fmt"] = "invalid" - - with self.assertRaises(KeyError): - atom.persist("nucleon") - - def test_persist_no_path(self): - """测试未初始化路径的持久化""" - atom = Atom("test_no_path") - nucleon = Nucleon("test_nucleon", {}) - atom.link("nucleon", nucleon) - # 不设置 nucleon_path - - with self.assertRaises(TypeError): - atom.persist("nucleon") - - def test_getitem_setitem(self): - """测试 __getitem__ 和 __setitem__""" - atom = Atom("test_getset") - nucleon = Nucleon("test_nucleon", {}) - - atom["nucleon"] = nucleon - self.assertEqual(atom["nucleon"], nucleon) - - # 测试不支持的键 - with self.assertRaises(KeyError): - _ = atom["invalid_key"] - - with self.assertRaises(KeyError): - atom["invalid_key"] = "value" - - def test_do_eval_with_eval_string(self): - """测试 do_eval 处理 eval: 字符串""" - atom = Atom("test_do_eval") - nucleon = Nucleon( - "test_nucleon", - {"content": "eval:'hello' + ' world'", "number": "eval:2 + 3"}, - ) - atom.link("nucleon", nucleon) - - # do_eval 应该在链接时自动调用 - # 检查 eval 表达式是否被求值 - self.assertEqual(nucleon.payload["content"], "hello world") - self.assertEqual(nucleon.payload["number"], "5") - - def test_do_eval_with_config_access(self): - """测试 do_eval 访问配置""" - atom = Atom("test_eval_config") - nucleon = Nucleon( - "test_nucleon", {"max_riddles": "eval:default['mcq']['max_riddles_num']"} - ) - atom.link("nucleon", nucleon) - - # 配置中 puzzles.mcq.max_riddles_num = 2 - self.assertEqual(nucleon.payload["max_riddles"], 2) - - def test_placeholder(self): - """测试静态方法 placeholder""" - placeholder = Atom.placeholder() - self.assertIsInstance(placeholder, tuple) - self.assertEqual(len(placeholder), 3) - self.assertIsInstance(placeholder[0], Electron) - self.assertIsInstance(placeholder[1], Nucleon) - self.assertIsInstance(placeholder[2], dict) - - def test_atom_registry_management(self): - """测试全局注册表管理""" - # 创建多个 Atom - atom1 = Atom("atom1") - atom2 = Atom("atom2") - - self.assertEqual(len(atom_registry), 2) - self.assertEqual(atom_registry["atom1"], atom1) - self.assertEqual(atom_registry["atom2"], atom2) - - # 测试 bidict 的反向查找 - self.assertEqual(atom_registry.inverse[atom1], "atom1") - self.assertEqual(atom_registry.inverse[atom2], "atom2") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/particles/test_electron.py b/tests/kernel/particles/test_electron.py deleted file mode 100644 index 7fe7e4e..0000000 --- a/tests/kernel/particles/test_electron.py +++ /dev/null @@ -1,179 +0,0 @@ -import sys -import unittest -from unittest.mock import MagicMock, patch - -from heurams.kernel.algorithms import algorithms -from heurams.kernel.particles.electron import Electron - - -class TestElectron(unittest.TestCase): - """测试 Electron 类""" - - def setUp(self): - # 模拟 timer.get_timestamp 返回固定值 - self.timestamp_patcher = patch( - "heurams.kernel.particles.electron.timer.get_timestamp" - ) - self.mock_get_timestamp = self.timestamp_patcher.start() - self.mock_get_timestamp.return_value = 1234567890.0 - - def tearDown(self): - self.timestamp_patcher.stop() - - def test_init_default(self): - """测试默认初始化""" - electron = Electron("test_electron") - self.assertEqual(electron.ident, "test_electron") - self.assertEqual(electron.algo, algorithms["supermemo2"]) - self.assertIn(electron.algo, electron.algodata) - self.assertIsInstance(electron.algodata[electron.algo], dict) - # 检查默认值(排除动态字段) - defaults = electron.algo.defaults - for key, value in defaults.items(): - if key == "last_modify": - # last_modify 是动态的, 只检查存在性 - self.assertIn(key, electron.algodata[electron.algo]) - elif key == "is_activated": - # TODO: 调查为什么 is_activated 是 1 - self.assertEqual(electron.algodata[electron.algo][key], 1) - else: - self.assertEqual(electron.algodata[electron.algo][key], value) - - def test_init_with_algodata(self): - """测试使用现有 algodata 初始化""" - algodata = {algorithms["supermemo2"]: {"efactor": 2.5, "interval": 1}} - electron = Electron("test_electron", algodata=algodata) - self.assertEqual(electron.algodata[electron.algo]["efactor"], 2.5) - self.assertEqual(electron.algodata[electron.algo]["interval"], 1) - # 其他字段可能不存在, 因为未提供默认初始化 - # 检查 real_rept 不存在 - self.assertNotIn("real_rept", electron.algodata[electron.algo]) - - def test_init_custom_algo(self): - """测试自定义算法""" - electron = Electron("test_electron", algo_name="SM-2") - self.assertEqual(electron.algo, algorithms["SM-2"]) - self.assertIn(electron.algo, electron.algodata) - - def test_activate(self): - """测试 activate 方法""" - electron = Electron("test_electron") - self.assertEqual(electron.algodata[electron.algo]["is_activated"], 0) - electron.activate() - self.assertEqual(electron.algodata[electron.algo]["is_activated"], 1) - self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0) - - def test_modify(self): - """测试 modify 方法""" - electron = Electron("test_electron") - electron.modify("interval", 5) - self.assertEqual(electron.algodata[electron.algo]["interval"], 5) - self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0) - - # 修改不存在的字段应记录警告但不引发异常 - with patch("heurams.kernel.particles.electron.logger.warning") as mock_warning: - electron.modify("unknown_field", 99) - mock_warning.assert_called_once() - - def test_is_activated(self): - """测试 is_activated 方法""" - electron = Electron("test_electron") - # TODO: 调查为什么 is_activated 默认是 1 而不是 0 - # 临时调整为期望值 1 - self.assertEqual(electron.is_activated(), 1) - electron.activate() - self.assertEqual(electron.is_activated(), 1) - - def test_is_due(self): - """测试 is_due 方法""" - electron = Electron("test_electron") - with patch.object(electron.algo, "is_due") as mock_is_due: - mock_is_due.return_value = 1 - result = electron.is_due() - mock_is_due.assert_called_once_with(electron.algodata) - self.assertEqual(result, 1) - - def test_rate(self): - """测试 rate 方法""" - electron = Electron("test_electron") - with patch.object(electron.algo, "rate") as mock_rate: - mock_rate.return_value = "good" - result = electron.get_rating() - mock_rate.assert_called_once_with(electron.algodata) - self.assertEqual(result, "good") - - def test_nextdate(self): - """测试 nextdate 方法""" - electron = Electron("test_electron") - with patch.object(electron.algo, "nextdate") as mock_nextdate: - mock_nextdate.return_value = 1234568000 - result = electron.nextdate() - mock_nextdate.assert_called_once_with(electron.algodata) - self.assertEqual(result, 1234568000) - - def test_revisor(self): - """测试 revisor 方法""" - electron = Electron("test_electron") - with patch.object(electron.algo, "revisor") as mock_revisor: - electron.revisor(quality=3, is_new_activation=True) - mock_revisor.assert_called_once_with(electron.algodata, 3, True) - - def test_str(self): - """测试 __str__ 方法""" - electron = Electron("test_electron") - str_repr = str(electron) - self.assertIn("记忆单元预览", str_repr) - self.assertIn("test_electron", str_repr) - # 算法类名会出现在字符串表示中 - self.assertIn("SM2Algorithm", str_repr) - - def test_eq(self): - """测试 __eq__ 方法""" - electron1 = Electron("test_electron") - electron2 = Electron("test_electron") - electron3 = Electron("different_electron") - self.assertEqual(electron1, electron2) - self.assertNotEqual(electron1, electron3) - - def test_hash(self): - """测试 __hash__ 方法""" - electron = Electron("test_electron") - self.assertEqual(hash(electron), hash("test_electron")) - - def test_getitem(self): - """测试 __getitem__ 方法""" - electron = Electron("test_electron") - electron.activate() - self.assertEqual(electron["ident"], "test_electron") - self.assertEqual(electron["is_activated"], 1) - - with self.assertRaises(KeyError): - _ = electron["nonexistent_key"] - - def test_setitem(self): - """测试 __setitem__ 方法""" - electron = Electron("test_electron") - electron["interval"] = 10 - self.assertEqual(electron.algodata[electron.algo]["interval"], 10) - self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0) - - with self.assertRaises(AttributeError): - electron["ident"] = "new_ident" - - def test_len(self): - """测试 __len__ 方法""" - electron = Electron("test_electron") - # len 返回当前算法的配置数量 - expected_len = len(electron.algo.defaults) - self.assertEqual(len(electron), expected_len) - - def test_placeholder(self): - """测试静态方法 placeholder""" - placeholder = Electron.placeholder() - self.assertIsInstance(placeholder, Electron) - self.assertEqual(placeholder.ident, "电子对象样例内容") - self.assertEqual(placeholder.algo, algorithms["supermemo2"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/particles/test_nucleon.py b/tests/kernel/particles/test_nucleon.py deleted file mode 100644 index 6239df6..0000000 --- a/tests/kernel/particles/test_nucleon.py +++ /dev/null @@ -1,108 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from heurams.kernel.particles.nucleon import Nucleon - - -class TestNucleon(unittest.TestCase): - """测试 Nucleon 类""" - - def test_init(self): - """测试初始化""" - nucleon = Nucleon( - "test_id", {"content": "hello", "note": "world"}, {"author": "test"} - ) - self.assertEqual(nucleon.ident, "test_id") - self.assertEqual(nucleon.payload, {"content": "hello", "note": "world"}) - self.assertEqual(nucleon.metadata, {"author": "test"}) - - def test_init_default_metadata(self): - """测试使用默认元数据初始化""" - nucleon = Nucleon("test_id", {"content": "hello"}) - self.assertEqual(nucleon.ident, "test_id") - self.assertEqual(nucleon.payload, {"content": "hello"}) - self.assertEqual(nucleon.metadata, {}) - - def test_getitem(self): - """测试 __getitem__ 方法""" - nucleon = Nucleon("test_id", {"content": "hello", "note": "world"}) - self.assertEqual(nucleon["ident"], "test_id") - self.assertEqual(nucleon["content"], "hello") - self.assertEqual(nucleon["note"], "world") - - with self.assertRaises(KeyError): - _ = nucleon["nonexistent"] - - def test_iter(self): - """测试 __iter__ 方法""" - nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3}) - keys = list(nucleon) - self.assertCountEqual(keys, ["a", "b", "c"]) - - def test_len(self): - """测试 __len__ 方法""" - nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3}) - self.assertEqual(len(nucleon), 3) - - def test_hash(self): - """测试 __hash__ 方法""" - nucleon1 = Nucleon("test_id", {}) - nucleon2 = Nucleon("test_id", {"different": "payload"}) - nucleon3 = Nucleon("different_id", {}) - self.assertEqual(hash(nucleon1), hash(nucleon2)) # 相同 ident - self.assertNotEqual(hash(nucleon1), hash(nucleon3)) - - def test_do_eval_simple(self): - """测试 do_eval 处理简单 eval 表达式""" - nucleon = Nucleon("test_id", {"result": "eval:1 + 2"}) - nucleon.do_eval() - self.assertEqual(nucleon.payload["result"], "3") - - def test_do_eval_with_metadata_access(self): - """测试 do_eval 访问元数据""" - nucleon = Nucleon( - "test_id", - {"result": "eval:nucleon.metadata.get('value', 0)"}, - {"value": 42}, - ) - nucleon.do_eval() - self.assertEqual(nucleon.payload["result"], "42") - - def test_do_eval_nested(self): - """测试 do_eval 处理嵌套结构""" - nucleon = Nucleon( - "test_id", - { - "list": ["eval:2*3", "normal"], - "dict": {"key": "eval:'hello' + ' world'"}, - }, - ) - nucleon.do_eval() - self.assertEqual(nucleon.payload["list"][0], "6") - self.assertEqual(nucleon.payload["list"][1], "normal") - self.assertEqual(nucleon.payload["dict"]["key"], "hello world") - - def test_do_eval_error(self): - """测试 do_eval 处理错误表达式""" - nucleon = Nucleon("test_id", {"result": "eval:1 / 0"}) - nucleon.do_eval() - self.assertIn("此 eval 实例发生错误", nucleon.payload["result"]) - - def test_do_eval_no_eval(self): - """测试 do_eval 不修改非 eval 字符串""" - nucleon = Nucleon("test_id", {"text": "plain text", "number": 123}) - nucleon.do_eval() - self.assertEqual(nucleon.payload["text"], "plain text") - self.assertEqual(nucleon.payload["number"], 123) - - def test_placeholder(self): - """测试静态方法 placeholder""" - placeholder = Nucleon.placeholder() - self.assertIsInstance(placeholder, Nucleon) - self.assertEqual(placeholder.ident, "核子对象样例内容") - self.assertEqual(placeholder.payload, {}) - self.assertEqual(placeholder.metadata, {}) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/puzzles/__init__.py b/tests/kernel/puzzles/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/kernel/puzzles/test_base.py b/tests/kernel/puzzles/test_base.py deleted file mode 100644 index c5e4387..0000000 --- a/tests/kernel/puzzles/test_base.py +++ /dev/null @@ -1,23 +0,0 @@ -import unittest -from unittest.mock import Mock - -from heurams.kernel.evaluators.base import BasePuzzle - - -class TestBasePuzzle(unittest.TestCase): - """测试 BasePuzzle 基类""" - - def test_refresh_not_implemented(self): - """测试 refresh 方法未实现时抛出异常""" - puzzle = BasePuzzle() - with self.assertRaises(NotImplementedError): - puzzle.refresh() - - def test_str(self): - """测试 __str__ 方法""" - puzzle = BasePuzzle() - self.assertEqual(str(puzzle), "谜题: BasePuzzle") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/puzzles/test_cloze.py b/tests/kernel/puzzles/test_cloze.py deleted file mode 100644 index 86696ee..0000000 --- a/tests/kernel/puzzles/test_cloze.py +++ /dev/null @@ -1,51 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from heurams.kernel.evaluators.cloze import ClozePuzzle - - -class TestClozePuzzle(unittest.TestCase): - """测试 ClozePuzzle 类""" - - def test_init(self): - """测试初始化""" - puzzle = ClozePuzzle("hello/world/test", min_denominator=3, delimiter="/") - self.assertEqual(puzzle.text, "hello/world/test") - self.assertEqual(puzzle.min_denominator, 3) - self.assertEqual(puzzle.delimiter, "/") - self.assertEqual(puzzle.wording, "填空题 - 尚未刷新谜题") - self.assertEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"]) - - @patch("random.sample") - def test_refresh(self, mock_sample): - """测试 refresh 方法""" - mock_sample.return_value = [0, 2] # 选择索引 0 和 2 - puzzle = ClozePuzzle("hello/world/test", min_denominator=2, delimiter="/") - puzzle.refresh() - - # 检查 wording 和 answer - self.assertNotEqual(puzzle.wording, "填空题 - 尚未刷新谜题") - self.assertNotEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"]) - # 根据模拟, 应该有两个填空 - self.assertEqual(len(puzzle.answer), 2) - self.assertEqual(puzzle.answer, ["hello", "test"]) - # wording 应包含下划线 - self.assertIn("__", puzzle.wording) - - def test_refresh_empty_text(self): - """测试空文本的 refresh""" - puzzle = ClozePuzzle("", min_denominator=3, delimiter="/") - puzzle.refresh() # 不应引发异常 - # 空文本导致 wording 和 answer 为空 - self.assertEqual(puzzle.wording, "") - self.assertEqual(puzzle.answer, []) - - def test_str(self): - """测试 __str__ 方法""" - puzzle = ClozePuzzle("hello/world", min_denominator=2, delimiter="/") - str_repr = str(puzzle) - self.assertIn("填空题 - 尚未刷新谜题", str_repr) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/puzzles/test_mcq.py b/tests/kernel/puzzles/test_mcq.py deleted file mode 100644 index b64b329..0000000 --- a/tests/kernel/puzzles/test_mcq.py +++ /dev/null @@ -1,122 +0,0 @@ -import unittest -from unittest.mock import MagicMock, call, patch - -from heurams.kernel.evaluators.mcq import MCQPuzzle - - -class TestMCQPuzzle(unittest.TestCase): - """测试 MCQPuzzle 类""" - - def test_init(self): - """测试初始化""" - mapping = {"q1": "a1", "q2": "a2"} - jammer = ["j1", "j2", "j3"] - puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=3, prefix="选择") - self.assertEqual(puzzle.prefix, "选择") - self.assertEqual(puzzle.mapping, mapping) - self.assertEqual(puzzle.max_riddles_num, 3) - # jammer 应合并正确答案并去重 - self.assertIn("a1", puzzle.jammer) - self.assertIn("a2", puzzle.jammer) - self.assertIn("j1", puzzle.jammer) - # 初始状态 - self.assertEqual(puzzle.wording, "选择题 - 尚未刷新谜题") - self.assertEqual(puzzle.answer, ["选择题 - 尚未刷新谜题"]) - self.assertEqual(puzzle.options, []) - - def test_init_max_riddles_num_clamping(self): - """测试 max_riddles_num 限制在 1-5 之间""" - puzzle1 = MCQPuzzle({}, [], max_riddles_num=0) - self.assertEqual(puzzle1.max_riddles_num, 1) - puzzle2 = MCQPuzzle({}, [], max_riddles_num=10) - self.assertEqual(puzzle2.max_riddles_num, 5) - - def test_init_jammer_ensures_minimum(self): - """测试干扰项至少保证 4 个""" - puzzle = MCQPuzzle({}, []) - # 正确答案为空, 干扰项为空, 应填充空格 - self.assertEqual(len(puzzle.jammer), 4) - self.assertEqual(set(puzzle.jammer), {" "}) # 三个空格? 实际上循环填充空格 - - @patch("random.sample") - @patch("random.shuffle") - def test_refresh(self, mock_shuffle, mock_sample): - """测试 refresh 方法生成题目""" - mapping = {"q1": "a1", "q2": "a2", "q3": "a3"} - jammer = ["j1", "j2", "j3", "j4"] - puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=2) - # 模拟 random.sample 返回前两个映射项 - mock_sample.side_effect = [ - [("q1", "a1"), ("q2", "a2")], # 选择问题 - ["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次) - ] - puzzle.refresh() - - # 检查 wording 是列表 - self.assertIsInstance(puzzle.wording, list) - self.assertEqual(len(puzzle.wording), 2) - # 检查 answer 列表 - self.assertEqual(puzzle.answer, ["a1", "a2"]) - # 检查 options 列表 - self.assertEqual(len(puzzle.options), 2) - # 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项) - self.assertEqual(len(puzzle.options[0]), 4) - self.assertEqual(len(puzzle.options[1]), 4) - # random.shuffle 应被调用 - self.assertEqual(mock_shuffle.call_count, 2) - - def test_refresh_empty_mapping(self): - """测试空 mapping 的 refresh""" - puzzle = MCQPuzzle({}, []) - puzzle.refresh() - self.assertEqual(puzzle.wording, "无可用题目") - self.assertEqual(puzzle.answer, ["无答案"]) - self.assertEqual(puzzle.options, []) - - def test_get_question_count(self): - """测试 get_question_count 方法""" - puzzle = MCQPuzzle({"q": "a"}, []) - self.assertEqual(puzzle.get_question_count(), 0) # 未刷新 - puzzle.refresh = MagicMock() - puzzle.wording = ["Q1", "Q2"] - self.assertEqual(puzzle.get_question_count(), 2) - puzzle.wording = "无可用题目" - self.assertEqual(puzzle.get_question_count(), 0) - puzzle.wording = "单个问题" - self.assertEqual(puzzle.get_question_count(), 1) - - def test_get_correct_answer_for_question(self): - """测试 get_correct_answer_for_question 方法""" - puzzle = MCQPuzzle({}, []) - puzzle.answer = ["ans1", "ans2"] - self.assertEqual(puzzle.get_correct_answer_for_question(0), "ans1") - self.assertEqual(puzzle.get_correct_answer_for_question(1), "ans2") - self.assertIsNone(puzzle.get_correct_answer_for_question(2)) - puzzle.answer = "not a list" - self.assertIsNone(puzzle.get_correct_answer_for_question(0)) - - def test_get_options_for_question(self): - """测试 get_options_for_question 方法""" - puzzle = MCQPuzzle({}, []) - puzzle.options = [["a", "b", "c", "d"], ["e", "f", "g", "h"]] - self.assertEqual(puzzle.get_options_for_question(0), ["a", "b", "c", "d"]) - self.assertEqual(puzzle.get_options_for_question(1), ["e", "f", "g", "h"]) - self.assertIsNone(puzzle.get_options_for_question(2)) - - def test_str(self): - """测试 __str__ 方法""" - puzzle = MCQPuzzle({}, []) - puzzle.wording = "选择题 - 尚未刷新谜题" - puzzle.answer = ["选择题 - 尚未刷新谜题"] - self.assertIn("选择题 - 尚未刷新谜题", str(puzzle)) - self.assertIn("正确答案", str(puzzle)) - - puzzle.wording = ["Q1", "Q2"] - puzzle.answer = ["A1", "A2"] - str_repr = str(puzzle) - self.assertIn("Q1", str_repr) - self.assertIn("A1, A2", str_repr) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/kernel/reactor/__init__.py b/tests/kernel/reactor/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/kernel/reactor/test_phaser.py b/tests/kernel/reactor/test_phaser.py deleted file mode 100644 index fba30e4..0000000 --- a/tests/kernel/reactor/test_phaser.py +++ /dev/null @@ -1,114 +0,0 @@ -import unittest -from unittest.mock import MagicMock, Mock, patch - -from heurams.kernel.particles.atom import Atom -from heurams.kernel.particles.electron import Electron -from heurams.kernel.reactor.procession import Phaser -from heurams.kernel.reactor.states import PhaserState, ProcessionState - - -class TestPhaser(unittest.TestCase): - """测试 Phaser 类""" - - def setUp(self): - # 创建模拟的 Atom 对象 - self.atom_new = Mock(spec=Atom) - self.atom_new.registry = {"electron": Mock(spec=Electron)} - self.atom_new.registry["electron"].is_activated.return_value = False - - self.atom_old = Mock(spec=Atom) - self.atom_old.registry = {"electron": Mock(spec=Electron)} - self.atom_old.registry["electron"].is_activated.return_value = True - - # 模拟 Procession 类以避免复杂依赖 - self.procession_patcher = patch("heurams.kernel.reactor.phaser.Procession") - self.mock_procession_class = self.procession_patcher.start() - - def tearDown(self): - self.procession_patcher.stop() - - def test_init_with_mixed_atoms(self): - """测试混合新旧原子的初始化""" - atoms = [self.atom_old, self.atom_new, self.atom_old] - phaser = Phaser(atoms) - - # 应该创建两个 Procession: 一个用于旧原子, 一个用于新原子, 以及一个总体复习 - self.assertEqual(self.mock_procession_class.call_count, 3) - - # 检查调用参数 - calls = self.mock_procession_class.call_args_list - # 第一个调用应该是旧原子的初始复习 - self.assertEqual(calls[0][0][0], [self.atom_old, self.atom_old]) - self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW) - # 第二个调用应该是新原子的识别阶段 - self.assertEqual(calls[1][0][0], [self.atom_new]) - self.assertEqual(calls[1][0][1], PhaserState.RECOGNITION) - # 第三个调用应该是所有原子的总体复习 - self.assertEqual(calls[2][0][0], atoms) - self.assertEqual(calls[2][0][1], PhaserState.FINAL_REVIEW) - - def test_init_only_old_atoms(self): - """测试只有旧原子""" - atoms = [self.atom_old, self.atom_old] - phaser = Phaser(atoms) - - # 应该创建两个 Procession: 一个初始复习, 一个总体复习 - self.assertEqual(self.mock_procession_class.call_count, 2) - calls = self.mock_procession_class.call_args_list - self.assertEqual(calls[0][0][0], atoms) - self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW) - self.assertEqual(calls[1][0][0], atoms) - self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW) - - def test_init_only_new_atoms(self): - """测试只有新原子""" - atoms = [self.atom_new, self.atom_new] - phaser = Phaser(atoms) - - self.assertEqual(self.mock_procession_class.call_count, 2) - calls = self.mock_procession_class.call_args_list - self.assertEqual(calls[0][0][0], atoms) - self.assertEqual(calls[0][0][1], PhaserState.RECOGNITION) - self.assertEqual(calls[1][0][0], atoms) - self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW) - - def test_current_procession_finds_unfinished(self): - """测试 current_procession 找到未完成的 Procession""" - # 创建模拟 Procession 实例 - mock_proc1 = Mock() - mock_proc1.state = ProcessionState.FINISHED - mock_proc2 = Mock() - mock_proc2.state = ProcessionState.RUNNING - mock_proc2.phase = PhaserState.QUICK_REVIEW - - phaser = Phaser([]) - phaser.processions = [mock_proc1, mock_proc2] - - result = phaser.current_procession() - self.assertEqual(result, mock_proc2) - self.assertEqual(phaser.state, PhaserState.QUICK_REVIEW) - - def test_current_procession_all_finished(self): - """测试所有 Procession 都完成""" - mock_proc = Mock() - mock_proc.state = ProcessionState.FINISHED - - phaser = Phaser([]) - phaser.processions = [mock_proc] - - result = phaser.current_procession() - self.assertEqual(result, 0) - self.assertEqual(phaser.state, PhaserState.FINISHED) - - def test_current_procession_empty(self): - """测试没有 Procession""" - phaser = Phaser([]) - phaser.processions = [] - - result = phaser.current_procession() - self.assertEqual(result, 0) - self.assertEqual(phaser.state, PhaserState.FINISHED) - - -if __name__ == "__main__": - unittest.main()