From 98ec6504a469e4a65247c15d12dde32bef683365 Mon Sep 17 00:00:00 2001 From: david-ajax Date: Sun, 21 Dec 2025 02:15:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E9=AA=8C=E6=80=A7=20SM-15M=20?= =?UTF-8?q?=E7=AE=97=E6=B3=95=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 实验性 SM-15M 逆向工程算法实现 --- README.md | 2 +- src/heurams/interface/widgets/basic_puzzle.py | 2 +- src/heurams/interface/widgets/mcq_puzzle.py | 4 +- src/heurams/kernel/algorithms/__init__.py | 4 +- src/heurams/kernel/algorithms/sm15m.py | 281 +++ src/heurams/kernel/algorithms/sm15m_calc.py | 1573 +++++++++++++++++ src/heurams/kernel/particles/electron.py | 2 +- src/heurams/kernel/puzzles/base.py | 2 +- src/heurams/kernel/puzzles/recognition.py | 2 +- src/heurams/kernel/reactor/procession.py | 2 +- src/heurams/providers/llm/openai.py | 2 +- src/heurams/services/logger.py | 8 +- tests/interface/test_dashboard.py | 2 +- tests/kernel/algorithms/test_sm2.py | 2 +- tests/kernel/particles/test_electron.py | 2 +- tests/kernel/puzzles/test_mcq.py | 4 +- 16 files changed, 1875 insertions(+), 19 deletions(-) create mode 100644 src/heurams/kernel/algorithms/sm15m.py create mode 100644 src/heurams/kernel/algorithms/sm15m_calc.py diff --git a/README.md b/README.md index 1a2583b..b4f9148 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ python -m heurams.interface ## 配置 -配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置. +配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置. ## 项目结构 diff --git a/src/heurams/interface/widgets/basic_puzzle.py b/src/heurams/interface/widgets/basic_puzzle.py index 05d35ed..9908e60 100644 --- a/src/heurams/interface/widgets/basic_puzzle.py +++ b/src/heurams/interface/widgets/basic_puzzle.py @@ -49,7 +49,7 @@ class BasicEvaluation(BasePuzzleWidget): # 显示主要内容 yield Label(self.atom.registry["nucleon"]["content"], id="main") - # 显示评估说明(可选) + # 显示评估说明(可选) yield Static("请评估你对这个内容的记忆程度: ", classes="instruction") # 按钮容器 diff --git a/src/heurams/interface/widgets/mcq_puzzle.py b/src/heurams/interface/widgets/mcq_puzzle.py index 1589847..9ba9e08 100644 --- a/src/heurams/interface/widgets/mcq_puzzle.py +++ b/src/heurams/interface/widgets/mcq_puzzle.py @@ -116,7 +116,7 @@ class MCQPuzzle(BasePuzzleWidget): self.screen.rating = rating # type: ignore self.handler(rating) - # 重置输入(如果回答错误) + # 重置输入(如果回答错误) if not is_correct: self.inputlist = [] self.refresh_buttons() @@ -127,7 +127,7 @@ class MCQPuzzle(BasePuzzleWidget): self.update_display() def refresh_buttons(self): - """刷新按钮显示(用于题目切换)""" + """刷新按钮显示(用于题目切换)""" # 移除所有选项按钮 logger.debug("刷新按钮") self.cursor += 1 diff --git a/src/heurams/kernel/algorithms/__init__.py b/src/heurams/kernel/algorithms/__init__.py index b19ae41..a6ebbeb 100644 --- a/src/heurams/kernel/algorithms/__init__.py +++ b/src/heurams/kernel/algorithms/__init__.py @@ -1,6 +1,7 @@ from heurams.services.logger import get_logger from .sm2 import SM2Algorithm +from .sm15m import SM15MAlgorithm logger = get_logger(__name__) @@ -10,7 +11,8 @@ __all__ = [ algorithms = { "SM-2": SM2Algorithm, - "supermemo2": SM2Algorithm, + "SM-15M": SM15MAlgorithm, + # "SM-15M": SM15MAlgorithm, } logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys())) diff --git a/src/heurams/kernel/algorithms/sm15m.py b/src/heurams/kernel/algorithms/sm15m.py new file mode 100644 index 0000000..3e30d29 --- /dev/null +++ b/src/heurams/kernel/algorithms/sm15m.py @@ -0,0 +1,281 @@ +""" +SM-15 接口兼容实现, 基于 SM-15 算法的逆向工程 +全局状态保存在文件中, 项目状态通过 algodata 字典传递 + +基于: https://github.com/kazuaki/sm.js +原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida (MIT 许可证) +""" + +import datetime +import json +import os +from typing import TypedDict + +from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF, + RANGE_AF, RANGE_REPETITION, + SM, THRESHOLD_RECALL, Item) + +# 全局状态文件路径 +_GLOBAL_STATE_FILE = os.path.expanduser("~/.sm15_global_state.json") + + +def _get_global_sm(): + """获取全局 SM 实例, 从文件加载或创建新的""" + if os.path.exists(_GLOBAL_STATE_FILE): + try: + with open(_GLOBAL_STATE_FILE, "r", encoding="utf-8") as f: + data = json.load(f) + sm_instance = SM.load(data) + return sm_instance + except Exception: + # 如果加载失败, 创建新的实例 + pass + + # 创建新的 SM 实例 + sm_instance = SM() + # 保存初始状态 + _save_global_sm(sm_instance) + return sm_instance + + +def _save_global_sm(sm_instance): + """保存全局 SM 实例到文件""" + try: + data = sm_instance.data() + with open(_GLOBAL_STATE_FILE, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + except Exception: + # 忽略保存错误 + pass + + +class SM15MAlgorithm: + algo_name = "SM-15M" + + class AlgodataDict(TypedDict): + efactor: float + real_rept: int + rept: int + interval: int + last_date: int + next_date: int + is_activated: int + last_modify: float + + defaults = { + "efactor": 2.5, + "real_rept": 0, + "rept": 0, + "interval": 0, + "last_date": 0, + "next_date": 0, + "is_activated": 0, + "last_modify": 0.0, + } + + @classmethod + def _get_timestamp(cls): + """获取当前时间戳(秒)""" + return datetime.datetime.now().timestamp() + + @classmethod + def _get_daystamp(cls): + """获取当前天数戳(从某个纪元开始的天数)""" + # 使用与原始 SM-2 相同的纪元:1970-01-01 + now = datetime.datetime.now() + epoch = datetime.datetime(1970, 1, 1) + delta = now - epoch + return delta.days + + @classmethod + def _algodata_to_item(cls, algodata, sm_instance): + """将 algodata 转换为 Item 实例""" + # 从 algodata 获取 SM-2 数据 + sm15_data = algodata.get(cls.algo_name, cls.defaults.copy()) + + # 创建 Item 实例 + item = Item(sm_instance) + + # 映射字段 + # efactor -> A-Factor (需要转换) + efactor = sm15_data.get("efactor", 2.5) + # SM-2 的 efactor 范围 [1.3, 2.5+], SM-15 的 A-Factor 范围 [1.2, 6.9] + # 简单线性映射:af = (efactor - 1.3) * (MAX_AF - MIN_AF) / (2.5 - 1.3) + MIN_AF + # 但 efactor 可能大于 2.5, 所以需要限制 + af = max(MIN_AF, min(MAX_AF, efactor * 2.0)) # 粗略映射 + # 调试 + # print(f"DEBUG: efactor={efactor}, af before set={af}") + item.af(af) + # print(f"DEBUG: item.af() after set={item.af()}") + + # rept -> repetition (成功回忆次数) + rept = sm15_data.get("rept", 0) + item.repetition = ( + rept - 1 if rept > 0 else -1 + ) # SM-15 中 repetition=-1 表示新项目 + + # real_rept -> lapse? 或者忽略 + real_rept = sm15_data.get("real_rept", 0) + # 可以存储在 value 中或忽略 + + # interval -> optimum_interval (需要从天数转换为毫秒) + interval_days = sm15_data.get("interval", 0) + if interval_days == 0: + item.optimum_interval = sm_instance.interval_base + else: + item.optimum_interval = interval_days * 24 * 60 * 60 * 1000 # 天转毫秒 + + # last_date -> previous_date + last_date_days = sm15_data.get("last_date", 0) + if last_date_days > 0: + epoch = datetime.datetime(1970, 1, 1) + item.previous_date = epoch + datetime.timedelta(days=last_date_days) + + # next_date -> due_date + next_date_days = sm15_data.get("next_date", 0) + if next_date_days > 0: + epoch = datetime.datetime(1970, 1, 1) + item.due_date = epoch + datetime.timedelta(days=next_date_days) + + # is_activated 和 last_modify 忽略 + + # 将原始 algodata 保存在 value 中以便恢复 + item.value = { + "front": "SM-15 item", + "back": "SM-15 item", + "_sm15_data": sm15_data, + } + + return item + + @classmethod + def _item_to_algodata(cls, item, algodata): + """将 Item 实例状态写回 algodata""" + if cls.algo_name not in algodata: + algodata[cls.algo_name] = cls.defaults.copy() + + sm15_data = algodata[cls.algo_name] + + # A-Factor -> efactor (反向映射) + af = item.af() + if af is None: + af = MIN_AF + # 反向粗略映射 + efactor = max(1.3, min(af / 2.0, 10.0)) # 限制范围 + # 调试 + # print(f"DEBUG: item.af()={af}, computed efactor={efactor}") + sm15_data["efactor"] = efactor + + # repetition -> rept + rept = item.repetition + 1 if item.repetition >= 0 else 0 + sm15_data["rept"] = rept + + # real_rept: 递增在 revisor 中处理, 这里保持不变 + # 但如果没有 real_rept 字段, 则初始化为0 + if "real_rept" not in sm15_data: + sm15_data["real_rept"] = 0 + + # optimum_interval -> interval (毫秒转天) + interval_ms = item.optimum_interval + if interval_ms == item.sm.interval_base: + sm15_data["interval"] = 0 + else: + interval_days = max(0, round(interval_ms / (24 * 60 * 60 * 1000))) + sm15_data["interval"] = interval_days + + # previous_date -> last_date + if item.previous_date: + epoch = datetime.datetime(1970, 1, 1) + last_date_days = (item.previous_date - epoch).days + sm15_data["last_date"] = last_date_days + else: + sm15_data["last_date"] = 0 + + # due_date -> next_date + if item.due_date: + epoch = datetime.datetime(1970, 1, 1) + next_date_days = (item.due_date - epoch).days + sm15_data["next_date"] = next_date_days + else: + sm15_data["next_date"] = 0 + + # is_activated: 保持不变或设为1 + if "is_activated" not in sm15_data: + sm15_data["is_activated"] = 1 + + # last_modify: 更新时间戳 + sm15_data["last_modify"] = cls._get_timestamp() + + return algodata + + @classmethod + def revisor( + cls, algodata: dict, feedback: int = 5, is_new_activation: bool = False + ): + """SM-15 算法迭代决策机制实现""" + # 获取全局 SM 实例 + sm_instance = _get_global_sm() + + # 将 algodata 转换为 Item + item = cls._algodata_to_item(algodata, sm_instance) + + # 处理 is_new_activation + if is_new_activation: + # 重置为初始状态 + item.repetition = -1 + item.lapse = 0 + item.optimum_interval = sm_instance.interval_base + item.previous_date = None + item.due_date = datetime.datetime.fromtimestamp(0) + item.af(2.5) # 重置 efactor + + # 将项目临时添加到 SM 实例(以便 answer 更新共享状态) + sm_instance.q.append(item) + + # 处理反馈(评分) + # SM-2 的 feedback 是 0-5, SM-15 的 grade 也是 0-5 + grade = feedback + now = datetime.datetime.now() + + # 调用 answer 方法 + item.answer(grade, now) + + # 更新共享状态(FI-Graph, ForgettingCurves, OFM) + if item.repetition >= 0: + sm_instance.forgetting_curves.register_point(grade, item, now) + sm_instance.ofm.update() + sm_instance.fi_g.update(grade, item, now) + + # 从队列中移除项目 + sm_instance.q.remove(item) + + # 保存全局状态 + _save_global_sm(sm_instance) + + # 将更新后的 Item 状态写回 algodata + cls._item_to_algodata(item, algodata) + + # 更新 real_rept(总复习次数) + algodata[cls.algo_name]["real_rept"] += 1 + + @classmethod + def is_due(cls, algodata): + """检查项目是否到期""" + sm15_data = algodata.get(cls.algo_name, cls.defaults.copy()) + next_date_days = sm15_data.get("next_date", 0) + current_daystamp = cls._get_daystamp() + return next_date_days <= current_daystamp + + @classmethod + def rate(cls, algodata): + """获取项目的评分(返回 efactor 字符串)""" + sm15_data = algodata.get(cls.algo_name, cls.defaults.copy()) + efactor = sm15_data.get("efactor", 2.5) + return str(efactor) + + @classmethod + def nextdate(cls, algodata) -> int: + """获取下次复习日期(天数戳)""" + sm15_data = algodata.get(cls.algo_name, cls.defaults.copy()) + next_date_days = sm15_data.get("next_date", 0) + return next_date_days diff --git a/src/heurams/kernel/algorithms/sm15m_calc.py b/src/heurams/kernel/algorithms/sm15m_calc.py new file mode 100644 index 0000000..bcd2e5b --- /dev/null +++ b/src/heurams/kernel/algorithms/sm15m_calc.py @@ -0,0 +1,1573 @@ +""" +基于: https://github.com/kazuaki/sm.js +原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida +MIT 许可证 + +================================================================================ + +主要算法概念: + +1. 间隔重复 (Spaced Repetition) + - 根据记忆强度动态调整复习间隔 + - 使用遗忘曲线预测记忆保留率 + +2. A-Factor (难度因子) + - 表示项目的记忆难度 + - 范围: MIN_AF (1.2) 到 MAX_AF (≈7.5) + - 值越大表示项目越容易记忆 + +3. O-Factor (最优因子) + - 基于重复次数和难度因子的最优间隔乘数 + - 存储在 O-Factor 矩阵中 + +4. R-Factor (回忆因子) + - 基于重复次数和实际遗忘指数的实际间隔乘数 + - 存储在 R-Factor 矩阵中 + +5. 遗忘曲线 (Forgetting Curve) + - 描述记忆保留率随时间衰减的曲线 + - 用于计算遗忘指数 (Forgetting Index) + +6. 遗忘指数-评分图 (FI-Grade Graph) + - 建立遗忘指数与用户评分之间的关系 + - 用于校正回忆因子 + +================================================================================ +""" + +import datetime +import json +import math +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple + +# ============================================================================ +# Global Constants +# ============================================================================ + +# A-Factor 的取值范围大小(矩阵维度) +RANGE_AF = 20 + +# 重复次数的取值范围大小(矩阵维度) +RANGE_REPETITION = 20 + +# 最小 A-Factor 值(最简单的项目) +MIN_AF = 1.2 + +# A-Factor 的步长(每个等级的增量) +NOTCH_AF = 0.3 + +# 最大 A-Factor 值(最难的项目) +# 计算公式: MIN_AF + NOTCH_AF * (RANGE_AF - 1) = 1.2 + 0.3 * 19 = 6.9 +MAX_AF = MIN_AF + NOTCH_AF * (RANGE_AF - 1) + +# 最大评分值(用户评分的上限) +MAX_GRADE = 5 + +# 记忆阈值:评分 >= 此值表示成功回忆 +THRESHOLD_RECALL = 3 + +# ============================================================================ +# Helper Functions +# ============================================================================ + + +def sum_values(values): + """ + 计算列表中所有数值的和 + + 参数: + values: 数值列表 + + 返回: + 所有数值的总和 + """ + return sum(values) + + +def mse(y_func, points): + """ + 计算函数 y 与数据点之间的均方误差 (Mean Squared Error) + + 参数: + y_func: 函数 y = f(x) + points: 数据点列表, 每个点为 (x, y) + + 返回: + 均方误差值, 衡量函数拟合程度 + """ + errors = [(y_func(p[0]) - p[1]) ** 2 for p in points] + return sum(errors) / len(points) if errors else 0 + + +def exponential_regression(points): + """ + 指数回归: y = a * exp(b * x) + + 使用最小二乘法拟合指数函数 y = a * e^(b*x)。 + 算法参考: http://mathworld.wolfram.com/LeastSquaresFittingExponential.html + + 参数: + points: 数据点列表, 每个点为 (x, y) + + 返回: + 包含以下键的字典: + - 'y': 函数 y(x) = a * exp(b * x) + - 'x': 反函数 x(y) = (ln(y) - ln(a)) / b + - 'a': 系数 a + - 'b': 指数系数 b + - 'mse': 计算均方误差的函数 + + 数学推导: + 对 y = a * e^(b*x) 两边取对数: ln(y) = ln(a) + b*x + 令 Y' = ln(y), a' = ln(a), 转换为线性回归: Y' = a' + b*x + 使用最小二乘法求解 a' 和 b, 然后 a = exp(a') + """ + n = len(points) + X = [p[0] for p in points] + Y = [p[1] for p in points] + logY = [math.log(y) for y in Y] + sqX = [x * x for x in X] + + sum_logY = sum(logY) + sum_sqX = sum(sqX) + sumX = sum(X) + sumX_logY = sum(X[i] * logY[i] for i in range(n)) + sq_sumX = sumX * sumX + + a_coeff = (sum_logY * sum_sqX - sumX * sumX_logY) / (n * sum_sqX - sq_sumX) + b_coeff = (n * sumX_logY - sumX * sum_logY) / (n * sum_sqX - sq_sumX) + + a = math.exp(a_coeff) + b = b_coeff + + def y_func(x): + return a * math.exp(b * x) + + def x_func(y): + return (-a_coeff + math.log(y)) / b if b != 0 else 0 + + result = { + "y": y_func, + "x": x_func, + "a": a, + "b": b, + "mse": lambda: mse(y_func, points), + } + return result + + +def linear_regression(points): + """ + 线性回归: y = a + b * x + + 使用最小二乘法拟合线性函数。 + + 参数: + points: 数据点列表, 每个点为 (x, y) + + 返回: + 包含以下键的字典: + - 'y': 函数 y(x) = a + b * x + - 'x': 反函数 x(y) = (y - a) / b + - 'a': 截距 + - 'b': 斜率 + + 计算公式: + b = (n*Σxy - ΣxΣy) / (n*Σx² - (Σx)²) + a = (Σy - b*Σx) / n + """ + n = len(points) + X = [p[0] for p in points] + Y = [p[1] for p in points] + sqX = [x * x for x in X] + + sumY = sum(Y) + sum_sqX = sum(sqX) + sumX = sum(X) + sumXY = sum(X[i] * Y[i] for i in range(n)) + sq_sumX = sumX * sumX + + a = (sumY * sum_sqX - sumX * sumXY) / (n * sum_sqX - sq_sumX) + b = (n * sumXY - sumX * sumY) / (n * sum_sqX - sq_sumX) + + def y_func(x): + return a + b * x + + def x_func(y): + return (y - a) / b if b != 0 else 0 + + return {"y": y_func, "x": x_func, "a": a, "b": b} + + +def power_law_model(a, b): + """ + 幂律模型: y = a * x^b + + 创建幂律函数模型对象。 + + 参数: + a: 系数 + b: 指数 + + 返回: + 包含以下键的字典: + - 'y': 函数 y(x) = a * x^b + - 'x': 反函数 x(y) = (y / a)^(1/b) + - 'a': 系数 a + - 'b': 指数 b + """ + + def y_func(x): + return a * (x**b) + + def x_func(y): + return (y / a) ** (1 / b) if a != 0 and b != 0 else 0 + + return {"y": y_func, "x": x_func, "a": a, "b": b} + + +def power_law_regression(points): + """ + 幂律回归: y = a * x^b + + 使用最小二乘法拟合幂律函数。 + 算法参考: http://mathworld.wolfram.com/LeastSquaresFittingPowerLaw.html + + 参数: + points: 数据点列表, 每个点为 (x, y) + + 返回: + 幂律模型字典(包含 'y', 'x', 'a', 'b', 'mse' 键) + + 数学推导: + 对 y = a * x^b 两边取对数: ln(y) = ln(a) + b * ln(x) + 令 Y' = ln(y), X' = ln(x), a' = ln(a) + 转换为线性回归: Y' = a' + b * X' + 使用最小二乘法求解 a' 和 b, 然后 a = exp(a') + """ + n = len(points) + X = [p[0] for p in points] + Y = [p[1] for p in points] + logX = [math.log(x) for x in X] + logY = [math.log(y) for y in Y] + + sum_logX_logY = sum(logX[i] * logY[i] for i in range(n)) + sum_logX = sum(logX) + sum_logY = sum(logY) + sum_sq_logX = sum(lx * lx for lx in logX) + sq_sum_logX = sum_logX * sum_logX + + b = (n * sum_logX_logY - sum_logX * sum_logY) / (n * sum_sq_logX - sq_sum_logX) + a_coeff = (sum_logY - b * sum_logX) / n + a = math.exp(a_coeff) + + model = power_law_model(a, b) + model["mse"] = lambda: mse(model["y"], points) + return model + + +def fixed_point_power_law_regression(points, fixed_point): + """ + 定点幂律回归: y = q * (x/p)^b + + 拟合经过固定点 (p, q) 的幂律函数。 + 在 SM-15 算法中用于拟合 O-Factor 矩阵。 + + 参数: + points: 数据点列表, 每个点为 (x, y) + fixed_point: 固定点 (p, q), 函数必须经过此点 + + 返回: + 幂律模型字典(包含 'y', 'x', 'a', 'b' 键) + + 数学推导: + 给定固定点 (p, q), 模型为: y = q * (x/p)^b + 对两边取对数: ln(y) = b * ln(x/p) + ln(q) + 令 Y' = ln(y) - ln(q), X' = ln(x/p) + 转换为通过原点的线性回归: Y' = b * X' + 使用最小二乘法求解 b + """ + n = len(points) + p, q = fixed_point + logQ = math.log(q) + + X = [math.log(point[0] / p) for point in points] + Y = [math.log(point[1]) - logQ for point in points] + + # Linear regression through origin on transformed points + sumXY = sum(X[i] * Y[i] for i in range(n)) + sum_sqX = sum(x * x for x in X) + b = sumXY / sum_sqX if sum_sqX != 0 else 0 + + model = power_law_model(q / (p**b), b) + return model + + +def linear_regression_through_origin(points): + """ + 通过原点的线性回归: y = b * x + + 拟合通过原点的线性函数, 即截距为 0。 + + 参数: + points: 数据点列表, 每个点为 (x, y) + + 返回: + 包含以下键的字典: + - 'y': 函数 y(x) = b * x + - 'x': 反函数 x(y) = y / b + - 'b': 斜率 + + 计算公式: + b = Σ(x_i * y_i) / Σ(x_i²) + """ + n = len(points) + X = [p[0] for p in points] + Y = [p[1] for p in points] + + sumXY = sum(X[i] * Y[i] for i in range(n)) + sum_sqX = sum(x * x for x in X) + b = sumXY / sum_sqX if sum_sqX != 0 else 0 + + def y_func(x): + return b * x + + def x_func(y): + return y / b if b != 0 else 0 + + return {"y": y_func, "x": x_func, "b": b} + + +# ============================================================================ +# Core Classes +# ============================================================================ + + +class Item: + """ + 表示单个闪卡项目(记忆项目)。 + + 在 SM-15 算法中, 每个项目代表一个需要记忆的单元(如单词、概念等)。 + 项目包含记忆状态、复习历史和算法参数。 + + 属性: + sm: 所属的 SM 实例 + value: 项目内容(通常是字典, 包含 front/back) + lapse: 遗忘次数 + repetition: 成功回忆次数(-1 表示新项目) + of: O-Factor 值(最优因子) + optimum_interval: 最优复习间隔(毫秒) + due_date: 下次复习到期时间 + previous_date: 上次复习时间 + _afs: 估计的 A-Factor 历史记录 + _af: 当前 A-Factor 值 + + 主要功能: + 1. 计算实际间隔和 UF(使用因子) + 2. 管理 A-Factor(难度因子) + 3. 处理用户评分并更新记忆状态 + 4. 计算下一次复习间隔 + 5. 序列化和反序列化 + + 算法原理: + - 间隔重复基于最优间隔和 O-Factor 调整 + - A-Factor 反映项目难度, 通过历史估计值加权平均计算 + - UF(使用因子)是实际间隔与调整后最优间隔的比率 + - 当评分低于阈值时, 项目被标记为遗忘(lapse增加) + """ + + MAX_AFS_COUNT = 30 + + def __init__(self, sm, value=None): + """ + 初始化新的闪卡项目。 + + 参数: + sm: 所属的 SM 实例 + value: 项目内容(通常是包含 front/back 的字典) + + 初始状态: + - lapse(遗忘次数): 0 + - repetition(重复次数): -1(表示新项目) + - of(O-Factor): 1.0(默认值) + - optimum_interval(最优间隔): 等于 SM 的基础间隔 + - due_date(到期时间): 1970-01-01(立即到期) + - previous_date(上次复习): None(尚未复习) + - _afs(A-Factor 历史): 空列表 + - _af(当前 A-Factor): None(尚未计算) + """ + self.sm = sm + self.value = value + self.lapse = 0 + self.repetition = -1 + self.of = 1.0 + self.optimum_interval = sm.interval_base + self.due_date = datetime.datetime.fromtimestamp(0) # epoch start + self.previous_date = None + self._afs = [] # estimated A-Factor history + self._af = None # current A-Factor + + def interval(self, now=None): + """ + 计算自上次复习以来的实际间隔。 + + 参数: + now: 当前时间(默认为当前时间) + + 返回: + 实际间隔(毫秒) + + 注意: + - 如果项目尚未复习过(previous_date为None), 返回基础间隔 + - 间隔计算使用实际经过的时间, 而非计划的间隔 + - 返回值为毫秒, 与SM-15算法内部表示一致 + """ + if now is None: + now = datetime.datetime.now() + + if self.previous_date is None: + return self.sm.interval_base + return ( + now - self.previous_date + ).total_seconds() * 1000 # convert to milliseconds + + def uf(self, now=None): + """ + 计算 UF(使用因子, Utilization Factor)。 + + UF 是实际间隔与调整后最优间隔的比率: + UF = 实际间隔 / (最优间隔 / O-Factor) + + 参数: + now: 当前时间(默认为当前时间) + + 返回: + UF 值 + + 算法意义: + - UF = 1: 实际间隔等于调整后最优间隔 + - UF > 1: 实际间隔长于最优间隔(可能更难回忆) + - UF < 1: 实际间隔短于最优间隔(可能更容易回忆) + - UF 用于估计 A-Factor 和校正记忆模型 + """ + if now is None: + now = datetime.datetime.now() + + interval = self.interval(now) + adjusted_optimum = self.optimum_interval / self.of + return interval / adjusted_optimum if adjusted_optimum != 0 else 0 + + def af(self, value=None): + """ + 获取或设置 A-Factor(难度因子)。 + + A-Factor 表示项目的记忆难度, 值越大表示项目越容易记忆。 + 取值范围: MIN_AF (1.2) 到 MAX_AF (≈7.5), 步长为 NOTCH_AF (0.3)。 + + 参数: + value: 要设置的 A-Factor 值(如果为None则返回当前值) + + 返回: + 当前或设置后的 A-Factor 值 + + 处理逻辑: + - 如果 value 为 None: 返回当前 _af 值 + - 如果提供 value: 将其舍入到最近的 notch 值, 确保在有效范围内 + - 舍入公式: a = round((value - MIN_AF) / NOTCH_AF) + - 最终值: MIN_AF + a * NOTCH_AF, 限制在 [MIN_AF, MAX_AF] 范围内 + """ + if value is None: + return self._af + + # Round to nearest notch + a = round((value - MIN_AF) / NOTCH_AF) + self._af = max(MIN_AF, min(MAX_AF, MIN_AF + a * NOTCH_AF)) + return self._af + + def af_index(self): + """ + 获取最接近的 A-Factor 在矩阵中的索引。 + + 由于 A-Factor 矩阵使用离散值(20个等级), 需要将连续的实际 A-Factor + 映射到最接近的离散值索引。 + + 返回: + A-Factor 矩阵中的索引(0 到 RANGE_AF-1) + + 算法: + 1. 生成所有可能的 A-Factor 值: MIN_AF + i * NOTCH_AF + 2. 计算当前 A-Factor 与每个可能值的绝对差 + 3. 返回差值最小的索引 + + 用途: + 用于从 O-Factor 矩阵和 R-Factor 矩阵中查找对应的值。 + """ + afs = [MIN_AF + i * NOTCH_AF for i in range(RANGE_AF)] + + # Find index with minimum difference + min_diff = float("inf") + min_index = 0 + + for i, af_val in enumerate(afs): + diff = abs(self.af() - af_val) # type: ignore + if diff < min_diff: + min_diff = diff + min_index = i + + return min_index + + def _I(self, now=None): + """ + 计算新的最优间隔(SM-15 算法的第1步)。 + + 注意:此实现与原始 SM-15 的不同之处在于使用实际间隔而非先前计算的间隔。 + + 参数: + now: 当前时间(默认为当前时间) + + 算法步骤: + 1. 根据重复次数和 A-Factor 索引从 O-Factor 矩阵获取 O-Factor 值 + 2. 计算新的 O-Factor: of = max(1, (of_val - 1) * (实际间隔/最优间隔) + 1) + 3. 更新最优间隔: 最优间隔 = round(最优间隔 * of) + 4. 更新时间: previous_date = now, due_date = now + 最优间隔 + + 特殊处理: + - 对于第一次重复(repetition == 0), 使用 lapse 作为 A-Factor 索引 + - 对于后续重复, 使用 af_index() 计算的索引 + + 数学意义: + - O-Factor 根据实际表现动态调整 + - 如果实际间隔长于最优间隔, O-Factor 增加(下次间隔更长) + - 如果实际间隔短于最优间隔, O-Factor 减小(下次间隔更短) + """ + if now is None: + now = datetime.datetime.now() + + # Get O-Factor from matrix + if self.repetition == 0: + af_index = self.lapse + else: + af_index = self.af_index() + + of_val = self.sm.ofm.of(self.repetition, af_index) + + # Calculate new O-Factor + actual_interval = self.interval(now) + self.of = max(1.0, (of_val - 1) * (actual_interval / self.optimum_interval) + 1) + + # Update optimum interval + self.optimum_interval = round(self.optimum_interval * self.of) + + # Update dates + self.previous_date = now + self.due_date = now + datetime.timedelta(milliseconds=self.optimum_interval) + + def _update_af(self, grade, now=None): + """ + 基于评分更新 A-Factor(SM-15 算法的第9、11步)。 + + 参数: + grade: 用户评分(0-5) + now: 当前时间(默认为当前时间) + + 算法步骤: + 1. 从 FI-Grade 图估计遗忘指数 (FI) + 2. 校正 UF: corrected_uf = UF * (requested_FI / estimated_FI) + 3. 估计 A-Factor: + - 如果 repetition > 0: 从 O-Factor 矩阵反推 A-Factor + - 否则: 直接使用 corrected_uf, 限制在有效范围内 + 4. 将估计值加入历史记录(保留最近的 MAX_AFS_COUNT 个) + 5. 计算加权平均值(最近的值权重更高) + 6. 更新当前 A-Factor + + 算法意义: + - 使用遗忘指数校正 UF, 考虑实际记忆表现 + - 通过 O-Factor 矩阵反推 A-Factor, 建立 UF 与 A-Factor 的关系 + - 使用加权平均平滑估计值, 避免单次表现的过度影响 + """ + if now is None: + now = datetime.datetime.now() + + estimated_fi = max(1.0, self.sm.fi_g.fi(grade)) + corrected_uf = self.uf(now) * (self.sm.requested_fi / estimated_fi) + + # Estimate A-Factor + if self.repetition > 0: + estimated_af = self.sm.ofm.af(self.repetition, corrected_uf) + else: + estimated_af = max(MIN_AF, min(MAX_AF, corrected_uf)) + + # Add to history (keep only recent values) + self._afs.append(estimated_af) + if len(self._afs) > self.MAX_AFS_COUNT: + self._afs = self._afs[-self.MAX_AFS_COUNT :] + + # Calculate weighted average + weights = list(range(1, len(self._afs) + 1)) + weighted_sum = sum(af * weight for af, weight in zip(self._afs, weights)) + total_weight = sum(weights) + + self.af(weighted_sum / total_weight if total_weight != 0 else estimated_af) + + def answer(self, grade, now=None): + """ + 处理用户评分, 更新项目状态。 + + 这是 SM-15 算法的核心方法, 根据用户评分决定项目下一步的状态。 + + 参数: + grade: 用户评分(0-5, 0表示完全遗忘, 5表示完美回忆) + now: 当前时间(默认为当前时间) + + 处理逻辑: + 1. 如果不是新项目(repetition >= 0), 更新 A-Factor + 2. 如果评分 >= THRESHOLD_RECALL (3): + - 增加重复次数(如果未达到上限) + - 调用 _I() 计算新的最优间隔和到期时间 + 3. 如果评分 < THRESHOLD_RECALL: + - 增加遗忘次数(如果未达到上限) + - 重置最优间隔为基础间隔 + - 重置 previous_date 为 None(下次 interval() 返回基础间隔) + - 设置 due_date 为当前时间(立即重新复习) + - 重置重复次数为 -1(重新开始学习) + + 算法意义: + - 成功回忆时, 项目进入下一轮间隔重复周期 + - 遗忘时, 项目重置为初始状态, 需要重新学习 + - 阈值 THRESHOLD_RECALL 区分成功与失败回忆 + """ + if now is None: + now = datetime.datetime.now() + + # Update A-Factor if not a new item + if self.repetition >= 0: + self._update_af(grade, now) + + if grade >= THRESHOLD_RECALL: + # Remembered successfully + if self.repetition < RANGE_REPETITION - 1: + self.repetition += 1 + self._I(now) + else: + # Forgotten + if self.lapse < RANGE_AF - 1: + self.lapse += 1 + self.optimum_interval = self.sm.interval_base + self.previous_date = None # reset interval calculation + self.due_date = now + self.repetition = -1 + + def data(self): + """ + 序列化项目数据, 用于保存和加载。 + + 返回: + 包含项目所有状态的字典, 可转换为 JSON 格式保存。 + + 数据结构: + - value: 项目内容 + - repetition: 重复次数 + - lapse: 遗忘次数 + - of: O-Factor 值 + - optimumInterval: 最优间隔(毫秒) + - dueDate: 到期时间(ISO 格式字符串) + - previousDate: 上次复习时间(ISO 格式字符串或 null) + - _afs: A-Factor 历史记录列表 + + 注意: + - 日期对象转换为 ISO 格式字符串以便序列化 + - 反序列化时需要在 load() 方法中转换回 datetime 对象 + - 保持与原始 JavaScript 版本的数据格式兼容 + """ + return { + "value": self.value, + "repetition": self.repetition, + "lapse": self.lapse, + "of": self.of, + "optimumInterval": self.optimum_interval, + "dueDate": ( + self.due_date.isoformat() + if isinstance(self.due_date, datetime.datetime) + else self.due_date + ), + "previousDate": ( + self.previous_date.isoformat() + if isinstance(self.previous_date, datetime.datetime) + else self.previous_date + ), + "_afs": self._afs, + } + + @classmethod + def load(cls, sm, data): + """ + 从序列化数据加载项目。 + + 参数: + sm: 所属的 SM 实例 + data: 序列化的项目数据字典 + + 返回: + 恢复状态的 Item 实例 + + 处理逻辑: + 1. 创建新的 Item 实例 + 2. 复制基本属性(value, repetition, lapse, of, optimumInterval, _afs) + 3. 转换日期字符串为 datetime 对象 + 4. 如果 previousDate 存在则转换, 否则设为 None + 5. 如果 _af 历史记录不为空, 设置当前 A-Factor 为最后一个值 + + 注意: + - 日期字符串应为 ISO 格式(如 data() 方法生成的格式) + - 保持与原始 JavaScript 版本的数据兼容性 + - 加载后项目状态完全恢复, 包括历史记录 + """ + item = cls(sm) + + # Copy basic properties + item.value = data.get("value") + item.repetition = data.get("repetition", -1) + item.lapse = data.get("lapse", 0) + item.of = data.get("of", 1.0) + item.optimum_interval = data.get("optimumInterval", sm.interval_base) + item._afs = data.get("_afs", []) + + # Parse dates + due_date_str = data.get("dueDate") + if due_date_str: + if isinstance(due_date_str, str): + item.due_date = datetime.datetime.fromisoformat( + due_date_str.replace("Z", "+00:00") + ) + else: + # Handle numeric timestamp + item.due_date = datetime.datetime.fromtimestamp(due_date_str / 1000) + + previous_date_str = data.get("previousDate") + if previous_date_str: + if isinstance(previous_date_str, str): + item.previous_date = datetime.datetime.fromisoformat( + previous_date_str.replace("Z", "+00:00") + ) + else: + item.previous_date = datetime.datetime.fromtimestamp( + previous_date_str / 1000 + ) + + # Initialize A-Factor if we have history + if item._afs: + item.af(sum(item._afs) / len(item._afs)) + + return item + + +class FI_G: + """ + 遗忘指数-评分图(FI-Grade Graph)。 + + 建立遗忘指数(Forgetting Index)与用户评分(Grade)之间的关系。 + 用于根据用户评分估计实际遗忘指数, 从而校正记忆模型。 + + 属性: + sm: 所属的 SM 实例 + points: 数据点列表, 每个点为 [fi, grade] + _graph: 缓存的回归模型 + MAX_POINTS_COUNT: 最大数据点数(5000) + GRADE_OFFSET: 评分偏移量(1), 避免评分为0时的数学问题 + + 算法原理: + 1. 收集 (遗忘指数, 评分) 数据点 + 2. 使用指数回归拟合 FI-Grade 关系 + 3. 根据评分估计遗忘指数, 用于校正 UF 和 A-Factor + + 默认初始化: + - 点1: (0, MAX_GRADE) - 遗忘指数为0时, 评分应为最高 + - 点2: (100, 0) - 遗忘指数为100时, 评分应为最低 + + 主要功能: + 1. 记录新的数据点 + 2. 根据评分估计遗忘指数 + 3. 更新图形(SM-15 算法的第10步) + """ + + MAX_POINTS_COUNT = 5000 + GRADE_OFFSET = 1 + + def __init__(self, sm, points=None): + self.sm = sm + self._graph = None + + if points is not None: + self.points = points + else: + # Initialize with default points + self.points = [] + self._register_point(0, MAX_GRADE) + self._register_point(100, 0) + + def _register_point(self, fi, g): + """Add a point to the graph.""" + self.points.append([fi, g + self.GRADE_OFFSET]) + + # Keep only recent points + if len(self.points) > self.MAX_POINTS_COUNT: + self.points = self.points[-self.MAX_POINTS_COUNT :] + + self._graph = None # Invalidate cached regression + + def update(self, grade, item, now=None): + """Update FI-G graph with new data (Step 10 in SM-15).""" + if now is None: + now = datetime.datetime.now() + + # Expected forgetting index + def expected_fi(): + # Simple linear forgetting curve assumption + return (item.uf(now) / item.of) * self.sm.requested_fi + + # Alternative method using forgetting curves (commented out) + # curve = self.sm.forgetting_curves.curves[item.repetition][item.af_index()] + # uf_val = curve.uf(100 - self.sm.requested_fi) + # return 100 - curve.retention(item.uf() / uf_val) + + self._register_point(expected_fi(), grade) + + def fi(self, grade): + """Estimate forgetting index for given grade.""" + if not self.points: + return 50.0 # Default value + + if self._graph is None: + self._graph = exponential_regression(self.points) + + estimated = self._graph["x"](grade + self.GRADE_OFFSET) + return max(0.0, min(100.0, estimated)) + + def grade(self, fi): + """Estimate grade for given forgetting index.""" + if not self.points: + return 2.5 # Default value + + if self._graph is None: + self._graph = exponential_regression(self.points) + + estimated = self._graph["y"](fi) + return estimated - self.GRADE_OFFSET + + def data(self): + """Serialize FI-G data.""" + return {"points": self.points} + + @classmethod + def load(cls, sm, data): + """Deserialize FI-G from data.""" + return cls(sm, data.get("points")) + + +class ForgettingCurve: + """ + 单个遗忘曲线, 针对特定的重复次数和 A-Factor。 + + 描述记忆保留率(Retention)随时间(通过 UF 表示)衰减的曲线。 + 每个曲线对应一个特定的(重复次数, A-Factor)组合。 + + 属性: + points: 数据点列表, 每个点为 [uf, retention] + _curve: 缓存的指数回归模型 + MAX_POINTS_COUNT: 最大数据点数(500) + FORGOTTEN: 遗忘状态的保留率值(1) + REMEMBERED: 成功回忆状态的保留率值(101) + + 数据表示: + - UF(使用因子): x轴, 表示时间(实际间隔/调整后最优间隔) + - 保留率: y轴, 1表示完全遗忘, 101表示完全回忆(实际为0-100%加偏移) + - 偏移量 FORGOTTEN=1 避免取对数时的数学问题 + + 算法原理: + 1. 收集 (UF, 回忆结果) 数据点 + 2. 使用指数回归拟合遗忘曲线 + 3. 根据 UF 预测保留率, 或根据保留率反推 UF + + 主要功能: + 1. 注册新的数据点(回忆结果) + 2. 计算给定 UF 的保留率 + 3. 计算给定保留率的 UF + 4. 序列化和反序列化 + """ + + MAX_POINTS_COUNT = 500 + FORGOTTEN = 1 + REMEMBERED = 100 + FORGOTTEN + + def __init__(self, points): + self.points = points + self._curve = None + + def register_point(self, grade, uf): + """Add a data point to the curve.""" + is_remembered = grade >= THRESHOLD_RECALL + self.points.append([uf, self.REMEMBERED if is_remembered else self.FORGOTTEN]) + + # Keep only recent points + if len(self.points) > self.MAX_POINTS_COUNT: + self.points = self.points[-self.MAX_POINTS_COUNT :] + + self._curve = None # Invalidate cached regression + + def retention(self, uf): + """Calculate retention probability for given UF.""" + if not self.points: + return 50.0 # Default retention + + if self._curve is None: + self._curve = exponential_regression(self.points) + + estimated = self._curve["y"](uf) + clamped = max(self.FORGOTTEN, min(estimated, self.REMEMBERED)) + return clamped - self.FORGOTTEN + + def uf(self, retention): + """Calculate UF for given retention probability.""" + if not self.points: + return 1.0 # Default UF + + if self._curve is None: + self._curve = exponential_regression(self.points) + + target = retention + self.FORGOTTEN + return max(0.0, self._curve["x"](target)) + + def data(self): + """Serialize curve data.""" + return self.points + + +class ForgettingCurves: + """ + 遗忘曲线矩阵(重复次数 × A-Factor)。 + + 包含 RANGE_REPETITION × RANGE_AF 个遗忘曲线, 每个曲线对应一个 + (重复次数, A-Factor)组合。这是 SM-15 算法的核心数据结构之一。 + + 属性: + sm: 所属的 SM 实例 + curves: 二维列表的遗忘曲线矩阵 [重复次数][A-Factor索引] + FORGOTTEN: 遗忘状态的保留率值(1) + REMEMBERED: 成功回忆状态的保留率值(101) + + 矩阵结构: + - 行: 重复次数(0 到 RANGE_REPETITION-1) + - 列: A-Factor 索引(0 到 RANGE_AF-1) + - 每个单元格: 一个 ForgettingCurve 实例 + + 初始化: + - 如果提供 points 参数: 从现有数据加载曲线 + - 否则: 生成初始曲线, 基于数学公式创建初始数据点 + + 主要功能: + 1. 为特定项目和评分注册数据点 + 2. 获取特定重复次数和 A-Factor 的曲线 + 3. 序列化和反序列化整个矩阵 + 4. 管理遗忘曲线数据的收集和更新 + + 算法作用: + - 建立 UF 与保留率之间的定量关系 + - 为 R-Factor 矩阵提供数据基础 + - 帮助估计项目的记忆强度随时间的变化 + """ + + FORGOTTEN = 1 + REMEMBERED = 100 + FORGOTTEN + + def __init__(self, sm, points=None): + self.sm = sm + self.curves = [] + + # Initialize curves matrix + for r in range(RANGE_REPETITION): + row = [] + for a in range(RANGE_AF): + if points is not None: + partial_points = points[r][a] + else: + # Generate initial points + if r > 0: + partial_points = [[0, self.REMEMBERED]] + [ + [ + MIN_AF + NOTCH_AF * i, + min( + self.REMEMBERED, + math.exp( + -(r + 1) + / 200 + * (i - a * math.sqrt(2 / (r + 1))) + ) + * (self.REMEMBERED - self.sm.requested_fi), + ), + ] + for i in range(21) + ] + else: + partial_points = [[0, self.REMEMBERED]] + [ + [ + MIN_AF + NOTCH_AF * i, + min( + self.REMEMBERED, + math.exp(-1 / (10 + 1 * (a + 1)) * (i - (a**0.6))) + * (self.REMEMBERED - self.sm.requested_fi), + ), + ] + for i in range(21) + ] + + row.append(ForgettingCurve(partial_points)) + self.curves.append(row) + + def register_point(self, grade, item, now=None): + """Register a data point in the appropriate curve.""" + if item.repetition > 0: + af_index = item.af_index() + else: + af_index = item.lapse + + self.curves[item.repetition][af_index].register_point(grade, item.uf(now)) + + def data(self): + """Serialize forgetting curves data.""" + return { + "points": [ + [self.curves[r][a].data() for a in range(RANGE_AF)] + for r in range(RANGE_REPETITION) + ] + } + + @classmethod + def load(cls, sm, data): + """Deserialize forgetting curves from data.""" + return cls(sm, data.get("points")) + + +class RFM: + """ + R-Factor 矩阵(回忆因子矩阵)。 + + R-Factor 表示在给定重复次数和 A-Factor 下, 达到目标遗忘指数所需的 UF 值。 + 实际上是遗忘曲线的包装器, 提供便捷的接口访问。 + + 属性: + sm: 所属的 SM 实例 + + 计算公式: + R-Factor = curve.uf(100 - requested_fi) + 其中 curve 是对应 (repetition, af_index) 的遗忘曲线 + uf() 方法返回达到指定保留率所需的 UF 值 + + 算法意义: + - R-Factor 是实际观察到的间隔乘数 + - 表示在特定记忆强度下, 达到目标遗忘水平所需的时间倍数 + - 用于与 O-Factor(最优因子)比较, 校正记忆模型 + - 是 O-Factor 矩阵计算的基础 + + 主要功能: + 获取特定重复次数和 A-Factor 索引的 R-Factor 值 + """ + + def __init__(self, sm): + self.sm = sm + + def rf(self, repetition, af_index): + """Get R-Factor for given repetition and A-Factor index.""" + return self.sm.forgetting_curves.curves[repetition][af_index].uf( + 100 - self.sm.requested_fi + ) + + +class OFM: + """ + O-Factor 矩阵(最优因子矩阵)。 + + O-Factor 表示在给定重复次数和 A-Factor 下的最优间隔乘数。 + 基于 R-Factor 矩阵通过幂律回归计算得出。 + + 属性: + sm: 所属的 SM 实例 + _ofm: 缓存的 O-Factor 矩阵 + _ofm0: 缓存的重复次数为0时的 O-Factor 数组 + INITIAL_REP_VALUE: 初始重复值(1) + + 矩阵结构: + - 行: 重复次数(0 到 RANGE_REPETITION-1) + - 列: A-Factor 索引(0 到 RANGE_AF-1) + - 每个单元格: O-Factor 值 + + 算法原理(update() 方法, SM-15 第8步): + 1. 对于每个 A-Factor 索引: + a. 收集 (重复次数, R-Factor) 数据点 + b. 使用定点幂律回归拟合, 固定点 (1, 1) + c. 生成该 A-Factor 对应的 O-Factor 数组 + 2. 对于重复次数0: + a. 收集 (A-Factor, R-Factor) 数据点 + b. 使用幂律回归拟合 + c. 生成重复次数0时的 O-Factor 数组 + + 主要功能: + 1. 更新 O-Factor 矩阵 + 2. 获取特定重复次数和 A-Factor 索引的 O-Factor + 3. 从 O-Factor 和 UF 反推 A-Factor + """ + + INITIAL_REP_VALUE = 1 + + def __init__(self, sm): + self.sm = sm + self._ofm = None + self._ofm0 = None + self.update() + + def update(self): + """Update O-Factor matrix (Step 8 in SM-15).""" + + # Helper functions + def af_from_index(a): + return a * NOTCH_AF + MIN_AF + + def rep_from_index(r): + return r + self.INITIAL_REP_VALUE + + # Calculate D-factors + dfs = [] + for a in range(RANGE_AF): + points = [ + [rep_from_index(r), self.sm.rfm.rf(r, a)] + for r in range(1, RANGE_REPETITION) + ] + fixed_point = [rep_from_index(1), af_from_index(a)] + model = fixed_point_power_law_regression(points, fixed_point) + dfs.append(model["b"]) + + # Transform D-factors + dfs_transformed = [af_from_index(a) / (2 ** dfs[a]) for a in range(RANGE_AF)] + + # Linear regression on D-factors + decay_points = [[a, dfs_transformed[a]] for a in range(RANGE_AF)] + decay = linear_regression(decay_points) + + # Create O-Factor model for each A-Factor + def create_ofm(a): + af = af_from_index(a) + b = ( + math.log(af / decay["y"](a)) / math.log(rep_from_index(1)) + if decay["y"](a) != 0 + else 0 + ) + model = power_law_model(af / (rep_from_index(1) ** b), b) + + return { + "y": lambda r: model["y"](rep_from_index(r)), + "x": lambda y: model["x"](y) - self.INITIAL_REP_VALUE, + } + + self._ofm = [create_ofm(a) for a in range(RANGE_AF)] + + # Create O-Factor model for repetition 0 + ofm0_points = [[a, self.sm.rfm.rf(0, a)] for a in range(RANGE_AF)] + ofm0 = exponential_regression(ofm0_points) + self._ofm0 = lambda a: ofm0["y"](a) + + def of(self, repetition, af_index): + """Get O-Factor for given repetition and A-Factor index.""" + if repetition == 0: + return self._ofm0(af_index) # type: ignore + else: + return self._ofm[af_index]["y"](repetition) # type: ignore + + def af(self, repetition, of_val): + """Get A-Factor index for given repetition and O-Factor.""" + af_from_idx = lambda a: a * NOTCH_AF + MIN_AF + + # Find closest A-Factor index + min_diff = float("inf") + min_index = 0 + + for a in range(RANGE_AF): + diff = abs(self.of(repetition, a) - of_val) + if diff < min_diff: + min_diff = diff + min_index = a + + return af_from_idx(min_index) + + +class SM: + """ + SM-15 算法主调度器。 + + 这是 SM-15 间隔重复算法的核心类, 负责协调所有组件和算法流程。 + 管理项目队列、处理用户交互、执行算法更新步骤。 + + 属性: + requested_fi: 目标遗忘指数(默认10%, 表示希望10%的项目被遗忘) + interval_base: 基础间隔(3小时, 毫秒单位) + q: 项目队列, 按 due_date 排序 + fi_g: FI-Grade 图实例 + forgetting_curves: 遗忘曲线矩阵实例 + rfm: R-Factor 矩阵实例 + ofm: O-Factor 矩阵实例 + + 主要功能: + 1. 项目管理: 添加、删除、查询项目 + 2. 复习调度: 获取到期项目, 处理用户评分 + 3. 算法协调: 调用各组件更新算法参数 + 4. 数据持久化: 保存和加载学习状态 + 5. 队列管理: 维护按到期时间排序的项目队列 + + 算法流程概览: + 1. 添加项目时创建 Item 实例, 插入排序队列 + 2. 复习时获取到期项目, 接收用户评分 + 3. 调用 answer() 处理评分, 更新项目状态 + 4. 更新 FI-Grade 图、遗忘曲线、O-Factor 矩阵 + 5. 重新计算项目的最优间隔和下次到期时间 + 6. 将项目重新插入队列的适当位置 + + 使用方式: + 1. 创建 SM 实例 + 2. 使用 add_item() 添加学习项目 + 3. 使用 next_item() 获取需要复习的项目 + 4. 使用 answer() 处理用户评分 + 5. 使用 data() 和 load() 保存/加载学习进度 + """ + + def __init__(self): + """ + 初始化 SM-15 调度器。 + + 设置默认参数并初始化所有算法组件。 + + 默认参数: + - requested_fi: 10.0(目标遗忘指数10%) + - interval_base: 3 * 60 * 60 * 1000(3小时, 毫秒单位) + - q: 空项目队列(按到期时间排序) + + 初始化的组件: + - fi_g: FI-Grade 图, 管理遗忘指数与评分的关系 + - forgetting_curves: 遗忘曲线矩阵, 存储记忆保留率数据 + - rfm: R-Factor 矩阵, 包装遗忘曲线提供 R-Factor 查询 + - ofm: O-Factor 矩阵, 计算和管理最优因子 + + 注意: + - interval_base 是算法的基础时间单位, 所有间隔计算基于此值 + - requested_fi 是算法的核心目标, 控制复习间隔的激进程度 + - 组件间存在依赖关系, 初始化顺序重要 + """ + self.requested_fi = 10.0 # target forgetting index (10%) + self.interval_base = 3 * 60 * 60 * 1000 # 3 hours in milliseconds + self.q = [] # items sorted by due_date + + # Initialize components + self.fi_g = FI_G(self) + self.forgetting_curves = ForgettingCurves(self) + self.rfm = RFM(self) + self.ofm = OFM(self) + + def _find_index_to_insert(self, item, r=None): + """Binary search to find insertion index for sorted queue.""" + if r is None: + r = list(range(len(self.q))) + + if not r: + return 0 + + v = item.due_date + i = len(r) // 2 + + if len(r) == 1: + return r[i] if v < self.q[r[i]].due_date else r[i] + 1 + + if v < self.q[r[i]].due_date: + return self._find_index_to_insert(item, r[:i]) + else: + return self._find_index_to_insert(item, r[i:]) + + def add_item(self, value): + """Add a new item to the queue.""" + item = Item(self, value) + index = self._find_index_to_insert(item) + self.q.insert(index, item) + + def next_item(self, is_advanceable=False): + """Get next item due for review.""" + if not self.q: + return None + + now = datetime.datetime.now() + if is_advanceable or self.q[0].due_date < now: + return self.q[0] + + return None + + def answer(self, grade, item, now=None): + """Process answer for given item.""" + if now is None: + now = datetime.datetime.now() + + self._update(grade, item, now) + self.discard(item) + + index = self._find_index_to_insert(item) + self.q.insert(index, item) + + def _update(self, grade, item, now=None): + """Internal update method.""" + if now is None: + now = datetime.datetime.now() + + if item.repetition >= 0: + self.forgetting_curves.register_point(grade, item, now) + self.ofm.update() + self.fi_g.update(grade, item, now) + + item.answer(grade, now) + + def discard(self, item): + """Remove item from queue.""" + if item in self.q: + self.q.remove(item) + + def data(self): + """Serialize SM state.""" + return { + "requestedFI": self.requested_fi, + "intervalBase": self.interval_base, + "q": [item.data() for item in self.q], + "fi_g": self.fi_g.data(), + "forgettingCurves": self.forgetting_curves.data(), + "version": 1, + } + + @classmethod + def load(cls, data): + """Deserialize SM from data.""" + sm = cls() + sm.requested_fi = data.get("requestedFI", 10.0) + sm.interval_base = data.get("intervalBase", 3 * 60 * 60 * 1000) + + # Load items + items_data = data.get("q", []) + sm.q = [Item.load(sm, item_data) for item_data in items_data] + + # Load components + sm.fi_g = FI_G.load(sm, data.get("fi_g", {})) + sm.forgetting_curves = ForgettingCurves.load( + sm, data.get("forgettingCurves", {}) + ) + + # Reinitialize RFM and update OFM + sm.rfm = RFM(sm) + sm.ofm = OFM(sm) + sm.ofm.update() + + return sm + + +# ============================================================================ +# Test Functions (for internal testing) +# ============================================================================ + +_test = { + "exponentialRegression": exponential_regression, + "linearRegression": linear_regression, + "powerLawRegression": power_law_regression, + "fixedPointPowerLawRegression": fixed_point_power_law_regression, + "linearRegressionThroughOrigin": linear_regression_through_origin, +} + +# ============================================================================ +# CLI Interface +# ============================================================================ + + +def main(): + """ + 简单的闪卡命令行应用程序。 + + 提供交互式命令行界面, 使用户能够使用 SM-15 算法进行闪卡学习。 + + 可用命令: + - a/add: 添加新卡片 + - n/next: 复习下一个到期的卡片 + - N/Next: 复习下一个卡片(即使未到期) + - s/save: 保存学习进度到文件 + - l/load: 从文件加载学习进度 + - e/exit: 退出程序 + - eval: 执行 Python 表达式(调试用) + - list: 列出所有卡片 + + 使用流程: + 1. 启动程序显示命令提示 + 2. 输入 'a' 添加新卡片, 依次输入正面和背面内容 + 3. 输入 'n' 复习到期的卡片 + 4. 对显示的卡片输入评分 (0-5) 或 'D' 丢弃卡片 + 5. 重复步骤3-4进行复习 + 6. 使用 's' 保存进度, 'l' 加载进度 + + 数据文件: + - 默认保存文件: data.json + - 格式: JSON, 包含所有卡片状态和算法数据 + - 兼容性: 与原始 JavaScript 版本的数据格式兼容 + + 注意事项: + - 评分范围: 0 (完全遗忘) 到 5 (完美回忆) + - 阈值: 评分 >= 3 表示成功回忆 + - 时间单位: 内部使用毫秒, 但用户界面使用自然时间表示 + """ + import sys + + print("(a)add, (n)next, (N)next advanceably, (s)save, (l)load, (e)exit") + + mode = ["entrance"] + data = None + sm = SM() + + def goto_entrance(): + nonlocal mode, data + mode = ["entrance"] + data = None + sys.stdout.write("sm> ") + sys.stdout.flush() + + goto_entrance() + + while True: + try: + user_input = input().strip() + except EOFError: + break + + if mode[0] == "entrance": + if user_input in ["a", "add"]: + mode = ["add"] + elif user_input in ["n", "next"]: + mode = ["next"] + elif user_input in ["N", "Next"]: + mode = ["next", "_adv"] + elif user_input in ["s", "save"]: + mode = ["save"] + elif user_input in ["l", "load"]: + mode = ["load"] + elif user_input in ["e", "exit"]: + mode = ["exit"] + elif user_input == "eval": + mode = ["eval"] + elif user_input == "list": + mode = ["list"] + else: + goto_entrance() + continue + + if mode[0] == "add": + if len(mode) == 1: + data = {"front": None, "back": None} + print("Enter the front of the new card:") + mode.append("front") + elif mode[1] == "front": + data["front"] = user_input # type: ignore + print("Enter the back of the new card:") + mode[1] = "back" + elif mode[1] == "back": + data["back"] = user_input # type: ignore + sm.add_item(data) + goto_entrance() + + elif mode[0] == "next": + if mode[1] in ["_adv", None]: + is_advanceable = mode[1] == "_adv" + data = sm.next_item(is_advanceable) + + if data is None: + if sm.q: + next_due = sm.q[0].due_date + print( + f'There is no card that can be shown now. The next card is due at "{next_due}".' + ) + else: + print("There is no card.") + goto_entrance() + else: + print( + f"How much do you remember [{data.value.get('front', 'No front')}]:" + ) + mode[1] = "review" + + elif mode[1] == "review": + try: + g = int(user_input) + if 0 <= g <= 5: + sm.answer(g, data) + print(f"The answer was [{data.value.get('back', 'No back')}].") # type: ignore + goto_entrance() + elif user_input == "D": + sm.discard(data) + goto_entrance() + else: + print( + "The value should be from '0' (bad) to '5' (good). Otherwise 'D' to discard:" + ) + except ValueError: + print("Please enter a number from 0 to 5, or 'D' to discard:") + + elif mode[0] == "save": + if len(mode) == 1: + print( + "Enter file name to save configuration. (default name is [data.json]):" + ) + mode.append(True) # type: ignore + else: + filename = user_input if user_input else "data.json" + with open(filename, "w") as f: + json.dump(sm.data(), f, indent=2) + print(f"Saved to {filename}") + goto_entrance() + + elif mode[0] == "load": + if len(mode) == 1: + print( + "Enter file name to load configuration. (default name is [data.json]):" + ) + mode.append(True) # type: ignore + else: + filename = user_input if user_input else "data.json" + with open(filename, "r") as f: + data = json.load(f) + sm = SM.load(data) + print(f"Loaded from {filename}") + goto_entrance() + + elif mode[0] == "exit": + if len(mode) == 1: + print("Exiting...") + break + + elif mode[0] == "eval": + if len(mode) == 1: + mode.append(True) # type: ignore + else: + try: + result = eval(user_input) + print(result) + except Exception as e: + print(f"Error: {e}") + goto_entrance() + + elif mode[0] == "list": + for item in sm.q: + print(json.dumps(item.data())) + goto_entrance() + + +if __name__ == "__main__": + try: + main() + except Exception as error: + print(f"An error occurred: {error}") + sys.exit(1) diff --git a/src/heurams/kernel/particles/electron.py b/src/heurams/kernel/particles/electron.py index 5bfc4a9..65f7ffe 100644 --- a/src/heurams/kernel/particles/electron.py +++ b/src/heurams/kernel/particles/electron.py @@ -9,7 +9,7 @@ logger = get_logger(__name__) class Electron: """电子: 记忆分析元数据及算法""" - def __init__(self, ident: str, algodata: dict = {}, algo_name: str = "supermemo2"): + def __init__(self, ident: str, algodata: dict = {}, algo_name: str = "SM-2"): """初始化电子对象 (记忆数据) Args: diff --git a/src/heurams/kernel/puzzles/base.py b/src/heurams/kernel/puzzles/base.py index 864c108..688d7c8 100644 --- a/src/heurams/kernel/puzzles/base.py +++ b/src/heurams/kernel/puzzles/base.py @@ -8,7 +8,7 @@ class BasePuzzle: """谜题基类""" def refresh(self): - logger.debug("BasePuzzle.refresh 被调用(未实现)") + logger.debug("BasePuzzle.refresh 被调用(未实现)") raise NotImplementedError("谜题对象未实现 refresh 方法") def __str__(self): diff --git a/src/heurams/kernel/puzzles/recognition.py b/src/heurams/kernel/puzzles/recognition.py index f54e6d3..964db3f 100644 --- a/src/heurams/kernel/puzzles/recognition.py +++ b/src/heurams/kernel/puzzles/recognition.py @@ -16,5 +16,5 @@ class RecognitionPuzzle(BasePuzzle): super().__init__() def refresh(self): - logger.debug("RecognitionPuzzle.refresh(空实现)") + logger.debug("RecognitionPuzzle.refresh(空实现)") pass diff --git a/src/heurams/kernel/reactor/procession.py b/src/heurams/kernel/reactor/procession.py index 4ca253e..39811b6 100644 --- a/src/heurams/kernel/reactor/procession.py +++ b/src/heurams/kernel/reactor/procession.py @@ -52,7 +52,7 @@ class Procession: self.queue.append(atom) logger.debug("原子已追加到队列, 新队列长度=%d", len(self.queue)) else: - logger.debug("原子未追加(重复或队列长度<=1)") + logger.debug("原子未追加(重复或队列长度<=1)") def __len__(self): length = len(self.queue) - self.cursor diff --git a/src/heurams/providers/llm/openai.py b/src/heurams/providers/llm/openai.py index 910ef0b..43a74f7 100644 --- a/src/heurams/providers/llm/openai.py +++ b/src/heurams/providers/llm/openai.py @@ -2,4 +2,4 @@ from heurams.services.logger import get_logger logger = get_logger(__name__) -logger.debug("OpenAI provider 模块已加载(未实现)") +logger.debug("OpenAI provider 模块已加载(未实现)") diff --git a/src/heurams/services/logger.py b/src/heurams/services/logger.py index e5a3147..d1662d5 100644 --- a/src/heurams/services/logger.py +++ b/src/heurams/services/logger.py @@ -45,7 +45,7 @@ def setup_logging( # 创建formatter formatter = logging.Formatter(log_format, date_format) - # 创建文件handler(使用RotatingFileHandler防止日志过大) + # 创建文件handler(使用RotatingFileHandler防止日志过大) file_handler = logging.handlers.RotatingFileHandler( filename=log_path, maxBytes=max_bytes, @@ -55,7 +55,7 @@ def setup_logging( file_handler.setFormatter(formatter) file_handler.setLevel(log_level) - # 配置root logger - 设置为 WARNING 级别(只记录重要信息) + # 配置root logger - 设置为 WARNING 级别(只记录重要信息) root_logger = logging.getLogger() root_logger.setLevel(logging.WARNING) # 这里改为 WARNING @@ -63,7 +63,7 @@ def setup_logging( for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) - # 创建自己的应用logger(单独设置DEBUG级别) + # 创建自己的应用logger(单独设置DEBUG级别) app_logger = logging.getLogger("heurams") app_logger.setLevel(log_level) # 保持DEBUG级别 app_logger.addHandler(file_handler) @@ -146,7 +146,7 @@ def exception(msg: str, *args, **kwargs) -> None: get_logger().exception(msg, *args, **kwargs) -# 初始化日志系统(硬编码配置) +# 初始化日志系统(硬编码配置) setup_logging() diff --git a/tests/interface/test_dashboard.py b/tests/interface/test_dashboard.py index dd0ea91..be519ac 100644 --- a/tests/interface/test_dashboard.py +++ b/tests/interface/test_dashboard.py @@ -17,7 +17,7 @@ from heurams.services.config import ConfigFile class TestDashboardScreenUnit(unittest.TestCase): - """DashboardScreen 的单元测试(不启动完整应用).""" + """DashboardScreen 的单元测试(不启动完整应用).""" def setUp(self): """在每个测试之前运行, 设置临时目录和配置.""" diff --git a/tests/kernel/algorithms/test_sm2.py b/tests/kernel/algorithms/test_sm2.py index 03713fc..9075715 100644 --- a/tests/kernel/algorithms/test_sm2.py +++ b/tests/kernel/algorithms/test_sm2.py @@ -94,7 +94,7 @@ class TestSM2Algorithm(unittest.TestCase): SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True) self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0) self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5) - # interval 应为 1(因为 rept=0) + # interval 应为 1(因为 rept=0) self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1) def test_revisor_efactor_calculation(self): diff --git a/tests/kernel/particles/test_electron.py b/tests/kernel/particles/test_electron.py index 2f1ea47..75c4552 100644 --- a/tests/kernel/particles/test_electron.py +++ b/tests/kernel/particles/test_electron.py @@ -27,7 +27,7 @@ class TestElectron(unittest.TestCase): self.assertEqual(electron.algo, algorithms["supermemo2"]) self.assertIn(electron.algo, electron.algodata) self.assertIsInstance(electron.algodata[electron.algo], dict) - # 检查默认值(排除动态字段) + # 检查默认值(排除动态字段) defaults = electron.algo.defaults for key, value in defaults.items(): if key == "last_modify": diff --git a/tests/kernel/puzzles/test_mcq.py b/tests/kernel/puzzles/test_mcq.py index 066dcee..61dc1a9 100644 --- a/tests/kernel/puzzles/test_mcq.py +++ b/tests/kernel/puzzles/test_mcq.py @@ -48,7 +48,7 @@ class TestMCQPuzzle(unittest.TestCase): # 模拟 random.sample 返回前两个映射项 mock_sample.side_effect = [ [("q1", "a1"), ("q2", "a2")], # 选择问题 - ["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次) + ["j1", "j2", "j3"], # 为每个问题选择干扰项(实际调用两次) ] puzzle.refresh() @@ -59,7 +59,7 @@ class TestMCQPuzzle(unittest.TestCase): self.assertEqual(puzzle.answer, ["a1", "a2"]) # 检查 options 列表 self.assertEqual(len(puzzle.options), 2) - # 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项) + # 每个选项列表应包含 4 个选项(正确答案 + 3 个干扰项) self.assertEqual(len(puzzle.options[0]), 4) self.assertEqual(len(puzzle.options[1]), 4) # random.shuffle 应被调用