fix: 改进代码
This commit is contained in:
@@ -3,6 +3,7 @@ import heurams.kernel.particles as pt
|
|||||||
from heurams.services.textproc import truncate
|
from heurams.services.textproc import truncate
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
|
|
||||||
repo = repolib.Repo.create_from_repodir(Path("./test_repo"))
|
repo = repolib.Repo.create_from_repodir(Path("./test_repo"))
|
||||||
alist = list()
|
alist = list()
|
||||||
print(repo.ident_index)
|
print(repo.ident_index)
|
||||||
@@ -17,13 +18,14 @@ for i in repo.ident_index:
|
|||||||
input()
|
input()
|
||||||
a = pt.Atom(n, e, repo.orbitic_data)
|
a = pt.Atom(n, e, repo.orbitic_data)
|
||||||
alist.append(a)
|
alist.append(a)
|
||||||
#e.activate()
|
# e.activate()
|
||||||
#e.revisor(5, True)
|
# e.revisor(5, True)
|
||||||
print(repr(a))
|
print(repr(a))
|
||||||
# print(repr(e))
|
# print(repr(e))
|
||||||
print(repo)
|
print(repo)
|
||||||
input()
|
input()
|
||||||
import heurams.kernel.reactor as rt
|
import heurams.kernel.reactor as rt
|
||||||
|
|
||||||
ph: rt.Phaser = rt.Phaser(alist)
|
ph: rt.Phaser = rt.Phaser(alist)
|
||||||
print(ph)
|
print(ph)
|
||||||
pr: rt.Procession = ph.current_procession() # type: ignore
|
pr: rt.Procession = ph.current_procession() # type: ignore
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ def environment_check():
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
logger.debug("检查环境路径")
|
logger.debug("检查环境路径")
|
||||||
subdir = ['cache/voice', 'repo', 'global', 'config']
|
subdir = ["cache/voice", "repo", "global", "config"]
|
||||||
for i in subdir:
|
for i in subdir:
|
||||||
i = Path(config_var.get()['paths']['data']) / i
|
i = Path(config_var.get()["paths"]["data"]) / i
|
||||||
if not i.exists():
|
if not i.exists():
|
||||||
logger.info("创建目录: %s", i)
|
logger.info("创建目录: %s", i)
|
||||||
print(f"创建 {i}")
|
print(f"创建 {i}")
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from heurams.kernel.particles import *
|
|||||||
from heurams.kernel.repolib import *
|
from heurams.kernel.repolib import *
|
||||||
from heurams.services.logger import get_logger
|
from heurams.services.logger import get_logger
|
||||||
|
|
||||||
|
import heurams.kernel.particles as pt
|
||||||
|
from pathlib import Path
|
||||||
from .about import AboutScreen
|
from .about import AboutScreen
|
||||||
from .preparation import PreparationScreen
|
from .preparation import PreparationScreen
|
||||||
|
|
||||||
@@ -42,7 +44,9 @@ class DashboardScreen(Screen):
|
|||||||
yield Header(show_clock=True)
|
yield Header(show_clock=True)
|
||||||
yield ScrollableContainer(
|
yield ScrollableContainer(
|
||||||
Label('欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
Label('欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
||||||
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})"),
|
Label(
|
||||||
|
f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})"
|
||||||
|
),
|
||||||
Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"),
|
Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"),
|
||||||
Label("选择待学习或待修改的项目:", classes="title-label"),
|
Label("选择待学习或待修改的项目:", classes="title-label"),
|
||||||
ListView(id="repo-list", classes="repo-list-view"),
|
ListView(id="repo-list", classes="repo-list-view"),
|
||||||
|
|||||||
@@ -41,9 +41,9 @@ class MemScreen(Screen):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
phaser: Phaser,
|
phaser: Phaser,
|
||||||
name = None,
|
name=None,
|
||||||
id = None,
|
id=None,
|
||||||
classes = None,
|
classes=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(name, id, classes)
|
super().__init__(name, id, classes)
|
||||||
self.phaser = phaser
|
self.phaser = phaser
|
||||||
@@ -69,14 +69,12 @@ class MemScreen(Screen):
|
|||||||
try:
|
try:
|
||||||
self.fission = self.procession.get_fission()
|
self.fission = self.procession.get_fission()
|
||||||
puzzle = self.fission.get_current_puzzle()
|
puzzle = self.fission.get_current_puzzle()
|
||||||
# logger.debug(puzzle_debug)
|
|
||||||
return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore
|
return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore
|
||||||
atom=self.atom, alia=puzzle["alia"] # type: ignore
|
atom=self.atom, alia=puzzle["alia"] # type: ignore
|
||||||
)
|
)
|
||||||
except (KeyError, StopIteration, AttributeError) as e:
|
except (KeyError, StopIteration, AttributeError) as e:
|
||||||
logger.debug(f"调度展开出错: {e}")
|
logger.debug(f"调度展开出错: {e}")
|
||||||
return Static(f"无法生成谜题 {e}")
|
return Static(f"无法生成谜题 {e}")
|
||||||
# logger.debug(shim.puzzle2widget[puzzle_debug["puzzle"]])
|
|
||||||
|
|
||||||
def _get_progress_text(self):
|
def _get_progress_text(self):
|
||||||
s = f"阶段: {self.procession.phase.name}\n"
|
s = f"阶段: {self.procession.phase.name}\n"
|
||||||
@@ -117,18 +115,14 @@ class MemScreen(Screen):
|
|||||||
from heurams.services.audio_service import play_by_path
|
from heurams.services.audio_service import play_by_path
|
||||||
from heurams.services.hasher import get_md5
|
from heurams.services.hasher import get_md5
|
||||||
|
|
||||||
path = Path(config_var.get()["paths"]['data']) / 'cache' / 'voice'
|
path = Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||||
path = (
|
path = path / f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav"
|
||||||
path
|
|
||||||
/ f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav"
|
|
||||||
)
|
|
||||||
if path.exists():
|
if path.exists():
|
||||||
play_by_path(path)
|
play_by_path(path)
|
||||||
else:
|
else:
|
||||||
from heurams.services.tts_service import convertor
|
from heurams.services.tts_service import convertor
|
||||||
convertor(
|
|
||||||
self.atom.registry["nucleon"]["tts_text"], path
|
convertor(self.atom.registry["nucleon"]["tts_text"], path)
|
||||||
)
|
|
||||||
play_by_path(path)
|
play_by_path(path)
|
||||||
|
|
||||||
def watch_rating(self, old_rating, new_rating) -> None:
|
def watch_rating(self, old_rating, new_rating) -> None:
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ import heurams.kernel.particles as pt
|
|||||||
import heurams.services.hasher as hasher
|
import heurams.services.hasher as hasher
|
||||||
from heurams.context import *
|
from heurams.context import *
|
||||||
|
|
||||||
cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / 'voice'
|
cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||||
|
|
||||||
|
|
||||||
class PrecachingScreen(Screen):
|
class PrecachingScreen(Screen):
|
||||||
"""预缓存音频文件屏幕
|
"""预缓存音频文件屏幕
|
||||||
@@ -204,9 +205,7 @@ class PrecachingScreen(Screen):
|
|||||||
|
|
||||||
from heurams.context import config_var, rootdir, workdir
|
from heurams.context import config_var, rootdir, workdir
|
||||||
|
|
||||||
shutil.rmtree(
|
shutil.rmtree(cache_dir, ignore_errors=True)
|
||||||
cache_dir, ignore_errors=True
|
|
||||||
)
|
|
||||||
self.update_status("已清空", "音频缓存已清空", 0)
|
self.update_status("已清空", "音频缓存已清空", 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.update_status("错误", f"清空缓存失败: {e}")
|
self.update_status("错误", f"清空缓存失败: {e}")
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ class PreparationScreen(Screen):
|
|||||||
atoms_to_provide.append(i)
|
atoms_to_provide.append(i)
|
||||||
from .memoqueue import MemScreen
|
from .memoqueue import MemScreen
|
||||||
import heurams.kernel.reactor as rt
|
import heurams.kernel.reactor as rt
|
||||||
|
|
||||||
pheser = rt.Phaser(atoms_to_provide)
|
pheser = rt.Phaser(atoms_to_provide)
|
||||||
memscreen = MemScreen(pheser)
|
memscreen = MemScreen(pheser)
|
||||||
self.app.push_screen(memscreen)
|
self.app.push_screen(memscreen)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Kernel 操作辅助函数库"""
|
"""Kernel 操作辅助函数库"""
|
||||||
|
|
||||||
import heurams.interface.widgets as pzw
|
import heurams.interface.widgets as pzw
|
||||||
import heurams.kernel.evaluators as pz
|
import heurams.kernel.evaluators as pz
|
||||||
|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class MCQPuzzle(BasePuzzleWidget):
|
|||||||
self._load()
|
self._load()
|
||||||
|
|
||||||
def _load(self):
|
def _load(self):
|
||||||
cfg = self.atom.registry["orbital"]["puzzles"][self.alia]
|
cfg = self.atom.registry["nucleon"]["puzzles"][self.alia]
|
||||||
self.puzzle = pz.MCQPuzzle(
|
self.puzzle = pz.MCQPuzzle(
|
||||||
cfg["mapping"], cfg["jammer"], int(cfg["max_riddles_num"]), cfg["prefix"]
|
cfg["mapping"], cfg["jammer"], int(cfg["max_riddles_num"]), cfg["prefix"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
from heurams.services.logger import get_logger
|
|
||||||
|
|
||||||
from .base import BaseAlgorithm
|
from .base import BaseAlgorithm
|
||||||
from .sm2 import SM2Algorithm
|
from .sm2 import SM2Algorithm
|
||||||
from .sm15m import SM15MAlgorithm
|
from .sm15m import SM15MAlgorithm
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SM2Algorithm",
|
"SM2Algorithm",
|
||||||
"BaseAlgorithm",
|
"BaseAlgorithm",
|
||||||
@@ -17,5 +13,3 @@ algorithms = {
|
|||||||
"SM-15M": SM15MAlgorithm,
|
"SM-15M": SM15MAlgorithm,
|
||||||
"Base": BaseAlgorithm,
|
"Base": BaseAlgorithm,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys()))
|
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ Evaluator 模块 - 生成评估模块
|
|||||||
|
|
||||||
from heurams.services.logger import get_logger
|
from heurams.services.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
from .base import BaseEvaluator
|
from .base import BaseEvaluator
|
||||||
from .cloze import ClozePuzzle
|
from .cloze import ClozePuzzle
|
||||||
from .mcq import MCQPuzzle
|
from .mcq import MCQPuzzle
|
||||||
@@ -26,38 +24,3 @@ puzzles = {
|
|||||||
"recognition": RecognitionPuzzle,
|
"recognition": RecognitionPuzzle,
|
||||||
"base": BaseEvaluator,
|
"base": BaseEvaluator,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_by_dict(config_dict: dict) -> BaseEvaluator:
|
|
||||||
"""
|
|
||||||
根据配置字典创建谜题
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dict: 配置字典, 包含谜题类型和参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BasePuzzle: 谜题实例
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: 当配置无效时抛出
|
|
||||||
"""
|
|
||||||
logger.debug(
|
|
||||||
"puzzles.create_by_dict: config_dict keys=%s", list(config_dict.keys())
|
|
||||||
)
|
|
||||||
puzzle_type = config_dict.get("type")
|
|
||||||
|
|
||||||
if puzzle_type == "cloze":
|
|
||||||
return puzzles["cloze"](
|
|
||||||
text=config_dict["text"],
|
|
||||||
min_denominator=config_dict.get("min_denominator", 7),
|
|
||||||
)
|
|
||||||
elif puzzle_type == "mcq":
|
|
||||||
return puzzles["mcq"](
|
|
||||||
mapping=config_dict["mapping"],
|
|
||||||
jammer=config_dict.get("jammer", []),
|
|
||||||
max_riddles_num=config_dict.get("max_riddles_num", 2),
|
|
||||||
prefix=config_dict.get("prefix", ""),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"未知的谜题类型: {puzzle_type}")
|
|
||||||
|
|||||||
@@ -1,4 +1,21 @@
|
|||||||
from .atom import Atom
|
from .atom import Atom
|
||||||
from .electron import Electron
|
from .electron import Electron
|
||||||
from .nucleon import Nucleon
|
from .nucleon import Nucleon
|
||||||
#from .orbital import Orbital
|
from .placeholders import (
|
||||||
|
AtomPlaceholder,
|
||||||
|
NucleonPlaceholder,
|
||||||
|
ElectronPlaceholder,
|
||||||
|
orbital_placeholder,
|
||||||
|
)
|
||||||
|
|
||||||
|
# from .orbital import Orbital
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Atom",
|
||||||
|
"Electron",
|
||||||
|
"Nucleon",
|
||||||
|
"AtomPlaceholder",
|
||||||
|
"NucleonPlaceholder",
|
||||||
|
"ElectronPlaceholder",
|
||||||
|
"orbital_placeholder",
|
||||||
|
]
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ class Nucleon:
|
|||||||
|
|
||||||
def __init__(self, ident, payload, common):
|
def __init__(self, ident, payload, common):
|
||||||
self.ident = ident
|
self.ident = ident
|
||||||
env = {"payload": payload,
|
env = {
|
||||||
"default": config_var.get()['puzzles'],
|
"payload": payload,
|
||||||
"nucleon": (payload | common)}
|
"default": config_var.get()["puzzles"],
|
||||||
|
"nucleon": (payload | common),
|
||||||
|
}
|
||||||
self.evalizer = Evalizer(environment=env)
|
self.evalizer = Evalizer(environment=env)
|
||||||
self.data: dict = self.evalizer(deepcopy((payload | common))) # type: ignore
|
self.data: dict = self.evalizer(deepcopy((payload | common))) # type: ignore
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,4 @@ from .phaser import Phaser
|
|||||||
from .procession import Procession
|
from .procession import Procession
|
||||||
from .states import PhaserState, ProcessionState
|
from .states import PhaserState, ProcessionState
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
__all__ = ["PhaserState", "ProcessionState", "Procession", "Fission", "Phaser"]
|
__all__ = ["PhaserState", "ProcessionState", "Procession", "Fission", "Phaser"]
|
||||||
|
|
||||||
logger.debug("反应堆模块已加载")
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Fission:
|
|||||||
phase_state.value if isinstance(phase_state, PhaserState) else phase_state
|
phase_state.value if isinstance(phase_state, PhaserState) else phase_state
|
||||||
)
|
)
|
||||||
|
|
||||||
self.orbital_schedule = atom.registry['orbital']["phases"][phase_value] # type: ignore
|
self.orbital_schedule = atom.registry["orbital"]["phases"][phase_value] # type: ignore
|
||||||
self.orbital_puzzles = atom.registry["nucleon"]["puzzles"]
|
self.orbital_puzzles = atom.registry["nucleon"]["puzzles"]
|
||||||
|
|
||||||
self.puzzles = list()
|
self.puzzles = list()
|
||||||
@@ -34,7 +34,6 @@ class Fission:
|
|||||||
{
|
{
|
||||||
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
||||||
"alia": item,
|
"alia": item,
|
||||||
"finished": 0,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
possibility -= 1
|
possibility -= 1
|
||||||
@@ -44,7 +43,6 @@ class Fission:
|
|||||||
{
|
{
|
||||||
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
||||||
"alia": item,
|
"alia": item,
|
||||||
"finished": 0,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +51,7 @@ class Fission:
|
|||||||
def get_puzzles(self):
|
def get_puzzles(self):
|
||||||
return self.puzzles
|
return self.puzzles
|
||||||
|
|
||||||
def get_current_puzzle(self, forward = 0):
|
def get_current_puzzle(self, forward=0):
|
||||||
if forward:
|
if forward:
|
||||||
if len(self.puzzles) <= self.cursor + 1:
|
if len(self.puzzles) <= self.cursor + 1:
|
||||||
return 0
|
return 0
|
||||||
@@ -62,7 +60,6 @@ class Fission:
|
|||||||
else:
|
else:
|
||||||
return self.puzzles[self.cursor]
|
return self.puzzles[self.cursor]
|
||||||
|
|
||||||
|
|
||||||
def check_passed(self):
|
def check_passed(self):
|
||||||
for i in self.puzzles:
|
for i in self.puzzles:
|
||||||
if i["finished"] == 0:
|
if i["finished"] == 0:
|
||||||
|
|||||||
@@ -130,6 +130,7 @@ class Phaser(Machine):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
from heurams.services.textproc import truncate
|
from heurams.services.textproc import truncate
|
||||||
from tabulate import tabulate as tabu
|
from tabulate import tabulate as tabu
|
||||||
|
|
||||||
lst = [
|
lst = [
|
||||||
{
|
{
|
||||||
"Type": "Phaser",
|
"Type": "Phaser",
|
||||||
@@ -138,4 +139,4 @@ class Phaser(Machine):
|
|||||||
"Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore
|
"Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
return str(tabu(lst, headers="keys")) + '\n'
|
return str(tabu(lst, headers="keys")) + "\n"
|
||||||
|
|||||||
@@ -63,8 +63,7 @@ class Procession(Machine):
|
|||||||
logger.debug("Procession 进入 FINISHED 状态")
|
logger.debug("Procession 进入 FINISHED 状态")
|
||||||
|
|
||||||
def forward(self, step=1):
|
def forward(self, step=1):
|
||||||
"""将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)
|
"""将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)"""
|
||||||
"""
|
|
||||||
logger.debug("Procession.forward: step=%d, 当前 cursor=%d", step, self.cursor)
|
logger.debug("Procession.forward: step=%d, 当前 cursor=%d", step, self.cursor)
|
||||||
self.cursor += step
|
self.cursor += step
|
||||||
if self.cursor >= len(self.queue):
|
if self.cursor >= len(self.queue):
|
||||||
@@ -84,8 +83,7 @@ class Procession(Machine):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def append(self, atom=None):
|
def append(self, atom=None):
|
||||||
"""追加(回忆失败的)原子(默认为当前原子)到队列末端
|
"""追加(回忆失败的)原子(默认为当前原子)到队列末端"""
|
||||||
"""
|
|
||||||
if atom is None:
|
if atom is None:
|
||||||
atom = self.current_atom
|
atom = self.current_atom
|
||||||
logger.debug("Procession.append: atom=%s", atom.ident if atom else "None")
|
logger.debug("Procession.append: atom=%s", atom.ident if atom else "None")
|
||||||
@@ -122,6 +120,7 @@ class Procession(Machine):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
from heurams.services.textproc import truncate
|
from heurams.services.textproc import truncate
|
||||||
|
|
||||||
dic = [
|
dic = [
|
||||||
{
|
{
|
||||||
"Type": "Procession",
|
"Type": "Procession",
|
||||||
@@ -132,4 +131,4 @@ class Procession(Machine):
|
|||||||
"Current Atom": self.current_atom.ident, # type: ignore
|
"Current Atom": self.current_atom.ident, # type: ignore
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
return str(tabu(dic, headers="keys")) + '\n'
|
return str(tabu(dic, headers="keys")) + "\n"
|
||||||
|
|||||||
@@ -1 +1,3 @@
|
|||||||
from .repo import *
|
from .repo import Repo, RepoManifest
|
||||||
|
|
||||||
|
__all__ = ["Repo", "RepoManifest"]
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
from ...utils.lict import Lict
|
|
||||||
|
|
||||||
|
|
||||||
def merge(x: Lict, y: Lict):
|
|
||||||
return Lict(list(x.values()) + list(y.values()))
|
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from .evalizor import Evalizer
|
||||||
|
from .lict import Lict
|
||||||
|
from .refvar import RefVar
|
||||||
|
|
||||||
|
__all__ = ["Evalizer", "Lict", "RefVar"]
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
DashboardScreen 的测试, 包括单元测试和 pilot 测试.
|
|
||||||
"""
|
|
||||||
import pathlib
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from textual.pilot import Pilot
|
|
||||||
|
|
||||||
from heurams.context import ConfigContext
|
|
||||||
from heurams.interface.__main__ import HeurAMSApp
|
|
||||||
from heurams.interface.screens.dashboard import DashboardScreen
|
|
||||||
from heurams.services.config import ConfigFile
|
|
||||||
|
|
||||||
|
|
||||||
class TestDashboardScreenUnit(unittest.TestCase):
|
|
||||||
"""DashboardScreen 的单元测试(不启动完整应用)."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""在每个测试之前运行, 设置临时目录和配置."""
|
|
||||||
# 创建临时目录用于测试数据
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
|
||||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
|
||||||
|
|
||||||
# 创建默认配置, 并修改路径指向临时目录
|
|
||||||
default_config_path = (
|
|
||||||
pathlib.Path(__file__).parent.parent.parent
|
|
||||||
/ "src/heurams/default/config/config.toml"
|
|
||||||
)
|
|
||||||
self.config = ConfigFile(default_config_path)
|
|
||||||
# 更新配置中的路径
|
|
||||||
config_data = self.config.data
|
|
||||||
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
|
||||||
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
|
||||||
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
|
||||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
|
||||||
# 禁用快速通过, 避免测试干扰
|
|
||||||
config_data["quick_pass"] = 0
|
|
||||||
# 禁用时间覆盖
|
|
||||||
config_data["daystamp_override"] = -1
|
|
||||||
config_data["timestamp_override"] = -1
|
|
||||||
|
|
||||||
# 创建目录
|
|
||||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
|
||||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 ConfigContext 设置配置
|
|
||||||
self.config_ctx = ConfigContext(self.config)
|
|
||||||
self.config_ctx.__enter__()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""在每个测试之后清理."""
|
|
||||||
self.config_ctx.__exit__(None, None, None)
|
|
||||||
self.temp_dir.cleanup()
|
|
||||||
|
|
||||||
def test_compose(self):
|
|
||||||
"""测试 compose 方法返回正确的部件."""
|
|
||||||
screen = DashboardScreen()
|
|
||||||
# 手动调用 compose 并收集部件
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
|
|
||||||
result = screen.compose()
|
|
||||||
widgets = list(result)
|
|
||||||
# 检查是否包含 Header 和 Footer
|
|
||||||
from textual.widgets import Footer, Header
|
|
||||||
|
|
||||||
header_present = any(isinstance(w, Header) for w in widgets)
|
|
||||||
footer_present = any(isinstance(w, Footer) for w in widgets)
|
|
||||||
self.assertTrue(header_present)
|
|
||||||
self.assertTrue(footer_present)
|
|
||||||
# 检查是否有 ScrollableContainer
|
|
||||||
from textual.containers import ScrollableContainer
|
|
||||||
|
|
||||||
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
|
||||||
self.assertTrue(container_present)
|
|
||||||
# 使用 query_one 查找 union-list, 即使屏幕未挂载也可能有效
|
|
||||||
list_view = screen.query_one("#union-list")
|
|
||||||
self.assertIsNotNone(list_view)
|
|
||||||
self.assertEqual(list_view.id, "union-list")
|
|
||||||
self.assertEqual(list_view.__class__.__name__, "ListView")
|
|
||||||
|
|
||||||
def test_item_desc_generator(self):
|
|
||||||
"""测试 item_desc_generator 函数."""
|
|
||||||
screen = DashboardScreen()
|
|
||||||
# 模拟一个文件名
|
|
||||||
filename = "test.toml"
|
|
||||||
result = screen.analyser(filename)
|
|
||||||
self.assertIsInstance(result, dict)
|
|
||||||
self.assertIn(0, result)
|
|
||||||
self.assertIn(1, result)
|
|
||||||
# 检查内容
|
|
||||||
self.assertIn("test.toml", result[0])
|
|
||||||
# 由于 electron 文件不存在, 应显示“尚未激活”
|
|
||||||
self.assertIn("尚未激活", result[1])
|
|
||||||
|
|
||||||
|
|
||||||
@unittest.skip("Pilot 测试需要进一步配置, 暂不运行")
|
|
||||||
class TestDashboardScreenPilot(unittest.TestCase):
|
|
||||||
"""使用 Textual Pilot 的集成测试."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""配置临时目录和配置."""
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
|
||||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
|
||||||
|
|
||||||
default_config_path = (
|
|
||||||
pathlib.Path(__file__).parent.parent.parent
|
|
||||||
/ "src/heurams/default/config/config.toml"
|
|
||||||
)
|
|
||||||
self.config = ConfigFile(default_config_path)
|
|
||||||
config_data = self.config.data
|
|
||||||
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
|
||||||
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
|
||||||
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
|
||||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
|
||||||
config_data["quick_pass"] = 0
|
|
||||||
config_data["daystamp_override"] = -1
|
|
||||||
config_data["timestamp_override"] = -1
|
|
||||||
|
|
||||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
|
||||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.config_ctx = ConfigContext(self.config)
|
|
||||||
self.config_ctx.__enter__()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.config_ctx.__exit__(None, None, None)
|
|
||||||
self.temp_dir.cleanup()
|
|
||||||
|
|
||||||
def test_dashboard_loads_with_pilot(self):
|
|
||||||
"""使用 Pilot 测试 DashboardScreen 加载."""
|
|
||||||
with patch("heurams.interface.__main__.environment_check"):
|
|
||||||
app = HeurAMSApp()
|
|
||||||
# 注意: Pilot 在 Textual 6.9.0 中的用法可能不同
|
|
||||||
# 以下为示例代码, 可能需要调整
|
|
||||||
pilot = Pilot(app)
|
|
||||||
# 等待应用启动
|
|
||||||
pilot.pause()
|
|
||||||
screen = app.screen
|
|
||||||
self.assertEqual(screen.__class__.__name__, "DashboardScreen")
|
|
||||||
union_list = app.query_one("#union-list")
|
|
||||||
self.assertIsNotNone(union_list)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
SyncScreen 和 SyncService 的测试.
|
|
||||||
"""
|
|
||||||
import pathlib
|
|
||||||
import tempfile
|
|
||||||
import time
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
|
||||||
|
|
||||||
from heurams.context import ConfigContext
|
|
||||||
from heurams.services.config import ConfigFile
|
|
||||||
from heurams.services.sync_service import (
|
|
||||||
ConflictStrategy,
|
|
||||||
SyncConfig,
|
|
||||||
SyncMode,
|
|
||||||
SyncService,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSyncServiceUnit(unittest.TestCase):
|
|
||||||
"""SyncService 的单元测试."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""在每个测试之前运行, 设置临时目录和模拟客户端."""
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
|
||||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
|
||||||
|
|
||||||
# 创建测试文件
|
|
||||||
self.test_file = self.temp_path / "test.txt"
|
|
||||||
self.test_file.write_text("测试内容")
|
|
||||||
|
|
||||||
# 模拟 WebDAV 客户端
|
|
||||||
self.mock_client = MagicMock()
|
|
||||||
|
|
||||||
# 创建同步配置
|
|
||||||
self.config = SyncConfig(
|
|
||||||
enabled=True,
|
|
||||||
url="https://example.com/dav/",
|
|
||||||
username="test",
|
|
||||||
password="test",
|
|
||||||
remote_path="/heurams/",
|
|
||||||
sync_mode=SyncMode.BIDIRECTIONAL,
|
|
||||||
conflict_strategy=ConflictStrategy.NEWER,
|
|
||||||
verify_ssl=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""在每个测试之后清理."""
|
|
||||||
self.temp_dir.cleanup()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_sync_service_initialization(self, mock_client_class):
|
|
||||||
"""测试同步服务初始化."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
|
|
||||||
# 验证客户端已创建
|
|
||||||
mock_client_class.assert_called_once()
|
|
||||||
self.assertIsNotNone(service.client)
|
|
||||||
self.assertEqual(service.config, self.config)
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_sync_service_disabled(self, mock_client_class):
|
|
||||||
"""测试同步服务未启用."""
|
|
||||||
config = SyncConfig(enabled=False)
|
|
||||||
service = SyncService(config)
|
|
||||||
|
|
||||||
# 客户端不应初始化
|
|
||||||
mock_client_class.assert_not_called()
|
|
||||||
self.assertIsNone(service.client)
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_test_connection_success(self, mock_client_class):
|
|
||||||
"""测试连接测试成功."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
self.mock_client.list.return_value = []
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
result = service.test_connection()
|
|
||||||
|
|
||||||
self.assertTrue(result)
|
|
||||||
self.mock_client.list.assert_called_once()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_test_connection_failure(self, mock_client_class):
|
|
||||||
"""测试连接测试失败."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
self.mock_client.list.side_effect = Exception("连接失败")
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
result = service.test_connection()
|
|
||||||
|
|
||||||
self.assertFalse(result)
|
|
||||||
self.mock_client.list.assert_called_once()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_upload_file(self, mock_client_class):
|
|
||||||
"""测试上传单个文件."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
result = service.upload_file(self.test_file)
|
|
||||||
|
|
||||||
self.assertTrue(result)
|
|
||||||
self.mock_client.upload_file.assert_called_once()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_download_file(self, mock_client_class):
|
|
||||||
"""测试下载单个文件."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
remote_path = "/heurams/test.txt"
|
|
||||||
local_path = self.temp_path / "downloaded.txt"
|
|
||||||
|
|
||||||
result = service.download_file(remote_path, local_path)
|
|
||||||
|
|
||||||
self.assertTrue(result)
|
|
||||||
self.mock_client.download_file.assert_called_once()
|
|
||||||
self.assertTrue(local_path.parent.exists())
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_sync_directory_no_files(self, mock_client_class):
|
|
||||||
"""测试同步空目录."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
self.mock_client.list.return_value = []
|
|
||||||
self.mock_client.mkdir.return_value = None
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
result = service.sync_directory(self.temp_path)
|
|
||||||
|
|
||||||
self.assertTrue(result["success"])
|
|
||||||
self.assertEqual(result["uploaded"], 0)
|
|
||||||
self.assertEqual(result["downloaded"], 0)
|
|
||||||
self.mock_client.mkdir.assert_called_once()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_sync_directory_upload_only(self, mock_client_class):
|
|
||||||
"""测试仅上传模式."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
self.mock_client.list.return_value = []
|
|
||||||
self.mock_client.mkdir.return_value = None
|
|
||||||
|
|
||||||
config = SyncConfig(
|
|
||||||
enabled=True,
|
|
||||||
url="https://example.com/dav/",
|
|
||||||
username="test",
|
|
||||||
password="test",
|
|
||||||
remote_path="/heurams/",
|
|
||||||
sync_mode=SyncMode.UPLOAD_ONLY,
|
|
||||||
conflict_strategy=ConflictStrategy.NEWER,
|
|
||||||
)
|
|
||||||
|
|
||||||
service = SyncService(config)
|
|
||||||
result = service.sync_directory(self.temp_path)
|
|
||||||
|
|
||||||
self.assertTrue(result["success"])
|
|
||||||
self.mock_client.mkdir.assert_called_once()
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_conflict_strategy_newer(self, mock_client_class):
|
|
||||||
"""测试 NEWER 冲突策略."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
|
|
||||||
# 模拟远程文件存在
|
|
||||||
self.mock_client.list.return_value = ["test.txt"]
|
|
||||||
self.mock_client.info.return_value = {
|
|
||||||
"size": 100,
|
|
||||||
"modified": "2023-01-01T00:00:00Z",
|
|
||||||
}
|
|
||||||
self.mock_client.mkdir.return_value = None
|
|
||||||
|
|
||||||
service = SyncService(self.config)
|
|
||||||
result = service.sync_directory(self.temp_path)
|
|
||||||
|
|
||||||
self.assertTrue(result["success"])
|
|
||||||
# 应该有一个冲突
|
|
||||||
self.assertGreaterEqual(result.get("conflicts", 0), 0)
|
|
||||||
|
|
||||||
@patch("heurams.services.sync_service.Client")
|
|
||||||
def test_create_sync_service_from_config(self, mock_client_class):
|
|
||||||
"""测试从配置文件创建同步服务."""
|
|
||||||
mock_client_class.return_value = self.mock_client
|
|
||||||
|
|
||||||
# 创建临时配置文件
|
|
||||||
config_data = {
|
|
||||||
"sync": {
|
|
||||||
"webdav": {
|
|
||||||
"enabled": True,
|
|
||||||
"url": "https://example.com/dav/",
|
|
||||||
"username": "test",
|
|
||||||
"password": "test",
|
|
||||||
"remote_path": "/heurams/",
|
|
||||||
"sync_mode": "bidirectional",
|
|
||||||
"conflict_strategy": "newer",
|
|
||||||
"verify_ssl": True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 模拟 config_var
|
|
||||||
with patch("heurams.services.sync_service.config_var") as mock_config_var:
|
|
||||||
mock_config = MagicMock()
|
|
||||||
mock_config.data = config_data
|
|
||||||
mock_config_var.get.return_value = mock_config
|
|
||||||
|
|
||||||
from heurams.services.sync_service import create_sync_service_from_config
|
|
||||||
|
|
||||||
service = create_sync_service_from_config()
|
|
||||||
|
|
||||||
self.assertIsNotNone(service)
|
|
||||||
self.assertIsNotNone(service.client)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSyncScreenUnit(unittest.TestCase):
|
|
||||||
"""SyncScreen 的单元测试."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""在每个测试之前运行."""
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
|
||||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
|
||||||
|
|
||||||
# 创建默认配置
|
|
||||||
default_config_path = (
|
|
||||||
pathlib.Path(__file__).parent.parent.parent
|
|
||||||
/ "src/heurams/default/config/config.toml"
|
|
||||||
)
|
|
||||||
self.config = ConfigFile(default_config_path)
|
|
||||||
|
|
||||||
# 更新配置中的路径
|
|
||||||
config_data = self.config.data
|
|
||||||
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
|
||||||
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
|
||||||
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
|
||||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
|
||||||
|
|
||||||
# 添加同步配置
|
|
||||||
if "sync" not in config_data:
|
|
||||||
config_data["sync"] = {}
|
|
||||||
config_data["sync"]["webdav"] = {
|
|
||||||
"enabled": False,
|
|
||||||
"url": "",
|
|
||||||
"username": "",
|
|
||||||
"password": "",
|
|
||||||
"remote_path": "/heurams/",
|
|
||||||
"sync_mode": "bidirectional",
|
|
||||||
"conflict_strategy": "newer",
|
|
||||||
"verify_ssl": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 创建目录
|
|
||||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
|
||||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 ConfigContext 设置配置
|
|
||||||
self.config_ctx = ConfigContext(self.config)
|
|
||||||
self.config_ctx.__enter__()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""在每个测试之后清理."""
|
|
||||||
self.config_ctx.__exit__(None, None, None)
|
|
||||||
self.temp_dir.cleanup()
|
|
||||||
|
|
||||||
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
|
|
||||||
def test_sync_screen_compose(self, mock_create_service):
|
|
||||||
"""测试 SyncScreen 的 compose 方法."""
|
|
||||||
from heurams.interface.screens.synctool import SyncScreen
|
|
||||||
|
|
||||||
# 模拟同步服务
|
|
||||||
mock_service = MagicMock()
|
|
||||||
mock_service.client = MagicMock()
|
|
||||||
mock_create_service.return_value = mock_service
|
|
||||||
|
|
||||||
screen = SyncScreen()
|
|
||||||
|
|
||||||
# 测试 compose 方法
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
|
|
||||||
result = screen.compose()
|
|
||||||
widgets = list(result)
|
|
||||||
|
|
||||||
# 检查基本部件
|
|
||||||
from textual.containers import ScrollableContainer
|
|
||||||
from textual.widgets import Button, Footer, Header, ProgressBar, Static
|
|
||||||
|
|
||||||
header_present = any(isinstance(w, Header) for w in widgets)
|
|
||||||
footer_present = any(isinstance(w, Footer) for w in widgets)
|
|
||||||
self.assertTrue(header_present)
|
|
||||||
self.assertTrue(footer_present)
|
|
||||||
|
|
||||||
# 检查容器
|
|
||||||
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
|
||||||
self.assertTrue(container_present)
|
|
||||||
|
|
||||||
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
|
|
||||||
def test_sync_screen_load_config(self, mock_create_service):
|
|
||||||
"""测试 SyncScreen 加载配置."""
|
|
||||||
from heurams.interface.screens.synctool import SyncScreen
|
|
||||||
|
|
||||||
mock_service = MagicMock()
|
|
||||||
mock_service.client = MagicMock()
|
|
||||||
mock_create_service.return_value = mock_service
|
|
||||||
|
|
||||||
screen = SyncScreen()
|
|
||||||
screen.load_config()
|
|
||||||
|
|
||||||
# 验证配置已加载
|
|
||||||
self.assertIsNotNone(screen.sync_config)
|
|
||||||
mock_create_service.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from heurams.kernel.algorithms.sm2 import SM2Algorithm
|
|
||||||
|
|
||||||
|
|
||||||
class TestSM2Algorithm(unittest.TestCase):
|
|
||||||
"""测试 SM2Algorithm 类"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# 模拟 timer 函数
|
|
||||||
self.timestamp_patcher = patch(
|
|
||||||
"heurams.kernel.algorithms.sm2.timer.get_timestamp"
|
|
||||||
)
|
|
||||||
self.daystamp_patcher = patch(
|
|
||||||
"heurams.kernel.algorithms.sm2.timer.get_daystamp"
|
|
||||||
)
|
|
||||||
self.mock_get_timestamp = self.timestamp_patcher.start()
|
|
||||||
self.mock_get_daystamp = self.daystamp_patcher.start()
|
|
||||||
|
|
||||||
# 设置固定返回值
|
|
||||||
self.mock_get_timestamp.return_value = 1000.0
|
|
||||||
self.mock_get_daystamp.return_value = 100
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.timestamp_patcher.stop()
|
|
||||||
self.daystamp_patcher.stop()
|
|
||||||
|
|
||||||
def test_defaults(self):
|
|
||||||
"""测试默认值"""
|
|
||||||
defaults = SM2Algorithm.defaults
|
|
||||||
self.assertEqual(defaults["efactor"], 2.5)
|
|
||||||
self.assertEqual(defaults["real_rept"], 0)
|
|
||||||
self.assertEqual(defaults["rept"], 0)
|
|
||||||
self.assertEqual(defaults["interval"], 0)
|
|
||||||
self.assertEqual(defaults["last_date"], 0)
|
|
||||||
self.assertEqual(defaults["next_date"], 0)
|
|
||||||
self.assertEqual(defaults["is_activated"], 0)
|
|
||||||
# last_modify 是动态的, 仅检查存在性
|
|
||||||
self.assertIn("last_modify", defaults)
|
|
||||||
|
|
||||||
def test_revisor_feedback_minus_one(self):
|
|
||||||
"""测试 feedback = -1 时跳过更新"""
|
|
||||||
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=-1)
|
|
||||||
# 数据应保持不变
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 0)
|
|
||||||
|
|
||||||
def test_revisor_feedback_less_than_3(self):
|
|
||||||
"""测试 feedback < 3 重置 rept 和 interval"""
|
|
||||||
algodata = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 2.5,
|
|
||||||
"rept": 5,
|
|
||||||
"interval": 10,
|
|
||||||
"real_rept": 3,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=2)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
|
||||||
# rept=0 导致 interval 被设置为 1
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 4) # 递增
|
|
||||||
|
|
||||||
def test_revisor_feedback_greater_equal_3(self):
|
|
||||||
"""测试 feedback >= 3 递增 rept"""
|
|
||||||
algodata = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 2.5,
|
|
||||||
"rept": 2,
|
|
||||||
"interval": 6,
|
|
||||||
"real_rept": 2,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=4)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 3)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 3)
|
|
||||||
# interval 应根据 rept 和 efactor 重新计算
|
|
||||||
# rept=3, interval = round(6 * 2.5) = 15
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
|
|
||||||
|
|
||||||
def test_revisor_new_activation(self):
|
|
||||||
"""测试 is_new_activation 重置 rept 和 efactor"""
|
|
||||||
algodata = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 3.0,
|
|
||||||
"rept": 5,
|
|
||||||
"interval": 20,
|
|
||||||
"real_rept": 5,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
|
|
||||||
# interval 应为 1(因为 rept=0)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
|
|
||||||
|
|
||||||
def test_revisor_efactor_calculation(self):
|
|
||||||
"""测试 efactor 计算"""
|
|
||||||
algodata = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 2.5,
|
|
||||||
"rept": 1,
|
|
||||||
"interval": 6,
|
|
||||||
"real_rept": 1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=5)
|
|
||||||
# efactor = 2.5 + (0.1 - (5-5)*(0.08 + (5-5)*0.02)) = 2.5 + 0.1 = 2.6
|
|
||||||
self.assertAlmostEqual(
|
|
||||||
algodata[SM2Algorithm.algo_name]["efactor"], 2.6, places=6
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试 efactor 下限为 1.3
|
|
||||||
algodata[SM2Algorithm.algo_name]["efactor"] = 1.2
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=5)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 1.3)
|
|
||||||
|
|
||||||
def test_revisor_interval_calculation(self):
|
|
||||||
"""测试 interval 计算规则"""
|
|
||||||
algodata = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 2.5,
|
|
||||||
"rept": 0,
|
|
||||||
"interval": 0,
|
|
||||||
"real_rept": 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=4)
|
|
||||||
# rept 从 0 递增到 1, 因此 interval 应为 6
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 6)
|
|
||||||
|
|
||||||
# 现在 rept=1, 再次调用 revisor 递增到 2
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=4)
|
|
||||||
# rept=2, interval = round(6 * 2.5) = 15
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
|
|
||||||
|
|
||||||
# 单独测试 rept=1 的情况
|
|
||||||
algodata2 = {
|
|
||||||
SM2Algorithm.algo_name: {
|
|
||||||
"efactor": 2.5,
|
|
||||||
"rept": 1,
|
|
||||||
"interval": 0,
|
|
||||||
"real_rept": 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SM2Algorithm.revisor(algodata2, feedback=4)
|
|
||||||
# rept 递增到 2, interval = round(0 * 2.5) = 0
|
|
||||||
self.assertEqual(algodata2[SM2Algorithm.algo_name]["interval"], 0)
|
|
||||||
|
|
||||||
def test_revisor_updates_dates(self):
|
|
||||||
"""测试更新日期字段"""
|
|
||||||
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
|
|
||||||
self.mock_get_daystamp.return_value = 200
|
|
||||||
SM2Algorithm.revisor(algodata, feedback=5)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_date"], 200)
|
|
||||||
self.assertEqual(
|
|
||||||
algodata[SM2Algorithm.algo_name]["next_date"],
|
|
||||||
200 + algodata[SM2Algorithm.algo_name]["interval"],
|
|
||||||
)
|
|
||||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_modify"], 1000.0)
|
|
||||||
|
|
||||||
def test_is_due(self):
|
|
||||||
"""测试 is_due 方法"""
|
|
||||||
algodata = {SM2Algorithm.algo_name: {"next_date": 100}}
|
|
||||||
self.mock_get_daystamp.return_value = 150
|
|
||||||
self.assertTrue(SM2Algorithm.is_due(algodata))
|
|
||||||
|
|
||||||
algodata[SM2Algorithm.algo_name]["next_date"] = 200
|
|
||||||
self.assertFalse(SM2Algorithm.is_due(algodata))
|
|
||||||
|
|
||||||
def test_rate(self):
|
|
||||||
"""测试 rate 方法"""
|
|
||||||
algodata = {SM2Algorithm.algo_name: {"efactor": 2.7}}
|
|
||||||
self.assertEqual(SM2Algorithm.get_rating(algodata), "2.7")
|
|
||||||
|
|
||||||
def test_nextdate(self):
|
|
||||||
"""测试 nextdate 方法"""
|
|
||||||
algodata = {SM2Algorithm.algo_name: {"next_date": 12345}}
|
|
||||||
self.assertEqual(SM2Algorithm.nextdate(algodata), 12345)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,202 +0,0 @@
|
|||||||
import json
|
|
||||||
import pathlib
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import toml
|
|
||||||
|
|
||||||
from heurams.context import ConfigContext
|
|
||||||
from heurams.kernel.particles.atom import Atom, atom_registry
|
|
||||||
from heurams.kernel.particles.electron import Electron
|
|
||||||
from heurams.kernel.particles.nucleon import Nucleon
|
|
||||||
from heurams.kernel.particles.orbital import Orbital
|
|
||||||
from heurams.services.config import ConfigFile
|
|
||||||
|
|
||||||
|
|
||||||
class TestAtom(unittest.TestCase):
|
|
||||||
"""测试 Atom 类"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""在每个测试之前运行"""
|
|
||||||
# 创建临时目录用于持久化测试
|
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
|
||||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
|
||||||
|
|
||||||
# 创建默认配置
|
|
||||||
self.config = ConfigFile(
|
|
||||||
pathlib.Path(__file__).parent.parent.parent.parent
|
|
||||||
/ "src/heurams/default/config/config.toml"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用 ConfigContext 设置配置
|
|
||||||
self.config_ctx = ConfigContext(self.config)
|
|
||||||
self.config_ctx.__enter__()
|
|
||||||
|
|
||||||
# 清空全局注册表
|
|
||||||
atom_registry.clear()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""在每个测试之后运行"""
|
|
||||||
self.config_ctx.__exit__(None, None, None)
|
|
||||||
self.temp_dir.cleanup()
|
|
||||||
atom_registry.clear()
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""测试 Atom 初始化"""
|
|
||||||
atom = Atom("test_atom")
|
|
||||||
self.assertEqual(atom.ident, "test_atom")
|
|
||||||
self.assertIn("test_atom", atom_registry)
|
|
||||||
self.assertEqual(atom_registry["test_atom"], atom)
|
|
||||||
|
|
||||||
# 检查 registry 默认值
|
|
||||||
self.assertIsNone(atom.registry["nucleon"])
|
|
||||||
self.assertIsNone(atom.registry["electron"])
|
|
||||||
self.assertIsNone(atom.registry["orbital"])
|
|
||||||
self.assertEqual(atom.registry["nucleon_fmt"], "toml")
|
|
||||||
self.assertEqual(atom.registry["electron_fmt"], "json")
|
|
||||||
self.assertEqual(atom.registry["orbital_fmt"], "toml")
|
|
||||||
|
|
||||||
def test_link(self):
|
|
||||||
"""测试 link 方法"""
|
|
||||||
atom = Atom("test_link")
|
|
||||||
nucleon = Nucleon("test_nucleon", {"content": "test content"})
|
|
||||||
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
self.assertEqual(atom.registry["nucleon"], nucleon)
|
|
||||||
|
|
||||||
# 测试链接不支持的键
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
atom.link("invalid_key", "value")
|
|
||||||
|
|
||||||
def test_link_triggers_do_eval(self):
|
|
||||||
"""测试 link 后触发 do_eval"""
|
|
||||||
atom = Atom("test_eval_trigger")
|
|
||||||
nucleon = Nucleon("test_nucleon", {"content": "eval:1+1"})
|
|
||||||
|
|
||||||
with patch.object(atom, "do_eval") as mock_do_eval:
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
mock_do_eval.assert_called_once()
|
|
||||||
|
|
||||||
def test_persist_toml(self):
|
|
||||||
"""测试 TOML 持久化"""
|
|
||||||
atom = Atom("test_persist_toml")
|
|
||||||
nucleon = Nucleon("test_nucleon", {"content": "test"})
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
|
|
||||||
# 设置路径
|
|
||||||
test_path = self.temp_path / "test.toml"
|
|
||||||
atom.link("nucleon_path", test_path)
|
|
||||||
|
|
||||||
atom.persist("nucleon")
|
|
||||||
|
|
||||||
# 验证文件存在且内容正确
|
|
||||||
self.assertTrue(test_path.exists())
|
|
||||||
with open(test_path, "r") as f:
|
|
||||||
data = toml.load(f)
|
|
||||||
self.assertEqual(data["ident"], "test_nucleon")
|
|
||||||
self.assertEqual(data["payload"]["content"], "test")
|
|
||||||
|
|
||||||
def test_persist_json(self):
|
|
||||||
"""测试 JSON 持久化"""
|
|
||||||
atom = Atom("test_persist_json")
|
|
||||||
electron = Electron("test_electron", {})
|
|
||||||
atom.link("electron", electron)
|
|
||||||
|
|
||||||
test_path = self.temp_path / "test.json"
|
|
||||||
atom.link("electron_path", test_path)
|
|
||||||
|
|
||||||
atom.persist("electron")
|
|
||||||
|
|
||||||
self.assertTrue(test_path.exists())
|
|
||||||
with open(test_path, "r") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
self.assertIn("supermemo2", data)
|
|
||||||
|
|
||||||
def test_persist_invalid_format(self):
|
|
||||||
"""测试无效持久化格式"""
|
|
||||||
atom = Atom("test_invalid_format")
|
|
||||||
nucleon = Nucleon("test_nucleon", {})
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
atom.link("nucleon_path", self.temp_path / "test.txt")
|
|
||||||
atom.registry["nucleon_fmt"] = "invalid"
|
|
||||||
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
atom.persist("nucleon")
|
|
||||||
|
|
||||||
def test_persist_no_path(self):
|
|
||||||
"""测试未初始化路径的持久化"""
|
|
||||||
atom = Atom("test_no_path")
|
|
||||||
nucleon = Nucleon("test_nucleon", {})
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
# 不设置 nucleon_path
|
|
||||||
|
|
||||||
with self.assertRaises(TypeError):
|
|
||||||
atom.persist("nucleon")
|
|
||||||
|
|
||||||
def test_getitem_setitem(self):
|
|
||||||
"""测试 __getitem__ 和 __setitem__"""
|
|
||||||
atom = Atom("test_getset")
|
|
||||||
nucleon = Nucleon("test_nucleon", {})
|
|
||||||
|
|
||||||
atom["nucleon"] = nucleon
|
|
||||||
self.assertEqual(atom["nucleon"], nucleon)
|
|
||||||
|
|
||||||
# 测试不支持的键
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
_ = atom["invalid_key"]
|
|
||||||
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
atom["invalid_key"] = "value"
|
|
||||||
|
|
||||||
def test_do_eval_with_eval_string(self):
|
|
||||||
"""测试 do_eval 处理 eval: 字符串"""
|
|
||||||
atom = Atom("test_do_eval")
|
|
||||||
nucleon = Nucleon(
|
|
||||||
"test_nucleon",
|
|
||||||
{"content": "eval:'hello' + ' world'", "number": "eval:2 + 3"},
|
|
||||||
)
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
|
|
||||||
# do_eval 应该在链接时自动调用
|
|
||||||
# 检查 eval 表达式是否被求值
|
|
||||||
self.assertEqual(nucleon.payload["content"], "hello world")
|
|
||||||
self.assertEqual(nucleon.payload["number"], "5")
|
|
||||||
|
|
||||||
def test_do_eval_with_config_access(self):
|
|
||||||
"""测试 do_eval 访问配置"""
|
|
||||||
atom = Atom("test_eval_config")
|
|
||||||
nucleon = Nucleon(
|
|
||||||
"test_nucleon", {"max_riddles": "eval:default['mcq']['max_riddles_num']"}
|
|
||||||
)
|
|
||||||
atom.link("nucleon", nucleon)
|
|
||||||
|
|
||||||
# 配置中 puzzles.mcq.max_riddles_num = 2
|
|
||||||
self.assertEqual(nucleon.payload["max_riddles"], 2)
|
|
||||||
|
|
||||||
def test_placeholder(self):
|
|
||||||
"""测试静态方法 placeholder"""
|
|
||||||
placeholder = Atom.placeholder()
|
|
||||||
self.assertIsInstance(placeholder, tuple)
|
|
||||||
self.assertEqual(len(placeholder), 3)
|
|
||||||
self.assertIsInstance(placeholder[0], Electron)
|
|
||||||
self.assertIsInstance(placeholder[1], Nucleon)
|
|
||||||
self.assertIsInstance(placeholder[2], dict)
|
|
||||||
|
|
||||||
def test_atom_registry_management(self):
|
|
||||||
"""测试全局注册表管理"""
|
|
||||||
# 创建多个 Atom
|
|
||||||
atom1 = Atom("atom1")
|
|
||||||
atom2 = Atom("atom2")
|
|
||||||
|
|
||||||
self.assertEqual(len(atom_registry), 2)
|
|
||||||
self.assertEqual(atom_registry["atom1"], atom1)
|
|
||||||
self.assertEqual(atom_registry["atom2"], atom2)
|
|
||||||
|
|
||||||
# 测试 bidict 的反向查找
|
|
||||||
self.assertEqual(atom_registry.inverse[atom1], "atom1")
|
|
||||||
self.assertEqual(atom_registry.inverse[atom2], "atom2")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,179 +0,0 @@
|
|||||||
import sys
|
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from heurams.kernel.algorithms import algorithms
|
|
||||||
from heurams.kernel.particles.electron import Electron
|
|
||||||
|
|
||||||
|
|
||||||
class TestElectron(unittest.TestCase):
|
|
||||||
"""测试 Electron 类"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# 模拟 timer.get_timestamp 返回固定值
|
|
||||||
self.timestamp_patcher = patch(
|
|
||||||
"heurams.kernel.particles.electron.timer.get_timestamp"
|
|
||||||
)
|
|
||||||
self.mock_get_timestamp = self.timestamp_patcher.start()
|
|
||||||
self.mock_get_timestamp.return_value = 1234567890.0
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.timestamp_patcher.stop()
|
|
||||||
|
|
||||||
def test_init_default(self):
|
|
||||||
"""测试默认初始化"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
self.assertEqual(electron.ident, "test_electron")
|
|
||||||
self.assertEqual(electron.algo, algorithms["supermemo2"])
|
|
||||||
self.assertIn(electron.algo, electron.algodata)
|
|
||||||
self.assertIsInstance(electron.algodata[electron.algo], dict)
|
|
||||||
# 检查默认值(排除动态字段)
|
|
||||||
defaults = electron.algo.defaults
|
|
||||||
for key, value in defaults.items():
|
|
||||||
if key == "last_modify":
|
|
||||||
# last_modify 是动态的, 只检查存在性
|
|
||||||
self.assertIn(key, electron.algodata[electron.algo])
|
|
||||||
elif key == "is_activated":
|
|
||||||
# TODO: 调查为什么 is_activated 是 1
|
|
||||||
self.assertEqual(electron.algodata[electron.algo][key], 1)
|
|
||||||
else:
|
|
||||||
self.assertEqual(electron.algodata[electron.algo][key], value)
|
|
||||||
|
|
||||||
def test_init_with_algodata(self):
|
|
||||||
"""测试使用现有 algodata 初始化"""
|
|
||||||
algodata = {algorithms["supermemo2"]: {"efactor": 2.5, "interval": 1}}
|
|
||||||
electron = Electron("test_electron", algodata=algodata)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["efactor"], 2.5)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 1)
|
|
||||||
# 其他字段可能不存在, 因为未提供默认初始化
|
|
||||||
# 检查 real_rept 不存在
|
|
||||||
self.assertNotIn("real_rept", electron.algodata[electron.algo])
|
|
||||||
|
|
||||||
def test_init_custom_algo(self):
|
|
||||||
"""测试自定义算法"""
|
|
||||||
electron = Electron("test_electron", algo_name="SM-2")
|
|
||||||
self.assertEqual(electron.algo, algorithms["SM-2"])
|
|
||||||
self.assertIn(electron.algo, electron.algodata)
|
|
||||||
|
|
||||||
def test_activate(self):
|
|
||||||
"""测试 activate 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 0)
|
|
||||||
electron.activate()
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 1)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
|
||||||
|
|
||||||
def test_modify(self):
|
|
||||||
"""测试 modify 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
electron.modify("interval", 5)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 5)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
|
||||||
|
|
||||||
# 修改不存在的字段应记录警告但不引发异常
|
|
||||||
with patch("heurams.kernel.particles.electron.logger.warning") as mock_warning:
|
|
||||||
electron.modify("unknown_field", 99)
|
|
||||||
mock_warning.assert_called_once()
|
|
||||||
|
|
||||||
def test_is_activated(self):
|
|
||||||
"""测试 is_activated 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
# TODO: 调查为什么 is_activated 默认是 1 而不是 0
|
|
||||||
# 临时调整为期望值 1
|
|
||||||
self.assertEqual(electron.is_activated(), 1)
|
|
||||||
electron.activate()
|
|
||||||
self.assertEqual(electron.is_activated(), 1)
|
|
||||||
|
|
||||||
def test_is_due(self):
|
|
||||||
"""测试 is_due 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
with patch.object(electron.algo, "is_due") as mock_is_due:
|
|
||||||
mock_is_due.return_value = 1
|
|
||||||
result = electron.is_due()
|
|
||||||
mock_is_due.assert_called_once_with(electron.algodata)
|
|
||||||
self.assertEqual(result, 1)
|
|
||||||
|
|
||||||
def test_rate(self):
|
|
||||||
"""测试 rate 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
with patch.object(electron.algo, "rate") as mock_rate:
|
|
||||||
mock_rate.return_value = "good"
|
|
||||||
result = electron.get_rating()
|
|
||||||
mock_rate.assert_called_once_with(electron.algodata)
|
|
||||||
self.assertEqual(result, "good")
|
|
||||||
|
|
||||||
def test_nextdate(self):
|
|
||||||
"""测试 nextdate 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
with patch.object(electron.algo, "nextdate") as mock_nextdate:
|
|
||||||
mock_nextdate.return_value = 1234568000
|
|
||||||
result = electron.nextdate()
|
|
||||||
mock_nextdate.assert_called_once_with(electron.algodata)
|
|
||||||
self.assertEqual(result, 1234568000)
|
|
||||||
|
|
||||||
def test_revisor(self):
|
|
||||||
"""测试 revisor 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
with patch.object(electron.algo, "revisor") as mock_revisor:
|
|
||||||
electron.revisor(quality=3, is_new_activation=True)
|
|
||||||
mock_revisor.assert_called_once_with(electron.algodata, 3, True)
|
|
||||||
|
|
||||||
def test_str(self):
|
|
||||||
"""测试 __str__ 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
str_repr = str(electron)
|
|
||||||
self.assertIn("记忆单元预览", str_repr)
|
|
||||||
self.assertIn("test_electron", str_repr)
|
|
||||||
# 算法类名会出现在字符串表示中
|
|
||||||
self.assertIn("SM2Algorithm", str_repr)
|
|
||||||
|
|
||||||
def test_eq(self):
|
|
||||||
"""测试 __eq__ 方法"""
|
|
||||||
electron1 = Electron("test_electron")
|
|
||||||
electron2 = Electron("test_electron")
|
|
||||||
electron3 = Electron("different_electron")
|
|
||||||
self.assertEqual(electron1, electron2)
|
|
||||||
self.assertNotEqual(electron1, electron3)
|
|
||||||
|
|
||||||
def test_hash(self):
|
|
||||||
"""测试 __hash__ 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
self.assertEqual(hash(electron), hash("test_electron"))
|
|
||||||
|
|
||||||
def test_getitem(self):
|
|
||||||
"""测试 __getitem__ 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
electron.activate()
|
|
||||||
self.assertEqual(electron["ident"], "test_electron")
|
|
||||||
self.assertEqual(electron["is_activated"], 1)
|
|
||||||
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
_ = electron["nonexistent_key"]
|
|
||||||
|
|
||||||
def test_setitem(self):
|
|
||||||
"""测试 __setitem__ 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
electron["interval"] = 10
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 10)
|
|
||||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
|
||||||
|
|
||||||
with self.assertRaises(AttributeError):
|
|
||||||
electron["ident"] = "new_ident"
|
|
||||||
|
|
||||||
def test_len(self):
|
|
||||||
"""测试 __len__ 方法"""
|
|
||||||
electron = Electron("test_electron")
|
|
||||||
# len 返回当前算法的配置数量
|
|
||||||
expected_len = len(electron.algo.defaults)
|
|
||||||
self.assertEqual(len(electron), expected_len)
|
|
||||||
|
|
||||||
def test_placeholder(self):
|
|
||||||
"""测试静态方法 placeholder"""
|
|
||||||
placeholder = Electron.placeholder()
|
|
||||||
self.assertIsInstance(placeholder, Electron)
|
|
||||||
self.assertEqual(placeholder.ident, "电子对象样例内容")
|
|
||||||
self.assertEqual(placeholder.algo, algorithms["supermemo2"])
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from heurams.kernel.particles.nucleon import Nucleon
|
|
||||||
|
|
||||||
|
|
||||||
class TestNucleon(unittest.TestCase):
|
|
||||||
"""测试 Nucleon 类"""
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""测试初始化"""
|
|
||||||
nucleon = Nucleon(
|
|
||||||
"test_id", {"content": "hello", "note": "world"}, {"author": "test"}
|
|
||||||
)
|
|
||||||
self.assertEqual(nucleon.ident, "test_id")
|
|
||||||
self.assertEqual(nucleon.payload, {"content": "hello", "note": "world"})
|
|
||||||
self.assertEqual(nucleon.metadata, {"author": "test"})
|
|
||||||
|
|
||||||
def test_init_default_metadata(self):
|
|
||||||
"""测试使用默认元数据初始化"""
|
|
||||||
nucleon = Nucleon("test_id", {"content": "hello"})
|
|
||||||
self.assertEqual(nucleon.ident, "test_id")
|
|
||||||
self.assertEqual(nucleon.payload, {"content": "hello"})
|
|
||||||
self.assertEqual(nucleon.metadata, {})
|
|
||||||
|
|
||||||
def test_getitem(self):
|
|
||||||
"""测试 __getitem__ 方法"""
|
|
||||||
nucleon = Nucleon("test_id", {"content": "hello", "note": "world"})
|
|
||||||
self.assertEqual(nucleon["ident"], "test_id")
|
|
||||||
self.assertEqual(nucleon["content"], "hello")
|
|
||||||
self.assertEqual(nucleon["note"], "world")
|
|
||||||
|
|
||||||
with self.assertRaises(KeyError):
|
|
||||||
_ = nucleon["nonexistent"]
|
|
||||||
|
|
||||||
def test_iter(self):
|
|
||||||
"""测试 __iter__ 方法"""
|
|
||||||
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
|
|
||||||
keys = list(nucleon)
|
|
||||||
self.assertCountEqual(keys, ["a", "b", "c"])
|
|
||||||
|
|
||||||
def test_len(self):
|
|
||||||
"""测试 __len__ 方法"""
|
|
||||||
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
|
|
||||||
self.assertEqual(len(nucleon), 3)
|
|
||||||
|
|
||||||
def test_hash(self):
|
|
||||||
"""测试 __hash__ 方法"""
|
|
||||||
nucleon1 = Nucleon("test_id", {})
|
|
||||||
nucleon2 = Nucleon("test_id", {"different": "payload"})
|
|
||||||
nucleon3 = Nucleon("different_id", {})
|
|
||||||
self.assertEqual(hash(nucleon1), hash(nucleon2)) # 相同 ident
|
|
||||||
self.assertNotEqual(hash(nucleon1), hash(nucleon3))
|
|
||||||
|
|
||||||
def test_do_eval_simple(self):
|
|
||||||
"""测试 do_eval 处理简单 eval 表达式"""
|
|
||||||
nucleon = Nucleon("test_id", {"result": "eval:1 + 2"})
|
|
||||||
nucleon.do_eval()
|
|
||||||
self.assertEqual(nucleon.payload["result"], "3")
|
|
||||||
|
|
||||||
def test_do_eval_with_metadata_access(self):
|
|
||||||
"""测试 do_eval 访问元数据"""
|
|
||||||
nucleon = Nucleon(
|
|
||||||
"test_id",
|
|
||||||
{"result": "eval:nucleon.metadata.get('value', 0)"},
|
|
||||||
{"value": 42},
|
|
||||||
)
|
|
||||||
nucleon.do_eval()
|
|
||||||
self.assertEqual(nucleon.payload["result"], "42")
|
|
||||||
|
|
||||||
def test_do_eval_nested(self):
|
|
||||||
"""测试 do_eval 处理嵌套结构"""
|
|
||||||
nucleon = Nucleon(
|
|
||||||
"test_id",
|
|
||||||
{
|
|
||||||
"list": ["eval:2*3", "normal"],
|
|
||||||
"dict": {"key": "eval:'hello' + ' world'"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
nucleon.do_eval()
|
|
||||||
self.assertEqual(nucleon.payload["list"][0], "6")
|
|
||||||
self.assertEqual(nucleon.payload["list"][1], "normal")
|
|
||||||
self.assertEqual(nucleon.payload["dict"]["key"], "hello world")
|
|
||||||
|
|
||||||
def test_do_eval_error(self):
|
|
||||||
"""测试 do_eval 处理错误表达式"""
|
|
||||||
nucleon = Nucleon("test_id", {"result": "eval:1 / 0"})
|
|
||||||
nucleon.do_eval()
|
|
||||||
self.assertIn("此 eval 实例发生错误", nucleon.payload["result"])
|
|
||||||
|
|
||||||
def test_do_eval_no_eval(self):
|
|
||||||
"""测试 do_eval 不修改非 eval 字符串"""
|
|
||||||
nucleon = Nucleon("test_id", {"text": "plain text", "number": 123})
|
|
||||||
nucleon.do_eval()
|
|
||||||
self.assertEqual(nucleon.payload["text"], "plain text")
|
|
||||||
self.assertEqual(nucleon.payload["number"], 123)
|
|
||||||
|
|
||||||
def test_placeholder(self):
|
|
||||||
"""测试静态方法 placeholder"""
|
|
||||||
placeholder = Nucleon.placeholder()
|
|
||||||
self.assertIsInstance(placeholder, Nucleon)
|
|
||||||
self.assertEqual(placeholder.ident, "核子对象样例内容")
|
|
||||||
self.assertEqual(placeholder.payload, {})
|
|
||||||
self.assertEqual(placeholder.metadata, {})
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
from heurams.kernel.evaluators.base import BasePuzzle
|
|
||||||
|
|
||||||
|
|
||||||
class TestBasePuzzle(unittest.TestCase):
|
|
||||||
"""测试 BasePuzzle 基类"""
|
|
||||||
|
|
||||||
def test_refresh_not_implemented(self):
|
|
||||||
"""测试 refresh 方法未实现时抛出异常"""
|
|
||||||
puzzle = BasePuzzle()
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
puzzle.refresh()
|
|
||||||
|
|
||||||
def test_str(self):
|
|
||||||
"""测试 __str__ 方法"""
|
|
||||||
puzzle = BasePuzzle()
|
|
||||||
self.assertEqual(str(puzzle), "谜题: BasePuzzle")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
from heurams.kernel.evaluators.cloze import ClozePuzzle
|
|
||||||
|
|
||||||
|
|
||||||
class TestClozePuzzle(unittest.TestCase):
|
|
||||||
"""测试 ClozePuzzle 类"""
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""测试初始化"""
|
|
||||||
puzzle = ClozePuzzle("hello/world/test", min_denominator=3, delimiter="/")
|
|
||||||
self.assertEqual(puzzle.text, "hello/world/test")
|
|
||||||
self.assertEqual(puzzle.min_denominator, 3)
|
|
||||||
self.assertEqual(puzzle.delimiter, "/")
|
|
||||||
self.assertEqual(puzzle.wording, "填空题 - 尚未刷新谜题")
|
|
||||||
self.assertEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"])
|
|
||||||
|
|
||||||
@patch("random.sample")
|
|
||||||
def test_refresh(self, mock_sample):
|
|
||||||
"""测试 refresh 方法"""
|
|
||||||
mock_sample.return_value = [0, 2] # 选择索引 0 和 2
|
|
||||||
puzzle = ClozePuzzle("hello/world/test", min_denominator=2, delimiter="/")
|
|
||||||
puzzle.refresh()
|
|
||||||
|
|
||||||
# 检查 wording 和 answer
|
|
||||||
self.assertNotEqual(puzzle.wording, "填空题 - 尚未刷新谜题")
|
|
||||||
self.assertNotEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"])
|
|
||||||
# 根据模拟, 应该有两个填空
|
|
||||||
self.assertEqual(len(puzzle.answer), 2)
|
|
||||||
self.assertEqual(puzzle.answer, ["hello", "test"])
|
|
||||||
# wording 应包含下划线
|
|
||||||
self.assertIn("__", puzzle.wording)
|
|
||||||
|
|
||||||
def test_refresh_empty_text(self):
|
|
||||||
"""测试空文本的 refresh"""
|
|
||||||
puzzle = ClozePuzzle("", min_denominator=3, delimiter="/")
|
|
||||||
puzzle.refresh() # 不应引发异常
|
|
||||||
# 空文本导致 wording 和 answer 为空
|
|
||||||
self.assertEqual(puzzle.wording, "")
|
|
||||||
self.assertEqual(puzzle.answer, [])
|
|
||||||
|
|
||||||
def test_str(self):
|
|
||||||
"""测试 __str__ 方法"""
|
|
||||||
puzzle = ClozePuzzle("hello/world", min_denominator=2, delimiter="/")
|
|
||||||
str_repr = str(puzzle)
|
|
||||||
self.assertIn("填空题 - 尚未刷新谜题", str_repr)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, call, patch
|
|
||||||
|
|
||||||
from heurams.kernel.evaluators.mcq import MCQPuzzle
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCQPuzzle(unittest.TestCase):
|
|
||||||
"""测试 MCQPuzzle 类"""
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""测试初始化"""
|
|
||||||
mapping = {"q1": "a1", "q2": "a2"}
|
|
||||||
jammer = ["j1", "j2", "j3"]
|
|
||||||
puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=3, prefix="选择")
|
|
||||||
self.assertEqual(puzzle.prefix, "选择")
|
|
||||||
self.assertEqual(puzzle.mapping, mapping)
|
|
||||||
self.assertEqual(puzzle.max_riddles_num, 3)
|
|
||||||
# jammer 应合并正确答案并去重
|
|
||||||
self.assertIn("a1", puzzle.jammer)
|
|
||||||
self.assertIn("a2", puzzle.jammer)
|
|
||||||
self.assertIn("j1", puzzle.jammer)
|
|
||||||
# 初始状态
|
|
||||||
self.assertEqual(puzzle.wording, "选择题 - 尚未刷新谜题")
|
|
||||||
self.assertEqual(puzzle.answer, ["选择题 - 尚未刷新谜题"])
|
|
||||||
self.assertEqual(puzzle.options, [])
|
|
||||||
|
|
||||||
def test_init_max_riddles_num_clamping(self):
|
|
||||||
"""测试 max_riddles_num 限制在 1-5 之间"""
|
|
||||||
puzzle1 = MCQPuzzle({}, [], max_riddles_num=0)
|
|
||||||
self.assertEqual(puzzle1.max_riddles_num, 1)
|
|
||||||
puzzle2 = MCQPuzzle({}, [], max_riddles_num=10)
|
|
||||||
self.assertEqual(puzzle2.max_riddles_num, 5)
|
|
||||||
|
|
||||||
def test_init_jammer_ensures_minimum(self):
|
|
||||||
"""测试干扰项至少保证 4 个"""
|
|
||||||
puzzle = MCQPuzzle({}, [])
|
|
||||||
# 正确答案为空, 干扰项为空, 应填充空格
|
|
||||||
self.assertEqual(len(puzzle.jammer), 4)
|
|
||||||
self.assertEqual(set(puzzle.jammer), {" "}) # 三个空格? 实际上循环填充空格
|
|
||||||
|
|
||||||
@patch("random.sample")
|
|
||||||
@patch("random.shuffle")
|
|
||||||
def test_refresh(self, mock_shuffle, mock_sample):
|
|
||||||
"""测试 refresh 方法生成题目"""
|
|
||||||
mapping = {"q1": "a1", "q2": "a2", "q3": "a3"}
|
|
||||||
jammer = ["j1", "j2", "j3", "j4"]
|
|
||||||
puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=2)
|
|
||||||
# 模拟 random.sample 返回前两个映射项
|
|
||||||
mock_sample.side_effect = [
|
|
||||||
[("q1", "a1"), ("q2", "a2")], # 选择问题
|
|
||||||
["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次)
|
|
||||||
]
|
|
||||||
puzzle.refresh()
|
|
||||||
|
|
||||||
# 检查 wording 是列表
|
|
||||||
self.assertIsInstance(puzzle.wording, list)
|
|
||||||
self.assertEqual(len(puzzle.wording), 2)
|
|
||||||
# 检查 answer 列表
|
|
||||||
self.assertEqual(puzzle.answer, ["a1", "a2"])
|
|
||||||
# 检查 options 列表
|
|
||||||
self.assertEqual(len(puzzle.options), 2)
|
|
||||||
# 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项)
|
|
||||||
self.assertEqual(len(puzzle.options[0]), 4)
|
|
||||||
self.assertEqual(len(puzzle.options[1]), 4)
|
|
||||||
# random.shuffle 应被调用
|
|
||||||
self.assertEqual(mock_shuffle.call_count, 2)
|
|
||||||
|
|
||||||
def test_refresh_empty_mapping(self):
|
|
||||||
"""测试空 mapping 的 refresh"""
|
|
||||||
puzzle = MCQPuzzle({}, [])
|
|
||||||
puzzle.refresh()
|
|
||||||
self.assertEqual(puzzle.wording, "无可用题目")
|
|
||||||
self.assertEqual(puzzle.answer, ["无答案"])
|
|
||||||
self.assertEqual(puzzle.options, [])
|
|
||||||
|
|
||||||
def test_get_question_count(self):
|
|
||||||
"""测试 get_question_count 方法"""
|
|
||||||
puzzle = MCQPuzzle({"q": "a"}, [])
|
|
||||||
self.assertEqual(puzzle.get_question_count(), 0) # 未刷新
|
|
||||||
puzzle.refresh = MagicMock()
|
|
||||||
puzzle.wording = ["Q1", "Q2"]
|
|
||||||
self.assertEqual(puzzle.get_question_count(), 2)
|
|
||||||
puzzle.wording = "无可用题目"
|
|
||||||
self.assertEqual(puzzle.get_question_count(), 0)
|
|
||||||
puzzle.wording = "单个问题"
|
|
||||||
self.assertEqual(puzzle.get_question_count(), 1)
|
|
||||||
|
|
||||||
def test_get_correct_answer_for_question(self):
|
|
||||||
"""测试 get_correct_answer_for_question 方法"""
|
|
||||||
puzzle = MCQPuzzle({}, [])
|
|
||||||
puzzle.answer = ["ans1", "ans2"]
|
|
||||||
self.assertEqual(puzzle.get_correct_answer_for_question(0), "ans1")
|
|
||||||
self.assertEqual(puzzle.get_correct_answer_for_question(1), "ans2")
|
|
||||||
self.assertIsNone(puzzle.get_correct_answer_for_question(2))
|
|
||||||
puzzle.answer = "not a list"
|
|
||||||
self.assertIsNone(puzzle.get_correct_answer_for_question(0))
|
|
||||||
|
|
||||||
def test_get_options_for_question(self):
|
|
||||||
"""测试 get_options_for_question 方法"""
|
|
||||||
puzzle = MCQPuzzle({}, [])
|
|
||||||
puzzle.options = [["a", "b", "c", "d"], ["e", "f", "g", "h"]]
|
|
||||||
self.assertEqual(puzzle.get_options_for_question(0), ["a", "b", "c", "d"])
|
|
||||||
self.assertEqual(puzzle.get_options_for_question(1), ["e", "f", "g", "h"])
|
|
||||||
self.assertIsNone(puzzle.get_options_for_question(2))
|
|
||||||
|
|
||||||
def test_str(self):
|
|
||||||
"""测试 __str__ 方法"""
|
|
||||||
puzzle = MCQPuzzle({}, [])
|
|
||||||
puzzle.wording = "选择题 - 尚未刷新谜题"
|
|
||||||
puzzle.answer = ["选择题 - 尚未刷新谜题"]
|
|
||||||
self.assertIn("选择题 - 尚未刷新谜题", str(puzzle))
|
|
||||||
self.assertIn("正确答案", str(puzzle))
|
|
||||||
|
|
||||||
puzzle.wording = ["Q1", "Q2"]
|
|
||||||
puzzle.answer = ["A1", "A2"]
|
|
||||||
str_repr = str(puzzle)
|
|
||||||
self.assertIn("Q1", str_repr)
|
|
||||||
self.assertIn("A1, A2", str_repr)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
|
||||||
|
|
||||||
from heurams.kernel.particles.atom import Atom
|
|
||||||
from heurams.kernel.particles.electron import Electron
|
|
||||||
from heurams.kernel.reactor.procession import Phaser
|
|
||||||
from heurams.kernel.reactor.states import PhaserState, ProcessionState
|
|
||||||
|
|
||||||
|
|
||||||
class TestPhaser(unittest.TestCase):
|
|
||||||
"""测试 Phaser 类"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
# 创建模拟的 Atom 对象
|
|
||||||
self.atom_new = Mock(spec=Atom)
|
|
||||||
self.atom_new.registry = {"electron": Mock(spec=Electron)}
|
|
||||||
self.atom_new.registry["electron"].is_activated.return_value = False
|
|
||||||
|
|
||||||
self.atom_old = Mock(spec=Atom)
|
|
||||||
self.atom_old.registry = {"electron": Mock(spec=Electron)}
|
|
||||||
self.atom_old.registry["electron"].is_activated.return_value = True
|
|
||||||
|
|
||||||
# 模拟 Procession 类以避免复杂依赖
|
|
||||||
self.procession_patcher = patch("heurams.kernel.reactor.phaser.Procession")
|
|
||||||
self.mock_procession_class = self.procession_patcher.start()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.procession_patcher.stop()
|
|
||||||
|
|
||||||
def test_init_with_mixed_atoms(self):
|
|
||||||
"""测试混合新旧原子的初始化"""
|
|
||||||
atoms = [self.atom_old, self.atom_new, self.atom_old]
|
|
||||||
phaser = Phaser(atoms)
|
|
||||||
|
|
||||||
# 应该创建两个 Procession: 一个用于旧原子, 一个用于新原子, 以及一个总体复习
|
|
||||||
self.assertEqual(self.mock_procession_class.call_count, 3)
|
|
||||||
|
|
||||||
# 检查调用参数
|
|
||||||
calls = self.mock_procession_class.call_args_list
|
|
||||||
# 第一个调用应该是旧原子的初始复习
|
|
||||||
self.assertEqual(calls[0][0][0], [self.atom_old, self.atom_old])
|
|
||||||
self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW)
|
|
||||||
# 第二个调用应该是新原子的识别阶段
|
|
||||||
self.assertEqual(calls[1][0][0], [self.atom_new])
|
|
||||||
self.assertEqual(calls[1][0][1], PhaserState.RECOGNITION)
|
|
||||||
# 第三个调用应该是所有原子的总体复习
|
|
||||||
self.assertEqual(calls[2][0][0], atoms)
|
|
||||||
self.assertEqual(calls[2][0][1], PhaserState.FINAL_REVIEW)
|
|
||||||
|
|
||||||
def test_init_only_old_atoms(self):
|
|
||||||
"""测试只有旧原子"""
|
|
||||||
atoms = [self.atom_old, self.atom_old]
|
|
||||||
phaser = Phaser(atoms)
|
|
||||||
|
|
||||||
# 应该创建两个 Procession: 一个初始复习, 一个总体复习
|
|
||||||
self.assertEqual(self.mock_procession_class.call_count, 2)
|
|
||||||
calls = self.mock_procession_class.call_args_list
|
|
||||||
self.assertEqual(calls[0][0][0], atoms)
|
|
||||||
self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW)
|
|
||||||
self.assertEqual(calls[1][0][0], atoms)
|
|
||||||
self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW)
|
|
||||||
|
|
||||||
def test_init_only_new_atoms(self):
|
|
||||||
"""测试只有新原子"""
|
|
||||||
atoms = [self.atom_new, self.atom_new]
|
|
||||||
phaser = Phaser(atoms)
|
|
||||||
|
|
||||||
self.assertEqual(self.mock_procession_class.call_count, 2)
|
|
||||||
calls = self.mock_procession_class.call_args_list
|
|
||||||
self.assertEqual(calls[0][0][0], atoms)
|
|
||||||
self.assertEqual(calls[0][0][1], PhaserState.RECOGNITION)
|
|
||||||
self.assertEqual(calls[1][0][0], atoms)
|
|
||||||
self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW)
|
|
||||||
|
|
||||||
def test_current_procession_finds_unfinished(self):
|
|
||||||
"""测试 current_procession 找到未完成的 Procession"""
|
|
||||||
# 创建模拟 Procession 实例
|
|
||||||
mock_proc1 = Mock()
|
|
||||||
mock_proc1.state = ProcessionState.FINISHED
|
|
||||||
mock_proc2 = Mock()
|
|
||||||
mock_proc2.state = ProcessionState.RUNNING
|
|
||||||
mock_proc2.phase = PhaserState.QUICK_REVIEW
|
|
||||||
|
|
||||||
phaser = Phaser([])
|
|
||||||
phaser.processions = [mock_proc1, mock_proc2]
|
|
||||||
|
|
||||||
result = phaser.current_procession()
|
|
||||||
self.assertEqual(result, mock_proc2)
|
|
||||||
self.assertEqual(phaser.state, PhaserState.QUICK_REVIEW)
|
|
||||||
|
|
||||||
def test_current_procession_all_finished(self):
|
|
||||||
"""测试所有 Procession 都完成"""
|
|
||||||
mock_proc = Mock()
|
|
||||||
mock_proc.state = ProcessionState.FINISHED
|
|
||||||
|
|
||||||
phaser = Phaser([])
|
|
||||||
phaser.processions = [mock_proc]
|
|
||||||
|
|
||||||
result = phaser.current_procession()
|
|
||||||
self.assertEqual(result, 0)
|
|
||||||
self.assertEqual(phaser.state, PhaserState.FINISHED)
|
|
||||||
|
|
||||||
def test_current_procession_empty(self):
|
|
||||||
"""测试没有 Procession"""
|
|
||||||
phaser = Phaser([])
|
|
||||||
phaser.processions = []
|
|
||||||
|
|
||||||
result = phaser.current_procession()
|
|
||||||
self.assertEqual(result, 0)
|
|
||||||
self.assertEqual(phaser.state, PhaserState.FINISHED)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user