fix: 改进代码

This commit is contained in:
2026-01-04 04:46:19 +08:00
parent c585c79e73
commit 65486794b7
34 changed files with 87 additions and 1570 deletions

View File

@@ -3,6 +3,7 @@ import heurams.kernel.particles as pt
from heurams.services.textproc import truncate
from pathlib import Path
import time
repo = repolib.Repo.create_from_repodir(Path("./test_repo"))
alist = list()
print(repo.ident_index)
@@ -24,6 +25,7 @@ for i in repo.ident_index:
print(repo)
input()
import heurams.kernel.reactor as rt
ph: rt.Phaser = rt.Phaser(alist)
print(ph)
pr: rt.Procession = ph.current_procession() # type: ignore

View File

@@ -17,9 +17,9 @@ def environment_check():
from pathlib import Path
logger.debug("检查环境路径")
subdir = ['cache/voice', 'repo', 'global', 'config']
subdir = ["cache/voice", "repo", "global", "config"]
for i in subdir:
i = Path(config_var.get()['paths']['data']) / i
i = Path(config_var.get()["paths"]["data"]) / i
if not i.exists():
logger.info("创建目录: %s", i)
print(f"创建 {i}")

View File

@@ -14,6 +14,8 @@ from heurams.kernel.particles import *
from heurams.kernel.repolib import *
from heurams.services.logger import get_logger
import heurams.kernel.particles as pt
from pathlib import Path
from .about import AboutScreen
from .preparation import PreparationScreen
@@ -42,7 +44,9 @@ class DashboardScreen(Screen):
yield Header(show_clock=True)
yield ScrollableContainer(
Label('欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})"),
Label(
f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()["timezone_offset"] / 3600})"
),
Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"),
Label("选择待学习或待修改的项目:", classes="title-label"),
ListView(id="repo-list", classes="repo-list-view"),

View File

@@ -69,14 +69,12 @@ class MemScreen(Screen):
try:
self.fission = self.procession.get_fission()
puzzle = self.fission.get_current_puzzle()
# logger.debug(puzzle_debug)
return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore
atom=self.atom, alia=puzzle["alia"] # type: ignore
)
except (KeyError, StopIteration, AttributeError) as e:
logger.debug(f"调度展开出错: {e}")
return Static(f"无法生成谜题 {e}")
# logger.debug(shim.puzzle2widget[puzzle_debug["puzzle"]])
def _get_progress_text(self):
s = f"阶段: {self.procession.phase.name}\n"
@@ -117,18 +115,14 @@ class MemScreen(Screen):
from heurams.services.audio_service import play_by_path
from heurams.services.hasher import get_md5
path = Path(config_var.get()["paths"]['data']) / 'cache' / 'voice'
path = (
path
/ f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav"
)
path = Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
path = path / f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav"
if path.exists():
play_by_path(path)
else:
from heurams.services.tts_service import convertor
convertor(
self.atom.registry["nucleon"]["tts_text"], path
)
convertor(self.atom.registry["nucleon"]["tts_text"], path)
play_by_path(path)
def watch_rating(self, old_rating, new_rating) -> None:

View File

@@ -12,7 +12,8 @@ import heurams.kernel.particles as pt
import heurams.services.hasher as hasher
from heurams.context import *
cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / 'voice'
cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
class PrecachingScreen(Screen):
"""预缓存音频文件屏幕
@@ -204,9 +205,7 @@ class PrecachingScreen(Screen):
from heurams.context import config_var, rootdir, workdir
shutil.rmtree(
cache_dir, ignore_errors=True
)
shutil.rmtree(cache_dir, ignore_errors=True)
self.update_status("已清空", "音频缓存已清空", 0)
except Exception as e:
self.update_status("错误", f"清空缓存失败: {e}")

View File

@@ -128,6 +128,7 @@ class PreparationScreen(Screen):
atoms_to_provide.append(i)
from .memoqueue import MemScreen
import heurams.kernel.reactor as rt
pheser = rt.Phaser(atoms_to_provide)
memscreen = MemScreen(pheser)
self.app.push_screen(memscreen)

View File

@@ -1,4 +1,5 @@
"""Kernel 操作辅助函数库"""
import heurams.interface.widgets as pzw
import heurams.kernel.evaluators as pz

View File

@@ -54,7 +54,7 @@ class MCQPuzzle(BasePuzzleWidget):
self._load()
def _load(self):
cfg = self.atom.registry["orbital"]["puzzles"][self.alia]
cfg = self.atom.registry["nucleon"]["puzzles"][self.alia]
self.puzzle = pz.MCQPuzzle(
cfg["mapping"], cfg["jammer"], int(cfg["max_riddles_num"]), cfg["prefix"]
)

View File

@@ -1,11 +1,7 @@
from heurams.services.logger import get_logger
from .base import BaseAlgorithm
from .sm2 import SM2Algorithm
from .sm15m import SM15MAlgorithm
logger = get_logger(__name__)
__all__ = [
"SM2Algorithm",
"BaseAlgorithm",
@@ -17,5 +13,3 @@ algorithms = {
"SM-15M": SM15MAlgorithm,
"Base": BaseAlgorithm,
}
logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys()))

View File

@@ -6,8 +6,6 @@ Evaluator 模块 - 生成评估模块
from heurams.services.logger import get_logger
logger = get_logger(__name__)
from .base import BaseEvaluator
from .cloze import ClozePuzzle
from .mcq import MCQPuzzle
@@ -26,38 +24,3 @@ puzzles = {
"recognition": RecognitionPuzzle,
"base": BaseEvaluator,
}
@staticmethod
def create_by_dict(config_dict: dict) -> BaseEvaluator:
"""
根据配置字典创建谜题
Args:
config_dict: 配置字典, 包含谜题类型和参数
Returns:
BasePuzzle: 谜题实例
Raises:
ValueError: 当配置无效时抛出
"""
logger.debug(
"puzzles.create_by_dict: config_dict keys=%s", list(config_dict.keys())
)
puzzle_type = config_dict.get("type")
if puzzle_type == "cloze":
return puzzles["cloze"](
text=config_dict["text"],
min_denominator=config_dict.get("min_denominator", 7),
)
elif puzzle_type == "mcq":
return puzzles["mcq"](
mapping=config_dict["mapping"],
jammer=config_dict.get("jammer", []),
max_riddles_num=config_dict.get("max_riddles_num", 2),
prefix=config_dict.get("prefix", ""),
)
else:
raise ValueError(f"未知的谜题类型: {puzzle_type}")

View File

@@ -1,4 +1,21 @@
from .atom import Atom
from .electron import Electron
from .nucleon import Nucleon
from .placeholders import (
AtomPlaceholder,
NucleonPlaceholder,
ElectronPlaceholder,
orbital_placeholder,
)
# from .orbital import Orbital
__all__ = [
"Atom",
"Electron",
"Nucleon",
"AtomPlaceholder",
"NucleonPlaceholder",
"ElectronPlaceholder",
"orbital_placeholder",
]

View File

@@ -13,9 +13,11 @@ class Nucleon:
def __init__(self, ident, payload, common):
self.ident = ident
env = {"payload": payload,
"default": config_var.get()['puzzles'],
"nucleon": (payload | common)}
env = {
"payload": payload,
"default": config_var.get()["puzzles"],
"nucleon": (payload | common),
}
self.evalizer = Evalizer(environment=env)
self.data: dict = self.evalizer(deepcopy((payload | common))) # type: ignore

View File

@@ -5,8 +5,4 @@ from .phaser import Phaser
from .procession import Procession
from .states import PhaserState, ProcessionState
logger = get_logger(__name__)
__all__ = ["PhaserState", "ProcessionState", "Procession", "Fission", "Phaser"]
logger.debug("反应堆模块已加载")

View File

@@ -20,7 +20,7 @@ class Fission:
phase_state.value if isinstance(phase_state, PhaserState) else phase_state
)
self.orbital_schedule = atom.registry['orbital']["phases"][phase_value] # type: ignore
self.orbital_schedule = atom.registry["orbital"]["phases"][phase_value] # type: ignore
self.orbital_puzzles = atom.registry["nucleon"]["puzzles"]
self.puzzles = list()
@@ -34,7 +34,6 @@ class Fission:
{
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
"alia": item,
"finished": 0,
}
)
possibility -= 1
@@ -44,7 +43,6 @@ class Fission:
{
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
"alia": item,
"finished": 0,
}
)
@@ -62,7 +60,6 @@ class Fission:
else:
return self.puzzles[self.cursor]
def check_passed(self):
for i in self.puzzles:
if i["finished"] == 0:

View File

@@ -130,6 +130,7 @@ class Phaser(Machine):
def __repr__(self):
from heurams.services.textproc import truncate
from tabulate import tabulate as tabu
lst = [
{
"Type": "Phaser",
@@ -138,4 +139,4 @@ class Phaser(Machine):
"Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore
},
]
return str(tabu(lst, headers="keys")) + '\n'
return str(tabu(lst, headers="keys")) + "\n"

View File

@@ -63,8 +63,7 @@ class Procession(Machine):
logger.debug("Procession 进入 FINISHED 状态")
def forward(self, step=1):
"""将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)
"""
"""将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)"""
logger.debug("Procession.forward: step=%d, 当前 cursor=%d", step, self.cursor)
self.cursor += step
if self.cursor >= len(self.queue):
@@ -84,8 +83,7 @@ class Procession(Machine):
return 0
def append(self, atom=None):
"""追加(回忆失败的)原子(默认为当前原子)到队列末端
"""
"""追加(回忆失败的)原子(默认为当前原子)到队列末端"""
if atom is None:
atom = self.current_atom
logger.debug("Procession.append: atom=%s", atom.ident if atom else "None")
@@ -122,6 +120,7 @@ class Procession(Machine):
def __repr__(self):
from heurams.services.textproc import truncate
dic = [
{
"Type": "Procession",
@@ -132,4 +131,4 @@ class Procession(Machine):
"Current Atom": self.current_atom.ident, # type: ignore
}
]
return str(tabu(dic, headers="keys")) + '\n'
return str(tabu(dic, headers="keys")) + "\n"

View File

@@ -1 +1,3 @@
from .repo import *
from .repo import Repo, RepoManifest
__all__ = ["Repo", "RepoManifest"]

View File

@@ -1,5 +0,0 @@
from ...utils.lict import Lict
def merge(x: Lict, y: Lict):
return Lict(list(x.values()) + list(y.values()))

View File

@@ -0,0 +1,5 @@
from .evalizor import Evalizer
from .lict import Lict
from .refvar import RefVar
__all__ = ["Evalizer", "Lict", "RefVar"]

View File

@@ -1,153 +0,0 @@
#!/usr/bin/env python3
"""
DashboardScreen 的测试, 包括单元测试和 pilot 测试.
"""
import pathlib
import tempfile
import time
import unittest
from unittest.mock import MagicMock, patch
from textual.pilot import Pilot
from heurams.context import ConfigContext
from heurams.interface.__main__ import HeurAMSApp
from heurams.interface.screens.dashboard import DashboardScreen
from heurams.services.config import ConfigFile
class TestDashboardScreenUnit(unittest.TestCase):
"""DashboardScreen 的单元测试(不启动完整应用)."""
def setUp(self):
"""在每个测试之前运行, 设置临时目录和配置."""
# 创建临时目录用于测试数据
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建默认配置, 并修改路径指向临时目录
default_config_path = (
pathlib.Path(__file__).parent.parent.parent
/ "src/heurams/default/config/config.toml"
)
self.config = ConfigFile(default_config_path)
# 更新配置中的路径
config_data = self.config.data
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
# 禁用快速通过, 避免测试干扰
config_data["quick_pass"] = 0
# 禁用时间覆盖
config_data["daystamp_override"] = -1
config_data["timestamp_override"] = -1
# 创建目录
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
# 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__()
def tearDown(self):
"""在每个测试之后清理."""
self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup()
def test_compose(self):
"""测试 compose 方法返回正确的部件."""
screen = DashboardScreen()
# 手动调用 compose 并收集部件
from textual.app import ComposeResult
result = screen.compose()
widgets = list(result)
# 检查是否包含 Header 和 Footer
from textual.widgets import Footer, Header
header_present = any(isinstance(w, Header) for w in widgets)
footer_present = any(isinstance(w, Footer) for w in widgets)
self.assertTrue(header_present)
self.assertTrue(footer_present)
# 检查是否有 ScrollableContainer
from textual.containers import ScrollableContainer
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
self.assertTrue(container_present)
# 使用 query_one 查找 union-list, 即使屏幕未挂载也可能有效
list_view = screen.query_one("#union-list")
self.assertIsNotNone(list_view)
self.assertEqual(list_view.id, "union-list")
self.assertEqual(list_view.__class__.__name__, "ListView")
def test_item_desc_generator(self):
"""测试 item_desc_generator 函数."""
screen = DashboardScreen()
# 模拟一个文件名
filename = "test.toml"
result = screen.analyser(filename)
self.assertIsInstance(result, dict)
self.assertIn(0, result)
self.assertIn(1, result)
# 检查内容
self.assertIn("test.toml", result[0])
# 由于 electron 文件不存在, 应显示“尚未激活”
self.assertIn("尚未激活", result[1])
@unittest.skip("Pilot 测试需要进一步配置, 暂不运行")
class TestDashboardScreenPilot(unittest.TestCase):
"""使用 Textual Pilot 的集成测试."""
def setUp(self):
"""配置临时目录和配置."""
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
default_config_path = (
pathlib.Path(__file__).parent.parent.parent
/ "src/heurams/default/config/config.toml"
)
self.config = ConfigFile(default_config_path)
config_data = self.config.data
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
config_data["quick_pass"] = 0
config_data["daystamp_override"] = -1
config_data["timestamp_override"] = -1
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__()
def tearDown(self):
self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup()
def test_dashboard_loads_with_pilot(self):
"""使用 Pilot 测试 DashboardScreen 加载."""
with patch("heurams.interface.__main__.environment_check"):
app = HeurAMSApp()
# 注意: Pilot 在 Textual 6.9.0 中的用法可能不同
# 以下为示例代码, 可能需要调整
pilot = Pilot(app)
# 等待应用启动
pilot.pause()
screen = app.screen
self.assertEqual(screen.__class__.__name__, "DashboardScreen")
union_list = app.query_one("#union-list")
self.assertIsNotNone(union_list)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,317 +0,0 @@
#!/usr/bin/env python3
"""
SyncScreen 和 SyncService 的测试.
"""
import pathlib
import tempfile
import time
import unittest
from unittest.mock import MagicMock, Mock, patch
from heurams.context import ConfigContext
from heurams.services.config import ConfigFile
from heurams.services.sync_service import (
ConflictStrategy,
SyncConfig,
SyncMode,
SyncService,
)
class TestSyncServiceUnit(unittest.TestCase):
"""SyncService 的单元测试."""
def setUp(self):
"""在每个测试之前运行, 设置临时目录和模拟客户端."""
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建测试文件
self.test_file = self.temp_path / "test.txt"
self.test_file.write_text("测试内容")
# 模拟 WebDAV 客户端
self.mock_client = MagicMock()
# 创建同步配置
self.config = SyncConfig(
enabled=True,
url="https://example.com/dav/",
username="test",
password="test",
remote_path="/heurams/",
sync_mode=SyncMode.BIDIRECTIONAL,
conflict_strategy=ConflictStrategy.NEWER,
verify_ssl=True,
)
def tearDown(self):
"""在每个测试之后清理."""
self.temp_dir.cleanup()
@patch("heurams.services.sync_service.Client")
def test_sync_service_initialization(self, mock_client_class):
"""测试同步服务初始化."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
# 验证客户端已创建
mock_client_class.assert_called_once()
self.assertIsNotNone(service.client)
self.assertEqual(service.config, self.config)
@patch("heurams.services.sync_service.Client")
def test_sync_service_disabled(self, mock_client_class):
"""测试同步服务未启用."""
config = SyncConfig(enabled=False)
service = SyncService(config)
# 客户端不应初始化
mock_client_class.assert_not_called()
self.assertIsNone(service.client)
@patch("heurams.services.sync_service.Client")
def test_test_connection_success(self, mock_client_class):
"""测试连接测试成功."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
service = SyncService(self.config)
result = service.test_connection()
self.assertTrue(result)
self.mock_client.list.assert_called_once()
@patch("heurams.services.sync_service.Client")
def test_test_connection_failure(self, mock_client_class):
"""测试连接测试失败."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.side_effect = Exception("连接失败")
service = SyncService(self.config)
result = service.test_connection()
self.assertFalse(result)
self.mock_client.list.assert_called_once()
@patch("heurams.services.sync_service.Client")
def test_upload_file(self, mock_client_class):
"""测试上传单个文件."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
result = service.upload_file(self.test_file)
self.assertTrue(result)
self.mock_client.upload_file.assert_called_once()
@patch("heurams.services.sync_service.Client")
def test_download_file(self, mock_client_class):
"""测试下载单个文件."""
mock_client_class.return_value = self.mock_client
service = SyncService(self.config)
remote_path = "/heurams/test.txt"
local_path = self.temp_path / "downloaded.txt"
result = service.download_file(remote_path, local_path)
self.assertTrue(result)
self.mock_client.download_file.assert_called_once()
self.assertTrue(local_path.parent.exists())
@patch("heurams.services.sync_service.Client")
def test_sync_directory_no_files(self, mock_client_class):
"""测试同步空目录."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None
service = SyncService(self.config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result["success"])
self.assertEqual(result["uploaded"], 0)
self.assertEqual(result["downloaded"], 0)
self.mock_client.mkdir.assert_called_once()
@patch("heurams.services.sync_service.Client")
def test_sync_directory_upload_only(self, mock_client_class):
"""测试仅上传模式."""
mock_client_class.return_value = self.mock_client
self.mock_client.list.return_value = []
self.mock_client.mkdir.return_value = None
config = SyncConfig(
enabled=True,
url="https://example.com/dav/",
username="test",
password="test",
remote_path="/heurams/",
sync_mode=SyncMode.UPLOAD_ONLY,
conflict_strategy=ConflictStrategy.NEWER,
)
service = SyncService(config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result["success"])
self.mock_client.mkdir.assert_called_once()
@patch("heurams.services.sync_service.Client")
def test_conflict_strategy_newer(self, mock_client_class):
"""测试 NEWER 冲突策略."""
mock_client_class.return_value = self.mock_client
# 模拟远程文件存在
self.mock_client.list.return_value = ["test.txt"]
self.mock_client.info.return_value = {
"size": 100,
"modified": "2023-01-01T00:00:00Z",
}
self.mock_client.mkdir.return_value = None
service = SyncService(self.config)
result = service.sync_directory(self.temp_path)
self.assertTrue(result["success"])
# 应该有一个冲突
self.assertGreaterEqual(result.get("conflicts", 0), 0)
@patch("heurams.services.sync_service.Client")
def test_create_sync_service_from_config(self, mock_client_class):
"""测试从配置文件创建同步服务."""
mock_client_class.return_value = self.mock_client
# 创建临时配置文件
config_data = {
"sync": {
"webdav": {
"enabled": True,
"url": "https://example.com/dav/",
"username": "test",
"password": "test",
"remote_path": "/heurams/",
"sync_mode": "bidirectional",
"conflict_strategy": "newer",
"verify_ssl": True,
}
}
}
# 模拟 config_var
with patch("heurams.services.sync_service.config_var") as mock_config_var:
mock_config = MagicMock()
mock_config.data = config_data
mock_config_var.get.return_value = mock_config
from heurams.services.sync_service import create_sync_service_from_config
service = create_sync_service_from_config()
self.assertIsNotNone(service)
self.assertIsNotNone(service.client)
class TestSyncScreenUnit(unittest.TestCase):
"""SyncScreen 的单元测试."""
def setUp(self):
"""在每个测试之前运行."""
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建默认配置
default_config_path = (
pathlib.Path(__file__).parent.parent.parent
/ "src/heurams/default/config/config.toml"
)
self.config = ConfigFile(default_config_path)
# 更新配置中的路径
config_data = self.config.data
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
# 添加同步配置
if "sync" not in config_data:
config_data["sync"] = {}
config_data["sync"]["webdav"] = {
"enabled": False,
"url": "",
"username": "",
"password": "",
"remote_path": "/heurams/",
"sync_mode": "bidirectional",
"conflict_strategy": "newer",
"verify_ssl": True,
}
# 创建目录
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
pathlib.Path(config_data["paths"][dir_key]).mkdir(
parents=True, exist_ok=True
)
# 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__()
def tearDown(self):
"""在每个测试之后清理."""
self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup()
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_compose(self, mock_create_service):
"""测试 SyncScreen 的 compose 方法."""
from heurams.interface.screens.synctool import SyncScreen
# 模拟同步服务
mock_service = MagicMock()
mock_service.client = MagicMock()
mock_create_service.return_value = mock_service
screen = SyncScreen()
# 测试 compose 方法
from textual.app import ComposeResult
result = screen.compose()
widgets = list(result)
# 检查基本部件
from textual.containers import ScrollableContainer
from textual.widgets import Button, Footer, Header, ProgressBar, Static
header_present = any(isinstance(w, Header) for w in widgets)
footer_present = any(isinstance(w, Footer) for w in widgets)
self.assertTrue(header_present)
self.assertTrue(footer_present)
# 检查容器
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
self.assertTrue(container_present)
@patch("heurams.interface.screens.synctool.create_sync_service_from_config")
def test_sync_screen_load_config(self, mock_create_service):
"""测试 SyncScreen 加载配置."""
from heurams.interface.screens.synctool import SyncScreen
mock_service = MagicMock()
mock_service.client = MagicMock()
mock_create_service.return_value = mock_service
screen = SyncScreen()
screen.load_config()
# 验证配置已加载
self.assertIsNotNone(screen.sync_config)
mock_create_service.assert_called_once()
if __name__ == "__main__":
unittest.main()

View File

@@ -1,186 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
from heurams.kernel.algorithms.sm2 import SM2Algorithm
class TestSM2Algorithm(unittest.TestCase):
"""测试 SM2Algorithm 类"""
def setUp(self):
# 模拟 timer 函数
self.timestamp_patcher = patch(
"heurams.kernel.algorithms.sm2.timer.get_timestamp"
)
self.daystamp_patcher = patch(
"heurams.kernel.algorithms.sm2.timer.get_daystamp"
)
self.mock_get_timestamp = self.timestamp_patcher.start()
self.mock_get_daystamp = self.daystamp_patcher.start()
# 设置固定返回值
self.mock_get_timestamp.return_value = 1000.0
self.mock_get_daystamp.return_value = 100
def tearDown(self):
self.timestamp_patcher.stop()
self.daystamp_patcher.stop()
def test_defaults(self):
"""测试默认值"""
defaults = SM2Algorithm.defaults
self.assertEqual(defaults["efactor"], 2.5)
self.assertEqual(defaults["real_rept"], 0)
self.assertEqual(defaults["rept"], 0)
self.assertEqual(defaults["interval"], 0)
self.assertEqual(defaults["last_date"], 0)
self.assertEqual(defaults["next_date"], 0)
self.assertEqual(defaults["is_activated"], 0)
# last_modify 是动态的, 仅检查存在性
self.assertIn("last_modify", defaults)
def test_revisor_feedback_minus_one(self):
"""测试 feedback = -1 时跳过更新"""
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
SM2Algorithm.revisor(algodata, feedback=-1)
# 数据应保持不变
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 0)
def test_revisor_feedback_less_than_3(self):
"""测试 feedback < 3 重置 rept 和 interval"""
algodata = {
SM2Algorithm.algo_name: {
"efactor": 2.5,
"rept": 5,
"interval": 10,
"real_rept": 3,
}
}
SM2Algorithm.revisor(algodata, feedback=2)
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
# rept=0 导致 interval 被设置为 1
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 4) # 递增
def test_revisor_feedback_greater_equal_3(self):
"""测试 feedback >= 3 递增 rept"""
algodata = {
SM2Algorithm.algo_name: {
"efactor": 2.5,
"rept": 2,
"interval": 6,
"real_rept": 2,
}
}
SM2Algorithm.revisor(algodata, feedback=4)
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 3)
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 3)
# interval 应根据 rept 和 efactor 重新计算
# rept=3, interval = round(6 * 2.5) = 15
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
def test_revisor_new_activation(self):
"""测试 is_new_activation 重置 rept 和 efactor"""
algodata = {
SM2Algorithm.algo_name: {
"efactor": 3.0,
"rept": 5,
"interval": 20,
"real_rept": 5,
}
}
SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True)
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
# interval 应为 1(因为 rept=0)
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
def test_revisor_efactor_calculation(self):
"""测试 efactor 计算"""
algodata = {
SM2Algorithm.algo_name: {
"efactor": 2.5,
"rept": 1,
"interval": 6,
"real_rept": 1,
}
}
SM2Algorithm.revisor(algodata, feedback=5)
# efactor = 2.5 + (0.1 - (5-5)*(0.08 + (5-5)*0.02)) = 2.5 + 0.1 = 2.6
self.assertAlmostEqual(
algodata[SM2Algorithm.algo_name]["efactor"], 2.6, places=6
)
# 测试 efactor 下限为 1.3
algodata[SM2Algorithm.algo_name]["efactor"] = 1.2
SM2Algorithm.revisor(algodata, feedback=5)
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 1.3)
def test_revisor_interval_calculation(self):
"""测试 interval 计算规则"""
algodata = {
SM2Algorithm.algo_name: {
"efactor": 2.5,
"rept": 0,
"interval": 0,
"real_rept": 0,
}
}
SM2Algorithm.revisor(algodata, feedback=4)
# rept 从 0 递增到 1, 因此 interval 应为 6
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 6)
# 现在 rept=1, 再次调用 revisor 递增到 2
SM2Algorithm.revisor(algodata, feedback=4)
# rept=2, interval = round(6 * 2.5) = 15
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
# 单独测试 rept=1 的情况
algodata2 = {
SM2Algorithm.algo_name: {
"efactor": 2.5,
"rept": 1,
"interval": 0,
"real_rept": 0,
}
}
SM2Algorithm.revisor(algodata2, feedback=4)
# rept 递增到 2, interval = round(0 * 2.5) = 0
self.assertEqual(algodata2[SM2Algorithm.algo_name]["interval"], 0)
def test_revisor_updates_dates(self):
"""测试更新日期字段"""
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
self.mock_get_daystamp.return_value = 200
SM2Algorithm.revisor(algodata, feedback=5)
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_date"], 200)
self.assertEqual(
algodata[SM2Algorithm.algo_name]["next_date"],
200 + algodata[SM2Algorithm.algo_name]["interval"],
)
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_modify"], 1000.0)
def test_is_due(self):
"""测试 is_due 方法"""
algodata = {SM2Algorithm.algo_name: {"next_date": 100}}
self.mock_get_daystamp.return_value = 150
self.assertTrue(SM2Algorithm.is_due(algodata))
algodata[SM2Algorithm.algo_name]["next_date"] = 200
self.assertFalse(SM2Algorithm.is_due(algodata))
def test_rate(self):
"""测试 rate 方法"""
algodata = {SM2Algorithm.algo_name: {"efactor": 2.7}}
self.assertEqual(SM2Algorithm.get_rating(algodata), "2.7")
def test_nextdate(self):
"""测试 nextdate 方法"""
algodata = {SM2Algorithm.algo_name: {"next_date": 12345}}
self.assertEqual(SM2Algorithm.nextdate(algodata), 12345)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,202 +0,0 @@
import json
import pathlib
import tempfile
import unittest
from unittest.mock import MagicMock, patch
import toml
from heurams.context import ConfigContext
from heurams.kernel.particles.atom import Atom, atom_registry
from heurams.kernel.particles.electron import Electron
from heurams.kernel.particles.nucleon import Nucleon
from heurams.kernel.particles.orbital import Orbital
from heurams.services.config import ConfigFile
class TestAtom(unittest.TestCase):
"""测试 Atom 类"""
def setUp(self):
"""在每个测试之前运行"""
# 创建临时目录用于持久化测试
self.temp_dir = tempfile.TemporaryDirectory()
self.temp_path = pathlib.Path(self.temp_dir.name)
# 创建默认配置
self.config = ConfigFile(
pathlib.Path(__file__).parent.parent.parent.parent
/ "src/heurams/default/config/config.toml"
)
# 使用 ConfigContext 设置配置
self.config_ctx = ConfigContext(self.config)
self.config_ctx.__enter__()
# 清空全局注册表
atom_registry.clear()
def tearDown(self):
"""在每个测试之后运行"""
self.config_ctx.__exit__(None, None, None)
self.temp_dir.cleanup()
atom_registry.clear()
def test_init(self):
"""测试 Atom 初始化"""
atom = Atom("test_atom")
self.assertEqual(atom.ident, "test_atom")
self.assertIn("test_atom", atom_registry)
self.assertEqual(atom_registry["test_atom"], atom)
# 检查 registry 默认值
self.assertIsNone(atom.registry["nucleon"])
self.assertIsNone(atom.registry["electron"])
self.assertIsNone(atom.registry["orbital"])
self.assertEqual(atom.registry["nucleon_fmt"], "toml")
self.assertEqual(atom.registry["electron_fmt"], "json")
self.assertEqual(atom.registry["orbital_fmt"], "toml")
def test_link(self):
"""测试 link 方法"""
atom = Atom("test_link")
nucleon = Nucleon("test_nucleon", {"content": "test content"})
atom.link("nucleon", nucleon)
self.assertEqual(atom.registry["nucleon"], nucleon)
# 测试链接不支持的键
with self.assertRaises(ValueError):
atom.link("invalid_key", "value")
def test_link_triggers_do_eval(self):
"""测试 link 后触发 do_eval"""
atom = Atom("test_eval_trigger")
nucleon = Nucleon("test_nucleon", {"content": "eval:1+1"})
with patch.object(atom, "do_eval") as mock_do_eval:
atom.link("nucleon", nucleon)
mock_do_eval.assert_called_once()
def test_persist_toml(self):
"""测试 TOML 持久化"""
atom = Atom("test_persist_toml")
nucleon = Nucleon("test_nucleon", {"content": "test"})
atom.link("nucleon", nucleon)
# 设置路径
test_path = self.temp_path / "test.toml"
atom.link("nucleon_path", test_path)
atom.persist("nucleon")
# 验证文件存在且内容正确
self.assertTrue(test_path.exists())
with open(test_path, "r") as f:
data = toml.load(f)
self.assertEqual(data["ident"], "test_nucleon")
self.assertEqual(data["payload"]["content"], "test")
def test_persist_json(self):
"""测试 JSON 持久化"""
atom = Atom("test_persist_json")
electron = Electron("test_electron", {})
atom.link("electron", electron)
test_path = self.temp_path / "test.json"
atom.link("electron_path", test_path)
atom.persist("electron")
self.assertTrue(test_path.exists())
with open(test_path, "r") as f:
data = json.load(f)
self.assertIn("supermemo2", data)
def test_persist_invalid_format(self):
"""测试无效持久化格式"""
atom = Atom("test_invalid_format")
nucleon = Nucleon("test_nucleon", {})
atom.link("nucleon", nucleon)
atom.link("nucleon_path", self.temp_path / "test.txt")
atom.registry["nucleon_fmt"] = "invalid"
with self.assertRaises(KeyError):
atom.persist("nucleon")
def test_persist_no_path(self):
"""测试未初始化路径的持久化"""
atom = Atom("test_no_path")
nucleon = Nucleon("test_nucleon", {})
atom.link("nucleon", nucleon)
# 不设置 nucleon_path
with self.assertRaises(TypeError):
atom.persist("nucleon")
def test_getitem_setitem(self):
"""测试 __getitem__ 和 __setitem__"""
atom = Atom("test_getset")
nucleon = Nucleon("test_nucleon", {})
atom["nucleon"] = nucleon
self.assertEqual(atom["nucleon"], nucleon)
# 测试不支持的键
with self.assertRaises(KeyError):
_ = atom["invalid_key"]
with self.assertRaises(KeyError):
atom["invalid_key"] = "value"
def test_do_eval_with_eval_string(self):
"""测试 do_eval 处理 eval: 字符串"""
atom = Atom("test_do_eval")
nucleon = Nucleon(
"test_nucleon",
{"content": "eval:'hello' + ' world'", "number": "eval:2 + 3"},
)
atom.link("nucleon", nucleon)
# do_eval 应该在链接时自动调用
# 检查 eval 表达式是否被求值
self.assertEqual(nucleon.payload["content"], "hello world")
self.assertEqual(nucleon.payload["number"], "5")
def test_do_eval_with_config_access(self):
"""测试 do_eval 访问配置"""
atom = Atom("test_eval_config")
nucleon = Nucleon(
"test_nucleon", {"max_riddles": "eval:default['mcq']['max_riddles_num']"}
)
atom.link("nucleon", nucleon)
# 配置中 puzzles.mcq.max_riddles_num = 2
self.assertEqual(nucleon.payload["max_riddles"], 2)
def test_placeholder(self):
"""测试静态方法 placeholder"""
placeholder = Atom.placeholder()
self.assertIsInstance(placeholder, tuple)
self.assertEqual(len(placeholder), 3)
self.assertIsInstance(placeholder[0], Electron)
self.assertIsInstance(placeholder[1], Nucleon)
self.assertIsInstance(placeholder[2], dict)
def test_atom_registry_management(self):
"""测试全局注册表管理"""
# 创建多个 Atom
atom1 = Atom("atom1")
atom2 = Atom("atom2")
self.assertEqual(len(atom_registry), 2)
self.assertEqual(atom_registry["atom1"], atom1)
self.assertEqual(atom_registry["atom2"], atom2)
# 测试 bidict 的反向查找
self.assertEqual(atom_registry.inverse[atom1], "atom1")
self.assertEqual(atom_registry.inverse[atom2], "atom2")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,179 +0,0 @@
import sys
import unittest
from unittest.mock import MagicMock, patch
from heurams.kernel.algorithms import algorithms
from heurams.kernel.particles.electron import Electron
class TestElectron(unittest.TestCase):
"""测试 Electron 类"""
def setUp(self):
# 模拟 timer.get_timestamp 返回固定值
self.timestamp_patcher = patch(
"heurams.kernel.particles.electron.timer.get_timestamp"
)
self.mock_get_timestamp = self.timestamp_patcher.start()
self.mock_get_timestamp.return_value = 1234567890.0
def tearDown(self):
self.timestamp_patcher.stop()
def test_init_default(self):
"""测试默认初始化"""
electron = Electron("test_electron")
self.assertEqual(electron.ident, "test_electron")
self.assertEqual(electron.algo, algorithms["supermemo2"])
self.assertIn(electron.algo, electron.algodata)
self.assertIsInstance(electron.algodata[electron.algo], dict)
# 检查默认值(排除动态字段)
defaults = electron.algo.defaults
for key, value in defaults.items():
if key == "last_modify":
# last_modify 是动态的, 只检查存在性
self.assertIn(key, electron.algodata[electron.algo])
elif key == "is_activated":
# TODO: 调查为什么 is_activated 是 1
self.assertEqual(electron.algodata[electron.algo][key], 1)
else:
self.assertEqual(electron.algodata[electron.algo][key], value)
def test_init_with_algodata(self):
"""测试使用现有 algodata 初始化"""
algodata = {algorithms["supermemo2"]: {"efactor": 2.5, "interval": 1}}
electron = Electron("test_electron", algodata=algodata)
self.assertEqual(electron.algodata[electron.algo]["efactor"], 2.5)
self.assertEqual(electron.algodata[electron.algo]["interval"], 1)
# 其他字段可能不存在, 因为未提供默认初始化
# 检查 real_rept 不存在
self.assertNotIn("real_rept", electron.algodata[electron.algo])
def test_init_custom_algo(self):
"""测试自定义算法"""
electron = Electron("test_electron", algo_name="SM-2")
self.assertEqual(electron.algo, algorithms["SM-2"])
self.assertIn(electron.algo, electron.algodata)
def test_activate(self):
"""测试 activate 方法"""
electron = Electron("test_electron")
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 0)
electron.activate()
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 1)
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
def test_modify(self):
"""测试 modify 方法"""
electron = Electron("test_electron")
electron.modify("interval", 5)
self.assertEqual(electron.algodata[electron.algo]["interval"], 5)
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
# 修改不存在的字段应记录警告但不引发异常
with patch("heurams.kernel.particles.electron.logger.warning") as mock_warning:
electron.modify("unknown_field", 99)
mock_warning.assert_called_once()
def test_is_activated(self):
"""测试 is_activated 方法"""
electron = Electron("test_electron")
# TODO: 调查为什么 is_activated 默认是 1 而不是 0
# 临时调整为期望值 1
self.assertEqual(electron.is_activated(), 1)
electron.activate()
self.assertEqual(electron.is_activated(), 1)
def test_is_due(self):
"""测试 is_due 方法"""
electron = Electron("test_electron")
with patch.object(electron.algo, "is_due") as mock_is_due:
mock_is_due.return_value = 1
result = electron.is_due()
mock_is_due.assert_called_once_with(electron.algodata)
self.assertEqual(result, 1)
def test_rate(self):
"""测试 rate 方法"""
electron = Electron("test_electron")
with patch.object(electron.algo, "rate") as mock_rate:
mock_rate.return_value = "good"
result = electron.get_rating()
mock_rate.assert_called_once_with(electron.algodata)
self.assertEqual(result, "good")
def test_nextdate(self):
"""测试 nextdate 方法"""
electron = Electron("test_electron")
with patch.object(electron.algo, "nextdate") as mock_nextdate:
mock_nextdate.return_value = 1234568000
result = electron.nextdate()
mock_nextdate.assert_called_once_with(electron.algodata)
self.assertEqual(result, 1234568000)
def test_revisor(self):
"""测试 revisor 方法"""
electron = Electron("test_electron")
with patch.object(electron.algo, "revisor") as mock_revisor:
electron.revisor(quality=3, is_new_activation=True)
mock_revisor.assert_called_once_with(electron.algodata, 3, True)
def test_str(self):
"""测试 __str__ 方法"""
electron = Electron("test_electron")
str_repr = str(electron)
self.assertIn("记忆单元预览", str_repr)
self.assertIn("test_electron", str_repr)
# 算法类名会出现在字符串表示中
self.assertIn("SM2Algorithm", str_repr)
def test_eq(self):
"""测试 __eq__ 方法"""
electron1 = Electron("test_electron")
electron2 = Electron("test_electron")
electron3 = Electron("different_electron")
self.assertEqual(electron1, electron2)
self.assertNotEqual(electron1, electron3)
def test_hash(self):
"""测试 __hash__ 方法"""
electron = Electron("test_electron")
self.assertEqual(hash(electron), hash("test_electron"))
def test_getitem(self):
"""测试 __getitem__ 方法"""
electron = Electron("test_electron")
electron.activate()
self.assertEqual(electron["ident"], "test_electron")
self.assertEqual(electron["is_activated"], 1)
with self.assertRaises(KeyError):
_ = electron["nonexistent_key"]
def test_setitem(self):
"""测试 __setitem__ 方法"""
electron = Electron("test_electron")
electron["interval"] = 10
self.assertEqual(electron.algodata[electron.algo]["interval"], 10)
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
with self.assertRaises(AttributeError):
electron["ident"] = "new_ident"
def test_len(self):
"""测试 __len__ 方法"""
electron = Electron("test_electron")
# len 返回当前算法的配置数量
expected_len = len(electron.algo.defaults)
self.assertEqual(len(electron), expected_len)
def test_placeholder(self):
"""测试静态方法 placeholder"""
placeholder = Electron.placeholder()
self.assertIsInstance(placeholder, Electron)
self.assertEqual(placeholder.ident, "电子对象样例内容")
self.assertEqual(placeholder.algo, algorithms["supermemo2"])
if __name__ == "__main__":
unittest.main()

View File

@@ -1,108 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
from heurams.kernel.particles.nucleon import Nucleon
class TestNucleon(unittest.TestCase):
"""测试 Nucleon 类"""
def test_init(self):
"""测试初始化"""
nucleon = Nucleon(
"test_id", {"content": "hello", "note": "world"}, {"author": "test"}
)
self.assertEqual(nucleon.ident, "test_id")
self.assertEqual(nucleon.payload, {"content": "hello", "note": "world"})
self.assertEqual(nucleon.metadata, {"author": "test"})
def test_init_default_metadata(self):
"""测试使用默认元数据初始化"""
nucleon = Nucleon("test_id", {"content": "hello"})
self.assertEqual(nucleon.ident, "test_id")
self.assertEqual(nucleon.payload, {"content": "hello"})
self.assertEqual(nucleon.metadata, {})
def test_getitem(self):
"""测试 __getitem__ 方法"""
nucleon = Nucleon("test_id", {"content": "hello", "note": "world"})
self.assertEqual(nucleon["ident"], "test_id")
self.assertEqual(nucleon["content"], "hello")
self.assertEqual(nucleon["note"], "world")
with self.assertRaises(KeyError):
_ = nucleon["nonexistent"]
def test_iter(self):
"""测试 __iter__ 方法"""
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
keys = list(nucleon)
self.assertCountEqual(keys, ["a", "b", "c"])
def test_len(self):
"""测试 __len__ 方法"""
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
self.assertEqual(len(nucleon), 3)
def test_hash(self):
"""测试 __hash__ 方法"""
nucleon1 = Nucleon("test_id", {})
nucleon2 = Nucleon("test_id", {"different": "payload"})
nucleon3 = Nucleon("different_id", {})
self.assertEqual(hash(nucleon1), hash(nucleon2)) # 相同 ident
self.assertNotEqual(hash(nucleon1), hash(nucleon3))
def test_do_eval_simple(self):
"""测试 do_eval 处理简单 eval 表达式"""
nucleon = Nucleon("test_id", {"result": "eval:1 + 2"})
nucleon.do_eval()
self.assertEqual(nucleon.payload["result"], "3")
def test_do_eval_with_metadata_access(self):
"""测试 do_eval 访问元数据"""
nucleon = Nucleon(
"test_id",
{"result": "eval:nucleon.metadata.get('value', 0)"},
{"value": 42},
)
nucleon.do_eval()
self.assertEqual(nucleon.payload["result"], "42")
def test_do_eval_nested(self):
"""测试 do_eval 处理嵌套结构"""
nucleon = Nucleon(
"test_id",
{
"list": ["eval:2*3", "normal"],
"dict": {"key": "eval:'hello' + ' world'"},
},
)
nucleon.do_eval()
self.assertEqual(nucleon.payload["list"][0], "6")
self.assertEqual(nucleon.payload["list"][1], "normal")
self.assertEqual(nucleon.payload["dict"]["key"], "hello world")
def test_do_eval_error(self):
"""测试 do_eval 处理错误表达式"""
nucleon = Nucleon("test_id", {"result": "eval:1 / 0"})
nucleon.do_eval()
self.assertIn("此 eval 实例发生错误", nucleon.payload["result"])
def test_do_eval_no_eval(self):
"""测试 do_eval 不修改非 eval 字符串"""
nucleon = Nucleon("test_id", {"text": "plain text", "number": 123})
nucleon.do_eval()
self.assertEqual(nucleon.payload["text"], "plain text")
self.assertEqual(nucleon.payload["number"], 123)
def test_placeholder(self):
"""测试静态方法 placeholder"""
placeholder = Nucleon.placeholder()
self.assertIsInstance(placeholder, Nucleon)
self.assertEqual(placeholder.ident, "核子对象样例内容")
self.assertEqual(placeholder.payload, {})
self.assertEqual(placeholder.metadata, {})
if __name__ == "__main__":
unittest.main()

View File

@@ -1,23 +0,0 @@
import unittest
from unittest.mock import Mock
from heurams.kernel.evaluators.base import BasePuzzle
class TestBasePuzzle(unittest.TestCase):
"""测试 BasePuzzle 基类"""
def test_refresh_not_implemented(self):
"""测试 refresh 方法未实现时抛出异常"""
puzzle = BasePuzzle()
with self.assertRaises(NotImplementedError):
puzzle.refresh()
def test_str(self):
"""测试 __str__ 方法"""
puzzle = BasePuzzle()
self.assertEqual(str(puzzle), "谜题: BasePuzzle")
if __name__ == "__main__":
unittest.main()

View File

@@ -1,51 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
from heurams.kernel.evaluators.cloze import ClozePuzzle
class TestClozePuzzle(unittest.TestCase):
"""测试 ClozePuzzle 类"""
def test_init(self):
"""测试初始化"""
puzzle = ClozePuzzle("hello/world/test", min_denominator=3, delimiter="/")
self.assertEqual(puzzle.text, "hello/world/test")
self.assertEqual(puzzle.min_denominator, 3)
self.assertEqual(puzzle.delimiter, "/")
self.assertEqual(puzzle.wording, "填空题 - 尚未刷新谜题")
self.assertEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"])
@patch("random.sample")
def test_refresh(self, mock_sample):
"""测试 refresh 方法"""
mock_sample.return_value = [0, 2] # 选择索引 0 和 2
puzzle = ClozePuzzle("hello/world/test", min_denominator=2, delimiter="/")
puzzle.refresh()
# 检查 wording 和 answer
self.assertNotEqual(puzzle.wording, "填空题 - 尚未刷新谜题")
self.assertNotEqual(puzzle.answer, ["填空题 - 尚未刷新谜题"])
# 根据模拟, 应该有两个填空
self.assertEqual(len(puzzle.answer), 2)
self.assertEqual(puzzle.answer, ["hello", "test"])
# wording 应包含下划线
self.assertIn("__", puzzle.wording)
def test_refresh_empty_text(self):
"""测试空文本的 refresh"""
puzzle = ClozePuzzle("", min_denominator=3, delimiter="/")
puzzle.refresh() # 不应引发异常
# 空文本导致 wording 和 answer 为空
self.assertEqual(puzzle.wording, "")
self.assertEqual(puzzle.answer, [])
def test_str(self):
"""测试 __str__ 方法"""
puzzle = ClozePuzzle("hello/world", min_denominator=2, delimiter="/")
str_repr = str(puzzle)
self.assertIn("填空题 - 尚未刷新谜题", str_repr)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,122 +0,0 @@
import unittest
from unittest.mock import MagicMock, call, patch
from heurams.kernel.evaluators.mcq import MCQPuzzle
class TestMCQPuzzle(unittest.TestCase):
"""测试 MCQPuzzle 类"""
def test_init(self):
"""测试初始化"""
mapping = {"q1": "a1", "q2": "a2"}
jammer = ["j1", "j2", "j3"]
puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=3, prefix="选择")
self.assertEqual(puzzle.prefix, "选择")
self.assertEqual(puzzle.mapping, mapping)
self.assertEqual(puzzle.max_riddles_num, 3)
# jammer 应合并正确答案并去重
self.assertIn("a1", puzzle.jammer)
self.assertIn("a2", puzzle.jammer)
self.assertIn("j1", puzzle.jammer)
# 初始状态
self.assertEqual(puzzle.wording, "选择题 - 尚未刷新谜题")
self.assertEqual(puzzle.answer, ["选择题 - 尚未刷新谜题"])
self.assertEqual(puzzle.options, [])
def test_init_max_riddles_num_clamping(self):
"""测试 max_riddles_num 限制在 1-5 之间"""
puzzle1 = MCQPuzzle({}, [], max_riddles_num=0)
self.assertEqual(puzzle1.max_riddles_num, 1)
puzzle2 = MCQPuzzle({}, [], max_riddles_num=10)
self.assertEqual(puzzle2.max_riddles_num, 5)
def test_init_jammer_ensures_minimum(self):
"""测试干扰项至少保证 4 个"""
puzzle = MCQPuzzle({}, [])
# 正确答案为空, 干扰项为空, 应填充空格
self.assertEqual(len(puzzle.jammer), 4)
self.assertEqual(set(puzzle.jammer), {" "}) # 三个空格? 实际上循环填充空格
@patch("random.sample")
@patch("random.shuffle")
def test_refresh(self, mock_shuffle, mock_sample):
"""测试 refresh 方法生成题目"""
mapping = {"q1": "a1", "q2": "a2", "q3": "a3"}
jammer = ["j1", "j2", "j3", "j4"]
puzzle = MCQPuzzle(mapping, jammer, max_riddles_num=2)
# 模拟 random.sample 返回前两个映射项
mock_sample.side_effect = [
[("q1", "a1"), ("q2", "a2")], # 选择问题
["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次)
]
puzzle.refresh()
# 检查 wording 是列表
self.assertIsInstance(puzzle.wording, list)
self.assertEqual(len(puzzle.wording), 2)
# 检查 answer 列表
self.assertEqual(puzzle.answer, ["a1", "a2"])
# 检查 options 列表
self.assertEqual(len(puzzle.options), 2)
# 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项)
self.assertEqual(len(puzzle.options[0]), 4)
self.assertEqual(len(puzzle.options[1]), 4)
# random.shuffle 应被调用
self.assertEqual(mock_shuffle.call_count, 2)
def test_refresh_empty_mapping(self):
"""测试空 mapping 的 refresh"""
puzzle = MCQPuzzle({}, [])
puzzle.refresh()
self.assertEqual(puzzle.wording, "无可用题目")
self.assertEqual(puzzle.answer, ["无答案"])
self.assertEqual(puzzle.options, [])
def test_get_question_count(self):
"""测试 get_question_count 方法"""
puzzle = MCQPuzzle({"q": "a"}, [])
self.assertEqual(puzzle.get_question_count(), 0) # 未刷新
puzzle.refresh = MagicMock()
puzzle.wording = ["Q1", "Q2"]
self.assertEqual(puzzle.get_question_count(), 2)
puzzle.wording = "无可用题目"
self.assertEqual(puzzle.get_question_count(), 0)
puzzle.wording = "单个问题"
self.assertEqual(puzzle.get_question_count(), 1)
def test_get_correct_answer_for_question(self):
"""测试 get_correct_answer_for_question 方法"""
puzzle = MCQPuzzle({}, [])
puzzle.answer = ["ans1", "ans2"]
self.assertEqual(puzzle.get_correct_answer_for_question(0), "ans1")
self.assertEqual(puzzle.get_correct_answer_for_question(1), "ans2")
self.assertIsNone(puzzle.get_correct_answer_for_question(2))
puzzle.answer = "not a list"
self.assertIsNone(puzzle.get_correct_answer_for_question(0))
def test_get_options_for_question(self):
"""测试 get_options_for_question 方法"""
puzzle = MCQPuzzle({}, [])
puzzle.options = [["a", "b", "c", "d"], ["e", "f", "g", "h"]]
self.assertEqual(puzzle.get_options_for_question(0), ["a", "b", "c", "d"])
self.assertEqual(puzzle.get_options_for_question(1), ["e", "f", "g", "h"])
self.assertIsNone(puzzle.get_options_for_question(2))
def test_str(self):
"""测试 __str__ 方法"""
puzzle = MCQPuzzle({}, [])
puzzle.wording = "选择题 - 尚未刷新谜题"
puzzle.answer = ["选择题 - 尚未刷新谜题"]
self.assertIn("选择题 - 尚未刷新谜题", str(puzzle))
self.assertIn("正确答案", str(puzzle))
puzzle.wording = ["Q1", "Q2"]
puzzle.answer = ["A1", "A2"]
str_repr = str(puzzle)
self.assertIn("Q1", str_repr)
self.assertIn("A1, A2", str_repr)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,114 +0,0 @@
import unittest
from unittest.mock import MagicMock, Mock, patch
from heurams.kernel.particles.atom import Atom
from heurams.kernel.particles.electron import Electron
from heurams.kernel.reactor.procession import Phaser
from heurams.kernel.reactor.states import PhaserState, ProcessionState
class TestPhaser(unittest.TestCase):
"""测试 Phaser 类"""
def setUp(self):
# 创建模拟的 Atom 对象
self.atom_new = Mock(spec=Atom)
self.atom_new.registry = {"electron": Mock(spec=Electron)}
self.atom_new.registry["electron"].is_activated.return_value = False
self.atom_old = Mock(spec=Atom)
self.atom_old.registry = {"electron": Mock(spec=Electron)}
self.atom_old.registry["electron"].is_activated.return_value = True
# 模拟 Procession 类以避免复杂依赖
self.procession_patcher = patch("heurams.kernel.reactor.phaser.Procession")
self.mock_procession_class = self.procession_patcher.start()
def tearDown(self):
self.procession_patcher.stop()
def test_init_with_mixed_atoms(self):
"""测试混合新旧原子的初始化"""
atoms = [self.atom_old, self.atom_new, self.atom_old]
phaser = Phaser(atoms)
# 应该创建两个 Procession: 一个用于旧原子, 一个用于新原子, 以及一个总体复习
self.assertEqual(self.mock_procession_class.call_count, 3)
# 检查调用参数
calls = self.mock_procession_class.call_args_list
# 第一个调用应该是旧原子的初始复习
self.assertEqual(calls[0][0][0], [self.atom_old, self.atom_old])
self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW)
# 第二个调用应该是新原子的识别阶段
self.assertEqual(calls[1][0][0], [self.atom_new])
self.assertEqual(calls[1][0][1], PhaserState.RECOGNITION)
# 第三个调用应该是所有原子的总体复习
self.assertEqual(calls[2][0][0], atoms)
self.assertEqual(calls[2][0][1], PhaserState.FINAL_REVIEW)
def test_init_only_old_atoms(self):
"""测试只有旧原子"""
atoms = [self.atom_old, self.atom_old]
phaser = Phaser(atoms)
# 应该创建两个 Procession: 一个初始复习, 一个总体复习
self.assertEqual(self.mock_procession_class.call_count, 2)
calls = self.mock_procession_class.call_args_list
self.assertEqual(calls[0][0][0], atoms)
self.assertEqual(calls[0][0][1], PhaserState.QUICK_REVIEW)
self.assertEqual(calls[1][0][0], atoms)
self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW)
def test_init_only_new_atoms(self):
"""测试只有新原子"""
atoms = [self.atom_new, self.atom_new]
phaser = Phaser(atoms)
self.assertEqual(self.mock_procession_class.call_count, 2)
calls = self.mock_procession_class.call_args_list
self.assertEqual(calls[0][0][0], atoms)
self.assertEqual(calls[0][0][1], PhaserState.RECOGNITION)
self.assertEqual(calls[1][0][0], atoms)
self.assertEqual(calls[1][0][1], PhaserState.FINAL_REVIEW)
def test_current_procession_finds_unfinished(self):
"""测试 current_procession 找到未完成的 Procession"""
# 创建模拟 Procession 实例
mock_proc1 = Mock()
mock_proc1.state = ProcessionState.FINISHED
mock_proc2 = Mock()
mock_proc2.state = ProcessionState.RUNNING
mock_proc2.phase = PhaserState.QUICK_REVIEW
phaser = Phaser([])
phaser.processions = [mock_proc1, mock_proc2]
result = phaser.current_procession()
self.assertEqual(result, mock_proc2)
self.assertEqual(phaser.state, PhaserState.QUICK_REVIEW)
def test_current_procession_all_finished(self):
"""测试所有 Procession 都完成"""
mock_proc = Mock()
mock_proc.state = ProcessionState.FINISHED
phaser = Phaser([])
phaser.processions = [mock_proc]
result = phaser.current_procession()
self.assertEqual(result, 0)
self.assertEqual(phaser.state, PhaserState.FINISHED)
def test_current_procession_empty(self):
"""测试没有 Procession"""
phaser = Phaser([])
phaser.processions = []
result = phaser.current_procession()
self.assertEqual(result, 0)
self.assertEqual(phaser.state, PhaserState.FINISHED)
if __name__ == "__main__":
unittest.main()