feat: 实验性 SM-15M 算法实现

实验性 SM-15M 逆向工程算法实现
This commit is contained in:
2025-12-21 02:15:23 +08:00
parent 243eea864b
commit 98ec6504a4
16 changed files with 1875 additions and 19 deletions

View File

@@ -82,7 +82,7 @@ python -m heurams.interface
## 配置
配置文件位于 `config/config.toml`相对于工作目录. 如果不存在, 会使用内置的默认配置.
配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置.
## 项目结构

View File

@@ -49,7 +49,7 @@ class BasicEvaluation(BasePuzzleWidget):
# 显示主要内容
yield Label(self.atom.registry["nucleon"]["content"], id="main")
# 显示评估说明可选
# 显示评估说明(可选)
yield Static("请评估你对这个内容的记忆程度: ", classes="instruction")
# 按钮容器

View File

@@ -116,7 +116,7 @@ class MCQPuzzle(BasePuzzleWidget):
self.screen.rating = rating # type: ignore
self.handler(rating)
# 重置输入如果回答错误
# 重置输入(如果回答错误)
if not is_correct:
self.inputlist = []
self.refresh_buttons()
@@ -127,7 +127,7 @@ class MCQPuzzle(BasePuzzleWidget):
self.update_display()
def refresh_buttons(self):
"""刷新按钮显示用于题目切换"""
"""刷新按钮显示(用于题目切换)"""
# 移除所有选项按钮
logger.debug("刷新按钮")
self.cursor += 1

View File

@@ -1,6 +1,7 @@
from heurams.services.logger import get_logger
from .sm2 import SM2Algorithm
from .sm15m import SM15MAlgorithm
logger = get_logger(__name__)
@@ -10,7 +11,8 @@ __all__ = [
algorithms = {
"SM-2": SM2Algorithm,
"supermemo2": SM2Algorithm,
"SM-15M": SM15MAlgorithm,
# "SM-15M": SM15MAlgorithm,
}
logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys()))

View File

@@ -0,0 +1,281 @@
"""
SM-15 接口兼容实现, 基于 SM-15 算法的逆向工程
全局状态保存在文件中, 项目状态通过 algodata 字典传递
基于: https://github.com/kazuaki/sm.js
原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida (MIT 许可证)
"""
import datetime
import json
import os
from typing import TypedDict
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF,
RANGE_AF, RANGE_REPETITION,
SM, THRESHOLD_RECALL, Item)
# Path of the JSON file that holds the global SM state; "~" is expanded to
# the current user's home directory.
_GLOBAL_STATE_FILE = os.path.expanduser("~/.sm15_global_state.json")
def _get_global_sm():
    """Return the global SM instance, loading it from disk when possible.

    The state file at ``_GLOBAL_STATE_FILE`` is parsed as JSON and handed to
    ``SM.load``. When the file is missing, or when reading/parsing/loading
    fails for any reason, a fresh ``SM()`` is created, written back through
    ``_save_global_sm`` and returned instead.

    Returns:
        The loaded SM instance, or a newly created one.
    """
    import logging  # local import: this module defines no module-level logger

    try:
        # EAFP: open directly instead of checking os.path.exists() first, so
        # there is no window between the check and the open.
        with open(_GLOBAL_STATE_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        return SM.load(data)
    except FileNotFoundError:
        # Nothing has been persisted yet; fall through and create a new one.
        pass
    except Exception:
        # Deliberately best effort: an unreadable or corrupt state file must
        # not stop a review session. Leave a trace, because the file is about
        # to be replaced by a fresh state.
        logging.getLogger(__name__).warning(
            "Failed to load SM-15 global state from %s; starting fresh",
            _GLOBAL_STATE_FILE,
            exc_info=True,
        )

    sm_instance = SM()
    # Persist the initial state right away.
    _save_global_sm(sm_instance)
    return sm_instance
def _save_global_sm(sm_instance):
    """Persist the global SM instance to ``_GLOBAL_STATE_FILE`` (best effort).

    The state is serialised in memory first, written to a temporary file next
    to the target and then moved over it with ``os.replace``. A failure at any
    step therefore leaves the previously saved state file untouched instead of
    truncated.

    Args:
        sm_instance: Object whose ``data()`` method returns a
            JSON-serialisable structure.

    Errors are logged and otherwise ignored; the function returns ``None``
    either way.
    """
    import logging  # local import: this module defines no module-level logger

    # Include the pid so that two processes never share a temp file.
    tmp_path = f"{_GLOBAL_STATE_FILE}.{os.getpid()}.tmp"
    try:
        # Serialise before touching the file system: a non-serialisable value
        # fails here, before anything on disk has changed.
        payload = json.dumps(sm_instance.data(), indent=2)
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, _GLOBAL_STATE_FILE)
    except Exception:
        # Saving stays best effort, but the failure is no longer invisible.
        logging.getLogger(__name__).warning(
            "Failed to save SM-15 global state to %s",
            _GLOBAL_STATE_FILE,
            exc_info=True,
        )
        try:
            os.remove(tmp_path)
        except OSError:
            # The temp file was never created or is already gone.
            pass
class SM15MAlgorithm:
    """Adapter that drives the SM-15 engine through an ``algodata`` dict.

    Per-item state is kept in ``algodata[algo_name]`` using the fields listed
    in ``AlgodataDict``. For every review the record is converted to an
    ``Item``, answered against the global ``SM`` instance, and converted back.
    All methods are classmethods; the class holds no instance state.
    """

    algo_name = "SM-15M"

    class AlgodataDict(TypedDict):
        efactor: float
        real_rept: int
        rept: int
        interval: int
        last_date: int
        next_date: int
        is_activated: int
        last_modify: float

    defaults = {
        "efactor": 2.5,
        "real_rept": 0,
        "rept": 0,
        "interval": 0,
        "last_date": 0,
        "next_date": 0,
        "is_activated": 0,
        "last_modify": 0.0,
    }

    # Naive 1970-01-01 epoch used for every day-stamp <-> datetime conversion.
    _EPOCH = datetime.datetime(1970, 1, 1)
    # Item intervals are handled in milliseconds, algodata intervals in days.
    _MS_PER_DAY = 24 * 60 * 60 * 1000

    @classmethod
    def _get_timestamp(cls):
        """Return the current time as a POSIX timestamp in seconds (float)."""
        return datetime.datetime.now().timestamp()

    @classmethod
    def _get_daystamp(cls):
        """Return the number of whole days between 1970-01-01 and now.

        Both operands are naive datetimes and ``now()`` is local time, so the
        day boundary follows the local clock.
        """
        return (datetime.datetime.now() - cls._EPOCH).days

    @classmethod
    def _algodata_to_item(cls, algodata, sm_instance):
        """Build an ``Item`` bound to ``sm_instance`` from the stored record.

        Args:
            algodata: Mapping that may contain a record under ``algo_name``;
                ``defaults`` is used when it does not.
            sm_instance: SM instance the item is attached to; its
                ``interval_base`` is used for items without an interval.

        Returns:
            The populated ``Item``.
        """
        sm15_data = algodata.get(cls.algo_name, cls.defaults.copy())
        item = Item(sm_instance)

        # efactor -> A-Factor: rough mapping (factor 2), clamped to
        # [MIN_AF, MAX_AF] because a stored efactor may exceed 2.5.
        efactor = sm15_data.get("efactor", 2.5)
        item.af(max(MIN_AF, min(MAX_AF, efactor * 2.0)))

        # rept -> repetition; repetition == -1 marks a new item.
        rept = sm15_data.get("rept", 0)
        item.repetition = rept - 1 if rept > 0 else -1

        # real_rept, is_activated and last_modify have no Item counterpart.
        # NOTE(review): item.lapse is not stored in algodata either, so it is
        # whatever Item() initialises it to on every call - confirm that this
        # is acceptable.

        # interval (days) -> optimum_interval (milliseconds).
        interval_days = sm15_data.get("interval", 0)
        if interval_days == 0:
            item.optimum_interval = sm_instance.interval_base
        else:
            item.optimum_interval = interval_days * cls._MS_PER_DAY

        # last_date / next_date are day-stamps; 0 means "not set", in which
        # case the attribute is left as Item() created it.
        last_date_days = sm15_data.get("last_date", 0)
        if last_date_days > 0:
            item.previous_date = cls._EPOCH + datetime.timedelta(
                days=last_date_days
            )

        next_date_days = sm15_data.get("next_date", 0)
        if next_date_days > 0:
            item.due_date = cls._EPOCH + datetime.timedelta(days=next_date_days)

        # Keep the source record alongside the item so it can be recovered.
        item.value = {
            "front": "SM-15 item",
            "back": "SM-15 item",
            "_sm15_data": sm15_data,
        }
        return item

    @classmethod
    def _item_to_algodata(cls, item, algodata):
        """Write the state of ``item`` back into ``algodata[algo_name]``.

        A record is created from ``defaults`` when none exists. ``real_rept``
        and ``is_activated`` are only initialised when missing (to 0 and 1
        respectively); ``last_modify`` is always refreshed.

        Args:
            item: Item whose ``af()``, ``repetition``, ``optimum_interval``,
                ``sm.interval_base``, ``previous_date`` and ``due_date`` are
                read.
            algodata: Mapping that is updated in place.

        Returns:
            The same ``algodata`` object.
        """
        sm15_data = algodata.setdefault(cls.algo_name, cls.defaults.copy())

        # A-Factor -> efactor: inverse of the rough mapping, clamped to
        # [1.3, 10.0].
        af = item.af()
        if af is None:
            af = MIN_AF
        sm15_data["efactor"] = max(1.3, min(af / 2.0, 10.0))

        # repetition -> rept (a negative repetition is stored as 0).
        sm15_data["rept"] = item.repetition + 1 if item.repetition >= 0 else 0

        # real_rept is incremented by revisor(); only make sure it exists.
        sm15_data.setdefault("real_rept", 0)

        # optimum_interval (milliseconds) -> interval (days); the base
        # interval is stored as 0.
        interval_ms = item.optimum_interval
        if interval_ms == item.sm.interval_base:
            sm15_data["interval"] = 0
        else:
            sm15_data["interval"] = max(0, round(interval_ms / cls._MS_PER_DAY))

        # previous_date / due_date -> day-stamps, 0 when unset.
        if item.previous_date:
            sm15_data["last_date"] = (item.previous_date - cls._EPOCH).days
        else:
            sm15_data["last_date"] = 0

        if item.due_date:
            sm15_data["next_date"] = (item.due_date - cls._EPOCH).days
        else:
            sm15_data["next_date"] = 0

        sm15_data.setdefault("is_activated", 1)
        sm15_data["last_modify"] = cls._get_timestamp()
        return algodata

    @classmethod
    def revisor(
        cls, algodata: dict, feedback: int = 5, is_new_activation: bool = False
    ):
        """Apply one review result to ``algodata`` and to the global SM state.

        Args:
            algodata: Mapping holding the item's record; updated in place.
            feedback: Grade passed unchanged to ``Item.answer`` (the original
                comments describe it as 0-5).
            is_new_activation: When true, the item is reset to its initial
                state before the grade is applied.

        Side effects: loads and re-saves the global SM state file and
        increments ``real_rept`` in the record.
        """
        sm_instance = _get_global_sm()
        item = cls._algodata_to_item(algodata, sm_instance)

        if is_new_activation:
            # Reset the item to its initial state.
            item.repetition = -1
            item.lapse = 0
            item.optimum_interval = sm_instance.interval_base
            item.previous_date = None
            item.due_date = datetime.datetime.fromtimestamp(0)
            # NOTE(review): 2.5 is passed as an A-Factor here, whereas a
            # stored efactor of 2.5 is mapped to A-Factor 5.0 by
            # _algodata_to_item - confirm which value is intended.
            item.af(2.5)

        # Temporarily enqueue the item for the duration of the answer/update
        # calls below.
        sm_instance.q.append(item)

        now = datetime.datetime.now()
        item.answer(feedback, now)

        # Update the state shared by all items (forgetting curves, OFM,
        # FI-graph).
        # NOTE(review): this runs after item.answer(), i.e. on the post-answer
        # item state. The sm.js reference this port is based on appears to
        # update the shared state before answering - confirm the order.
        if item.repetition >= 0:
            sm_instance.forgetting_curves.register_point(feedback, item, now)
            sm_instance.ofm.update()
            sm_instance.fi_g.update(feedback, item, now)

        sm_instance.q.remove(item)
        _save_global_sm(sm_instance)

        cls._item_to_algodata(item, algodata)
        # real_rept counts every review.
        algodata[cls.algo_name]["real_rept"] += 1

    @classmethod
    def is_due(cls, algodata):
        """Return True when the stored ``next_date`` is today or earlier."""
        sm15_data = algodata.get(cls.algo_name, cls.defaults.copy())
        return sm15_data.get("next_date", 0) <= cls._get_daystamp()

    @classmethod
    def rate(cls, algodata):
        """Return the stored ``efactor`` (default 2.5) as a string."""
        sm15_data = algodata.get(cls.algo_name, cls.defaults.copy())
        return str(sm15_data.get("efactor", 2.5))

    @classmethod
    def nextdate(cls, algodata) -> int:
        """Return the stored ``next_date`` day-stamp (default 0)."""
        sm15_data = algodata.get(cls.algo_name, cls.defaults.copy())
        return sm15_data.get("next_date", 0)

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@ logger = get_logger(__name__)
class Electron:
"""电子: 记忆分析元数据及算法"""
def __init__(self, ident: str, algodata: dict = {}, algo_name: str = "supermemo2"):
def __init__(self, ident: str, algodata: dict = {}, algo_name: str = "SM-2"):
"""初始化电子对象 (记忆数据)
Args:

View File

@@ -8,7 +8,7 @@ class BasePuzzle:
"""谜题基类"""
def refresh(self):
logger.debug("BasePuzzle.refresh 被调用未实现")
logger.debug("BasePuzzle.refresh 被调用(未实现)")
raise NotImplementedError("谜题对象未实现 refresh 方法")
def __str__(self):

View File

@@ -16,5 +16,5 @@ class RecognitionPuzzle(BasePuzzle):
super().__init__()
def refresh(self):
logger.debug("RecognitionPuzzle.refresh空实现")
logger.debug("RecognitionPuzzle.refresh(空实现)")
pass

View File

@@ -52,7 +52,7 @@ class Procession:
self.queue.append(atom)
logger.debug("原子已追加到队列, 新队列长度=%d", len(self.queue))
else:
logger.debug("原子未追加重复或队列长度<=1")
logger.debug("原子未追加(重复或队列长度<=1)")
def __len__(self):
length = len(self.queue) - self.cursor

View File

@@ -2,4 +2,4 @@ from heurams.services.logger import get_logger
logger = get_logger(__name__)
logger.debug("OpenAI provider 模块已加载未实现")
logger.debug("OpenAI provider 模块已加载(未实现)")

View File

@@ -45,7 +45,7 @@ def setup_logging(
# 创建formatter
formatter = logging.Formatter(log_format, date_format)
# 创建文件handler使用RotatingFileHandler防止日志过大
# 创建文件handler(使用RotatingFileHandler防止日志过大)
file_handler = logging.handlers.RotatingFileHandler(
filename=log_path,
maxBytes=max_bytes,
@@ -55,7 +55,7 @@ def setup_logging(
file_handler.setFormatter(formatter)
file_handler.setLevel(log_level)
# 配置root logger - 设置为 WARNING 级别只记录重要信息
# 配置root logger - 设置为 WARNING 级别(只记录重要信息)
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING) # 这里改为 WARNING
@@ -63,7 +63,7 @@ def setup_logging(
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
# 创建自己的应用logger单独设置DEBUG级别
# 创建自己的应用logger(单独设置DEBUG级别)
app_logger = logging.getLogger("heurams")
app_logger.setLevel(log_level) # 保持DEBUG级别
app_logger.addHandler(file_handler)
@@ -146,7 +146,7 @@ def exception(msg: str, *args, **kwargs) -> None:
get_logger().exception(msg, *args, **kwargs)
# 初始化日志系统硬编码配置
# 初始化日志系统(硬编码配置)
setup_logging()

View File

@@ -17,7 +17,7 @@ from heurams.services.config import ConfigFile
class TestDashboardScreenUnit(unittest.TestCase):
"""DashboardScreen 的单元测试不启动完整应用."""
"""DashboardScreen 的单元测试(不启动完整应用)."""
def setUp(self):
"""在每个测试之前运行, 设置临时目录和配置."""

View File

@@ -94,7 +94,7 @@ class TestSM2Algorithm(unittest.TestCase):
SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True)
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
# interval 应为 1因为 rept=0
# interval 应为 1(因为 rept=0)
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
def test_revisor_efactor_calculation(self):

View File

@@ -27,7 +27,7 @@ class TestElectron(unittest.TestCase):
self.assertEqual(electron.algo, algorithms["supermemo2"])
self.assertIn(electron.algo, electron.algodata)
self.assertIsInstance(electron.algodata[electron.algo], dict)
# 检查默认值排除动态字段
# 检查默认值(排除动态字段)
defaults = electron.algo.defaults
for key, value in defaults.items():
if key == "last_modify":

View File

@@ -48,7 +48,7 @@ class TestMCQPuzzle(unittest.TestCase):
# 模拟 random.sample 返回前两个映射项
mock_sample.side_effect = [
[("q1", "a1"), ("q2", "a2")], # 选择问题
["j1", "j2", "j3"], # 为每个问题选择干扰项实际调用两次
["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次)
]
puzzle.refresh()
@@ -59,7 +59,7 @@ class TestMCQPuzzle(unittest.TestCase):
self.assertEqual(puzzle.answer, ["a1", "a2"])
# 检查 options 列表
self.assertEqual(len(puzzle.options), 2)
# 每个选项列表应包含 4 个选项正确答案 + 3 个干扰项
# 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项)
self.assertEqual(len(puzzle.options[0]), 4)
self.assertEqual(len(puzzle.options[1]), 4)
# random.shuffle 应被调用