Compare commits
35 Commits
0.4.1
...
refactor/v
| Author | SHA1 | Date | |
|---|---|---|---|
| 016ae16100 | |||
| 5b1a627ddb | |||
| 2e4a663adf | |||
| ed85655e8b | |||
| 47c14e520b | |||
| ca7ef92b05 | |||
| 22b41789eb | |||
| e1c935f348 | |||
| 65486794b7 | |||
| c585c79e73 | |||
| 5d883b015e | |||
| a689604021 | |||
| 55c656e8f9 | |||
| 94aaef386b | |||
| aacf4fdbdf | |||
| eced6130f1 | |||
| 9b32a01a10 | |||
| 94839c6369 | |||
| 573bf22b2b | |||
| eaa38fb880 | |||
| b65dad6a1f | |||
| c13b8bed98 | |||
| b5f30ec4ee | |||
| 87cefedb61 | |||
| 0fb421412e | |||
| ee0646ac79 | |||
| d8fc18166d | |||
| a2e12c7462 | |||
| 1efe034a59 | |||
| 0a365b568a | |||
| e303d4dc1e | |||
| cb78290f05 | |||
| e0417981b1 | |||
| a0660d3348 | |||
| f5e0417292 |
11
.gitignore
vendored
11
.gitignore
vendored
@@ -5,19 +5,16 @@
|
||||
__pycache__/
|
||||
.idea/
|
||||
cache/
|
||||
#nucleon/test.toml
|
||||
electron/test.toml
|
||||
data/repo/cngk
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
old/
|
||||
# config/
|
||||
data/cache/
|
||||
data/electron/
|
||||
data/nucleon/
|
||||
!data/nucleon/test*
|
||||
data/orbital/
|
||||
data/global/
|
||||
config/config_dev.toml
|
||||
AGENTS.md
|
||||
*.log.*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.13
|
||||
@@ -3,19 +3,22 @@
|
||||
欢迎为此项目做出贡献!
|
||||
本项目是一个开源项目, 我们鼓励社区成员参与改进.
|
||||
|
||||
## 开发流程
|
||||
## 开发规范
|
||||
|
||||
1. **分支划分**:
|
||||
1. 分支划分:
|
||||
- `main` 分支: 稳定版本
|
||||
- `dev` 分支: 开发版本
|
||||
- 功能分支: 从 `dev` 分支创建, 命名格式为 `feature/描述` 或 `fix/描述` 或 `refactor/描述`
|
||||
2. **代码风格**:
|
||||
2. 代码风格:
|
||||
- 请使用 Black 格式化代码
|
||||
- 遵循 PEP 8 规范
|
||||
- 添加适当的文档字符串
|
||||
3. **提交消息**:
|
||||
3. 提交消息:
|
||||
- 使用简体中文或英文撰写清晰的提交消息
|
||||
- 格式: 遵循 Conventional Commits 规范
|
||||
4. 合并方式:
|
||||
- 不使用 Fast-forward 合并
|
||||
- 可以设置 `git config merge.ff false`
|
||||
|
||||
## 设置开发环境
|
||||
|
||||
|
||||
109
README.md
109
README.md
@@ -1,35 +1,80 @@
|
||||
# 潜进 (HeurAMS) - 启发式辅助记忆程序
|
||||
|
||||
## 概述
|
||||
"潜进" (HeurAMS: Heuristic Auxiliary Memorizing Scheduler, 启发式记忆辅助调度器) 是为习题册, 古诗词, 及其他问答/记忆/理解型知识设计的多用途辅助记忆软件, 提供动态规划的优化记忆方案
|
||||
"潜进" (HeurAMS: Heuristic Auxiliary Memorizing Scheduler, 启发式记忆辅助调度器) 是为习题册, 古诗词, 及其他问答/记忆/理解型知识设计的开放源代码多用途辅助记忆软件, 提供动态规划的优化记忆方案
|
||||
|
||||
## 关于此仓库
|
||||
"潜进" 软件组项目包含多个子项目
|
||||
此仓库包含了 "潜进" 项目的核心和基于 Textual 的基本用户界面的实现
|
||||
本仓库为 "潜进" 软件组项目的核心部分, 包含核心功能模块以及基于 Textual 框架的基础用户界面(heurams.interface)实现
|
||||
除了通过用户界面进行学习外, 你也可以在 Python 中导入 `heurams` 库, 使用其中实现的状态机, 算法迭代器和数据模型构建辅助记忆功能
|
||||
本仓库在 AGPLv3 下开放源代码(详见 LICENSE 文件)
|
||||
|
||||
## 开发进程
|
||||
- 0.0.x: 简易调度器实现与最小原型.
|
||||
- 0.1.x: 命令行操作的调度器.
|
||||
- 0.2.x: 使用 Textual 构建富文本终端用户界面, 项目可行性验证, 采用 SM-2 原始算法, 评估方式为用户自评估的原型.
|
||||
- 0.3.x: 简单的多文件项目, 创建了记忆内容/算法数据结构, 基于 SM-2 改进算法的自动复习测评评估. 重点设计古诗文记忆理解功能, 以及 TUI 界面实现, 简单的 TTS 集成.
|
||||
- 0.4.x: 使用模块管理解耦设计, 增加文档与类型标注, 采用上下文设计模式的隐式依赖注入与遵从 IoC, 注册器设计的算法与功能实现, 支持其他调度算法模块 (SM-2, FSRS) 与谜题模块, 采用日志调试, 更新文件格式, 引入动态数据模式(宏驱动的动态内容生成), 与基于文件的策略调控, 更佳的用户数据处理, 加入模块化扩展集成, 将算法数据格式换为 json 提高性能, 采用 provider-service 抽象架构, 支持切换服务提供者, 整体兼容性改进.
|
||||
> 下一步?
|
||||
> 使用 Flutter 构建酷酷的现代化前端, 增加云同步/文档源服务...
|
||||
## 版本日志
|
||||
|
||||
### 0.0.x
|
||||
- 简易调度器实现与最小原型
|
||||
|
||||
### 0.1.x
|
||||
- 命令行操作的调度器
|
||||
|
||||
### 0.2.x
|
||||
- 使用 Textual 构建富文本终端用户界面
|
||||
- 项目可行性验证
|
||||
- 采用 SM-2 原始算法, 评估方式为用户自评估原型
|
||||
|
||||
### 0.3.x Frontal 前端
|
||||
- 简单的多文件项目
|
||||
- 创建了记忆内容/算法数据结构
|
||||
- 基于 SM-2 改进算法的自动复习测评评估
|
||||
- 重点设计古诗文记忆理解功能
|
||||
- TUI 界面改进
|
||||
- 简单的 TTS 集成
|
||||
|
||||
### 0.4.x Fledge 雏鸟
|
||||
- 开发目标转为多用途
|
||||
- 使用模块管理解耦设计
|
||||
- 增加文档与类型标注
|
||||
- 采用上下文设计模式的隐式依赖注入与遵从 IoC, 注册器设计的算法与功能实现
|
||||
- 支持其他调度算法模块 (SM-2, SM-18M 参考理论变体, FSRS) 与谜题模块
|
||||
- 采用规范的日志调试取代 Textual Devtools 调试
|
||||
- 更新数据持久化协议规范
|
||||
- 引入动态数据模式 (宏驱动的动态内容生成) , 与基于文件的策略调控
|
||||
- 更佳的用户数据处理
|
||||
- 加入模块化扩展集成
|
||||
- 更换算法数据格式, 提高性能
|
||||
- 采用 provider-service 抽象架构, 支持切换服务提供者
|
||||
- 整体兼容性改进
|
||||
|
||||
### 0.5.x Fulcrum 支点
|
||||
- 以仓库 (repository) 对象作为文件系统与运行时对象间的桥梁, 提高解耦性与性能
|
||||
- 使用具有列表-字典 API 同步特性的 "Lict" 对象作为 Repo 数据的内部存储
|
||||
- 将粒子对象作为纯运行时对象, 数据通过引用自动同步至 Repo, 减少负担
|
||||
- 实现声音形式回顾 "电台" 功能
|
||||
- 改进数据存储结构, 实现选择性持久化
|
||||
- 增强可配置性
|
||||
- 使用 Transitions 状态机库重新实现 Reactor 模块系列状态机, 增强可维护性
|
||||
- 实现整体回顾记忆功能, 与队列式记忆功能并列
|
||||
- 加入状态机快照功能 (基于 pickle) , 使中断的记忆流程得以恢复
|
||||
- 增加 "整体文章引用" 功能, 实现从一篇长文本中摘取内容片段记忆并在原文中高亮查看的组织操作
|
||||
|
||||
### 下一步?
|
||||
- 增加云同步 / 文档源服务
|
||||
- 使用 Flutter 构建酷酷的现代化前端
|
||||
- ...
|
||||
|
||||
## 特性
|
||||
|
||||
### 间隔迭代算法
|
||||
> 许多出版物都广泛讨论了不同重复间隔对学习效果的影响. 特别是, 间隔效应被认为是一种普遍现象. 间隔效应是指, 如果重复的间隔是分散/稀疏的, 而不是集中重复, 那么学习任务的表现会更好. 因此, 有观点提出, 学习中使用的最佳重复间隔是**最长的、但不会导致遗忘的间隔**.
|
||||
> 许多出版物都广泛讨论了不同重复间隔对学习效果的影响. 特别是, 间隔效应被认为是一种普遍现象. 间隔效应是指, 如果重复的间隔是分散/稀疏的, 而不是集中重复, 那么学习任务的表现会更好. 因此, 有观点提出, 学习中使用的最佳重复间隔是**最长的, 但不会导致遗忘的间隔**.
|
||||
- 采用经实证的 SM-2 间隔迭代算法, 此算法亦用作 Anki 闪卡记忆软件的默认闪卡调度器
|
||||
- 动态规划每个记忆单元的记忆间隔时间表
|
||||
- 动态跟踪记忆反馈数据, 优化长期记忆保留率与稳定性
|
||||
|
||||
### 学习进程优化
|
||||
- 逐字解析: 支持逐字详细释义解析
|
||||
- 语法分析: 接入生成式人工智能, 支持古文结构交互式解析
|
||||
- 自然语音: 集成微软神经网络文本转语音 (TTS) 技术
|
||||
- 多种谜题类型: 选择题 (MCQ)、填空题 (Cloze)、识别题 (Recognition)
|
||||
- 元数据配置: 支持配置详细附加数据
|
||||
- 自然语音: 集成文本转语音 (TTS) 功能, 支持"电台"回顾功能
|
||||
- 多种谜题类型: 选择题 (MCQ), 填空题 (Cloze), 识别题 (Recognition)
|
||||
- 动态内容生成: 支持宏驱动的模板系统, 根据上下文动态生成题目
|
||||
- 云同步支持: 通过多种协议同步数据到远程服务器
|
||||
|
||||
### 实用用户界面
|
||||
- 响应式 Textual 框架构建的跨平台 TUI 界面
|
||||
@@ -37,10 +82,10 @@
|
||||
- 简洁直观的复习流程设计
|
||||
|
||||
### 架构特性
|
||||
- 模块化设计: 算法、谜题、服务提供者可插拔替换
|
||||
- 模块化设计: 算法, 谜题, 服务提供者可插拔替换
|
||||
- 上下文管理: 使用 ContextVar 实现隐式依赖注入
|
||||
- 数据持久化: TOML 配置与内容, JSON 算法状态
|
||||
- 服务抽象: 音频播放、TTS、LLM 通过 provider 架构支持多种后端
|
||||
- 服务抽象: 音频播放, TTS, LLM 通过 provider 架构支持多种后端
|
||||
- 完整日志系统: 带轮转的日志记录, 便于调试
|
||||
|
||||
## 安装
|
||||
@@ -62,31 +107,17 @@
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 使用
|
||||
|
||||
### 启动应用
|
||||
## 启动应用
|
||||
```bash
|
||||
# 在任一目录(建议是空目录或者包根目录, 将被用作存放数据)下运行
|
||||
python -m heurams.interface
|
||||
```
|
||||
|
||||
### 数据目录结构
|
||||
应用会在工作目录下创建以下数据目录:
|
||||
- `data/nucleon/`: 记忆内容 (TOML 格式)
|
||||
- `data/electron/`: 算法状态 (JSON 格式)
|
||||
- `data/orbital/`: 策略配置 (TOML 格式)
|
||||
- `data/cache/`: 音频缓存文件
|
||||
- `data/template/`: 内容模板
|
||||
|
||||
首次运行时会自动创建这些目录.
|
||||
|
||||
## 配置
|
||||
|
||||
配置文件位于 `config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置.
|
||||
配置文件位于 `./data/config/config.toml`(相对于工作目录). 如果不存在, 会使用内置的默认配置.
|
||||
|
||||
## 项目结构
|
||||
|
||||
### 架构图
|
||||
### 架构图(待更新 0.5.0)
|
||||
|
||||
以下 Mermaid 图展示了 HeurAMS 的主要组件及其关系:
|
||||
|
||||
@@ -104,6 +135,7 @@ graph TB
|
||||
Timer[时间服务]
|
||||
AudioService[音频服务]
|
||||
TTSService[TTS服务]
|
||||
SyncService[同步服务]
|
||||
OtherServices[其他服务]
|
||||
end
|
||||
|
||||
@@ -146,17 +178,18 @@ graph TB
|
||||
Algorithms --> Files
|
||||
```
|
||||
|
||||
### 目录结构
|
||||
### 目录结构(待更新 0.5.0)
|
||||
```
|
||||
src/heurams/
|
||||
├── __init__.py # 包入口点
|
||||
├── context.py # 全局上下文、路径、配置上下文管理器
|
||||
├── context.py # 全局上下文, 路径, 配置上下文管理器
|
||||
├── services/ # 核心服务
|
||||
│ ├── config.py # 配置管理
|
||||
│ ├── logger.py # 日志系统
|
||||
│ ├── timer.py # 时间服务
|
||||
│ ├── audio_service.py # 音频播放抽象
|
||||
│ └── tts_service.py # 文本转语音抽象
|
||||
│ ├── tts_service.py # 文本转语音抽象
|
||||
│ └── sync_service.py # WebDAV 同步服务
|
||||
├── kernel/ # 核心业务逻辑
|
||||
│ ├── algorithms/ # 间隔重复算法 (FSRS, SM2)
|
||||
│ ├── particles/ # 数据模型 (Atom, Electron, Nucleon, Orbital)
|
||||
|
||||
59
data/config/config.toml
Normal file
59
data/config/config.toml
Normal file
@@ -0,0 +1,59 @@
|
||||
# [调试] 将更改保存到文件
|
||||
persist_to_file = 1
|
||||
|
||||
# [调试] 覆写时间, 设为 -1 以禁用
|
||||
daystamp_override = -1
|
||||
timestamp_override = -1
|
||||
|
||||
# [调试] 一键通过
|
||||
quick_pass = true
|
||||
|
||||
# 对于每个项目的默认新记忆原子数量
|
||||
scheduled_num = 8
|
||||
|
||||
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
||||
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||
|
||||
[interface]
|
||||
|
||||
[interface.memorizor]
|
||||
autovoice = false # 自动语音播放, 仅限于 recognition 组件
|
||||
|
||||
[algorithm]
|
||||
default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS
|
||||
|
||||
[puzzles] # 谜题默认配置
|
||||
|
||||
[puzzles.mcq]
|
||||
max_riddles_num = 2
|
||||
|
||||
[puzzles.cloze]
|
||||
min_denominator = 3
|
||||
|
||||
[paths] # 相对于配置文件的 ".." (即工作目录) 而言 或绝对路径
|
||||
data = "./data"
|
||||
cache = "./data/cache"
|
||||
config = "./data/config"
|
||||
global = "./data/global"
|
||||
repo = "./data/repo"
|
||||
[services] # 定义服务到提供者的映射
|
||||
audio = "playsound" # 可选项: playsound(通用), termux(仅用于支持 Android Termux), mpg123(TODO)
|
||||
tts = "edgetts" # 可选项: edgetts
|
||||
llm = "openai" # 可选项: openai
|
||||
sync = "webdav" # 可选项: 留空, webdav
|
||||
|
||||
[providers.tts.edgetts] # EdgeTTS 设置
|
||||
voice = "zh-CN-XiaoxiaoNeural" # 可选项: zh-CN-YunjianNeural (男声), zh-CN-XiaoxiaoNeural (女声)
|
||||
|
||||
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
||||
url = ""
|
||||
key = ""
|
||||
|
||||
[providers.sync.webdav] # WebDAV 同步设置
|
||||
url = ""
|
||||
username = ""
|
||||
password = ""
|
||||
remote_path = "/heurams/"
|
||||
verify_ssl = true
|
||||
|
||||
[sync]
|
||||
1
data/repo/test/algodata.json
Normal file
1
data/repo/test/algodata.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
||||
3
data/repo/test/manifest.toml
Normal file
3
data/repo/test/manifest.toml
Normal file
@@ -0,0 +1,3 @@
|
||||
title = "测试单元: 过秦论"
|
||||
author = "__heurams__"
|
||||
desc = "高考古诗文: 过秦论"
|
||||
11
data/repo/test/payload.toml
Normal file
11
data/repo/test/payload.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
["秦孝公据崤函之固, 拥雍州之地,"]
|
||||
note = []
|
||||
content = "秦孝公/据/崤函/之固/, 拥/雍州/之地,/"
|
||||
translation = "秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,"
|
||||
keyword_note = {"据"="占据", "崤函"="崤山和函谷关", "雍州"="古代九州之一"}
|
||||
|
||||
["君臣固守以窥周室,"]
|
||||
note = []
|
||||
content = "君臣/固守/以窥/周室,/"
|
||||
translation = "君臣牢固地守卫着,借以窥视周王室的权力,"
|
||||
keyword_note = {"窥"="窥视"}
|
||||
11
data/repo/test/schedule.toml
Normal file
11
data/repo/test/schedule.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
schedule = ["quick_review", "recognition", "final_review"]
|
||||
|
||||
[phases]
|
||||
quick_review = [["FillBlank", "1.0"], ["SelectMeaning", "0.5"], ["Recognition", "1.0"]]
|
||||
recognition = [["Recognition", "1.0"]]
|
||||
final_review = [["FillBlank", "1.0"], ["SelectMeaning", "1.0"], ["Recognition", "1.0"]]
|
||||
|
||||
[annotation]
|
||||
"quick_review" = "复习旧知"
|
||||
"recognition" = "新知识"
|
||||
"final_review" = "总复习"
|
||||
17
data/repo/test/typedef.toml
Normal file
17
data/repo/test/typedef.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[annotation]
|
||||
note = "笔记"
|
||||
keyword_note = "关键词翻译"
|
||||
translation = "语句翻译"
|
||||
delimiter = "分隔符"
|
||||
content = "内容"
|
||||
tts_text = "文本转语音文本"
|
||||
|
||||
[common]
|
||||
delimiter = "/"
|
||||
tts_text = "eval:payload['content'].replace('/', '')"
|
||||
|
||||
[common.puzzles] # 谜题定义
|
||||
# 我们称 "Recognition" 为 recognition 谜题的 alia
|
||||
"Recognition" = { __origin__ = "recognition", __hint__ = "", primary = "eval:payload['content']", secondary = ["eval:payload['keyword_note']", "eval:payload['note']"], top_dim = ["eval:payload['translation']"] }
|
||||
"SelectMeaning" = { __origin__ = "mcq", __hint__ = "eval:payload['content']", primary = "eval:payload['content']", mapping = "eval:payload['keyword_note']", jammer = "eval:list(payload['keyword_note'].values())", max_riddles_num = "eval:default['mcq']['max_riddles_num']", prefix = "选择正确项: " }
|
||||
"FillBlank" = { __origin__ = "cloze", __hint__ = "", text = "eval:payload['content']", delimiter = "eval:nucleon['delimiter']", min_denominator = "eval:default['cloze']['min_denominator']"}
|
||||
@@ -1,23 +0,0 @@
|
||||
# Nucleon 是 HeurAMS 软件项目使用的基于 TOML 的专有源文件格式, 版本 5
|
||||
# 建议使用的 MIME 类型: application/vnd.xyz.imwangzhiyu.heurams-nucleon.v5+toml
|
||||
|
||||
[__metadata__]
|
||||
[__metadata__.attribution] # 元信息
|
||||
desc = "带有宏支持的空白模板"
|
||||
|
||||
[__metadata__.annotation] # 键批注
|
||||
|
||||
[__metadata__.formation] # 文件配置
|
||||
#delimiter = "/"
|
||||
#tts_text = "eval:nucleon['content'].replace('/', '')"
|
||||
|
||||
[__metadata__.orbital.puzzles] # 谜题定义
|
||||
# 我们称 "Recognition" 为 recognition 谜题的 alia
|
||||
#"Recognition" = { __origin__ = "recognition", __hint__ = "", primary = "eval:nucleon['content']", secondary = ["eval:nucleon['keyword_note']", "eval:nucleon['note']"], top_dim = ["eval:nucleon['translation']"] }
|
||||
#"SelectMeaning" = { __origin__ = "mcq", __hint__ = "eval:nucleon['content']", mapping = "eval:nucleon['keyword_note']", jammer = "eval:nucleon['keyword_note']", max_riddles_num = "eval:default['mcq']['max_riddles_num']", prefix = "选择正确项: " }
|
||||
#"FillBlank" = { __origin__ = "cloze", __hint__ = "", text = "eval:nucleon['content']", delimiter = "eval:metadata['formation']['delimiter']", min_denominator = "eval:default['cloze']['min_denominator']"}
|
||||
|
||||
[__metadata__.orbital.schedule] # 内置的推荐学习方案
|
||||
#quick_review = [["FillBlank", "1.0"], ["SelectMeaning", "0.5"], ["recognition", "1.0"]]
|
||||
#recognition = [["Recognition", "1.0"]]
|
||||
#final_review = [["FillBlank", "0.7"], ["SelectMeaning", "0.7"], ["recognition", "1.0"]]
|
||||
@@ -14,6 +14,14 @@ scheduled_num = 8
|
||||
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
||||
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||
|
||||
[interface]
|
||||
|
||||
[interface.memorizor]
|
||||
autovoice = true # 自动语音播放, 仅限于 recognition 组件
|
||||
|
||||
[algorithm]
|
||||
default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS
|
||||
|
||||
[puzzles] # 谜题默认配置
|
||||
|
||||
[puzzles.mcq]
|
||||
@@ -23,17 +31,26 @@ max_riddles_num = 2
|
||||
min_denominator = 3
|
||||
|
||||
[paths] # 相对于配置文件的 ".." (即工作目录) 而言 或绝对路径
|
||||
nucleon_dir = "./data/nucleon"
|
||||
electron_dir = "./data/electron"
|
||||
orbital_dir = "./data/orbital"
|
||||
cache_dir = "./data/cache"
|
||||
template_dir = "./data/template"
|
||||
data = "./data"
|
||||
|
||||
[services] # 定义服务到提供者的映射
|
||||
audio = "playsound" # 可选项: playsound(通用), termux(仅用于支持 Android Termux), mpg123(TODO)
|
||||
tts = "edgetts" # 可选项: edgetts
|
||||
llm = "openai" # 可选项: openai
|
||||
sync = "webdav" # 可选项: 留空, webdav
|
||||
|
||||
[providers.tts.edgetts] # EdgeTTS 设置
|
||||
voice = "zh-CN-XiaoxiaoNeural" # 可选项: zh-CN-YunjianNeural (男声), zh-CN-XiaoxiaoNeural (女声)
|
||||
|
||||
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
||||
url = ""
|
||||
key = ""
|
||||
|
||||
[providers.sync.webdav] # WebDAV 同步设置
|
||||
url = ""
|
||||
username = ""
|
||||
password = ""
|
||||
remote_path = "/heurams/"
|
||||
verify_ssl = true
|
||||
|
||||
[sync]
|
||||
14
examples/jiebatest.py
Normal file
14
examples/jiebatest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# encoding=utf-8
|
||||
import jieba
|
||||
|
||||
# jieba.enable_paddle()# 启动paddle模式。 0.40版之后开始支持,早期版本不支持
|
||||
strs = ["我来到北京清华大学", "乒乓球拍卖完了", "中国科学技术大学"]
|
||||
# for str in strs:
|
||||
# seg_list = jieba.cut(str,use_paddle=True) # 使用paddle模式
|
||||
# print("Paddle Mode: " + '/'.join(list(seg_list)))
|
||||
|
||||
seg_list = jieba.cut("秦孝公据崤函之固, 拥雍州之地", cut_all=False)
|
||||
print("Default Mode: " + "/ ".join(seg_list)) # 精确模式
|
||||
|
||||
seg_list = jieba.cut("他来到了网易杭研大厦") # 默认是精确模式
|
||||
print(", ".join(seg_list))
|
||||
764
examples/repo.ipynb
Normal file
764
examples/repo.ipynb
Normal file
@@ -0,0 +1,764 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51b89355",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 演练场\n",
|
||||
"此笔记本将带你了解 repomgr 与 particles 对象相关操作 \n",
|
||||
"此笔记本内含的系统命令默认仅存在于 Linux 操作系统, 如果你使用 Windows, 请在安装 busybox 或 cygwin 或 WSL 的环境下执行此笔记本"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f5c49014",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 从一个例子开始\n",
|
||||
"## 了解文件结构\n",
|
||||
"了解一下文件结构"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"id": "a5ed9864",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[01;34m.\u001b[0m\n",
|
||||
"├── \u001b[01;34mdata\u001b[0m\n",
|
||||
"│ └── \u001b[01;34mconfig\u001b[0m\n",
|
||||
"│ └── \u001b[00mconfig.toml\u001b[0m\n",
|
||||
"├── \u001b[00mjiebatest.py\u001b[0m\n",
|
||||
"├── \u001b[00mrepo.ipynb\u001b[0m\n",
|
||||
"├── \u001b[00msimplemem.py\u001b[0m\n",
|
||||
"└── \u001b[01;34mtest_repo\u001b[0m\n",
|
||||
" ├── \u001b[00malgodata.json\u001b[0m\n",
|
||||
" ├── \u001b[00mmanifest.toml\u001b[0m\n",
|
||||
" ├── \u001b[00mpayload.toml\u001b[0m\n",
|
||||
" ├── \u001b[00mschedule.toml\u001b[0m\n",
|
||||
" └── \u001b[00mtypedef.toml\u001b[0m\n",
|
||||
"\n",
|
||||
"4 directories, 9 files\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!tree # 了解文件结构"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4e10922b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"如果你先前运行了单元格, 请运行下面一格清理."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "9777730e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"zsh:1: no matches found: heurams.log*\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!rm -rf test_new_repo\n",
|
||||
"!rm -rf heurams.log*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "058c098f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 导入模块\n",
|
||||
"导入所需模块, 你会看到欢迎信息, 标示了库所使用的配置. \n",
|
||||
"HeurAMS 在基础设施也使用配置文件实现隐式的依赖注入. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "bf1b00c8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import heurams.kernel.repolib as repolib # 这是 RepoLib 子模块, 用于管理和结构化 repo(中文含义: 仓库) 数据结构与本地文件间的联系\n",
|
||||
"import heurams.kernel.particles as pt # 这是 Particles(中文含义: 粒子) 子模块, 用于运行时的记忆管理操作\n",
|
||||
"from pathlib import (\n",
|
||||
" Path,\n",
|
||||
") # 这是 Python 的 Pathlib 模块, 用于表示文件路径, 在整个项目中, 都使用此模块表示路径"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ea1f68bb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 运行时检查\n",
|
||||
"如你所见, repo 在文件系统内存储为一个文件夹. \n",
|
||||
"因此在载入之前, 首先要检查这是否是一个合乎标准的 repo 文件夹. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "897b62d7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"这是一个 合规 的 repo!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"is_vaild = repolib.Repo.check_repodir(Path(\"./test_repo\"))\n",
|
||||
"print(f\"这是一个 {'合规' if is_vaild else '不合规'} 的 repo!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "24a19991",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 加载仓库\n",
|
||||
"接下来, 正式加载 repo."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "708ae7e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_repo = repolib.Repo.create_from_repodir(Path(\"./test_repo\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "474f8eb7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 导出为字典\n",
|
||||
"作为一个数据容器, repo 相应地建立了导入和导出的功能. \n",
|
||||
"我们刚刚从本地文件夹导入了一个 repo. \n",
|
||||
"现在试试导出为一个字典."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "a11115fb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'algodata': [('秦孝公据崤函之固, 拥雍州之地,', {}), ('君臣固守以窥周室,', {})],\n",
|
||||
" 'manifest': {'author': '__heurams__',\n",
|
||||
" 'desc': '高考古诗文: 过秦论',\n",
|
||||
" 'title': '测试单元: 过秦论'},\n",
|
||||
" 'payload': [('秦孝公据崤函之固, 拥雍州之地,',\n",
|
||||
" {'content': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/',\n",
|
||||
" 'keyword_note': {'崤函': '崤山和函谷关', '据': '占据', '雍州': '古代九州之一'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'translation': '秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,'}),\n",
|
||||
" ('君臣固守以窥周室,',\n",
|
||||
" {'content': '君臣/固守/以窥/周室,/',\n",
|
||||
" 'keyword_note': {'窥': '窥视'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'translation': '君臣牢固地守卫着,借以窥视周王室的权力,'})],\n",
|
||||
" 'schedule': {'phases': {'final_review': [['FillBlank', '0.7'],\n",
|
||||
" ['SelectMeaning', '0.7'],\n",
|
||||
" ['Recognition', '1.0']],\n",
|
||||
" 'quick_review': [['FillBlank', '1.0'],\n",
|
||||
" ['SelectMeaning', '0.5'],\n",
|
||||
" ['Recognition', '1.0']],\n",
|
||||
" 'recognition': [['Recognition', '1.0']]},\n",
|
||||
" 'schedule': ['quick_review', 'recognition', 'final_review']},\n",
|
||||
" 'source': PosixPath('test_repo'),\n",
|
||||
" 'typedef': {'annotation': {'content': '内容',\n",
|
||||
" 'delimiter': '分隔符',\n",
|
||||
" 'keyword_note': '关键词翻译',\n",
|
||||
" 'note': '笔记',\n",
|
||||
" 'translation': '语句翻译',\n",
|
||||
" 'tts_text': '文本转语音文本'},\n",
|
||||
" 'common': {'delimiter': '/',\n",
|
||||
" 'puzzles': {'FillBlank': {'__hint__': '',\n",
|
||||
" '__origin__': 'cloze',\n",
|
||||
" 'delimiter': \"eval:nucleon['delimiter']\",\n",
|
||||
" 'min_denominator': \"eval:default['cloze']['min_denominator']\",\n",
|
||||
" 'text': \"eval:payload['content']\"},\n",
|
||||
" 'Recognition': {'__hint__': '',\n",
|
||||
" '__origin__': 'recognition',\n",
|
||||
" 'primary': \"eval:payload['content']\",\n",
|
||||
" 'secondary': [\"eval:payload['keyword_note']\",\n",
|
||||
" \"eval:payload['note']\"],\n",
|
||||
" 'top_dim': [\"eval:payload['translation']\"]},\n",
|
||||
" 'SelectMeaning': {'__hint__': \"eval:payload['content']\",\n",
|
||||
" '__origin__': 'mcq',\n",
|
||||
" 'jammer': \"eval:list(payload['keyword_note'].values())\",\n",
|
||||
" 'mapping': \"eval:payload['keyword_note']\",\n",
|
||||
" 'max_riddles_num': \"eval:default['mcq']['max_riddles_num']\",\n",
|
||||
" 'prefix': '选择正确项: ',\n",
|
||||
" 'primary': \"eval:payload['content']\"}},\n",
|
||||
" 'tts_text': \"eval:payload['content'].replace('/', \"\n",
|
||||
" \"'')\"}}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_repo_dic = test_repo.export_to_single_dict()\n",
|
||||
"from pprint import pprint\n",
|
||||
"\n",
|
||||
"pprint(test_repo_dic)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "35a2e06f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 持久化与部分保存\n",
|
||||
"如你所见, 所有内容被结构化地输出了! \n",
|
||||
"\n",
|
||||
"现在写回到文件夹! \n",
|
||||
"\n",
|
||||
"我们注意到, 并非所有的内容都要被修改. \n",
|
||||
"我们可以只保存接受修改的一部分, 默认情况下, 是迭代的记忆数据(algodata). \n",
|
||||
"这就是为什么我们一般不使用单个 json 或 toml 来存储 repo.\n",
|
||||
"\n",
|
||||
"persist_to_repodir 接受两个可选参数: \n",
|
||||
"- save_list: 默认为 [\"algodata\"], 是要持久化的数据.\n",
|
||||
"- source: 默认为原目录, 你也可以手动指定为其他文件夹(通过 Path)\n",
|
||||
"\n",
|
||||
"现在做一些演练, 我们将创建一个位于 test_new_repo 的\"克隆\". \n",
|
||||
"除非文件夹已经存在, Repo 对象将会为你自动创建新文件夹."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "05eeaacc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001b[01;34m.\u001b[0m\n",
|
||||
"├── \u001b[01;34mdata\u001b[0m\n",
|
||||
"│ └── \u001b[01;34mconfig\u001b[0m\n",
|
||||
"│ └── \u001b[00mconfig.toml\u001b[0m\n",
|
||||
"├── \u001b[00mjiebatest.py\u001b[0m\n",
|
||||
"├── \u001b[00mrepo.ipynb\u001b[0m\n",
|
||||
"├── \u001b[00msimplemem.py\u001b[0m\n",
|
||||
"├── \u001b[01;34mtest_new_repo\u001b[0m\n",
|
||||
"│ ├── \u001b[00malgodata.json\u001b[0m\n",
|
||||
"│ ├── \u001b[00mmanifest.toml\u001b[0m\n",
|
||||
"│ ├── \u001b[00mpayload.toml\u001b[0m\n",
|
||||
"│ ├── \u001b[00mschedule.toml\u001b[0m\n",
|
||||
"│ └── \u001b[00mtypedef.toml\u001b[0m\n",
|
||||
"└── \u001b[01;34mtest_repo\u001b[0m\n",
|
||||
" ├── \u001b[00malgodata.json\u001b[0m\n",
|
||||
" ├── \u001b[00mmanifest.toml\u001b[0m\n",
|
||||
" ├── \u001b[00mpayload.toml\u001b[0m\n",
|
||||
" ├── \u001b[00mschedule.toml\u001b[0m\n",
|
||||
" └── \u001b[00mtypedef.toml\u001b[0m\n",
|
||||
"\n",
|
||||
"5 directories, 14 files\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_repo.persist_to_repodir(\n",
|
||||
" save_list=[\"schedule\", \"payload\", \"manifest\", \"typedef\", \"algodata\"],\n",
|
||||
" source=Path(\"test_new_repo\"),\n",
|
||||
")\n",
|
||||
"!tree"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "059d7bdf",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"如你所见, test_new_repo 已被生成!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4ef8925c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 数据结构\n",
|
||||
"现在讲解 repo 的数据结构"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c19fed95",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Lict 对象\n",
|
||||
"Lict 对象集成了部分列表和字典的功能, 数据在这两种风格的 API 间都可用, 且修改是同步的. \n",
|
||||
"Lict 默认情况下不会保存序列顺序, 而是在列表形式下, 自动按索引字符序排布, 详情请参阅源代码. \n",
|
||||
"现在导入并初始化一个 Lict 对象:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "7e88bd7c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('name', 'tom'), ('age', 12), ('enemy', 'jerry')]\n",
|
||||
"[('name', 'tom'), ('age', 12), ('enemy', 'jerry')]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from heurams.kernel.auxiliary.lict import Lict\n",
|
||||
"\n",
|
||||
"lct = Lict() # 空的\n",
|
||||
"lct = Lict(initlist=[(\"name\", \"tom\"), (\"age\", 12), (\"enemy\", \"jerry\")]) # 基于列表\n",
|
||||
"print(lct)\n",
|
||||
"lct = Lict(initdict={\"name\": \"tom\", \"age\": 12, \"enemy\": \"jerry\"}) # 基于字典\n",
|
||||
"print(lct)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4d760bf9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 输出形式\n",
|
||||
"Lict 的\"官方\"输出形式是列表形式\n",
|
||||
"你也可以选择输出字典形式"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "248f6cba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'name': 'tom', 'age': 12, 'enemy': 'jerry'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(lct.dicted_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "29dce184",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### dicted_data 属性与修改方式\n",
|
||||
"dicted_data 属性是一个字典, 它自动同步来自 Lict 对象操作的修改.\n",
|
||||
"一个注意事项: 不要直接修改 dicted_data, 这将不会触发同步 hook.\n",
|
||||
"如果你一定要这样做, 请在完事后手动运行同步 hook.\n",
|
||||
"推荐的修改方式是直接把 lct 当作一个字典"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"id": "a0eb07a7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('name', 'tom'), ('age', 12), ('enemy', 'jerry')]\n",
|
||||
"[('name', 'tom'), ('age', 12), ('enemy', 'jerry'), ('type', 'cat')]\n",
|
||||
"[('name', 'tom'), ('age', 12), ('enemy', 'jerry'), ('type', 'cat'), ('is_human', False)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 由于 jupyter 的环境处理, 请不要重复运行此单元格, 如果想再看一遍, 请重启 jupyter 后再全部运行\n",
|
||||
"\n",
|
||||
"# 错误的方式\n",
|
||||
"lct.dicted_data[\"type\"] = \"cat\"\n",
|
||||
"print(lct) # 将不会同步修改\n",
|
||||
"\n",
|
||||
"# 不推荐, 但可用的方式\n",
|
||||
"lct.dicted_data[\"type\"] = \"cat\"\n",
|
||||
"lct._sync_based_on_dict()\n",
|
||||
"print(lct)\n",
|
||||
"\n",
|
||||
"# 推荐方式\n",
|
||||
"lct[\"is_human\"] = False\n",
|
||||
"print(lct)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2337d113",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### data 属性与修改方式\n",
|
||||
"data 属性是一个列表, 它自动同步来自 Lict 对象操作的修改.\n",
|
||||
"一个注意事项: 不要直接修改 data, 这将不会触发同步 hook, 并且可能破坏排序.\n",
|
||||
"如果你一定要这样做, 请在完事后手动运行同步 hook 和 sort, 此处不演示.\n",
|
||||
"推荐的修改方式是直接把 lct 当作一个列表, 且避免使用索引修改"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0ab442d4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'name': 'tom', 'age': 12, 'enemy': 'jerry', 'type': 'cat', 'is_human': False, 'enemy_2': 'spike'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 由于 Jupyter 的环境处理(环境状态会累积), 请不要重复运行此单元格, 如果想再看一遍, 请重启 jupyter 后再全部运行\n",
|
||||
"\n",
|
||||
"# 唯一推荐方式\n",
|
||||
"lct.append((\"enemy_2\", \"spike\"))\n",
|
||||
"print(lct.dicted_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a3383f59",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 多面手\n",
|
||||
"Lict 有一些很酷的功能\n",
|
||||
"详情请看源文件\n",
|
||||
"此处是一些例子"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"id": "f3ca752f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('age', 12), ('enemy', 'jerry'), ('is_human', False), ('name', 'tom'), ('type', 'cat'), ('enemy_2', 'spike')]\n",
|
||||
"{'age': 12, 'enemy': 'jerry', 'is_human': False, 'name': 'tom', 'type': 'cat', 'enemy_2': 'spike'}\n",
|
||||
"------\n",
|
||||
"('age', 12)\n",
|
||||
"('enemy', 'jerry')\n",
|
||||
"('is_human', False)\n",
|
||||
"('name', 'tom')\n",
|
||||
"('type', 'cat')\n",
|
||||
"('enemy_2', 'spike')\n",
|
||||
"6\n",
|
||||
"('enemy_2', 'spike')\n",
|
||||
"[('age', 12), ('enemy', 'jerry'), ('is_human', False), ('name', 'tom'), ('type', 'cat')]\n",
|
||||
"('type', 'cat')\n",
|
||||
"[('age', 12), ('enemy', 'jerry'), ('is_human', False), ('name', 'tom')]\n",
|
||||
"('name', 'tom')\n",
|
||||
"[('age', 12), ('enemy', 'jerry'), ('is_human', False)]\n",
|
||||
"('is_human', False)\n",
|
||||
"[('age', 12), ('enemy', 'jerry')]\n",
|
||||
"('enemy', 'jerry')\n",
|
||||
"[('age', 12)]\n",
|
||||
"('age', 12)\n",
|
||||
"[]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Ellipsis"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"lct = Lict(\n",
|
||||
" initdict={\n",
|
||||
" \"age\": 12,\n",
|
||||
" \"enemy\": \"jerry\",\n",
|
||||
" \"is_human\": False,\n",
|
||||
" \"name\": \"tom\",\n",
|
||||
" \"type\": \"cat\",\n",
|
||||
" \"enemy_2\": \"spike\",\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"print(lct)\n",
|
||||
"print(lct.dicted_data)\n",
|
||||
"print(\"------\")\n",
|
||||
"for i in lct:\n",
|
||||
" print(i)\n",
|
||||
"print(len(lct))\n",
|
||||
"while len(lct) > 0:\n",
|
||||
" print(lct.pop())\n",
|
||||
" print(lct)\n",
|
||||
"lct = Lict(\n",
|
||||
" initdict={\n",
|
||||
" \"age\": 12,\n",
|
||||
" \"enemy\": \"jerry\",\n",
|
||||
" \"is_human\": False,\n",
|
||||
" \"name\": \"tom\",\n",
|
||||
" \"type\": \"cat\",\n",
|
||||
" \"enemy_2\": \"spike\",\n",
|
||||
" }\n",
|
||||
")\n",
|
||||
"..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2d6d3483",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"关爱环境 从你我做起"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "773bf99c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"zsh:1: no matches found: heurams.log*\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!rm -rf test_new_repo\n",
|
||||
"!rm -rf heurams.log*"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "8645c5a2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{ 'content': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/',\n",
|
||||
" 'delimiter': '/',\n",
|
||||
" 'keyword_note': {'崤函': '崤山和函谷关', '据': '占据', '雍州': '古代九州之一'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'puzzles': { 'FillBlank': { '__hint__': '',\n",
|
||||
" '__origin__': 'cloze',\n",
|
||||
" 'delimiter': '/',\n",
|
||||
" 'min_denominator': 3,\n",
|
||||
" 'text': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/'},\n",
|
||||
" 'Recognition': { '__hint__': '',\n",
|
||||
" '__origin__': 'recognition',\n",
|
||||
" 'primary': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/',\n",
|
||||
" 'secondary': [ { '崤函': '崤山和函谷关',\n",
|
||||
" '据': '占据',\n",
|
||||
" '雍州': '古代九州之一'},\n",
|
||||
" []],\n",
|
||||
" 'top_dim': [ '秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,']},\n",
|
||||
" 'SelectMeaning': { '__hint__': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/',\n",
|
||||
" '__origin__': 'mcq',\n",
|
||||
" 'jammer': ['占据', '崤山和函谷关', '古代九州之一'],\n",
|
||||
" 'mapping': { '崤函': '崤山和函谷关',\n",
|
||||
" '据': '占据',\n",
|
||||
" '雍州': '古代九州之一'},\n",
|
||||
" 'max_riddles_num': 2,\n",
|
||||
" 'prefix': '选择正确项: ',\n",
|
||||
" 'primary': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/'}},\n",
|
||||
" 'translation': '秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,',\n",
|
||||
" 'tts_text': '秦孝公据崤函之固, 拥雍州之地,'}\n",
|
||||
"{ 'SM-2': { 'efactor': 2.5,\n",
|
||||
" 'interval': 1,\n",
|
||||
" 'is_activated': 1,\n",
|
||||
" 'last_date': 20459,\n",
|
||||
" 'last_modify': 1767700296.4950516,\n",
|
||||
" 'next_date': 20460,\n",
|
||||
" 'real_rept': 1,\n",
|
||||
" 'rept': 0}}\n",
|
||||
"{ 'content': '君臣/固守/以窥/周室,/',\n",
|
||||
" 'delimiter': '/',\n",
|
||||
" 'keyword_note': {'窥': '窥视'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'puzzles': { 'FillBlank': { '__hint__': '',\n",
|
||||
" '__origin__': 'cloze',\n",
|
||||
" 'delimiter': '/',\n",
|
||||
" 'min_denominator': 3,\n",
|
||||
" 'text': '君臣/固守/以窥/周室,/'},\n",
|
||||
" 'Recognition': { '__hint__': '',\n",
|
||||
" '__origin__': 'recognition',\n",
|
||||
" 'primary': '君臣/固守/以窥/周室,/',\n",
|
||||
" 'secondary': [{'窥': '窥视'}, []],\n",
|
||||
" 'top_dim': ['君臣牢固地守卫着,借以窥视周王室的权力,']},\n",
|
||||
" 'SelectMeaning': { '__hint__': '君臣/固守/以窥/周室,/',\n",
|
||||
" '__origin__': 'mcq',\n",
|
||||
" 'jammer': ['窥视'],\n",
|
||||
" 'mapping': {'窥': '窥视'},\n",
|
||||
" 'max_riddles_num': 2,\n",
|
||||
" 'prefix': '选择正确项: ',\n",
|
||||
" 'primary': '君臣/固守/以窥/周室,/'}},\n",
|
||||
" 'translation': '君臣牢固地守卫着,借以窥视周王室的权力,',\n",
|
||||
" 'tts_text': '君臣固守以窥周室,'}\n",
|
||||
"{ 'SM-2': { 'efactor': 2.5,\n",
|
||||
" 'interval': 1,\n",
|
||||
" 'is_activated': 1,\n",
|
||||
" 'last_date': 20459,\n",
|
||||
" 'last_modify': 1767700296.4968777,\n",
|
||||
" 'next_date': 20460,\n",
|
||||
" 'real_rept': 1,\n",
|
||||
" 'rept': 0}}\n",
|
||||
"{ 'algodata': [ ( '秦孝公据崤函之固, 拥雍州之地,',\n",
|
||||
" { 'SM-2': { 'efactor': 2.5,\n",
|
||||
" 'interval': 1,\n",
|
||||
" 'is_activated': 1,\n",
|
||||
" 'last_date': 20459,\n",
|
||||
" 'last_modify': 1767700296.4950516,\n",
|
||||
" 'next_date': 20460,\n",
|
||||
" 'real_rept': 1,\n",
|
||||
" 'rept': 0}}),\n",
|
||||
" ( '君臣固守以窥周室,',\n",
|
||||
" { 'SM-2': { 'efactor': 2.5,\n",
|
||||
" 'interval': 1,\n",
|
||||
" 'is_activated': 1,\n",
|
||||
" 'last_date': 20459,\n",
|
||||
" 'last_modify': 1767700296.4968777,\n",
|
||||
" 'next_date': 20460,\n",
|
||||
" 'real_rept': 1,\n",
|
||||
" 'rept': 0}})],\n",
|
||||
" 'manifest': { 'author': '__heurams__',\n",
|
||||
" 'desc': '高考古诗文: 过秦论',\n",
|
||||
" 'title': '测试单元: 过秦论'},\n",
|
||||
" 'payload': [ ( '秦孝公据崤函之固, 拥雍州之地,',\n",
|
||||
" { 'content': '秦孝公/据/崤函/之固/, 拥/雍州/之地,/',\n",
|
||||
" 'keyword_note': { '崤函': '崤山和函谷关',\n",
|
||||
" '据': '占据',\n",
|
||||
" '雍州': '古代九州之一'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'translation': '秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,'}),\n",
|
||||
" ( '君臣固守以窥周室,',\n",
|
||||
" { 'content': '君臣/固守/以窥/周室,/',\n",
|
||||
" 'keyword_note': {'窥': '窥视'},\n",
|
||||
" 'note': [],\n",
|
||||
" 'translation': '君臣牢固地守卫着,借以窥视周王室的权力,'})],\n",
|
||||
" 'schedule': { 'phases': { 'final_review': [ ['FillBlank', '0.7'],\n",
|
||||
" ['SelectMeaning', '0.7'],\n",
|
||||
" ['Recognition', '1.0']],\n",
|
||||
" 'quick_review': [ ['FillBlank', '1.0'],\n",
|
||||
" ['SelectMeaning', '0.5'],\n",
|
||||
" ['Recognition', '1.0']],\n",
|
||||
" 'recognition': [['Recognition', '1.0']]},\n",
|
||||
" 'schedule': [ 'quick_review',\n",
|
||||
" 'recognition',\n",
|
||||
" 'final_review']},\n",
|
||||
" 'source': PosixPath('test_repo'),\n",
|
||||
" 'typedef': { 'annotation': { 'content': '内容',\n",
|
||||
" 'delimiter': '分隔符',\n",
|
||||
" 'keyword_note': '关键词翻译',\n",
|
||||
" 'note': '笔记',\n",
|
||||
" 'translation': '语句翻译',\n",
|
||||
" 'tts_text': '文本转语音文本'},\n",
|
||||
" 'common': { 'delimiter': '/',\n",
|
||||
" 'puzzles': { 'FillBlank': { '__hint__': '',\n",
|
||||
" '__origin__': 'cloze',\n",
|
||||
" 'delimiter': \"eval:nucleon['delimiter']\",\n",
|
||||
" 'min_denominator': \"eval:default['cloze']['min_denominator']\",\n",
|
||||
" 'text': \"eval:payload['content']\"},\n",
|
||||
" 'Recognition': { '__hint__': '',\n",
|
||||
" '__origin__': 'recognition',\n",
|
||||
" 'primary': \"eval:payload['content']\",\n",
|
||||
" 'secondary': [ \"eval:payload['keyword_note']\",\n",
|
||||
" \"eval:payload['note']\"],\n",
|
||||
" 'top_dim': [ \"eval:payload['translation']\"]},\n",
|
||||
" 'SelectMeaning': { '__hint__': \"eval:payload['content']\",\n",
|
||||
" '__origin__': 'mcq',\n",
|
||||
" 'jammer': \"eval:list(payload['keyword_note'].values())\",\n",
|
||||
" 'mapping': \"eval:payload['keyword_note']\",\n",
|
||||
" 'max_riddles_num': \"eval:default['mcq']['max_riddles_num']\",\n",
|
||||
" 'prefix': '选择正确项: ',\n",
|
||||
" 'primary': \"eval:payload['content']\"}},\n",
|
||||
" 'tts_text': \"eval:payload['content'].replace('/', \"\n",
|
||||
" \"'')\"}}}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"repo = repolib.Repo.create_from_repodir(Path(\"./test_repo\"))\n",
|
||||
"for i in repo.ident_index:\n",
|
||||
" n = pt.Nucleon.create_on_nucleonic_data(\n",
|
||||
" nucleonic_data=repo.nucleonic_data_lict.get_itemic_unit(i)\n",
|
||||
" )\n",
|
||||
" e = pt.Electron.create_on_electonic_data(\n",
|
||||
" electronic_data=repo.electronic_data_lict.get_itemic_unit(i)\n",
|
||||
" )\n",
|
||||
" e.activate()\n",
|
||||
" e.revisor(5, True)\n",
|
||||
" print(repr(n))\n",
|
||||
" print(repr(e))\n",
|
||||
"print(repo)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
57
examples/simplemem.py
Normal file
57
examples/simplemem.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.repolib as repolib
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
repo = repolib.Repo.create_from_repodir(Path("./test_repo"))
|
||||
alist = list()
|
||||
print(repo.ident_index)
|
||||
for i in repo.ident_index:
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=repo.electronic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
print(n)
|
||||
input()
|
||||
a = pt.Atom(n, e, repo.orbitic_data)
|
||||
alist.append(a)
|
||||
# e.activate()
|
||||
# e.revisor(5, True)
|
||||
print(repr(a))
|
||||
# print(repr(e))
|
||||
print(repo)
|
||||
input()
|
||||
import heurams.kernel.reactor as rt
|
||||
|
||||
ph: rt.Phaser = rt.Phaser(alist)
|
||||
print(ph)
|
||||
pr: rt.Procession = ph.current_procession() # type: ignore
|
||||
print(pr)
|
||||
pr.forward()
|
||||
print(pr)
|
||||
pr.forward() # 如果过界了?
|
||||
print(pr) # 静默设置状态 无报错
|
||||
pr.forward()
|
||||
print(pr)
|
||||
pr = ph.current_procession() # type: ignore # 下一个队列
|
||||
print(pr)
|
||||
pr.forward()
|
||||
print(pr)
|
||||
pr.append() # 如果记忆失败了?
|
||||
print(pr)
|
||||
pr.forward()
|
||||
pr.append() # 如果记忆失败了?
|
||||
pr.append() # 如果记忆失败了?
|
||||
pr.append() # 如果记忆失败了?
|
||||
pr.append() # 如果记忆失败了?
|
||||
pr.append() # 如果记忆失败了?
|
||||
# 重复项目只会占据一个车尾
|
||||
print(pr)
|
||||
pr.forward()
|
||||
print(pr)
|
||||
pr = ph.current_procession() # type: ignore
|
||||
print(pr)
|
||||
1
examples/test_repo/algodata.json
Normal file
1
examples/test_repo/algodata.json
Normal file
@@ -0,0 +1 @@
|
||||
{}
|
||||
3
examples/test_repo/manifest.toml
Normal file
3
examples/test_repo/manifest.toml
Normal file
@@ -0,0 +1,3 @@
|
||||
title = "测试单元: 过秦论"
|
||||
author = "__heurams__"
|
||||
desc = "高考古诗文: 过秦论"
|
||||
11
examples/test_repo/payload.toml
Normal file
11
examples/test_repo/payload.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
["秦孝公据崤函之固, 拥雍州之地,"]
|
||||
note = []
|
||||
content = "秦孝公/据/崤函/之固/, 拥/雍州/之地,/"
|
||||
translation = "秦孝公占据着崤山和函谷关的险固地势,拥有雍州的土地,"
|
||||
keyword_note = {"据"="占据", "崤函"="崤山和函谷关", "雍州"="古代九州之一"}
|
||||
|
||||
["君臣固守以窥周室,"]
|
||||
note = []
|
||||
content = "君臣/固守/以窥/周室,/"
|
||||
translation = "君臣牢固地守卫着,借以窥视周王室的权力,"
|
||||
keyword_note = {"窥"="窥视"}
|
||||
5
examples/test_repo/schedule.toml
Normal file
5
examples/test_repo/schedule.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
schedule = ["quick_review", "recognition", "final_review"]
|
||||
[phases]
|
||||
quick_review = [["FillBlank", "1.0"], ["SelectMeaning", "0.5"], ["Recognition", "1.0"]]
|
||||
recognition = [["Recognition", "1.0"]]
|
||||
final_review = [["FillBlank", "0.7"], ["SelectMeaning", "0.7"], ["Recognition", "1.0"]]
|
||||
17
examples/test_repo/typedef.toml
Normal file
17
examples/test_repo/typedef.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[annotation]
|
||||
note = "笔记"
|
||||
keyword_note = "关键词翻译"
|
||||
translation = "语句翻译"
|
||||
delimiter = "分隔符"
|
||||
content = "内容"
|
||||
tts_text = "文本转语音文本"
|
||||
|
||||
[common]
|
||||
delimiter = "/"
|
||||
tts_text = "eval:payload['content'].replace('/', '')"
|
||||
|
||||
[common.puzzles] # 谜题定义
|
||||
# 我们称 "Recognition" 为 recognition 谜题的 alia
|
||||
"Recognition" = { __origin__ = "recognition", __hint__ = "", primary = "eval:payload['content']", secondary = ["eval:payload['keyword_note']", "eval:payload['note']"], top_dim = ["eval:payload['translation']"] }
|
||||
"SelectMeaning" = { __origin__ = "mcq", __hint__ = "eval:payload['content']", primary = "eval:payload['content']", mapping = "eval:payload['keyword_note']", jammer = "eval:list(payload['keyword_note'].values())", max_riddles_num = "eval:default['mcq']['max_riddles_num']", prefix = "选择正确项: " }
|
||||
"FillBlank" = { __origin__ = "cloze", __hint__ = "", text = "eval:payload['content']", delimiter = "eval:nucleon['delimiter']", min_denominator = "eval:default['cloze']['min_denominator']"}
|
||||
12
glossary.md
Normal file
12
glossary.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# 运行时对象
|
||||
Atom: 原子, 由核子, 电子, 轨道对象一并构成, 用于处理记忆所需一系列对象
|
||||
Nucleon: 核子, 负责解析文件动态内容, 并储存记忆材料内容与谜题定义, 是静态只读但可临时覆盖内容的
|
||||
Electron: 电子, 负责处理记忆算法数据
|
||||
Orbital: 轨道, 储存记忆阶段信息与谜题阶段内出现配置
|
||||
# 状态机对象
|
||||
Transitions: 一种状态机框架库
|
||||
Reactor: 状态机库
|
||||
Phaser...
|
||||
|
||||
rating: 用户评估生成的值
|
||||
quality: 用于单元反馈的值
|
||||
@@ -1,27 +1,27 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=45", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "heurams"
|
||||
version = "0.4.0"
|
||||
description = "Heuristic Assisted Memory Scheduler"
|
||||
license = {file = "LICENSE"}
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: GNU Affero General Public License v3",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Education",
|
||||
"Intended Audience :: Education",
|
||||
]
|
||||
keywords = ["spaced-repetition", "memory", "learning", "tui", "textual", "flashcards", "education"]
|
||||
dependencies = [
|
||||
"bidict==0.23.1",
|
||||
"playsound==1.2.2",
|
||||
"textual==5.3.0",
|
||||
"toml==0.10.2",
|
||||
]
|
||||
version = "0.5.0"
|
||||
description = "Heuristic Auxiliary Memory Scheduler"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "pluvium27", email = "pluvium27@outlook.com" }
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"edge-tts==7.0.2",
|
||||
"jieba==0.42.1",
|
||||
"openai==1.0.0",
|
||||
"playsound==1.2.2",
|
||||
"tabulate>=0.9.0",
|
||||
"textual==7.0.0",
|
||||
"toml==0.10.2",
|
||||
"transitions==0.9.3",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
[project.scripts]
|
||||
heurams = "heurams.__main__:main"
|
||||
tui = "heurams.interface.__main__:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.9.22,<0.10.0"]
|
||||
build-backend = "uv_build"
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
bidict==0.23.1
|
||||
edge-tts==7.0.2
|
||||
jieba==0.42.1
|
||||
openai==1.0.0
|
||||
playsound==1.2.2
|
||||
textual==5.3.0
|
||||
tabulate>=0.9.0
|
||||
textual==7.0.0
|
||||
toml==0.10.2
|
||||
transitions==0.9.3
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
prompt = """HeurAMS 已经被成功地安装在系统中.
|
||||
# __main__.py
|
||||
def main():
|
||||
prompt = """HeurAMS 已经被成功地安装在系统中.
|
||||
但 HeurAMS 被设计为一个带有辅助记忆调度器功能的软件包, 无法直接被执行, 但可被其他 Python 程序调用.
|
||||
若您想启动内置的基本用户界面,
|
||||
若您想启动内置的基本用户界面:
|
||||
请运行 python -m heurams.interface,
|
||||
或者 python -m heurams.interface.__main__
|
||||
python 代指您使用的解释器, 在某些发行版中可能是 python3, 而 python 命令被指向了 python2.
|
||||
尽管项目保留了 requirements.txt, 我们仍不推荐使用系统 python 和原始 venv 进行开发.
|
||||
项目的推荐开发环境工具是 uv.
|
||||
如果你的环境已经安装了 uv:
|
||||
先运行 uv sync 同步环境, 此命令只需要执行一遍, uv 会自动处理依赖.
|
||||
通过运行 uv run tui 启动内置基本用户界面.
|
||||
此时您的解释器在项目目录里的 .venv/bin 中, 使用 IDE 开发前, 务必切换解释器!
|
||||
注意: 一个常见的误区是, 执行 interface 下的 __main__.py 运行基本用户界面, 这会导致 Python 上下文环境异常, 请不要这样做."""
|
||||
print(prompt)
|
||||
print(prompt)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import shutil
|
||||
from contextvars import ContextVar
|
||||
|
||||
from heurams.services.config import ConfigFile
|
||||
@@ -14,26 +15,39 @@ from heurams.services.logger import get_logger
|
||||
# 数据文件路径规定: 以运行目录为准
|
||||
|
||||
rootdir = pathlib.Path(__file__).parent
|
||||
print(f"rootdir: {rootdir}")
|
||||
print(f"项目根目录: {rootdir}")
|
||||
logger = get_logger(__name__)
|
||||
logger.debug(f"项目根目录: {rootdir}")
|
||||
workdir = pathlib.Path.cwd()
|
||||
print(f"workdir: {workdir}")
|
||||
print(f"工作目录: {workdir}")
|
||||
logger.debug(f"工作目录: {workdir}")
|
||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||
"config_var", default=ConfigFile(rootdir / "default" / "config" / "config.toml")
|
||||
)
|
||||
try:
|
||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||
"config_var", default=ConfigFile(workdir / "config" / "config.toml")
|
||||
) # 配置文件
|
||||
print("已加载自定义用户配置")
|
||||
logger.info("已加载自定义用户配置, 路径: %s", workdir / "config" / "config.toml")
|
||||
except Exception as e:
|
||||
print("未能加载自定义用户配置")
|
||||
logger.warning("未能加载自定义用户配置, 错误: %s", e)
|
||||
|
||||
# runtime_var: ContextVar = ContextVar('runtime_var', default=dict()) # 运行时共享数据
|
||||
if pathlib.Path(workdir / "data" / "config" / "config_dev.toml").exists():
|
||||
print("使用开发设置")
|
||||
logger.debug("使用开发设置")
|
||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||
"config_var",
|
||||
default=ConfigFile(workdir / "data" / "config" / "config_dev.toml"),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||
"config_var",
|
||||
default=ConfigFile(workdir / "data" / "config" / "config.toml"),
|
||||
) # 配置文件
|
||||
except Exception as e:
|
||||
input("按下回车以创建新的配置文件, 或按下 Ctrl + C 以终止程序 ")
|
||||
(workdir / "data" / "config").mkdir(parents=True, exist_ok=True)
|
||||
(workdir / "data" / "config" / "config").unlink(missing_ok=True)
|
||||
shutil.copy(
|
||||
(rootdir / "default" / "config" / "config.toml"),
|
||||
workdir / "data" / "config" / "config.toml",
|
||||
)
|
||||
finally:
|
||||
config_var: ContextVar[ConfigFile] = ContextVar(
|
||||
"config_var",
|
||||
default=ConfigFile(workdir / "data" / "config" / "config.toml"),
|
||||
) # 配置文件
|
||||
|
||||
|
||||
class ConfigContext:
|
||||
|
||||
@@ -14,6 +14,14 @@ scheduled_num = 8
|
||||
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
||||
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||
|
||||
[interface]
|
||||
|
||||
[interface.memorizor]
|
||||
autovoice = true # 自动语音播放, 仅限于 recognition 组件
|
||||
|
||||
[algorithm]
|
||||
default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS
|
||||
|
||||
[puzzles] # 谜题默认配置
|
||||
|
||||
[puzzles.mcq]
|
||||
@@ -23,17 +31,26 @@ max_riddles_num = 2
|
||||
min_denominator = 3
|
||||
|
||||
[paths] # 相对于配置文件的 ".." (即工作目录) 而言 或绝对路径
|
||||
nucleon_dir = "./data/nucleon"
|
||||
electron_dir = "./data/electron"
|
||||
orbital_dir = "./data/orbital"
|
||||
cache_dir = "./data/cache"
|
||||
template_dir = "./data/template"
|
||||
data = "./data"
|
||||
|
||||
[services] # 定义服务到提供者的映射
|
||||
audio = "playsound" # 可选项: playsound(通用), termux(仅用于支持 Android Termux), mpg123(TODO)
|
||||
tts = "edgetts" # 可选项: edgetts
|
||||
llm = "openai" # 可选项: openai
|
||||
sync = "webdav" # 可选项: 留空, webdav
|
||||
|
||||
[providers.tts.edgetts] # EdgeTTS 设置
|
||||
voice = "zh-CN-XiaoxiaoNeural" # 可选项: zh-CN-YunjianNeural (男声), zh-CN-XiaoxiaoNeural (女声)
|
||||
|
||||
[providers.llm.openai] # 与 OpenAI 相容的语言模型接口服务设置
|
||||
url = ""
|
||||
key = ""
|
||||
|
||||
[providers.sync.webdav] # WebDAV 同步设置
|
||||
url = ""
|
||||
username = ""
|
||||
password = ""
|
||||
remote_path = "/heurams/"
|
||||
verify_ssl = true
|
||||
|
||||
[sync]
|
||||
|
||||
74
src/heurams/interface/__init__.py
Normal file
74
src/heurams/interface/__init__.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Type
|
||||
|
||||
from textual.app import App
|
||||
from textual.driver import Driver
|
||||
from textual.widgets import Button
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .screens.about import AboutScreen
|
||||
from .screens.dashboard import DashboardScreen
|
||||
from .screens.llmchat import LLMChatScreen
|
||||
from .screens.navigator import NavigatorScreen
|
||||
from .screens.precache import PrecachingScreen
|
||||
from .screens.radio import RadioScreen
|
||||
from .screens.repocreator import RepoCreatorScreen
|
||||
from .screens.repoeditor import RepoEditorScreen
|
||||
from .screens.synctool import SyncScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def environment_check():
|
||||
from pathlib import Path
|
||||
|
||||
logger.debug("检查环境路径")
|
||||
subdir = ["cache/voice", "repo", "global", "config"]
|
||||
for i in subdir:
|
||||
i = Path(config_var.get()["paths"]["data"]) / i
|
||||
if not i.exists():
|
||||
logger.info("创建目录: %s", i)
|
||||
print(f"创建 {i}")
|
||||
i.mkdir(exist_ok=True, parents=True)
|
||||
else:
|
||||
logger.debug("目录已存在: %s", i)
|
||||
print(f"找到 {i}")
|
||||
logger.debug("环境检查完成")
|
||||
|
||||
|
||||
class HeurAMSApp(App):
|
||||
TITLE = "潜进"
|
||||
CSS_PATH = "css/main.tcss"
|
||||
SUB_TITLE = "启发式辅助记忆调度器"
|
||||
BINDINGS = [
|
||||
("q", "go_back", "退出"),
|
||||
("d", "toggle_dark", "主题"),
|
||||
("n", "app.push_screen('navigator')", "导航"),
|
||||
("z", "app.push_screen('about')", "关于"),
|
||||
]
|
||||
SCREENS = {
|
||||
"dashboard": DashboardScreen,
|
||||
"repo_creator": RepoCreatorScreen,
|
||||
"precache_all": PrecachingScreen,
|
||||
"synctool": SyncScreen,
|
||||
"about": AboutScreen,
|
||||
"navigator": NavigatorScreen,
|
||||
"radio": RadioScreen,
|
||||
"repo_editor": RepoEditorScreen,
|
||||
"llmchat": LLMChatScreen,
|
||||
}
|
||||
|
||||
def on_mount(self) -> None:
|
||||
environment_check()
|
||||
self.push_screen("dashboard")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
pass
|
||||
# self.exit(event.button.id)
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
quit()
|
||||
|
||||
def action_do_nothing(self):
|
||||
self.refresh()
|
||||
@@ -1,89 +1,20 @@
|
||||
from textual.app import App
|
||||
from textual.widgets import Button
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.interface import HeurAMSApp
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .screens.about import AboutScreen
|
||||
from .screens.dashboard import DashboardScreen
|
||||
from .screens.nucreator import NucleonCreatorScreen
|
||||
from .screens.precache import PrecachingScreen
|
||||
from .screens.repocreator import RepoCreatorScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class HeurAMSApp(App):
|
||||
TITLE = "潜进"
|
||||
CSS_PATH = "css/main.tcss"
|
||||
SUB_TITLE = "启发式辅助记忆调度器"
|
||||
BINDINGS = [
|
||||
("q", "quit", "退出"),
|
||||
("d", "toggle_dark", "切换色调"),
|
||||
("1", "app.push_screen('dashboard')", "仪表盘"),
|
||||
("2", "app.push_screen('precache_all')", "缓存管理器"),
|
||||
("3", "app.push_screen('nucleon_creator')", "创建新单元"),
|
||||
("0", "app.push_screen('about')", "版本信息"),
|
||||
]
|
||||
SCREENS = {
|
||||
"dashboard": DashboardScreen,
|
||||
"nucleon_creator": NucleonCreatorScreen,
|
||||
"precache_all": PrecachingScreen,
|
||||
"about": AboutScreen,
|
||||
}
|
||||
|
||||
def on_mount(self) -> None:
|
||||
self.push_screen("dashboard")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
self.exit(event.button.id)
|
||||
|
||||
def action_do_nothing(self):
|
||||
print("DO NOTHING")
|
||||
self.refresh()
|
||||
|
||||
|
||||
def environment_check():
|
||||
from pathlib import Path
|
||||
|
||||
logger.debug("检查环境路径")
|
||||
|
||||
for i in config_var.get()["paths"].values():
|
||||
i = Path(i)
|
||||
if not i.exists():
|
||||
logger.info("创建目录: %s", i)
|
||||
print(f"创建 {i}")
|
||||
i.mkdir(exist_ok=True, parents=True)
|
||||
else:
|
||||
logger.debug("目录已存在: %s", i)
|
||||
print(f"找到 {i}")
|
||||
logger.debug("环境检查完成")
|
||||
|
||||
|
||||
def is_subdir(parent, child):
|
||||
try:
|
||||
child.relative_to(parent)
|
||||
logger.debug("is_subdir: %s 是 %s 的子目录", child, parent)
|
||||
return 1
|
||||
except:
|
||||
logger.debug("is_subdir: %s 不是 %s 的子目录", child, parent)
|
||||
return 0
|
||||
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 开发模式
|
||||
from heurams.context import config_var, rootdir, workdir
|
||||
|
||||
if is_subdir(Path(rootdir), Path(os.getcwd())):
|
||||
os.chdir(Path(rootdir) / ".." / "..")
|
||||
print(f'转入开发数据目录: {Path(rootdir)/".."/".."}')
|
||||
|
||||
environment_check()
|
||||
|
||||
app = HeurAMSApp()
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
|
||||
|
||||
def main():
|
||||
app = HeurAMSApp()
|
||||
app.run()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,64 @@
|
||||
NavigatorScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#dialog {
|
||||
grid-size: 2;
|
||||
grid-gutter: 1 1;
|
||||
grid-rows: 1fr 3;
|
||||
padding: 0 1;
|
||||
width: 46;
|
||||
height: 12;
|
||||
border: thick $background 80%;
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
/* LLM 聊天界面样式 */
|
||||
LLMChatScreen {
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#chat-container {
|
||||
height: 100%;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#toolbar {
|
||||
height: 3;
|
||||
margin-bottom: 1;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#toolbar Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
|
||||
#chat-log {
|
||||
height: 1fr;
|
||||
border: solid $primary;
|
||||
padding: 1;
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#input-container {
|
||||
height: 3;
|
||||
margin-top: 1;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#message-input {
|
||||
width: 1fr;
|
||||
margin-right: 1;
|
||||
}
|
||||
|
||||
#status-bar {
|
||||
height: 1;
|
||||
margin-top: 1;
|
||||
text-style: italic;
|
||||
color: $text-muted;
|
||||
}
|
||||
|
||||
.session-label {
|
||||
color: $primary;
|
||||
text-style: bold;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""关于界面"""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
@@ -9,6 +10,9 @@ from heurams.context import *
|
||||
|
||||
|
||||
class AboutScreen(Screen):
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
@@ -19,19 +23,24 @@ class AboutScreen(Screen):
|
||||
|
||||
版本 {version.ver} {version.stage.capitalize()}
|
||||
|
||||
开发代号: {version.codename.capitalize()}
|
||||
开发代号: {version.codename.capitalize()} {version.codename_cn}
|
||||
|
||||
一个基于启发式算法的开放源代码记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划.
|
||||
一个基于启发式算法的辅助记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划.
|
||||
|
||||
以 AGPL-3.0 开放源代码
|
||||
|
||||
您可在项目主页 https://ams.imwangzhiyu.xyz 获取用户指南, 开发文档与软件更新
|
||||
|
||||
如果您觉得这个软件有用, 请给它添加一个星标 :)
|
||||
|
||||
开发人员:
|
||||
|
||||
- Wang Zhiyu([@pluvium27](https://github.com/pluvium27)): 项目作者
|
||||
|
||||
特别感谢:
|
||||
|
||||
- [Piotr A. Woźniak](https://supermemo.guru/wiki/Piotr_Wozniak): SuperMemo-2 算法
|
||||
- [Piotr A. Woźniak](https://supermemo.guru/wiki/Piotr_Wozniak): SM-2 算法与 SM-15 算法理论
|
||||
- [Kazuaki Tanida](https://github.com/slaypni): SM-15 算法的 CoffeeScript 实现
|
||||
- [Thoughts Memo](https://www.zhihu.com/people/L.M.Sherlock): 文献参考
|
||||
|
||||
# 参与贡献
|
||||
|
||||
@@ -1,147 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
"""仪表盘界面"""
|
||||
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (Button, Footer, Header, Label, ListItem, ListView,
|
||||
Static)
|
||||
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.services.timer as timer
|
||||
import heurams.services.version as version
|
||||
from heurams.context import *
|
||||
from heurams.kernel.particles import *
|
||||
from heurams.kernel.repolib import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .about import AboutScreen
|
||||
from .navigator import NavigatorScreen
|
||||
from .preparation import PreparationScreen
|
||||
from .radio import RadioScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DashboardScreen(Screen):
|
||||
"""主仪表盘屏幕"""
|
||||
|
||||
SUB_TITLE = "仪表盘"
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.repostat = {}
|
||||
self.title2dirname = {}
|
||||
self.title2repo = {}
|
||||
self._load_data()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
yield ScrollableContainer(
|
||||
Label(f'欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
||||
Label(f"当前 UNIX 日时间戳: {timer.get_daystamp()}"),
|
||||
Label(f'时区修正: UTC+{config_var.get()["timezone_offset"] / 3600}'),
|
||||
Label("选择待学习或待修改的记忆单元集:", classes="title-label"),
|
||||
ListView(id="union-list", classes="union-list-view"),
|
||||
Label('欢迎使用 "潜进" 启发式辅助记忆调度器', classes="title-label"),
|
||||
Label(
|
||||
f'"潜进" 启发式辅助记忆调度器 | 版本 {version.ver} {version.codename.capitalize()} 2025'
|
||||
f"当前 UNIX 日时间戳: {timer.get_daystamp()} (UTC+{config_var.get()['timezone_offset'] / 3600})"
|
||||
),
|
||||
Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"),
|
||||
Label("选择待学习或待修改的项目:", classes="title-label"),
|
||||
ListView(id="repo-list", classes="repo-list-view"),
|
||||
Label(f'"潜进" 启发式辅助记忆调度器 版本 {version.ver} '),
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def item_desc_generator(self, filename) -> dict:
|
||||
"""简单分析以生成项目项显示文本
|
||||
|
||||
Returns:
|
||||
dict: 以数字为列表, 分别呈现单行字符串
|
||||
"""
|
||||
res = dict()
|
||||
filestem = pathlib.Path(filename).stem
|
||||
res[0] = f"{filename}\0"
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.kernel.particles.loader import load_electron
|
||||
|
||||
electron_file_path = pathlib.Path(config_var.get()["paths"]["electron_dir"]) / (
|
||||
filestem + ".json"
|
||||
def _load_data(self):
|
||||
self.repo_dirs = Repo.probe_valid_repos_in_dir(
|
||||
Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
)
|
||||
for repo_dir in self.repo_dirs:
|
||||
repo = Repo.create_from_repodir(repo_dir)
|
||||
self._analyse_repo(repo)
|
||||
|
||||
logger.debug(f"电子文件路径: {electron_file_path}")
|
||||
|
||||
if electron_file_path.exists(): # 未找到则创建电子文件 (json)
|
||||
pass
|
||||
else:
|
||||
electron_file_path.touch()
|
||||
with open(electron_file_path, "w") as f:
|
||||
f.write("{}")
|
||||
electron_dict = load_electron(path=electron_file_path) # TODO: 取消硬编码扩展名
|
||||
logger.debug(electron_dict)
|
||||
def _analyse_repo(self, repo: Repo):
|
||||
dirname = repo.source.name # type: ignore
|
||||
title = repo.manifest["title"]
|
||||
is_due = 0
|
||||
is_activated = 0
|
||||
unit_sum = len(repo)
|
||||
activated_sum = 0
|
||||
nextdate = 0x3F3F3F3F
|
||||
for i in electron_dict.values():
|
||||
i: pt.Electron
|
||||
logger.debug(i, i.is_due())
|
||||
if i.is_due():
|
||||
is_due = 1
|
||||
if i.is_activated():
|
||||
is_activated = 1
|
||||
nextdate = min(nextdate, i.nextdate())
|
||||
res[1] = f"下一次复习: {nextdate}\n"
|
||||
res[1] += f"{is_due if "需要复习" else "当前无需复习"}"
|
||||
if not is_activated:
|
||||
res[1] = " 尚未激活"
|
||||
return res
|
||||
for i in repo.ident_index:
|
||||
nucleon = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
electron = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=repo.electronic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
if electron.is_activated():
|
||||
activated_sum += 1
|
||||
if electron.is_due():
|
||||
is_due = 1
|
||||
nextdate = min(nextdate, electron.nextdate())
|
||||
is_unfinished = unit_sum > activated_sum
|
||||
if is_unfinished:
|
||||
nextdate = min(nextdate, timer.get_daystamp())
|
||||
need_to_study = is_due or is_unfinished
|
||||
prompt = f"{title}\0\n 进度: {activated_sum}/{unit_sum} ({round(activated_sum/unit_sum*100)}%)\n {'需要学习' if need_to_study else '无需操作'}"
|
||||
stat = {
|
||||
"is_due": is_due,
|
||||
"unit_sum": unit_sum,
|
||||
"title": title,
|
||||
"activated_sum": activated_sum,
|
||||
"nextdate": nextdate,
|
||||
"is_unfinished": is_unfinished,
|
||||
"need_to_study": need_to_study,
|
||||
"prompt": prompt,
|
||||
"dirname": dirname,
|
||||
}
|
||||
self.repostat[dirname] = stat
|
||||
self.title2dirname[title] = dirname
|
||||
self.title2repo[title] = repo
|
||||
|
||||
def on_mount(self) -> None:
|
||||
union_list_widget = self.query_one("#union-list", ListView)
|
||||
"""挂载组件时初始化"""
|
||||
repo_list_widget = self.query_one("#repo-list", ListView)
|
||||
|
||||
probe = probe_all(0)
|
||||
|
||||
if len(probe["nucleon"]):
|
||||
for file in probe["nucleon"]:
|
||||
text = self.item_desc_generator(file)
|
||||
union_list_widget.append(
|
||||
ListItem(
|
||||
Label(text[0] + "\n" + text[1]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
union_list_widget.append(
|
||||
# 按下次复习时间排序
|
||||
repodirs = sorted(
|
||||
self.repo_dirs,
|
||||
key=lambda f: self.repostat[f.name]["nextdate"],
|
||||
reverse=True,
|
||||
)
|
||||
repotitles = map(lambda f: self.repostat[f.name]["title"], repodirs)
|
||||
# 填充列表
|
||||
if not repodirs:
|
||||
repo_list_widget.append(
|
||||
ListItem(
|
||||
Static(
|
||||
"在 ./nucleon/ 中未找到任何内容源数据文件.\n请放置文件后重启应用.\n或者新建空的单元集."
|
||||
"在 ./data/repo/ 中未找到任何仓库.\n"
|
||||
"请导入仓库后重启应用, 或者新建空的仓库."
|
||||
)
|
||||
)
|
||||
)
|
||||
union_list_widget.disabled = True
|
||||
repo_list_widget.disabled = True
|
||||
return
|
||||
|
||||
for repotitle in repotitles:
|
||||
prompt = self.repostat[self.title2dirname[repotitle]]["prompt"]
|
||||
list_item = ListItem(Label(prompt))
|
||||
repo_list_widget.append(list_item)
|
||||
|
||||
# if not self.stay_enabled[repodir]:
|
||||
# list_item.disabled = True
|
||||
|
||||
def on_list_view_selected(self, event) -> None:
|
||||
"""处理列表项选择事件"""
|
||||
if not isinstance(event.item, ListItem):
|
||||
return
|
||||
|
||||
selected_label = event.item.query_one(Label)
|
||||
if "未找到任何 .toml 文件" in str(selected_label.renderable): # type: ignore
|
||||
label_text = str(selected_label.render())
|
||||
|
||||
if "未找到任何仓库" in label_text:
|
||||
return
|
||||
|
||||
selected_filename = pathlib.Path(
|
||||
str(selected_label.renderable)
|
||||
.partition("\0")[0] # 文件名末尾截断, 保留文件名
|
||||
.replace("*", "")
|
||||
) # 去除markdown加粗
|
||||
# 提取文件名
|
||||
selected_repotitle = label_text.partition("\0")[0].replace("*", "")
|
||||
selected_repo = self.title2repo[label_text.partition("\0")[0].replace("*", "")]
|
||||
|
||||
nucleon_file_path = (
|
||||
pathlib.Path(config_var.get()["paths"]["nucleon_dir"]) / selected_filename
|
||||
# 跳转到准备屏幕
|
||||
self.app.push_screen(
|
||||
PreparationScreen(
|
||||
selected_repo, self.repostat[self.title2dirname[selected_repotitle]]
|
||||
)
|
||||
)
|
||||
electron_file_path = pathlib.Path(config_var.get()["paths"]["electron_dir"]) / (
|
||||
str(selected_filename.stem) + ".json"
|
||||
)
|
||||
self.app.push_screen(PreparationScreen(nucleon_file_path, electron_file_path))
|
||||
|
||||
def on_button_pressed(self, event) -> None:
|
||||
if event.button.id == "new_nucleon_button":
|
||||
# 切换到创建单元
|
||||
from .nucreator import NucleonCreatorScreen
|
||||
|
||||
newscr = NucleonCreatorScreen()
|
||||
self.app.push_screen(newscr)
|
||||
elif event.button.id == "precache_all_button":
|
||||
# 切换到缓存管理器
|
||||
from .precache import PrecachingScreen
|
||||
|
||||
precache_screen = PrecachingScreen()
|
||||
self.app.push_screen(precache_screen)
|
||||
elif event.button.id == "about_button":
|
||||
from .about import AboutScreen
|
||||
|
||||
about_screen = AboutScreen()
|
||||
self.app.push_screen(about_screen)
|
||||
|
||||
def action_quit_app(self) -> None:
|
||||
"""退出应用程序"""
|
||||
self.app.exit()
|
||||
|
||||
def action_open_navigator(self) -> None:
|
||||
"""打开导航器"""
|
||||
self.app.push_screen(NavigatorScreen())
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
if event.button.id == "navigator-button":
|
||||
self.action_open_navigator()
|
||||
|
||||
204
src/heurams/interface/screens/favmgr.py
Normal file
204
src/heurams/interface/screens/favmgr.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""收藏夹管理器界面"""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Markdown,
|
||||
Static,
|
||||
)
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.services.favorite_service import FavoriteItem, favorite_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FavoriteManagerScreen(Screen):
|
||||
"""收藏夹管理器屏幕"""
|
||||
|
||||
SUB_TITLE = "收藏夹"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("d", "toggle_dark", ""),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.favorites: List[FavoriteItem] = []
|
||||
self._load_favorites()
|
||||
|
||||
def _load_favorites(self) -> None:
|
||||
"""加载收藏列表"""
|
||||
self.favorites = favorite_manager.get_all()
|
||||
logger.debug("加载 %d 个收藏项", len(self.favorites))
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="favorites-container"):
|
||||
if not self.favorites:
|
||||
yield Label("暂无收藏", classes="empty-label")
|
||||
yield Static("使用 * 键在记忆界面中添加收藏.")
|
||||
else:
|
||||
yield Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label")
|
||||
yield ListView(id="favorites-list")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载后填充列表"""
|
||||
if self.favorites:
|
||||
list_view = self.query_one("#favorites-list")
|
||||
for fav in self.favorites:
|
||||
list_view.append(self._create_favorite_item(fav)) # type: ignore
|
||||
|
||||
def _encode_favorite_key(self, repo_path: str, ident: str) -> str:
|
||||
"""编码仓库路径和标识符为安全的按钮 ID 部分"""
|
||||
# 使用 \x00 分隔两部分,然后进行 base64 编码
|
||||
combined = f"{repo_path}\x00{ident}"
|
||||
encoded = base64.urlsafe_b64encode(combined.encode()).decode()
|
||||
# 去掉填充的等号
|
||||
return encoded.rstrip("=")
|
||||
|
||||
def _decode_favorite_key(self, key: str) -> tuple[str, str]:
|
||||
"""解码按钮 ID 部分为仓库路径和标识符"""
|
||||
# 补全等号以使长度是4的倍数
|
||||
padded = key + "=" * ((4 - len(key) % 4) % 4)
|
||||
decoded = base64.urlsafe_b64decode(padded.encode()).decode()
|
||||
repo_path, ident = decoded.split("\x00", 1)
|
||||
return repo_path, ident
|
||||
|
||||
def _create_favorite_item(self, fav: FavoriteItem) -> ListItem:
|
||||
"""创建收藏项列表项"""
|
||||
# 尝试获取仓库信息
|
||||
repo_info = self._get_repo_info(fav.repo_path, fav)
|
||||
title = repo_info.get("title", fav.repo_path) if repo_info else fav.repo_path
|
||||
content_preview = repo_info.get("content_preview", "") if repo_info else ""
|
||||
added_time = self._format_time(fav.added)
|
||||
|
||||
# 构建显示文本
|
||||
display_text = f"[b]{title}[/b] ({fav.ident})\n"
|
||||
if content_preview:
|
||||
display_text += f"{content_preview}\n"
|
||||
display_text += f"添加于: {added_time}"
|
||||
if fav.tags:
|
||||
display_text += f" 标签: {', '.join(fav.tags)}"
|
||||
|
||||
# 创建安全的按钮 ID
|
||||
button_key = self._encode_favorite_key(fav.repo_path, fav.ident)
|
||||
# 创建列表项,包含移除按钮
|
||||
container = ScrollableContainer(
|
||||
Markdown(display_text, classes="favorite-content"),
|
||||
Button("移除", id=f"remove-{button_key}", variant="error"),
|
||||
classes="favorite-item",
|
||||
)
|
||||
return ListItem(container)
|
||||
|
||||
def _get_repo_info(self, repo_path: str, fav: FavoriteItem) -> Optional[dict]:
|
||||
"""获取仓库信息(标题、原子内容预览)"""
|
||||
try:
|
||||
data_repo = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dir = data_repo / repo_path
|
||||
if not repo_dir.exists():
|
||||
logger.warning("仓库目录不存在: %s", repo_dir)
|
||||
return None
|
||||
repo = Repo.create_from_repodir(repo_dir)
|
||||
# 获取原子内容预览
|
||||
content_preview = ""
|
||||
payload = repo.payload
|
||||
# 查找对应 ident 的 payload 条目
|
||||
for ident_key, content in payload:
|
||||
if ident_key == fav.ident:
|
||||
# 截断过长的内容
|
||||
if isinstance(content, dict) and "content" in content:
|
||||
text = content["content"]
|
||||
else:
|
||||
text = str(content)
|
||||
if len(text) > 100:
|
||||
content_preview = text[:100] + "..."
|
||||
else:
|
||||
content_preview = text
|
||||
break
|
||||
return {
|
||||
"title": repo.manifest["title"],
|
||||
"content_preview": content_preview,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("获取仓库信息失败: %s", e)
|
||||
return None
|
||||
|
||||
def _format_time(self, timestamp: int) -> str:
|
||||
"""格式化时间戳"""
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
button_id = event.button.id
|
||||
if button_id and button_id.startswith("remove-"):
|
||||
# 提取编码后的键
|
||||
key = button_id[7:] # 去掉 "remove-" 前缀
|
||||
try:
|
||||
repo_path, ident = self._decode_favorite_key(key)
|
||||
self._remove_favorite(repo_path, ident)
|
||||
except Exception as e:
|
||||
logger.error("解析按钮 ID 失败: %s", e)
|
||||
self.app.notify("操作失败: 无效的按钮标识", severity="error")
|
||||
|
||||
def _remove_favorite(self, repo_path: str, ident: str) -> None:
|
||||
"""移除收藏项"""
|
||||
if favorite_manager.remove(repo_path, ident):
|
||||
self.app.notify(f"已移除收藏: {ident}", severity="information")
|
||||
# 重新加载列表
|
||||
self._load_favorites()
|
||||
# 刷新界面
|
||||
self._refresh_list()
|
||||
else:
|
||||
self.app.notify(f"移除失败: {ident}", severity="error")
|
||||
|
||||
def _refresh_list(self) -> None:
|
||||
"""刷新列表显示"""
|
||||
container = self.query_one("#favorites-container")
|
||||
# 清空容器
|
||||
for child in container.children:
|
||||
child.remove()
|
||||
# 重新组合
|
||||
if not self.favorites:
|
||||
container.mount(Label("暂无收藏", classes="empty-label"))
|
||||
container.mount(Static("使用 * 键在记忆界面中添加收藏。"))
|
||||
else:
|
||||
container.mount(
|
||||
Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label")
|
||||
)
|
||||
list_view = ListView(id="favorites-list")
|
||||
container.mount(list_view)
|
||||
for fav in self.favorites:
|
||||
list_view.append(self._create_favorite_item(fav))
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
"""返回上一屏幕"""
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_toggle_dark(self) -> None:
|
||||
"""切换暗黑模式"""
|
||||
self.app.dark = not self.app.dark # type: ignore
|
||||
333
src/heurams/interface/screens/llmchat.py
Normal file
333
src/heurams/interface/screens/llmchat.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""LLM 聊天界面"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Horizontal
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Input, Label, RichLog, Static
|
||||
|
||||
from heurams.context import *
|
||||
from heurams.services.llm_service import ChatSession, get_chat_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMChatScreen(Screen):
|
||||
"""LLM 聊天屏幕"""
|
||||
|
||||
SUB_TITLE = "AI 聊天"
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("ctrl+s", "save_session", "保存会话"),
|
||||
("ctrl+l", "load_session", "加载会话"),
|
||||
("ctrl+n", "new_session", "新建会话"),
|
||||
("ctrl+c", "clear_history", "清空历史"),
|
||||
("escape", "focus_input", "聚焦输入"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.session_id = session_id
|
||||
self.chat_manager = get_chat_manager()
|
||||
self.current_session: Optional[ChatSession] = None
|
||||
self.is_streaming = False
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
|
||||
with Container(id="chat-container"):
|
||||
# 顶部工具栏
|
||||
with Horizontal(id="toolbar"):
|
||||
yield Button("新建会话", id="new-session", variant="primary")
|
||||
yield Button("保存会话", id="save-session", variant="default")
|
||||
yield Button("加载会话", id="load-session", variant="default")
|
||||
yield Button("清空历史", id="clear-history", variant="default")
|
||||
yield Button("设置系统提示", id="set-system-prompt", variant="default")
|
||||
yield Static(" | ", classes="separator")
|
||||
yield Label("当前会话:", classes="label")
|
||||
yield Static(id="current-session-label", classes="session-label")
|
||||
|
||||
# 聊天记录显示区域
|
||||
yield RichLog(
|
||||
id="chat-log",
|
||||
wrap=True,
|
||||
highlight=True,
|
||||
markup=True,
|
||||
classes="chat-log",
|
||||
)
|
||||
|
||||
# 输入区域
|
||||
with Horizontal(id="input-container"):
|
||||
yield Input(
|
||||
id="message-input",
|
||||
placeholder="输入消息... (按 Ctrl+Enter 发送, Esc 聚焦)",
|
||||
classes="message-input",
|
||||
)
|
||||
yield Button(
|
||||
"发送", id="send-button", variant="primary", classes="send-button"
|
||||
)
|
||||
|
||||
# 状态栏
|
||||
yield Static(id="status-bar", classes="status-bar")
|
||||
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载组件时初始化"""
|
||||
# 获取或创建会话
|
||||
self.current_session = self.chat_manager.get_session(self.session_id)
|
||||
if self.current_session is None:
|
||||
self.notify("无法创建 LLM 会话,请检查配置", severity="error")
|
||||
return
|
||||
|
||||
# 更新会话标签
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
|
||||
# 加载历史消息到聊天记录
|
||||
self._display_history()
|
||||
|
||||
# 聚焦输入框
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
# 检查配置
|
||||
self._check_config()
|
||||
|
||||
def _check_config(self):
|
||||
"""检查 LLM 配置"""
|
||||
config = config_var.get()
|
||||
provider_name = config["services"]["llm"]
|
||||
provider_config = config["providers"]["llm"][provider_name]
|
||||
|
||||
if provider_name == "openai":
|
||||
if not provider_config.get("key") and not provider_config.get("url"):
|
||||
self.notify(
|
||||
"未配置 OpenAI API key 或 URL,请在 config.toml 中配置 [providers.llm.openai]",
|
||||
severity="warning",
|
||||
)
|
||||
|
||||
def _display_history(self):
|
||||
"""显示当前会话的历史消息"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
chat_log = self.query_one("#chat-log", RichLog)
|
||||
chat_log.clear()
|
||||
|
||||
for msg in self.current_session.get_history():
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "user":
|
||||
chat_log.write(f"[bold cyan]你:[/bold cyan] {content}")
|
||||
elif role == "assistant":
|
||||
chat_log.write(f"[bold green]AI:[/bold green] {content}")
|
||||
elif role == "system":
|
||||
# 系统消息不显示在聊天记录中
|
||||
pass
|
||||
|
||||
def _add_message_to_log(self, role: str, content: str):
|
||||
"""添加消息到聊天记录显示"""
|
||||
chat_log = self.query_one("#chat-log", RichLog)
|
||||
if role == "user":
|
||||
chat_log.write(f"[bold cyan]你:[/bold cyan] {content}")
|
||||
elif role == "assistant":
|
||||
chat_log.write(f"[bold green]AI:[/bold green] {content}")
|
||||
chat_log.scroll_end()
|
||||
|
||||
async def on_input_submitted(self, event: Input.Submitted):
|
||||
"""处理输入提交"""
|
||||
if event.input.id == "message-input":
|
||||
await self._send_message()
|
||||
|
||||
async def on_button_pressed(self, event: Button.Pressed):
|
||||
"""处理按钮点击"""
|
||||
button_id = event.button.id
|
||||
|
||||
if button_id == "send-button":
|
||||
await self._send_message()
|
||||
elif button_id == "new-session":
|
||||
self.action_new_session()
|
||||
elif button_id == "save-session":
|
||||
self.action_save_session()
|
||||
elif button_id == "load-session":
|
||||
self.action_load_session()
|
||||
elif button_id == "clear-history":
|
||||
self.action_clear_history()
|
||||
elif button_id == "set-system-prompt":
|
||||
self.action_set_system_prompt()
|
||||
|
||||
async def _send_message(self):
|
||||
"""发送当前输入的消息"""
|
||||
if not self.current_session or self.is_streaming:
|
||||
return
|
||||
|
||||
input_widget = self.query_one("#message-input", Input)
|
||||
message = input_widget.value.strip()
|
||||
|
||||
if not message:
|
||||
return
|
||||
|
||||
# 清空输入框
|
||||
input_widget.value = ""
|
||||
|
||||
# 显示用户消息
|
||||
self._add_message_to_log("user", message)
|
||||
|
||||
# 禁用输入和按钮
|
||||
self._set_input_state(disabled=True)
|
||||
self.is_streaming = True
|
||||
|
||||
# 更新状态
|
||||
self.query_one("#status-bar", Static).update("AI 正在思考...")
|
||||
|
||||
try:
|
||||
# 发送消息并获取响应
|
||||
response = await self.current_session.send_message(message)
|
||||
|
||||
# 显示AI响应
|
||||
self._add_message_to_log("assistant", response)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"请求失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
self._add_message_to_log("assistant", f"[red]{error_msg}[/red]")
|
||||
self.notify(error_msg, severity="error")
|
||||
|
||||
finally:
|
||||
# 恢复输入和按钮
|
||||
self._set_input_state(disabled=False)
|
||||
self.is_streaming = False
|
||||
self.query_one("#status-bar", Static).update("就绪")
|
||||
input_widget.focus()
|
||||
|
||||
def _set_input_state(self, disabled: bool):
|
||||
"""设置输入控件状态"""
|
||||
self.query_one("#message-input", Input).disabled = disabled
|
||||
self.query_one("#send-button", Button).disabled = disabled
|
||||
|
||||
async def action_save_session(self):
|
||||
"""保存当前会话到文件"""
|
||||
if not self.current_session:
|
||||
self.notify("无当前会话", severity="error")
|
||||
return
|
||||
|
||||
# 默认保存到 data/chat_sessions/ 目录
|
||||
save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
file_path = save_dir / f"{self.current_session.session_id}.json"
|
||||
self.current_session.save_to_file(file_path)
|
||||
|
||||
self.notify(f"会话已保存到 {file_path}", severity="information")
|
||||
|
||||
async def action_load_session(self):
|
||||
"""从文件加载会话"""
|
||||
# 简化实现:加载默认目录下的第一个会话文件
|
||||
save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions"
|
||||
if not save_dir.exists():
|
||||
self.notify(f"目录不存在: {save_dir}", severity="error")
|
||||
return
|
||||
|
||||
session_files = list(save_dir.glob("*.json"))
|
||||
if not session_files:
|
||||
self.notify("未找到会话文件", severity="error")
|
||||
return
|
||||
|
||||
# 使用第一个文件(在实际应用中可以让用户选择)
|
||||
file_path = session_files[0]
|
||||
|
||||
try:
|
||||
# 获取 LLM 提供者
|
||||
provider_name = config_var.get()["services"]["llm"]
|
||||
provider_config = config_var.get()["providers"]["llm"][provider_name]
|
||||
from heurams.providers.llm import providers as prov
|
||||
|
||||
llm_provider = prov[provider_name](provider_config)
|
||||
|
||||
# 加载会话
|
||||
self.current_session = ChatSession.load_from_file(file_path, llm_provider)
|
||||
|
||||
# 更新聊天管理器
|
||||
self.chat_manager.sessions[self.current_session.session_id] = (
|
||||
self.current_session
|
||||
)
|
||||
|
||||
# 更新UI
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
self._display_history()
|
||||
|
||||
self.notify(f"已加载会话: {file_path.name}", severity="information")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("加载会话失败: %s", e)
|
||||
self.notify(f"加载失败: {str(e)}", severity="error")
|
||||
|
||||
async def action_new_session(self):
|
||||
"""创建新会话"""
|
||||
# 简单实现:使用时间戳作为会话ID
|
||||
import time
|
||||
|
||||
new_session_id = f"session_{int(time.time())}"
|
||||
|
||||
self.current_session = self.chat_manager.get_session(new_session_id)
|
||||
|
||||
# 更新UI
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
self._display_history()
|
||||
|
||||
self.notify(f"已创建新会话: {new_session_id}", severity="information")
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
async def action_clear_history(self):
|
||||
"""清空当前会话历史"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
self.current_session.clear_history()
|
||||
self._display_history()
|
||||
self.notify("历史已清空", severity="information")
|
||||
|
||||
async def action_set_system_prompt(self):
|
||||
"""设置系统提示词"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
# 使用输入框获取新提示词
|
||||
input_widget = self.query_one("#message-input", Input)
|
||||
current_value = input_widget.value
|
||||
|
||||
# 临时修改输入框提示
|
||||
input_widget.placeholder = "输入系统提示词... (按 Enter 确认, Esc 取消)"
|
||||
input_widget.value = self.current_session.system_prompt
|
||||
|
||||
# 等待用户输入
|
||||
self.notify("请输入系统提示词,按 Enter 确认", severity="information")
|
||||
|
||||
# 实际应用中需要更复杂的交互,这里简化处理
|
||||
# 用户手动输入后按 Enter 会触发 on_input_submitted
|
||||
# 这里我们只修改占位符,实际系统提示词设置需要额外界面
|
||||
|
||||
def action_focus_input(self):
|
||||
"""聚焦到输入框"""
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
def action_go_back(self):
|
||||
"""返回上级屏幕"""
|
||||
self.app.pop_screen()
|
||||
1
src/heurams/interface/screens/memointegrity.py
Normal file
1
src/heurams/interface/screens/memointegrity.py
Normal file
@@ -0,0 +1 @@
|
||||
"""整体式记忆工作界面"""
|
||||
256
src/heurams/interface/screens/memoqueue.py
Normal file
256
src/heurams/interface/screens/memoqueue.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""队列式记忆工作界面"""
|
||||
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Center, ScrollableContainer
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.reactor import *
|
||||
from heurams.services.favorite_service import favorite_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .. import shim
|
||||
|
||||
|
||||
class AtomState(Enum):
|
||||
FAILED = auto()
|
||||
NORMAL = auto()
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemScreen(Screen):
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("p", "prev", "查看上一个"),
|
||||
("d", "toggle_dark", ""),
|
||||
("v", "play_voice", "朗读"),
|
||||
("*", "toggle_favorite", "收藏"),
|
||||
("0,1,2,3", "app.push_screen('about')", ""),
|
||||
]
|
||||
|
||||
if config_var.get()["quick_pass"]:
|
||||
BINDINGS.append(("k", "quick_pass", "正确应答"))
|
||||
BINDINGS.append(("f", "quick_fail", "错误应答"))
|
||||
rating = reactive(-1)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phaser: Phaser,
|
||||
save_func: Callable,
|
||||
repo=None,
|
||||
name=None,
|
||||
id=None,
|
||||
classes=None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.phaser = phaser
|
||||
self.save_func = save_func
|
||||
self.repo = repo
|
||||
self.update_state()
|
||||
self.fission: Fission
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer():
|
||||
yield Label(self._get_progress_text(), id="progress")
|
||||
yield ScrollableContainer(id="puzzle-container")
|
||||
yield Footer()
|
||||
|
||||
def update_state(self):
|
||||
"""更新状态机"""
|
||||
self.procession: Procession = self.phaser.current_procession() # type: ignore
|
||||
self.atom: pt.Atom = self.procession.current_atom # type: ignore
|
||||
|
||||
def on_mount(self):
|
||||
self.fission = self.procession.get_fission()
|
||||
self.mount_puzzle()
|
||||
self.update_display()
|
||||
|
||||
def puzzle_widget(self):
|
||||
try:
|
||||
puzzle = self.fission.get_current_puzzle_inf()
|
||||
return shim.puzzle2widget[puzzle["puzzle"]]( # type: ignore
|
||||
atom=self.atom, alia=puzzle["alia"] # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"调度展开出错: {e}")
|
||||
return Static(f"无法生成谜题 {e}")
|
||||
|
||||
def _get_progress_text(self):
|
||||
s = f"阶段: {self.procession.phase.name}\n"
|
||||
# 收藏状态
|
||||
if self.repo is not None:
|
||||
fav_status = "已收藏" if self._is_current_atom_favorited() else "未收藏"
|
||||
s += f"收藏: {fav_status}\n"
|
||||
if config_var.get().get("debug_topline", 0):
|
||||
try:
|
||||
alia = self.fission.get_current_puzzle_inf()["alia"] # type: ignore
|
||||
s += f"谜题: {alia}\n"
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
stat = self.phaser.__repr__("simple", "")
|
||||
s += f"{stat}\n"
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
stat = self.procession.__repr__("simple", "")
|
||||
s += f"{stat}\n"
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
stat = self.fission.__repr__("simple", "")
|
||||
s += f"{stat}\n"
|
||||
except Exception as e:
|
||||
s = str(e)
|
||||
s += f"进度: {self.procession.process() + 1}/{self.procession.total_length()}"
|
||||
return s
|
||||
|
||||
def update_display(self):
|
||||
"""更新进度显示"""
|
||||
progress_widget = self.query_one("#progress")
|
||||
progress_widget.update(self._get_progress_text()) # type: ignore
|
||||
|
||||
def mount_puzzle(self):
|
||||
"""挂载当前谜题组件"""
|
||||
if self.procession.phase == PhaserState.FINISHED:
|
||||
self.mount_finished_widget()
|
||||
return
|
||||
container = self.query_one("#puzzle-container")
|
||||
for i in container.children:
|
||||
i.remove()
|
||||
container.mount(self.puzzle_widget())
|
||||
|
||||
def mount_finished_widget(self):
|
||||
"""挂载已完成组件"""
|
||||
container = self.query_one("#puzzle-container")
|
||||
for i in container.children:
|
||||
i.remove()
|
||||
from heurams.interface.widgets.finished import Finished
|
||||
|
||||
if config_var.get().get("persist_to_file", 0):
|
||||
self.save_func()
|
||||
container.mount(Finished(is_saved=config_var.get().get("persist_to_file", 0)))
|
||||
|
||||
def on_button_pressed(self, event):
|
||||
event.stop()
|
||||
|
||||
def action_play_voice(self):
|
||||
self.run_worker(self.play_voice, exclusive=True, thread=True)
|
||||
|
||||
def play_voice(self):
|
||||
"""朗读当前内容"""
|
||||
from pathlib import Path
|
||||
|
||||
from heurams.services.audio_service import play_by_path
|
||||
from heurams.services.hasher import get_md5
|
||||
|
||||
path = Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||
path = path / f"{get_md5(self.atom.registry['nucleon']["tts_text"])}.wav"
|
||||
if path.exists():
|
||||
play_by_path(path)
|
||||
else:
|
||||
from heurams.services.tts_service import convertor
|
||||
|
||||
convertor(self.atom.registry["nucleon"]["tts_text"], path)
|
||||
play_by_path(path)
|
||||
|
||||
def watch_rating(self, old_rating, new_rating) -> None:
|
||||
if new_rating == -1: # 安全值
|
||||
return
|
||||
self.update_state()
|
||||
if self.procession.phase == PhaserState.FINISHED:
|
||||
rating = -1
|
||||
return
|
||||
self.fission.report(new_rating)
|
||||
self.forward(new_rating)
|
||||
self.rating = -1
|
||||
|
||||
def forward(self, rating):
|
||||
self.update_state()
|
||||
allow_forward = 1 if rating >= 4 else 0
|
||||
if allow_forward:
|
||||
self.fission.forward()
|
||||
if self.fission.state == "retronly":
|
||||
self.forward_atom(self.fission.get_quality())
|
||||
self.update_state()
|
||||
self.mount_puzzle()
|
||||
self.update_display()
|
||||
|
||||
def atom_reporter(self, quality):
|
||||
if not self.atom.registry["runtime"]["locked"]:
|
||||
if not self.atom.registry["electron"].is_activated():
|
||||
self.atom.registry["electron"].activate()
|
||||
logger.debug(f"激活原子 {self.atom}")
|
||||
self.atom.lock(1)
|
||||
self.atom.minimize(5)
|
||||
else:
|
||||
self.atom.minimize(quality)
|
||||
else:
|
||||
pass
|
||||
|
||||
def forward_atom(self, quality):
|
||||
logger.debug(f"Quality: {quality}")
|
||||
self.atom_reporter(quality)
|
||||
if quality <= 3:
|
||||
self.procession.append()
|
||||
self.update_state() # 刷新状态
|
||||
self.procession.forward(1)
|
||||
self.update_state() # 刷新状态
|
||||
self.fission = self.procession.get_fission()
|
||||
|
||||
def action_go_back(self):
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_quick_pass(self):
|
||||
self.rating = 5
|
||||
|
||||
def action_quick_fail(self):
|
||||
self.rating = 3
|
||||
|
||||
def _get_repo_rel_path(self) -> str:
|
||||
"""获取仓库相对路径(相对于 data/repo)"""
|
||||
if self.repo is None:
|
||||
return ""
|
||||
# self.repo.source 是 Path 对象,指向仓库目录
|
||||
repo_full_path = self.repo.source
|
||||
data_repo_path = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
try:
|
||||
rel_path = repo_full_path.relative_to(data_repo_path)
|
||||
return str(rel_path)
|
||||
except ValueError:
|
||||
# 如果不在 data/repo 下,则返回完整路径(字符串形式)
|
||||
return str(repo_full_path)
|
||||
|
||||
def _is_current_atom_favorited(self) -> bool:
|
||||
"""检查当前原子是否已收藏"""
|
||||
if self.repo is None:
|
||||
return False
|
||||
repo_path = self._get_repo_rel_path()
|
||||
return favorite_manager.has(repo_path, self.atom.ident)
|
||||
|
||||
def action_toggle_favorite(self):
|
||||
"""切换收藏状态"""
|
||||
if self.repo is None:
|
||||
self.app.notify("无法收藏:未关联仓库", severity="error")
|
||||
return
|
||||
repo_path = self._get_repo_rel_path()
|
||||
ident = self.atom.ident
|
||||
if favorite_manager.has(repo_path, ident):
|
||||
favorite_manager.remove(repo_path, ident)
|
||||
self.app.notify(f"已取消收藏:{ident}", severity="information")
|
||||
else:
|
||||
favorite_manager.add(repo_path, ident)
|
||||
self.app.notify(f"已收藏:{ident}", severity="information")
|
||||
# 更新显示(如果需要)
|
||||
self.update_display()
|
||||
@@ -1,154 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from enum import Enum, auto
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Center, ScrollableContainer
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.reactor import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .. import shim
|
||||
|
||||
|
||||
class AtomState(Enum):
|
||||
FAILED = auto()
|
||||
NORMAL = auto()
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MemScreen(Screen):
|
||||
BINDINGS = [
|
||||
("q", "pop_screen", "返回"),
|
||||
# ("p", "prev", "复习上一个"),
|
||||
("d", "toggle_dark", ""),
|
||||
("v", "play_voice", "朗读"),
|
||||
("0,1,2,3", "app.push_screen('about')", ""),
|
||||
]
|
||||
|
||||
if config_var.get()["quick_pass"]:
|
||||
BINDINGS.append(("k", "quick_pass", "跳过"))
|
||||
rating = reactive(-1)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
atoms: list,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.atoms = atoms
|
||||
self.phaser = Phaser(atoms)
|
||||
# logger.debug(self.phaser.state)
|
||||
self.procession: Procession = self.phaser.current_procession() # type: ignore
|
||||
self.atom: pt.Atom = self.procession.current_atom
|
||||
# logger.debug(self.phaser.state)
|
||||
# self.procession.forward(1)
|
||||
for i in atoms:
|
||||
i.do_eval()
|
||||
|
||||
def on_mount(self):
|
||||
self.load_puzzle()
|
||||
pass
|
||||
|
||||
def puzzle_widget(self):
|
||||
try:
|
||||
logger.debug(self.phaser.state)
|
||||
logger.debug(self.procession.cursor)
|
||||
logger.debug(self.atom)
|
||||
self.fission = Fission(self.atom, self.phaser.state)
|
||||
puzzle_debug = next(self.fission.generate())
|
||||
# logger.debug(puzzle_debug)
|
||||
return shim.puzzle2widget[puzzle_debug["puzzle"]](
|
||||
atom=self.atom, alia=puzzle_debug["alia"]
|
||||
)
|
||||
except (KeyError, StopIteration, AttributeError) as e:
|
||||
logger.debug(f"调度展开出错: {e}")
|
||||
return Static("无法生成谜题")
|
||||
# logger.debug(shim.puzzle2widget[puzzle_debug["puzzle"]])
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer():
|
||||
yield Label(self._get_progress_text(), id="progress")
|
||||
|
||||
# self.mount(self.current_widget()) # type: ignore
|
||||
yield ScrollableContainer(id="puzzle-container")
|
||||
# yield Button("重新学习此单元", id="re-recognize", variant="warning")
|
||||
yield Footer()
|
||||
|
||||
def _get_progress_text(self):
|
||||
return f"当前进度: {self.procession.process() + 1}/{self.procession.total_length()}"
|
||||
|
||||
def update_display(self):
|
||||
progress_widget = self.query_one("#progress")
|
||||
progress_widget.update(self._get_progress_text()) # type: ignore
|
||||
|
||||
def load_puzzle(self):
|
||||
self.atom: pt.Atom = self.procession.current_atom
|
||||
container = self.query_one("#puzzle-container")
|
||||
for i in container.children:
|
||||
i.remove()
|
||||
container.mount(self.puzzle_widget())
|
||||
|
||||
def load_finished_widget(self):
|
||||
container = self.query_one("#puzzle-container")
|
||||
for i in container.children:
|
||||
i.remove()
|
||||
from heurams.interface.widgets.finished import Finished
|
||||
|
||||
container.mount(Finished())
|
||||
|
||||
def on_button_pressed(self, event):
|
||||
event.stop()
|
||||
|
||||
def watch_rating(self, old_rating, new_rating) -> None:
|
||||
if self.procession == 0:
|
||||
return
|
||||
if new_rating == -1:
|
||||
return
|
||||
forwards = 1 if new_rating >= 4 else 0
|
||||
self.rating = -1
|
||||
logger.debug(f"试图前进: {"允许" if forwards else "禁止"}")
|
||||
if forwards:
|
||||
ret = self.procession.forward(1)
|
||||
if ret == 0: # 若结束了此次队列
|
||||
self.procession = self.phaser.current_procession() # type: ignore
|
||||
if self.procession == 0: # 若所有队列都结束了
|
||||
logger.debug(f"记忆进程结束")
|
||||
for i in self.atoms:
|
||||
i: pt.Atom
|
||||
i.revise()
|
||||
i.persist("electron")
|
||||
self.load_finished_widget()
|
||||
return
|
||||
else:
|
||||
logger.debug(f"建立新队列 {self.procession.phase}")
|
||||
self.load_puzzle()
|
||||
else: # 若不通过
|
||||
self.procession.append()
|
||||
self.update_display()
|
||||
|
||||
def action_quick_pass(self):
|
||||
self.rating = 5
|
||||
self.atom.minimize(5)
|
||||
self.atom.registry["electron"].activate()
|
||||
self.atom.lock(1)
|
||||
|
||||
def action_play_voice(self):
|
||||
"""朗读当前内容"""
|
||||
pass
|
||||
|
||||
def action_toggle_dark(self):
|
||||
self.app.action_toggle_dark()
|
||||
|
||||
def action_pop_screen(self):
|
||||
self.app.pop_screen()
|
||||
93
src/heurams/interface/screens/navigator.py
Normal file
93
src/heurams/interface/screens/navigator.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import webbrowser
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Grid, ScrollableContainer
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
|
||||
|
||||
from heurams.context import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .favmgr import FavoriteManagerScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class NavigatorScreen(ModalScreen):
|
||||
"""导航器模态窗口"""
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("escape", "go_back", "返回"),
|
||||
("n", "go_back", "切换"),
|
||||
]
|
||||
|
||||
SCREENS = [
|
||||
("仪表盘", "dashboard"),
|
||||
("电台", "radio"),
|
||||
("语言模型集成", "llmchat"),
|
||||
# ("创建仓库", "repo_creator"),
|
||||
("缓存管理器", "precache_all"),
|
||||
("收藏夹管理器", FavoriteManagerScreen),
|
||||
("关于此软件", "about"),
|
||||
("调试日志", "logviewer"),
|
||||
# ("同步工具", "synctool"),
|
||||
# ("仓库编辑器", "repo_editor"),
|
||||
]
|
||||
|
||||
OTHERS = [
|
||||
("退出程序", "self.app.exit()"),
|
||||
("项目主页", "webbrowser.open('https://ams.imwangzhiyu.xyz')"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
with Grid(id="dialog"):
|
||||
yield Label(
|
||||
"[b]请选择要跳转的功能\n或记忆会话实例[/b]\n\n将在此处显示提示",
|
||||
classes="title-label",
|
||||
)
|
||||
yield ListView(
|
||||
*[ListItem(Label(title)) for title, _ in (self.SCREENS + self.OTHERS)],
|
||||
id="nav-list",
|
||||
classes="nav-list-view",
|
||||
)
|
||||
yield Static("按下回车以完成切换\n所有会话将被保存")
|
||||
yield Button(
|
||||
"关闭 (n)", id="close_button", variant="primary", classes="close-button", flat=True
|
||||
)
|
||||
|
||||
def on_mount(self) -> None:
|
||||
# 设置焦点到列表
|
||||
nav_list = self.query_one("#nav-list", ListView)
|
||||
nav_list.focus()
|
||||
|
||||
def on_list_view_selected(self, event) -> None:
|
||||
if not isinstance(event.item, ListItem):
|
||||
return
|
||||
selected_label = event.item.query_one(Label)
|
||||
label_text = str(selected_label.render())
|
||||
# 查找对应的屏幕标识
|
||||
for title, screen_id in self.SCREENS:
|
||||
if title == label_text:
|
||||
self.app.pop_screen()
|
||||
# 跳转到目标屏幕
|
||||
if isinstance(screen_id, str):
|
||||
# 已注册的字符串标识符
|
||||
self.app.push_screen(screen_id)
|
||||
else:
|
||||
self.app.push_screen(screen_id())
|
||||
return
|
||||
for title, cmd in self.OTHERS:
|
||||
if title == label_text:
|
||||
exec(cmd)
|
||||
return
|
||||
return
|
||||
|
||||
def on_button_pressed(self, event) -> None:
|
||||
event.stop()
|
||||
if event.button.id == "close_button":
|
||||
self.action_go_back()
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
self.app.pop_screen()
|
||||
@@ -1,8 +1,9 @@
|
||||
#!/usr/bin/env python3
|
||||
"""缓存工具界面"""
|
||||
|
||||
import pathlib
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal, ScrollableContainer
|
||||
from textual.containers import Horizontal, ScrollableContainer, Container
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, ProgressBar, Static
|
||||
from textual.worker import get_current_worker
|
||||
@@ -11,6 +12,19 @@ import heurams.kernel.particles as pt
|
||||
import heurams.services.hasher as hasher
|
||||
from heurams.context import *
|
||||
|
||||
# 兼容性缓存路径:优先使用 paths.cache,否则使用 data/cache
|
||||
paths = config_var.get()["paths"]
|
||||
cache_dir = pathlib.Path(paths.get("cache", paths["data"] + "/cache")) / "voice"
|
||||
|
||||
|
||||
def format_size(bytes_num: int) -> str:
|
||||
"""将字节数格式化为人类可读的字符串"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if bytes_num < 1024.0:
|
||||
return f"{bytes_num:.2f} {unit}"
|
||||
bytes_num /= 1024.0 # type: ignore
|
||||
return f"{bytes_num:.2f} PB"
|
||||
|
||||
|
||||
class PrecachingScreen(Screen):
|
||||
"""预缓存音频文件屏幕
|
||||
@@ -23,7 +37,9 @@ class PrecachingScreen(Screen):
|
||||
"""
|
||||
|
||||
SUB_TITLE = "缓存管理器"
|
||||
BINDINGS = [("q", "go_back", "返回")]
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def __init__(self, nucleons: list = [], desc: str = ""):
|
||||
super().__init__(name=None, id=None, classes=None)
|
||||
@@ -37,25 +53,70 @@ class PrecachingScreen(Screen):
|
||||
self.precache_worker = None
|
||||
self.cancel_flag = 0
|
||||
self.desc = desc
|
||||
for i in nucleons:
|
||||
i: pt.Nucleon
|
||||
i.do_eval()
|
||||
# print("完成 EVAL")
|
||||
# 不再需要缓存配置,保留配置读取以兼容
|
||||
self.cache_stats = {"total_size": 0, "file_count": 0, "human_size": "0 B", "cached_units": 0, "total_units": 0, "cache_rate": 0}
|
||||
self._update_cache_stats()
|
||||
|
||||
def _get_total_units(self) -> int:
|
||||
"""获取所有仓库的总单元数"""
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = Repo.probe_valid_repos_in_dir(repo_path)
|
||||
repos = map(Repo.create_from_repodir, repo_dirs)
|
||||
total = 0
|
||||
for repo in repos:
|
||||
try:
|
||||
total += len(repo.ident_index)
|
||||
except:
|
||||
continue
|
||||
return total
|
||||
|
||||
def _update_cache_stats(self) -> None:
|
||||
"""更新缓存统计信息"""
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
cached_units = 0
|
||||
if cache_dir.exists():
|
||||
for file in cache_dir.rglob("*"):
|
||||
if file.is_file():
|
||||
total_size += file.stat().st_size
|
||||
file_count += 1
|
||||
if file.suffix.lower() == ".wav":
|
||||
cached_units += 1
|
||||
total_units = self._get_total_units()
|
||||
cache_rate = (cached_units / total_units * 100) if total_units > 0 else 0
|
||||
|
||||
self.cache_stats["total_size"] = total_size
|
||||
self.cache_stats["file_count"] = file_count
|
||||
self.cache_stats["human_size"] = format_size(total_size)
|
||||
self.cache_stats["cached_units"] = cached_units
|
||||
self.cache_stats["total_units"] = total_units
|
||||
self.cache_stats["cache_rate"] = cache_rate
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="precache_container"):
|
||||
yield Label("[b]音频预缓存[/b]", classes="title-label")
|
||||
with Container(classes="cache-info"):
|
||||
yield Static(f"缓存路径: {cache_dir}", classes="cache-path")
|
||||
yield Static(f"文件数: {self.cache_stats['file_count']}", classes="cache-count")
|
||||
yield Static(f"总大小: {self.cache_stats['human_size']}", classes="cache-size")
|
||||
yield Button("刷新", id="refresh_cache_stats", variant="default")
|
||||
with Container():
|
||||
yield Static(
|
||||
f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% (已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)",
|
||||
classes="cache-usage-text"
|
||||
)
|
||||
if self.nucleons:
|
||||
yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info")
|
||||
yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info")
|
||||
else:
|
||||
yield Static("目标: 所有单元", classes="target-info")
|
||||
|
||||
if self.nucleons:
|
||||
yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info")
|
||||
yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info")
|
||||
else:
|
||||
yield Static("目标: 所有单元", classes="target-info")
|
||||
|
||||
yield Static(id="status", classes="status-info")
|
||||
yield Static(id="current_item", classes="current-item")
|
||||
yield ProgressBar(total=100, show_eta=False, id="progress_bar")
|
||||
yield Static(id="status", classes="status-info")
|
||||
yield Static(id="current_item", classes="current-item")
|
||||
yield ProgressBar(total=100, show_eta=False, id="progress_bar")
|
||||
|
||||
with Horizontal(classes="button-group"):
|
||||
if not self.is_precaching:
|
||||
@@ -73,6 +134,7 @@ class PrecachingScreen(Screen):
|
||||
def on_mount(self):
|
||||
"""挂载时初始化状态"""
|
||||
self.update_status("就绪", "等待开始...")
|
||||
self._update_cache_display()
|
||||
|
||||
def update_status(self, status, current_item="", progress=None):
|
||||
"""更新状态显示"""
|
||||
@@ -87,19 +149,36 @@ class PrecachingScreen(Screen):
|
||||
progress_bar.progress = progress
|
||||
progress_bar.advance(0) # 刷新显示
|
||||
|
||||
def _update_cache_display(self) -> None:
|
||||
"""更新缓存信息显示"""
|
||||
# 更新统计信息
|
||||
self._update_cache_stats()
|
||||
# 更新缓存率进度条
|
||||
# 更新缓存大小和文件数显示
|
||||
cache_count_widget = self.query_one(".cache-count", Static)
|
||||
cache_size_widget = self.query_one(".cache-size", Static)
|
||||
cache_usage_text = self.query_one(".cache-usage-text", Static)
|
||||
if cache_count_widget:
|
||||
cache_count_widget.update(f"文件数: {self.cache_stats['file_count']}")
|
||||
if cache_size_widget:
|
||||
cache_size_widget.update(f"总大小: {self.cache_stats['human_size']}")
|
||||
if cache_usage_text:
|
||||
cache_usage_text.update(
|
||||
f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% "
|
||||
f"(已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)"
|
||||
)
|
||||
|
||||
def precache_by_text(self, text: str):
|
||||
"""预缓存单段文本的音频"""
|
||||
from heurams.context import config_var, rootdir, workdir
|
||||
|
||||
cache_dir = pathlib.Path(config_var.get()["paths"]["cache_dir"])
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
cache_file = cache_dir / f"{hasher.get_md5(text)}.wav"
|
||||
if not cache_file.exists():
|
||||
try: # TODO: 调用模块消除tts耦合
|
||||
import edge_tts as tts
|
||||
try:
|
||||
from heurams.services.tts_service import convertor
|
||||
|
||||
communicate = tts.Communicate(text, "zh-CN-XiaoxiaoNeural")
|
||||
communicate.save_sync(str(cache_file))
|
||||
convertor(text, cache_file)
|
||||
return 1
|
||||
except Exception as e:
|
||||
print(f"预缓存失败 '{text}': {e}")
|
||||
@@ -108,10 +187,8 @@ class PrecachingScreen(Screen):
|
||||
|
||||
def precache_by_nucleon(self, nucleon: pt.Nucleon):
|
||||
"""依据 Nucleon 缓存"""
|
||||
# print(nucleon.metadata['formation']['tts_text'])
|
||||
ret = self.precache_by_text(nucleon.metadata["formation"]["tts_text"])
|
||||
ret = self.precache_by_text(nucleon["tts_text"])
|
||||
return ret
|
||||
# print(f"TTS 缓存: {nucleon.metadata['formation']['tts_text']}")
|
||||
|
||||
def precache_by_list(self, nucleons: list):
|
||||
"""依据 Nucleons 列表缓存"""
|
||||
@@ -120,7 +197,7 @@ class PrecachingScreen(Screen):
|
||||
worker = get_current_worker()
|
||||
if worker and worker.is_cancelled: # 函数在worker中执行且已被取消
|
||||
return False
|
||||
text = nucleon.metadata["formation"]["tts_text"]
|
||||
text = nucleon["tts_text"]
|
||||
# self.current_item = text[:30] + "..." if len(text) > 50 else text
|
||||
# print(text)
|
||||
self.processed += 1
|
||||
@@ -150,36 +227,30 @@ class PrecachingScreen(Screen):
|
||||
# print(f"返回 {ret}")
|
||||
return ret
|
||||
|
||||
def precache_by_filepath(self, path: pathlib.Path):
|
||||
"""预缓存单个文件的所有内容"""
|
||||
lst = list()
|
||||
for i in pt.load_nucleon(path):
|
||||
lst.append(i[0])
|
||||
return self.precache_by_list(lst)
|
||||
|
||||
def precache_all_files(self):
|
||||
"""预缓存所有文件"""
|
||||
from heurams.context import config_var, rootdir, workdir
|
||||
from heurams.kernel.repolib import Repo
|
||||
|
||||
nucleon_path = pathlib.Path(config_var.get()["paths"]["nucleon_dir"])
|
||||
nucleon_files = [
|
||||
f for f in nucleon_path.iterdir() if f.suffix == ".toml"
|
||||
] # TODO: 解耦合
|
||||
repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = Repo.probe_valid_repos_in_dir(repo_path)
|
||||
repos = map(Repo.create_from_repodir, repo_dirs)
|
||||
|
||||
# 计算总项目数
|
||||
self.total = 0
|
||||
nu = list()
|
||||
for file in nucleon_files:
|
||||
nucleon_list = list()
|
||||
for repo in repos:
|
||||
try:
|
||||
for i in pt.load_nucleon(file):
|
||||
nu.append(i[0])
|
||||
for i in repo.ident_index:
|
||||
nucleon_list.append(
|
||||
pt.Nucleon.create_on_nucleonic_data(
|
||||
repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
)
|
||||
except:
|
||||
continue
|
||||
self.total = len(nu)
|
||||
for i in nu:
|
||||
i: pt.Nucleon
|
||||
i.do_eval()
|
||||
return self.precache_by_list(nu)
|
||||
self.total = len(nucleon_list)
|
||||
return self.precache_by_list(nucleon_list)
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
event.stop()
|
||||
@@ -216,16 +287,19 @@ class PrecachingScreen(Screen):
|
||||
|
||||
from heurams.context import config_var, rootdir, workdir
|
||||
|
||||
shutil.rmtree(
|
||||
f"{config_var.get()["paths"]["cache_dir"]}", ignore_errors=True
|
||||
)
|
||||
shutil.rmtree(cache_dir, ignore_errors=True)
|
||||
self.update_status("已清空", "音频缓存已清空", 0)
|
||||
self._update_cache_display() # 更新缓存统计显示
|
||||
except Exception as e:
|
||||
self.update_status("错误", f"清空缓存失败: {e}")
|
||||
self.cancel_flag = 1
|
||||
self.processed = 0
|
||||
self.progress = 0
|
||||
|
||||
elif event.button.id == "refresh_cache_stats":
|
||||
# 刷新缓存统计信息
|
||||
self._update_cache_display()
|
||||
self.app.notify("缓存信息已刷新", severity="information")
|
||||
elif event.button.id == "go_back":
|
||||
self.action_go_back()
|
||||
|
||||
@@ -233,8 +307,3 @@ class PrecachingScreen(Screen):
|
||||
if self.is_precaching and self.precache_worker:
|
||||
self.precache_worker.cancel()
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_quit_app(self):
|
||||
if self.is_precaching and self.precache_worker:
|
||||
self.precache_worker.cancel()
|
||||
self.app.exit()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""记忆准备界面"""
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.reactive import reactive
|
||||
@@ -10,6 +11,7 @@ import heurams.kernel.particles as pt
|
||||
import heurams.services.hasher as hasher
|
||||
from heurams.context import *
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -28,24 +30,19 @@ class PreparationScreen(Screen):
|
||||
|
||||
scheduled_num = reactive(config_var.get()["scheduled_num"])
|
||||
|
||||
def __init__(self, nucleon_file: pathlib.Path, electron_file: pathlib.Path) -> None:
|
||||
def __init__(self, repo: Repo, repostat: dict) -> None:
|
||||
super().__init__(name=None, id=None, classes=None)
|
||||
self.nucleon_file = nucleon_file
|
||||
self.electron_file = electron_file
|
||||
self.nucleons_with_orbital = pt.load_nucleon(self.nucleon_file)
|
||||
self.electrons = pt.load_electron(self.electron_file)
|
||||
self.repo = repo
|
||||
self.repostat = repostat
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="vice_container"):
|
||||
yield Label(f"准备就绪: [b]{self.nucleon_file.stem}[/b]\n")
|
||||
yield Label(f"准备就绪: [b]{self.repostat['title']}[/b]\n")
|
||||
yield Label(
|
||||
f"内容源文件: {config_var.get()['paths']['nucleon_dir']}/[b]{self.nucleon_file.name}[/b]"
|
||||
f"仓库路径: {config_var.get()['paths']['data']}/repo/[b]{self.repostat['dirname']}[/b]"
|
||||
)
|
||||
yield Label(
|
||||
f"元数据文件: {config_var.get()['paths']['electron_dir']}/[b]{self.electron_file.name}[/b]"
|
||||
)
|
||||
yield Label(f"\n单元数量: {len(self.nucleons_with_orbital)}\n")
|
||||
yield Label(f"\n单元数量: {len(self.repo)}\n")
|
||||
yield Label(f"单次记忆数量: {self.scheduled_num}", id="schnum_label")
|
||||
|
||||
yield Button(
|
||||
@@ -62,7 +59,8 @@ class PreparationScreen(Screen):
|
||||
)
|
||||
|
||||
yield Static(f"\n单元预览:\n")
|
||||
yield Markdown(self._get_full_content().replace("/", ""), classes="full")
|
||||
for i in self._get_full_content().replace("/", "").splitlines():
|
||||
yield Static(i, classes="full")
|
||||
yield Footer()
|
||||
|
||||
# def watch_scheduled_num(self, old_scheduled_num, new_scheduled_num):
|
||||
@@ -75,10 +73,11 @@ class PreparationScreen(Screen):
|
||||
|
||||
def _get_full_content(self):
|
||||
content = ""
|
||||
for nucleon, orbital in self.nucleons_with_orbital:
|
||||
nucleon: pt.Nucleon
|
||||
# print(nucleon.payload)
|
||||
content += " - " + nucleon["content"] + " \n"
|
||||
for i in self.repo.ident_index:
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
content += f" • {n['content']} \n"
|
||||
return content
|
||||
|
||||
def action_go_back(self):
|
||||
@@ -88,9 +87,15 @@ class PreparationScreen(Screen):
|
||||
from ..screens.precache import PrecachingScreen
|
||||
|
||||
lst = list()
|
||||
for i in self.nucleons_with_orbital:
|
||||
lst.append(i[0])
|
||||
precache_screen = PrecachingScreen(lst)
|
||||
for i in self.repo.ident_index:
|
||||
lst.append(
|
||||
pt.Nucleon.create_on_nucleonic_data(
|
||||
self.repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
)
|
||||
precache_screen = PrecachingScreen(
|
||||
nucleons=lst, desc=self.repo.manifest["title"]
|
||||
)
|
||||
self.app.push_screen(precache_screen)
|
||||
|
||||
def action_quit_app(self):
|
||||
@@ -101,38 +106,35 @@ class PreparationScreen(Screen):
|
||||
logger.debug("按下按钮")
|
||||
if event.button.id == "start_memorizing_button":
|
||||
atoms = list()
|
||||
for nucleon, orbital in self.nucleons_with_orbital:
|
||||
atom = pt.Atom(nucleon.ident)
|
||||
atom.link("nucleon", nucleon)
|
||||
try:
|
||||
atom.link("electron", self.electrons[nucleon.ident])
|
||||
except KeyError:
|
||||
atom.link("electron", pt.Electron(nucleon.ident))
|
||||
atom.link("orbital", orbital)
|
||||
atom.link("nucleon_fmt", "toml")
|
||||
atom.link("electron_fmt", "json")
|
||||
atom.link("orbital_fmt", "toml")
|
||||
atom.link("nucleon_path", self.nucleon_file)
|
||||
atom.link("electron_path", self.electron_file)
|
||||
atom.link("orbital_path", None)
|
||||
atoms.append(atom)
|
||||
for i in self.repo.ident_index:
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=self.repo.electronic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
a = pt.Atom(n, e, self.repo.orbitic_data)
|
||||
atoms.append(a)
|
||||
|
||||
atoms_to_provide = list()
|
||||
left_new = self.scheduled_num
|
||||
for i in atoms:
|
||||
i: pt.Atom
|
||||
if i.registry["electron"].is_due():
|
||||
atoms_to_provide.append(i)
|
||||
if i.registry["electron"].is_activated():
|
||||
if i.registry["electron"].is_due():
|
||||
atoms_to_provide.append(i)
|
||||
else:
|
||||
if i.registry["electron"].is_activated():
|
||||
pass
|
||||
else:
|
||||
left_new -= 1
|
||||
if left_new >= 0:
|
||||
atoms_to_provide.append(i)
|
||||
logger.debug(f"ATP: {atoms_to_provide}")
|
||||
from .memorizor import MemScreen
|
||||
left_new -= 1
|
||||
if left_new >= 0:
|
||||
atoms_to_provide.append(i)
|
||||
import heurams.kernel.reactor as rt
|
||||
|
||||
memscreen = MemScreen(atoms_to_provide)
|
||||
from .memoqueue import MemScreen
|
||||
|
||||
pheser = rt.Phaser(atoms_to_provide)
|
||||
save_func = self.repo.persist_to_repodir
|
||||
memscreen = MemScreen(pheser, save_func, repo=self.repo)
|
||||
self.app.push_screen(memscreen)
|
||||
|
||||
elif event.button.id == "precache_button":
|
||||
self.action_precache()
|
||||
|
||||
217
src/heurams/interface/screens/radio.py
Normal file
217
src/heurams/interface/screens/radio.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""用于筛选当日记忆的条目 以音频形式重放"""
|
||||
|
||||
""" "前进电台" 界面"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, ScrollableContainer
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.context import config_var
|
||||
from heurams.services.audio_service import play_by_path
|
||||
from heurams.services.hasher import get_md5
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.services.tts_service import convertor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RadioScreen(Screen):
|
||||
SUB_TITLE = "电台"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("space", "toggle_play", "播放/暂停"),
|
||||
]
|
||||
|
||||
# 当前播放的原子索引
|
||||
current_index = reactive(0)
|
||||
# 播放状态: 'stopped', 'playing', 'paused'
|
||||
play_state = reactive("stopped")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self._organizer()
|
||||
|
||||
def _organizer(self):
|
||||
repodirs = Repo.probe_valid_repos_in_dir(Path(config_var.get()['paths']['data']) / 'repo')
|
||||
repos = list(map(lambda repodir: Repo.create_from_repodir(repodir), repodirs))
|
||||
for repo in repos:
|
||||
last_modify = 0.0
|
||||
for i in repo.ident_index:
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=repo.electronic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
last_modify = max(last_modify, e.las())
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with Container(id="main"):
|
||||
yield Label("[b]前进电台[/b]", classes="title")
|
||||
yield Static(f"共 {len(self.atoms)} 条当日记忆", id="status")
|
||||
with Container(id="controls"):
|
||||
yield Button("播放", id="play", variant="success")
|
||||
yield Button("暂停", id="pause", variant="primary")
|
||||
yield Button("上一首", id="prev", variant="default")
|
||||
yield Button("下一首", id="next", variant="default")
|
||||
yield Button("停止", id="stop", variant="error")
|
||||
yield ScrollableContainer(id="playlist")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载后更新播放列表显示"""
|
||||
self._update_playlist()
|
||||
|
||||
def _filter_due_atoms(self) -> List[pt.Atom]:
|
||||
"""筛选当日需要复习的原子(已激活且到期)"""
|
||||
atoms = []
|
||||
for ident in self.repo.ident_index:
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(ident)
|
||||
)
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=self.repo.electronic_data_lict.get_itemic_unit(ident)
|
||||
)
|
||||
a = pt.Atom(n, e, self.repo.orbitic_data)
|
||||
# 仅选择已激活且到期的原子
|
||||
if (
|
||||
a.registry["electron"].is_activated()
|
||||
and a.registry["electron"].is_due()
|
||||
):
|
||||
atoms.append(a)
|
||||
return atoms
|
||||
|
||||
def _update_playlist(self) -> None:
|
||||
"""更新播放列表显示"""
|
||||
container = self.query_one("#playlist")
|
||||
container.remove_children()
|
||||
for idx, atom in enumerate(self.atoms):
|
||||
content = atom.registry["nucleon"].get("content", "无内容")
|
||||
prefix = "▶ " if idx == self.current_index else " "
|
||||
widget = Static(f"{prefix}{idx+1}. {content[:50]}...")
|
||||
widget.set_class(idx == self.current_index, "current")
|
||||
container.mount(widget)
|
||||
|
||||
def _get_audio_path(self, atom: pt.Atom) -> Path:
|
||||
"""返回音频文件路径,若不存在则生成"""
|
||||
tts_text = atom.registry["nucleon"].get("tts_text", "")
|
||||
if not tts_text:
|
||||
tts_text = atom.registry["nucleon"].get("content", "")
|
||||
voice_dir = Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = voice_dir / f"{get_md5(tts_text)}.wav"
|
||||
if not path.exists():
|
||||
convertor(tts_text, path)
|
||||
return path
|
||||
|
||||
async def _play_atom(self, idx: int) -> None:
|
||||
"""播放指定索引的原子(异步)"""
|
||||
if idx < 0 or idx >= len(self.atoms):
|
||||
return
|
||||
atom = self.atoms[idx]
|
||||
try:
|
||||
path = self._get_audio_path(atom)
|
||||
self._current_path = path
|
||||
# 在后台线程中播放,避免阻塞UI
|
||||
await self.run_worker(
|
||||
lambda: play_by_path(path), exclusive=True, thread=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("播放失败: %s", e)
|
||||
|
||||
def _stop_playback(self) -> None:
|
||||
"""停止当前播放"""
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._play_task.cancel()
|
||||
self._play_task = None
|
||||
self._current_path = None
|
||||
self.play_state = "stopped"
|
||||
|
||||
async def _play_current(self) -> None:
|
||||
"""播放当前索引的原子"""
|
||||
self._stop_playback()
|
||||
self.play_state = "playing"
|
||||
self._play_task = asyncio.create_task(self._play_atom(self.current_index))
|
||||
try:
|
||||
await self._play_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if self.play_state == "playing":
|
||||
self.play_state = "stopped"
|
||||
|
||||
# 按钮事件处理
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
if button_id == "play":
|
||||
self.action_toggle_play()
|
||||
elif button_id == "pause":
|
||||
self.action_pause()
|
||||
elif button_id == "prev":
|
||||
self.action_prev()
|
||||
elif button_id == "next":
|
||||
self.action_next()
|
||||
elif button_id == "stop":
|
||||
self.action_stop()
|
||||
|
||||
# 键盘动作
|
||||
def action_toggle_play(self) -> None:
|
||||
if self.play_state == "playing":
|
||||
self.action_pause()
|
||||
else:
|
||||
self.action_play()
|
||||
|
||||
def action_play(self) -> None:
|
||||
if self.play_state != "playing":
|
||||
if self.play_state == "paused":
|
||||
# 恢复播放(目前暂停功能简单实现为停止)
|
||||
self.play_state = "playing"
|
||||
else:
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_pause(self) -> None:
|
||||
if self.play_state == "playing":
|
||||
self._stop_playback()
|
||||
self.play_state = "paused"
|
||||
|
||||
def action_stop(self) -> None:
|
||||
self._stop_playback()
|
||||
self.play_state = "stopped"
|
||||
|
||||
def action_next(self) -> None:
|
||||
if self.current_index < len(self.atoms) - 1:
|
||||
self.current_index += 1
|
||||
self._update_playlist()
|
||||
if self.play_state == "playing":
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_prev(self) -> None:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
self._update_playlist()
|
||||
if self.play_state == "playing":
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
self._stop_playback()
|
||||
self.app.pop_screen()
|
||||
|
||||
# 响应式更新
|
||||
def watch_current_index(self, old: int, new: int) -> None:
|
||||
self._update_playlist()
|
||||
|
||||
def watch_play_state(self, old: str, new: str) -> None:
|
||||
# 更新按钮状态(可在此添加样式变化)
|
||||
pass
|
||||
@@ -1,20 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""仓库创建向导界面"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import toml
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (Button, Footer, Header, Input, Label, Markdown,
|
||||
Select)
|
||||
from textual.widgets import Button, Footer, Header, Input, Label, Markdown, Select
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.version import ver
|
||||
|
||||
|
||||
class NucleonCreatorScreen(Screen):
|
||||
class RepoCreatorScreen(Screen):
|
||||
BINDINGS = [("q", "go_back", "返回")]
|
||||
SUB_TITLE = "单元集创建向导"
|
||||
SUB_TITLE = "仓库创建向导"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(name=None, id=None, classes=None)
|
||||
@@ -24,7 +24,7 @@ class NucleonCreatorScreen(Screen):
|
||||
|
||||
from heurams.context import config_var
|
||||
|
||||
template_dir = Path(config_var.get()["paths"]["template_dir"])
|
||||
template_dir = Path(config_var.get()["paths"]["data"]) / "templates"
|
||||
templates = list()
|
||||
for i in template_dir.iterdir():
|
||||
if i.name.endswith(".toml"):
|
||||
267
src/heurams/interface/screens/repoeditor.py
Normal file
267
src/heurams/interface/screens/repoeditor.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""仓库编辑器, 使用TextArea控件等实现仓库配置编辑"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import toml
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Static,
|
||||
TextArea,
|
||||
)
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RepoEditorScreen(Screen):
|
||||
"""仓库编辑器屏幕"""
|
||||
|
||||
SUB_TITLE = "仓库编辑器"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("s", "save_file", "保存"),
|
||||
("r", "reload_file", "重载"),
|
||||
("d", "toggle_dark", ""),
|
||||
]
|
||||
|
||||
# 当前选择的仓库路径
|
||||
selected_repo_path: reactive[Optional[Path]] = reactive(None)
|
||||
# 当前选择的文件名
|
||||
selected_filename: reactive[Optional[str]] = reactive(None)
|
||||
# 文件内容
|
||||
file_content: reactive[str] = reactive("")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: Optional[Repo] = None,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.repo = repo
|
||||
self.repo_dir: Optional[Path] = None
|
||||
self.file_list = []
|
||||
if repo is not None and repo.source is not None:
|
||||
self.repo_dir = repo.source
|
||||
self._load_file_list()
|
||||
# selected_repo_path 将在 on_mount 中设置,避免触发watch时组件未就绪
|
||||
|
||||
def _load_file_list(self) -> None:
|
||||
"""加载仓库目录下的文件列表"""
|
||||
if self.repo_dir is None:
|
||||
return
|
||||
self.file_list = []
|
||||
for fname in Repo.file_mapping.values():
|
||||
fpath = self.repo_dir / fname
|
||||
if fpath.exists():
|
||||
self.file_list.append(fname)
|
||||
# 也可能存在其他文件,但暂时只支持标准文件
|
||||
self.file_list.sort()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
with Container(id="main_container"):
|
||||
with Horizontal(id="top_panel"):
|
||||
# 左侧: 仓库选择
|
||||
with Vertical(id="repo_selector", classes="panel"):
|
||||
yield Label("仓库列表", classes="panel-title")
|
||||
yield ListView(
|
||||
*[
|
||||
ListItem(Label(repo_dir.name))
|
||||
for repo_dir in self._get_repo_dirs()
|
||||
],
|
||||
id="repo_list",
|
||||
classes="list-view",
|
||||
)
|
||||
# 中间: 文件列表
|
||||
with Vertical(id="file_selector", classes="panel"):
|
||||
yield Label("文件列表", classes="panel-title")
|
||||
yield ListView(
|
||||
*[ListItem(Label(fname)) for fname in self.file_list],
|
||||
id="file_list",
|
||||
classes="list-view",
|
||||
)
|
||||
# 右侧: 编辑区域
|
||||
with Vertical(id="editor_panel", classes="panel"):
|
||||
yield Label("编辑文件", classes="panel-title")
|
||||
yield TextArea(
|
||||
id="text_editor",
|
||||
language="plaintext",
|
||||
classes="text-editor",
|
||||
)
|
||||
with Horizontal(id="button_bar"):
|
||||
yield Button("保存", id="save_button", variant="primary")
|
||||
yield Button("重载", id="reload_button", variant="default")
|
||||
yield Button("返回", id="back_button", variant="error")
|
||||
yield Footer()
|
||||
|
||||
def _get_repo_dirs(self) -> list[Path]:
|
||||
"""获取data/repo/下所有有效仓库目录"""
|
||||
repo_root = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = []
|
||||
if repo_root.exists():
|
||||
for entry in repo_root.iterdir():
|
||||
if entry.is_dir():
|
||||
# 检查是否存在 manifest.toml
|
||||
if (entry / "manifest.toml").exists():
|
||||
repo_dirs.append(entry)
|
||||
return repo_dirs
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载组件时初始化"""
|
||||
# 如果已有仓库,设置 selected_repo_path 以触发watch(此时组件已就绪)
|
||||
if self.repo_dir is not None:
|
||||
self.selected_repo_path = self.repo_dir
|
||||
# 焦点放在仓库列表
|
||||
self.query_one("#repo_list", ListView).focus()
|
||||
|
||||
def watch_selected_repo_path(
|
||||
self, old_path: Optional[Path], new_path: Optional[Path]
|
||||
) -> None:
|
||||
"""当选择的仓库路径变化时,加载文件列表"""
|
||||
if new_path is None:
|
||||
self.file_list = []
|
||||
self.selected_filename = None
|
||||
self.file_content = ""
|
||||
return
|
||||
self.repo_dir = new_path
|
||||
self._load_file_list()
|
||||
# 如果组件已挂载,更新UI
|
||||
if self.is_mounted:
|
||||
file_list_view = self.query_one("#file_list", ListView)
|
||||
file_list_view.clear()
|
||||
for fname in self.file_list:
|
||||
file_list_view.append(ListItem(Label(fname)))
|
||||
# 清空编辑器
|
||||
self.query_one("#text_editor", TextArea).text = ""
|
||||
self.selected_filename = None
|
||||
|
||||
def watch_selected_filename(
|
||||
self, old_name: Optional[str], new_name: Optional[str]
|
||||
) -> None:
|
||||
"""当选择的文件名变化时,加载文件内容"""
|
||||
if new_name is None or self.repo_dir is None:
|
||||
self.file_content = ""
|
||||
return
|
||||
file_path = self.repo_dir / new_name
|
||||
if not file_path.exists():
|
||||
self.notify(f"文件不存在: {new_name}", severity="error")
|
||||
return
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
self.file_content = content
|
||||
# 如果组件已挂载,更新编辑器
|
||||
if self.is_mounted:
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
editor.text = content
|
||||
# 根据文件后缀设置语言
|
||||
if new_name.endswith(".toml"):
|
||||
editor.language = "toml"
|
||||
elif new_name.endswith(".json"):
|
||||
editor.language = "json"
|
||||
else:
|
||||
editor.language = "plaintext"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件失败: {e}")
|
||||
self.notify(f"读取文件失败: {e}", severity="error")
|
||||
|
||||
def watch_file_content(self, old_content: str, new_content: str) -> None:
|
||||
"""当文件内容变化时更新编辑器(仅当外部改变时)"""
|
||||
# 目前不需要做任何事情,因为编辑器内容已绑定
|
||||
pass
|
||||
|
||||
def on_list_view_selected(self, event) -> None:
|
||||
"""处理列表项选择事件"""
|
||||
if not isinstance(event.item, ListItem):
|
||||
return
|
||||
list_id = event.list_view.id
|
||||
selected_label = event.item.query_one(Label)
|
||||
selected_text = str(selected_label.render())
|
||||
|
||||
if list_id == "repo_list":
|
||||
# 用户选择了仓库
|
||||
repo_root = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
selected_dir = repo_root / selected_text
|
||||
if selected_dir.exists():
|
||||
self.selected_repo_path = selected_dir
|
||||
elif list_id == "file_list":
|
||||
# 用户选择了文件
|
||||
if self.repo_dir is None:
|
||||
self.notify("请先选择仓库", severity="warning")
|
||||
return
|
||||
self.selected_filename = selected_text
|
||||
|
||||
def on_button_pressed(self, event) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
event.stop()
|
||||
if event.button.id == "save_button":
|
||||
self.action_save_file()
|
||||
elif event.button.id == "reload_button":
|
||||
self.action_reload_file()
|
||||
elif event.button.id == "back_button":
|
||||
self.action_go_back()
|
||||
|
||||
def action_save_file(self) -> None:
|
||||
"""保存当前编辑的文件"""
|
||||
if self.repo_dir is None or self.selected_filename is None:
|
||||
self.notify("未选择仓库或文件", severity="warning")
|
||||
return
|
||||
file_path = self.repo_dir / self.selected_filename
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
new_content = editor.text
|
||||
# 验证格式
|
||||
try:
|
||||
if self.selected_filename.endswith(".toml"):
|
||||
toml.loads(new_content) # 验证TOML
|
||||
elif self.selected_filename.endswith(".json"):
|
||||
json.loads(new_content) # 验证JSON
|
||||
except Exception as e:
|
||||
self.notify(f"格式错误: {e}", severity="error")
|
||||
return
|
||||
# 写入文件
|
||||
try:
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
self.notify("保存成功", severity="information")
|
||||
except Exception as e:
|
||||
logger.error(f"保存文件失败: {e}")
|
||||
self.notify(f"保存文件失败: {e}", severity="error")
|
||||
|
||||
def action_reload_file(self) -> None:
|
||||
"""重新加载当前文件(放弃修改)"""
|
||||
if self.repo_dir is None or self.selected_filename is None:
|
||||
self.notify("未选择仓库或文件", severity="warning")
|
||||
return
|
||||
file_path = self.repo_dir / self.selected_filename
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
editor.text = content
|
||||
self.notify("已重载", severity="information")
|
||||
except Exception as e:
|
||||
logger.error(f"重载文件失败: {e}")
|
||||
self.notify(f"重载文件失败: {e}", severity="error")
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
"""返回上一屏幕"""
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_toggle_dark(self) -> None:
|
||||
"""切换暗色模式"""
|
||||
self.app.dark = not self.app.dark
|
||||
@@ -1,5 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
"""同步工具界面"""
|
||||
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal, ScrollableContainer
|
||||
@@ -18,22 +20,287 @@ class SyncScreen(Screen):
|
||||
|
||||
def __init__(self, nucleons: list = [], desc: str = ""):
|
||||
super().__init__(name=None, id=None, classes=None)
|
||||
self.sync_service = None
|
||||
self.is_syncing = False
|
||||
self.is_paused = False
|
||||
self.log_messages = []
|
||||
self.max_log_lines = 50
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="sync_container"):
|
||||
pass
|
||||
# 标题和连接状态
|
||||
yield Static("同步工具", classes="title")
|
||||
yield Static("", id="status_label", classes="status")
|
||||
|
||||
# 配置信息
|
||||
yield Static(f"同步协议: {config_var.get()['services']['sync']}")
|
||||
yield Static("服务器配置:", classes="section_title")
|
||||
with Horizontal(classes="config_info"):
|
||||
yield Static("远程服务器:", classes="config_label")
|
||||
yield Static("", id="server_url", classes="config_value")
|
||||
with Horizontal(classes="config_info"):
|
||||
yield Static("远程路径:", classes="config_label")
|
||||
yield Static("", id="remote_path", classes="config_value")
|
||||
|
||||
with Horizontal(classes="control_buttons"):
|
||||
yield Button("测试连接", id="test_connection", variant="primary")
|
||||
yield Button("开始同步", id="start_sync", variant="success")
|
||||
yield Button("暂停", id="pause_sync", variant="warning", disabled=True)
|
||||
yield Button("取消", id="cancel_sync", variant="error", disabled=True)
|
||||
|
||||
yield Static("同步进度", classes="section_title")
|
||||
yield ProgressBar(id="progress_bar", show_percentage=True, total=100)
|
||||
yield Static("", id="progress_label", classes="progress_text")
|
||||
|
||||
yield Static("同步日志", classes="section_title")
|
||||
yield Static("", id="log_output", classes="log_output")
|
||||
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self):
|
||||
"""挂载时初始化状态"""
|
||||
self.update_ui_from_config()
|
||||
self.log_message("同步工具已启动")
|
||||
|
||||
def update_ui_from_config(self):
|
||||
"""更新 UI 显示配置信息"""
|
||||
try:
|
||||
sync_cfg: dict = config_var.get()["providers"]["sync"]["webdav"]
|
||||
# 更新服务器 URL
|
||||
url = sync_cfg.get("url", "未配置")
|
||||
url_widget = self.query_one("#server_url")
|
||||
url_widget.update(url) # type: ignore
|
||||
# 更新远程路径
|
||||
remote_path = sync_cfg.get("remote_path", "/")
|
||||
path_widget = self.query_one("#remote_path")
|
||||
path_widget.update(remote_path) # type: ignore
|
||||
|
||||
# 更新状态标签
|
||||
status_widget = self.query_one("#status_label")
|
||||
if self.sync_service and self.sync_service.client:
|
||||
status_widget.update("✅ 同步服务已就绪") # type: ignore
|
||||
status_widget.add_class("ready")
|
||||
else:
|
||||
status_widget.update("❌ 同步服务未配置或未启用") # type: ignore
|
||||
status_widget.add_class("error")
|
||||
|
||||
except Exception as e:
|
||||
self.log_message(f"更新 UI 失败: {e}", is_error=True)
|
||||
|
||||
def update_status(self, status, current_item="", progress=None):
|
||||
"""更新状态显示"""
|
||||
try:
|
||||
status_widget = self.query_one("#status_label")
|
||||
status_widget.update(status) # type: ignore
|
||||
|
||||
if progress is not None:
|
||||
progress_bar = self.query_one("#progress_bar")
|
||||
progress_bar.progress = progress # type: ignore
|
||||
|
||||
progress_label = self.query_one("#progress_label")
|
||||
progress_label.update(f"{progress}% - {current_item}" if current_item else f"{progress}%") # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
self.log_message(f"更新状态失败: {e}", is_error=True)
|
||||
|
||||
def log_message(self, message: str, is_error: bool = False):
|
||||
"""添加日志消息并更新显示"""
|
||||
timestamp = time.strftime("%H:%M:%S")
|
||||
prefix = "[ERROR]" if is_error else "[INFO]"
|
||||
log_line = f"{timestamp} {prefix} {message}"
|
||||
|
||||
self.log_messages.append(log_line)
|
||||
# 保持日志行数不超过最大值
|
||||
if len(self.log_messages) > self.max_log_lines:
|
||||
self.log_messages = self.log_messages[-self.max_log_lines :]
|
||||
|
||||
# 更新日志显示
|
||||
try:
|
||||
log_widget = self.query_one("#log_output")
|
||||
log_widget.update("\n".join(self.log_messages)) # type: ignore
|
||||
except Exception:
|
||||
pass # 如果组件未就绪,忽略错误
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
button_id = event.button.id
|
||||
|
||||
if button_id == "test_connection":
|
||||
self.test_connection()
|
||||
elif button_id == "start_sync":
|
||||
self.start_sync()
|
||||
elif button_id == "pause_sync":
|
||||
self.pause_sync()
|
||||
elif button_id == "cancel_sync":
|
||||
self.cancel_sync()
|
||||
|
||||
event.stop()
|
||||
|
||||
def test_connection(self):
|
||||
"""测试 WebDAV 服务器连接"""
|
||||
if not self.sync_service:
|
||||
self.log_message("同步服务未初始化,请检查配置", is_error=True)
|
||||
self.update_status("❌ 同步服务未初始化")
|
||||
return
|
||||
|
||||
self.log_message("正在测试 WebDAV 连接...")
|
||||
self.update_status("正在测试连接...")
|
||||
|
||||
try:
|
||||
success = self.sync_service.test_connection()
|
||||
if success:
|
||||
self.log_message("连接测试成功")
|
||||
self.update_status("✅ 连接正常")
|
||||
else:
|
||||
self.log_message("连接测试失败", is_error=True)
|
||||
self.update_status("❌ 连接失败")
|
||||
except Exception as e:
|
||||
self.log_message(f"连接测试异常: {e}", is_error=True)
|
||||
self.update_status("❌ 连接异常")
|
||||
|
||||
def start_sync(self):
|
||||
"""开始同步"""
|
||||
if not self.sync_service:
|
||||
self.log_message("同步服务未初始化,无法开始同步", is_error=True)
|
||||
return
|
||||
|
||||
if self.is_syncing:
|
||||
self.log_message("同步已在进行中", is_error=True)
|
||||
return
|
||||
|
||||
self.is_syncing = True
|
||||
self.is_paused = False
|
||||
self.update_button_states()
|
||||
|
||||
self.log_message("开始同步数据...")
|
||||
self.update_status("正在同步...", progress=0)
|
||||
|
||||
# 启动后台同步任务
|
||||
self.run_worker(self.perform_sync, thread=True)
|
||||
|
||||
def perform_sync(self):
|
||||
"""执行同步任务(在后台线程中运行)"""
|
||||
worker = get_current_worker()
|
||||
|
||||
try:
|
||||
# 获取需要同步的本地目录
|
||||
from heurams.context import config_var
|
||||
|
||||
config = config_var.get()
|
||||
paths = config.get("paths", {})
|
||||
|
||||
# 同步 nucleon 目录
|
||||
nucleon_dir = pathlib.Path(paths.get("nucleon_dir", "./data/nucleon"))
|
||||
if nucleon_dir.exists():
|
||||
self.log_message(f"同步 nucleon 目录: {nucleon_dir}")
|
||||
self.update_status(f"同步 nucleon 目录...", progress=10)
|
||||
|
||||
result = self.sync_service.sync_directory(nucleon_dir) # type: ignore
|
||||
if result.get("success"):
|
||||
self.log_message(
|
||||
f"nucleon 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||
)
|
||||
else:
|
||||
self.log_message(
|
||||
f"nucleon 同步失败: {result.get('error', '未知错误')}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
# 同步 electron 目录
|
||||
electron_dir = pathlib.Path(paths.get("electron_dir", "./data/electron"))
|
||||
if electron_dir.exists():
|
||||
self.log_message(f"同步 electron 目录: {electron_dir}")
|
||||
self.update_status(f"同步 electron 目录...", progress=60)
|
||||
|
||||
result = self.sync_service.sync_directory(electron_dir) # type: ignore
|
||||
if result.get("success"):
|
||||
self.log_message(
|
||||
f"electron 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||
)
|
||||
else:
|
||||
self.log_message(
|
||||
f"electron 同步失败: {result.get('error', '未知错误')}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
# 同步 orbital 目录(如果存在)
|
||||
orbital_dir = pathlib.Path(paths.get("orbital_dir", "./data/orbital"))
|
||||
if orbital_dir.exists():
|
||||
self.log_message(f"同步 orbital 目录: {orbital_dir}")
|
||||
self.update_status(f"同步 orbital 目录...", progress=80)
|
||||
|
||||
result = self.sync_service.sync_directory(orbital_dir) # type: ignore
|
||||
if result.get("success"):
|
||||
self.log_message(
|
||||
f"orbital 同步完成: 上传 {result.get('uploaded', 0)} 个, 下载 {result.get('downloaded', 0)} 个"
|
||||
)
|
||||
else:
|
||||
self.log_message(
|
||||
f"orbital 同步失败: {result.get('error', '未知错误')}",
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
# 同步完成
|
||||
self.update_status("同步完成", progress=100)
|
||||
self.log_message("所有目录同步完成")
|
||||
|
||||
except Exception as e:
|
||||
self.log_message(f"同步过程中发生错误: {e}", is_error=True)
|
||||
self.update_status("同步失败")
|
||||
finally:
|
||||
# 重置同步状态
|
||||
self.is_syncing = False
|
||||
self.is_paused = False
|
||||
self.update_button_states() # type: ignore
|
||||
|
||||
def pause_sync(self):
|
||||
"""暂停同步"""
|
||||
if not self.is_syncing:
|
||||
return
|
||||
|
||||
self.is_paused = not self.is_paused
|
||||
self.update_button_states()
|
||||
|
||||
if self.is_paused:
|
||||
self.log_message("同步已暂停")
|
||||
self.update_status("同步已暂停")
|
||||
else:
|
||||
self.log_message("同步已恢复")
|
||||
self.update_status("正在同步...")
|
||||
|
||||
def cancel_sync(self):
|
||||
"""取消同步"""
|
||||
if not self.is_syncing:
|
||||
return
|
||||
|
||||
self.is_syncing = False
|
||||
self.is_paused = False
|
||||
self.update_button_states()
|
||||
|
||||
self.log_message("同步已取消")
|
||||
self.update_status("同步已取消")
|
||||
|
||||
def update_button_states(self):
|
||||
"""更新按钮状态"""
|
||||
try:
|
||||
start_button = self.query_one("#start_sync")
|
||||
pause_button = self.query_one("#pause_sync")
|
||||
cancel_button = self.query_one("#cancel_sync")
|
||||
|
||||
if self.is_syncing:
|
||||
start_button.disabled = True
|
||||
pause_button.disabled = False
|
||||
cancel_button.disabled = False
|
||||
pause_button.label = "继续" if self.is_paused else "暂停" # type: ignore
|
||||
else:
|
||||
start_button.disabled = False
|
||||
pause_button.disabled = True
|
||||
cancel_button.disabled = True
|
||||
|
||||
except Exception as e:
|
||||
self.log_message(f"更新按钮状态失败: {e}", is_error=True)
|
||||
|
||||
def action_go_back(self):
|
||||
self.app.pop_screen()
|
||||
|
||||
|
||||
@@ -1,34 +1,8 @@
|
||||
"""Kernel 操作辅助函数库"""
|
||||
|
||||
import random
|
||||
from typing import TypedDict
|
||||
|
||||
import heurams.interface.widgets as pzw
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
|
||||
staging = {} # 细粒度缓存区, 是 ident -> quality 的封装
|
||||
|
||||
|
||||
def report_to_staging(atom: pt.Atom, quality):
|
||||
staging[atom.ident] = min(quality, staging[atom.ident])
|
||||
|
||||
|
||||
def clear():
|
||||
staging = dict()
|
||||
|
||||
|
||||
def deploy_to_electron():
|
||||
for atom_ident, quality in staging.items():
|
||||
if pt.atom_registry[atom_ident].registry["electron"].is_activated:
|
||||
pt.atom_registry[atom_ident].registry["electron"].revisor(quality=quality)
|
||||
else:
|
||||
pt.atom_registry[atom_ident].registry["electron"].revisor(
|
||||
quality=quality, is_new_activation=True
|
||||
)
|
||||
clear()
|
||||
|
||||
|
||||
puzzle2widget = {
|
||||
pz.RecognitionPuzzle: pzw.Recognition,
|
||||
pz.ClozePuzzle: pzw.ClozePuzzle,
|
||||
|
||||
@@ -53,7 +53,7 @@ class ClozePuzzle(BasePuzzleWidget):
|
||||
self.hashmap = dict()
|
||||
|
||||
def _load(self):
|
||||
setting = self.atom.registry["orbital"]["puzzles"][self.alia]
|
||||
setting = self.atom.registry["nucleon"]["puzzles"][self.alia]
|
||||
self.puzzle = pz.ClozePuzzle(
|
||||
text=setting["text"],
|
||||
delimiter=setting["delimiter"],
|
||||
|
||||
@@ -7,24 +7,27 @@ class Finished(Widget):
|
||||
self,
|
||||
*children: Widget,
|
||||
alia="",
|
||||
is_saved=0,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
disabled: bool = False,
|
||||
markup: bool = True
|
||||
markup: bool = True,
|
||||
) -> None:
|
||||
self.alia = alia
|
||||
self.is_saved = is_saved
|
||||
super().__init__(
|
||||
*children,
|
||||
name=name,
|
||||
id=id,
|
||||
classes=classes,
|
||||
disabled=disabled,
|
||||
markup=markup
|
||||
markup=markup,
|
||||
)
|
||||
|
||||
def compose(self):
|
||||
yield Label("本次记忆进程结束", id="finished_msg")
|
||||
yield Label(f"算法数据{'已保存' if self.is_saved else "未能保存"}")
|
||||
yield Button("返回上一级", id="back-to-menu")
|
||||
|
||||
def on_button_pressed(self, event):
|
||||
|
||||
@@ -54,22 +54,25 @@ class MCQPuzzle(BasePuzzleWidget):
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
cfg = self.atom.registry["orbital"]["puzzles"][self.alia]
|
||||
cfg = self.atom.registry["nucleon"]["puzzles"][self.alia]
|
||||
if cfg['mapping'] == {}:
|
||||
self.screen.rating = 5 # type: ignore
|
||||
self.puzzle = pz.MCQPuzzle(
|
||||
cfg["mapping"], cfg["jammer"], int(cfg["max_riddles_num"]), cfg["prefix"]
|
||||
)
|
||||
self.puzzle.refresh()
|
||||
|
||||
def compose(self):
|
||||
self.atom.registry["nucleon"].do_eval()
|
||||
setting: Setting = self.atom.registry["nucleon"].metadata["orbital"]["puzzles"][
|
||||
self.alia
|
||||
]
|
||||
logger.debug(f"Puzzle Setting: {setting}")
|
||||
current_options = self.puzzle.options[len(self.inputlist)]
|
||||
yield Label(setting["primary"], id="sentence")
|
||||
yield Label(self.puzzle.wording[len(self.inputlist)], id="puzzle")
|
||||
yield Label(f"当前输入: {self.inputlist}", id="inputpreview")
|
||||
setting: Setting = self.atom.registry["nucleon"]["puzzles"][self.alia]
|
||||
if len(self.inputlist) > len(self.puzzle.options):
|
||||
logger.debug("ERR IDX")
|
||||
logger.debug(self.inputlist)
|
||||
logger.debug(self.puzzle.options)
|
||||
else:
|
||||
current_options = self.puzzle.options[len(self.inputlist)]
|
||||
yield Label(setting["primary"], id="sentence")
|
||||
yield Label(self.puzzle.wording[len(self.inputlist)], id="puzzle")
|
||||
yield Label(f"当前输入: {self.inputlist}", id="inputpreview")
|
||||
|
||||
# 渲染当前问题的选项
|
||||
with Container(id="btn-container"):
|
||||
|
||||
@@ -49,8 +49,13 @@ class Recognition(BasePuzzleWidget):
|
||||
self.alia = alia
|
||||
|
||||
def compose(self):
|
||||
cfg: RecognitionConfig = self.atom.registry["orbital"]["puzzles"][self.alia]
|
||||
delim = self.atom.registry["nucleon"].metadata["formation"]["delimiter"]
|
||||
from heurams.context import config_var
|
||||
|
||||
autovoice = config_var.get()["interface"]["memorizor"]["autovoice"]
|
||||
if autovoice:
|
||||
self.screen.action_play_voice() # type: ignore
|
||||
cfg: RecognitionConfig = self.atom.registry["nucleon"]["puzzles"][self.alia]
|
||||
delim = self.atom.registry["nucleon"]["delimiter"]
|
||||
replace_dict = {
|
||||
", ": ",",
|
||||
". ": ".",
|
||||
@@ -64,11 +69,12 @@ class Recognition(BasePuzzleWidget):
|
||||
}
|
||||
|
||||
nucleon = self.atom.registry["nucleon"]
|
||||
metadata = self.atom.registry["nucleon"].metadata
|
||||
metadata = self.atom.registry["nucleon"]
|
||||
primary = cfg["primary"]
|
||||
|
||||
with Center():
|
||||
yield Static(f"[dim]{cfg['top_dim']}[/]")
|
||||
for i in cfg["top_dim"]:
|
||||
yield Static(f"[dim]{i}[/]")
|
||||
yield Label("")
|
||||
|
||||
for old, new in replace_dict.items():
|
||||
@@ -84,7 +90,7 @@ class Recognition(BasePuzzleWidget):
|
||||
for item in cfg["secondary"]:
|
||||
if isinstance(item, list):
|
||||
for j in item:
|
||||
yield Markdown(f"### {metadata['annotation'][item]}: {j}")
|
||||
yield Markdown(f"### 笔记: {j}") #TODO ANNOTATION
|
||||
continue
|
||||
if isinstance(item, Dict):
|
||||
total = ""
|
||||
@@ -101,13 +107,3 @@ class Recognition(BasePuzzleWidget):
|
||||
if event.button.id == "ok":
|
||||
self.screen.rating = 5 # type: ignore
|
||||
self.handler(5)
|
||||
|
||||
def handler(self, rating):
|
||||
if not self.atom.registry["runtime"]["locked"]:
|
||||
if not self.atom.registry["electron"].is_activated():
|
||||
self.atom.registry["electron"].activate()
|
||||
logger.debug(f"激活原子 {self.atom}")
|
||||
self.atom.lock(1)
|
||||
self.atom.minimize(5)
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseAlgorithm
|
||||
from .sm2 import SM2Algorithm
|
||||
from .sm15m import SM15MAlgorithm
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"SM2Algorithm",
|
||||
"BaseAlgorithm",
|
||||
"SM15MAlgorithm",
|
||||
]
|
||||
|
||||
algorithms = {
|
||||
"SM-2": SM2Algorithm,
|
||||
"SM-15M": SM15MAlgorithm,
|
||||
# "SM-15M": SM15MAlgorithm,
|
||||
"Base": BaseAlgorithm,
|
||||
}
|
||||
|
||||
logger.debug("算法模块初始化完成, 注册的算法: %s", list(algorithms.keys()))
|
||||
|
||||
@@ -10,7 +10,6 @@ class BaseAlgorithm:
|
||||
algo_name = "BaseAlgorithm"
|
||||
|
||||
class AlgodataDict(TypedDict):
|
||||
efactor: float
|
||||
real_rept: int
|
||||
rept: int
|
||||
interval: int
|
||||
@@ -52,7 +51,7 @@ class BaseAlgorithm:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def rate(cls, algodata) -> str:
|
||||
def get_rating(cls, algodata) -> str:
|
||||
"""获取评分信息"""
|
||||
logger.debug(
|
||||
"BaseAlgorithm.rate 被调用, algodata keys: %s",
|
||||
@@ -68,3 +67,11 @@ class BaseAlgorithm:
|
||||
list(algodata.keys()) if algodata else [],
|
||||
)
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def check_integrity(cls, algodata):
|
||||
try:
|
||||
cls.AlgodataDict(**algodata[cls.algo_name])
|
||||
return 1
|
||||
except:
|
||||
return 0
|
||||
|
||||
@@ -2,21 +2,35 @@
|
||||
SM-15 接口兼容实现, 基于 SM-15 算法的逆向工程
|
||||
全局状态保存在文件中, 项目状态通过 algodata 字典传递
|
||||
|
||||
基于: https://github.com/kazuaki/sm.js
|
||||
原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida (MIT 许可证)
|
||||
基于: https://github.com/slaypni/sm.js
|
||||
原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida
|
||||
MIT 许可证
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from typing import TypedDict
|
||||
|
||||
from heurams.kernel.algorithms.sm15m_calc import (MAX_AF, MIN_AF, NOTCH_AF,
|
||||
RANGE_AF, RANGE_REPETITION,
|
||||
SM, THRESHOLD_RECALL, Item)
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.algorithms.sm15m_calc import (
|
||||
MAX_AF,
|
||||
MIN_AF,
|
||||
NOTCH_AF,
|
||||
RANGE_AF,
|
||||
RANGE_REPETITION,
|
||||
SM,
|
||||
THRESHOLD_RECALL,
|
||||
Item,
|
||||
)
|
||||
|
||||
# 全局状态文件路径
|
||||
_GLOBAL_STATE_FILE = os.path.expanduser("~/.sm15_global_state.json")
|
||||
_GLOBAL_STATE_FILE = os.path.expanduser(
|
||||
pathlib.Path(config_var.get()["paths"]["data"])
|
||||
/ "global"
|
||||
/ "sm15m_global_state.json"
|
||||
)
|
||||
|
||||
|
||||
def _get_global_sm():
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
基于: https://github.com/kazuaki/sm.js
|
||||
基于: https://github.com/slaypni/sm.js
|
||||
原始 CoffeeScript 代码: (c) 2014 Kazuaki Tanida
|
||||
MIT 许可证
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ class SM2Algorithm(BaseAlgorithm):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def rate(cls, algodata):
|
||||
def get_rating(cls, algodata):
|
||||
efactor = algodata[cls.algo_name]["efactor"]
|
||||
logger.debug("SM2.rate: efactor=%f", efactor)
|
||||
return str(efactor)
|
||||
|
||||
5
src/heurams/kernel/auxiliary/__init__.py
Normal file
5
src/heurams/kernel/auxiliary/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .evalizor import Evalizer
|
||||
from .lict import Lict
|
||||
from .refvar import RefVar
|
||||
|
||||
__all__ = ["Evalizer", "Lict", "RefVar"]
|
||||
33
src/heurams/kernel/auxiliary/evalizor.py
Normal file
33
src/heurams/kernel/auxiliary/evalizor.py
Normal file
@@ -0,0 +1,33 @@
|
||||
class Evalizer:
|
||||
"""几乎无副作用的模板系统
|
||||
|
||||
接受环境信息并创建一个模板解析工具, 工具传入参数支持list, dict及其嵌套
|
||||
副作用问题: 仅存在于 eval 函数
|
||||
"""
|
||||
|
||||
# TODO: 弃用风险极高的 eval
|
||||
# TODO: 异步/多线程执行避免堵塞
|
||||
def __init__(self, environment: dict) -> None:
|
||||
self.env = environment
|
||||
|
||||
def __call__(self, anyobj):
|
||||
return self.travel(anyobj)
|
||||
|
||||
def travel(self, anyobj):
|
||||
if isinstance(anyobj, list):
|
||||
return list(map(self.travel, anyobj))
|
||||
elif isinstance(anyobj, dict):
|
||||
return dict(map(self.travel, anyobj.items()))
|
||||
elif isinstance(anyobj, tuple):
|
||||
return tuple(map(self.travel, anyobj))
|
||||
elif isinstance(anyobj, str):
|
||||
if anyobj.startswith("eval:"):
|
||||
return self.eval_with_env(anyobj[5:])
|
||||
else:
|
||||
return anyobj
|
||||
else:
|
||||
return anyobj
|
||||
|
||||
def eval_with_env(self, s: str):
|
||||
ret = eval(s, globals(), self.env)
|
||||
return ret
|
||||
149
src/heurams/kernel/auxiliary/lict.py
Normal file
149
src/heurams/kernel/auxiliary/lict.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from collections import UserList
|
||||
from typing import Any, Iterator
|
||||
|
||||
|
||||
class Lict(UserList): # TODO: 优化同步(惰性同步), 当前性能为 O(n)
|
||||
""" "列典" 对象
|
||||
|
||||
同时兼容字典和列表大多数 API, 两边数据同步的容器
|
||||
列表数据是 dictobj.items() 的格式
|
||||
支持根据字典或列表初始化
|
||||
限制要求:
|
||||
- 键名一定唯一, 且仅能为字符串
|
||||
- 值一定是引用对象
|
||||
- 不使用并发
|
||||
- 不在乎列表顺序语义(严格按键名字符序排列)和列表索引查找, 因此外部的 sort, index 等功能不可用
|
||||
- append 的元组中, 表示键名的元素不能重复, 否则会导致覆盖行为
|
||||
|
||||
只有在 Python 3.7+ 中, forced_order 行为才能被取消.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initlist: list | None = None,
|
||||
initdict: dict | None = None,
|
||||
forced_order=False,
|
||||
):
|
||||
self.dicted_data = {}
|
||||
if initdict != None:
|
||||
initlist = list(initdict.items())
|
||||
super().__init__(initlist=initlist)
|
||||
self.forced_order = forced_order
|
||||
self._sync_based_on_list()
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def _sync_based_on_dict(self):
|
||||
self.data = list(self.dicted_data.items())
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def _sync_based_on_list(self):
|
||||
self.dicted_data = {}
|
||||
for i in self.data:
|
||||
self.dicted_data[i[0]] = i[1]
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return self.data.__iter__()
|
||||
|
||||
def __getitem__(self, i):
|
||||
if isinstance(i, str):
|
||||
return self.dicted_data[i]
|
||||
else:
|
||||
return super().__getitem__(i)
|
||||
|
||||
def get_itemic_unit(self, ident):
|
||||
return (ident, self.dicted_data[ident])
|
||||
|
||||
def __setitem__(self, i, item):
|
||||
if isinstance(i, str):
|
||||
self.dicted_data[i] = item
|
||||
self._sync_based_on_dict()
|
||||
else:
|
||||
if item != (item[0], item[1]):
|
||||
raise NotImplementedError
|
||||
super().__setitem__(i, item)
|
||||
self._sync_based_on_list()
|
||||
|
||||
def __delitem__(self, i):
|
||||
if isinstance(i, str):
|
||||
del self.dicted_data[i]
|
||||
self._sync_based_on_dict()
|
||||
else:
|
||||
super().__delitem__(i)
|
||||
self._sync_based_on_list()
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.data or item in self.keys() or item in self.values()
|
||||
|
||||
def append(self, item: Any) -> None:
|
||||
if item != (item[0], item[1]):
|
||||
raise NotImplementedError
|
||||
super().append(item)
|
||||
self._sync_based_on_list()
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def append_new(self, item: Any):
|
||||
if item != (item[0], item[1]):
|
||||
raise NotImplementedError
|
||||
if item[0] not in self:
|
||||
super().append(item)
|
||||
self._sync_based_on_list()
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def insert(self, i: int, item: Any) -> None:
|
||||
if item != (item[0], item[1]): # 确保 item 是遵从限制的元组
|
||||
raise NotImplementedError
|
||||
super().insert(i, item)
|
||||
self._sync_based_on_list()
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def pop(self, i: int = -1) -> Any:
|
||||
res = super().pop(i)
|
||||
self._sync_based_on_list()
|
||||
return res
|
||||
|
||||
def remove(self, item: Any) -> None:
|
||||
if isinstance(item, str):
|
||||
item = (item, self.dicted_data[item])
|
||||
if item != (item[0], item[1]):
|
||||
raise NotImplementedError
|
||||
super().remove(item)
|
||||
self._sync_based_on_list()
|
||||
if self.forced_order:
|
||||
self.data.sort()
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self._sync_based_on_list()
|
||||
|
||||
def index(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def extend(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def sort(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def reverse(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def keys(self):
|
||||
return self.dicted_data.keys()
|
||||
|
||||
def values(self):
|
||||
return self.dicted_data.values()
|
||||
|
||||
def items(self):
|
||||
return self.data
|
||||
|
||||
def keys_equal_with(self, other):
|
||||
return self.key_equality(self, other)
|
||||
|
||||
@classmethod
|
||||
def key_equality(cls, a, b):
|
||||
return a.keys() == b.keys()
|
||||
241
src/heurams/kernel/auxiliary/refvar.py
Normal file
241
src/heurams/kernel/auxiliary/refvar.py
Normal file
@@ -0,0 +1,241 @@
|
||||
class RefVar:
|
||||
def __init__(self, initvalue) -> None:
|
||||
self.data = initvalue
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RefVar({repr(self.data)})"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.data)
|
||||
|
||||
def __add__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data + other.data)
|
||||
return RefVar(self.data + other)
|
||||
|
||||
def __radd__(self, other):
|
||||
return RefVar(other + self.data)
|
||||
|
||||
def __sub__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data - other.data)
|
||||
return RefVar(self.data - other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
return RefVar(other - self.data)
|
||||
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data * other.data)
|
||||
return RefVar(self.data * other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
return RefVar(other * self.data)
|
||||
|
||||
def __truediv__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data / other.data)
|
||||
return RefVar(self.data / other)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
return RefVar(other / self.data)
|
||||
|
||||
def __floordiv__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data // other.data)
|
||||
return RefVar(self.data // other)
|
||||
|
||||
def __rfloordiv__(self, other):
|
||||
return RefVar(other // self.data)
|
||||
|
||||
def __mod__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data % other.data)
|
||||
return RefVar(self.data % other)
|
||||
|
||||
def __rmod__(self, other):
|
||||
return RefVar(other % self.data)
|
||||
|
||||
def __pow__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data**other.data)
|
||||
return RefVar(self.data**other)
|
||||
|
||||
def __rpow__(self, other):
|
||||
return RefVar(other**self.data)
|
||||
|
||||
def __neg__(self):
|
||||
return RefVar(-self.data)
|
||||
|
||||
def __pos__(self):
|
||||
return RefVar(+self.data)
|
||||
|
||||
def __abs__(self):
|
||||
return RefVar(abs(self.data))
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return self.data == other.data
|
||||
return self.data == other
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return self.data < other.data
|
||||
return self.data < other
|
||||
|
||||
def __le__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return self.data <= other.data
|
||||
return self.data <= other
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return self.data > other.data
|
||||
return self.data > other
|
||||
|
||||
def __ge__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return self.data >= other.data
|
||||
return self.data >= other
|
||||
|
||||
# 位运算
|
||||
def __and__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data & other.data)
|
||||
return RefVar(self.data & other)
|
||||
|
||||
def __rand__(self, other):
|
||||
return RefVar(other & self.data)
|
||||
|
||||
def __or__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data | other.data)
|
||||
return RefVar(self.data | other)
|
||||
|
||||
def __ror__(self, other):
|
||||
return RefVar(other | self.data)
|
||||
|
||||
def __xor__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data ^ other.data)
|
||||
return RefVar(self.data ^ other)
|
||||
|
||||
def __rxor__(self, other):
|
||||
return RefVar(other ^ self.data)
|
||||
|
||||
def __lshift__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data << other.data)
|
||||
return RefVar(self.data << other)
|
||||
|
||||
def __rlshift__(self, other):
|
||||
return RefVar(other << self.data)
|
||||
|
||||
def __rshift__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
return RefVar(self.data >> other.data)
|
||||
return RefVar(self.data >> other)
|
||||
|
||||
def __rrshift__(self, other):
|
||||
return RefVar(other >> self.data)
|
||||
|
||||
def __invert__(self):
|
||||
return RefVar(~self.data)
|
||||
|
||||
# 类型转换
|
||||
def __int__(self):
|
||||
return int(self.data)
|
||||
|
||||
def __float__(self):
|
||||
return float(self.data)
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.data)
|
||||
|
||||
def __complex__(self):
|
||||
return complex(self.data)
|
||||
|
||||
def __bytes__(self):
|
||||
return bytes(self.data)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.data)
|
||||
|
||||
# 容器操作(如果底层数据支持)
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.data[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.data[key]
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.data)
|
||||
|
||||
def __iadd__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data += other.data
|
||||
else:
|
||||
self.data += other
|
||||
return self
|
||||
|
||||
def __isub__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data -= other.data
|
||||
else:
|
||||
self.data -= other
|
||||
return self
|
||||
|
||||
def __imul__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data *= other.data
|
||||
else:
|
||||
self.data *= other
|
||||
return self
|
||||
|
||||
def __itruediv__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data /= other.data
|
||||
else:
|
||||
self.data /= other
|
||||
return self
|
||||
|
||||
def __ifloordiv__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data //= other.data
|
||||
else:
|
||||
self.data //= other
|
||||
return self
|
||||
|
||||
def __imod__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data %= other.data
|
||||
else:
|
||||
self.data %= other
|
||||
return self
|
||||
|
||||
def __ipow__(self, other):
|
||||
if isinstance(other, RefVar):
|
||||
self.data **= other.data
|
||||
else:
|
||||
self.data **= other
|
||||
return self
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if callable(self.data):
|
||||
return self.data(*args, **kwargs)
|
||||
raise TypeError(f"'{type(self.data).__name__}' object is not callable")
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.data, name)
|
||||
@@ -1,29 +1,21 @@
|
||||
"""
|
||||
Particle 模块 - 粒子对象系统
|
||||
|
||||
提供闪卡所需对象, 使用物理学粒子的领域驱动设计
|
||||
"""
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.debug("粒子模块已加载")
|
||||
|
||||
from .atom import Atom, atom_registry
|
||||
from .atom import Atom
|
||||
from .electron import Electron
|
||||
from .loader import load_electron, load_nucleon
|
||||
from .nucleon import Nucleon
|
||||
from .orbital import Orbital
|
||||
from .probe import probe_all, probe_by_filename
|
||||
from .placeholders import (
|
||||
AtomPlaceholder,
|
||||
ElectronPlaceholder,
|
||||
NucleonPlaceholder,
|
||||
orbital_placeholder,
|
||||
)
|
||||
|
||||
# from .orbital import Orbital
|
||||
|
||||
__all__ = [
|
||||
"Atom",
|
||||
"Electron",
|
||||
"Nucleon",
|
||||
"Orbital",
|
||||
"Atom",
|
||||
"probe_all",
|
||||
"probe_by_filename",
|
||||
"load_nucleon",
|
||||
"load_electron",
|
||||
"atom_registry",
|
||||
"AtomPlaceholder",
|
||||
"NucleonPlaceholder",
|
||||
"ElectronPlaceholder",
|
||||
"orbital_placeholder",
|
||||
]
|
||||
|
||||
@@ -1,17 +1,9 @@
|
||||
import json
|
||||
import pathlib
|
||||
import typing
|
||||
from typing import TypedDict
|
||||
|
||||
import bidict
|
||||
import toml
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .electron import Electron
|
||||
from .nucleon import Nucleon
|
||||
from .orbital import Orbital
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -19,19 +11,13 @@ logger = get_logger(__name__)
|
||||
class AtomRegister_runtime(TypedDict):
|
||||
locked: bool # 只读锁定标识符
|
||||
min_rate: int # 最低评分
|
||||
newact: bool # 新激活
|
||||
new_activation: bool # 新激活
|
||||
|
||||
|
||||
class AtomRegister(TypedDict):
|
||||
nucleon: Nucleon
|
||||
nucleon_path: pathlib.Path
|
||||
nucleon_fmt: str
|
||||
electron: Electron
|
||||
electron_path: pathlib.Path
|
||||
electron_fmt: str
|
||||
orbital: Orbital
|
||||
orbital_path: pathlib.Path
|
||||
orbital_fmt: str
|
||||
orbital: dict
|
||||
runtime: AtomRegister_runtime
|
||||
|
||||
|
||||
@@ -44,39 +30,27 @@ class Atom:
|
||||
以及关联路径
|
||||
"""
|
||||
|
||||
def __init__(self, ident=""):
|
||||
logger.debug("创建 Atom 实例, ident: '%s'", ident)
|
||||
self.ident = ident
|
||||
atom_registry[ident] = self
|
||||
logger.debug("Atom 已注册到全局注册表, 当前注册表大小: %d", len(atom_registry))
|
||||
# self.is_evaled = False
|
||||
self.registry: AtomRegister = { # type: ignore
|
||||
"nucleon": None,
|
||||
"nucleon_path": None,
|
||||
"nucleon_fmt": "toml",
|
||||
"electron": None,
|
||||
"electron_path": None,
|
||||
"electron_fmt": "json",
|
||||
"orbital": None,
|
||||
"orbital_path": None, # 允许设置为 None, 此时使用 nucleon 文件内的推荐配置
|
||||
"orbital_fmt": "toml",
|
||||
"runtime": {"locked": False, "min_rate": 0x3F3F3F3F, "newact": False},
|
||||
}
|
||||
self.do_eval()
|
||||
logger.debug("Atom 初始化完成")
|
||||
default_runtime = {
|
||||
"locked": False,
|
||||
"min_rate": 0x3F3F3F3F,
|
||||
"new_activation": False,
|
||||
}
|
||||
|
||||
def link(self, key, value):
|
||||
logger.debug("Atom.link: key='%s', value type: %s", key, type(value).__name__)
|
||||
if key in self.registry.keys():
|
||||
self.registry[key] = value
|
||||
logger.debug("键 '%s' 已链接, 触发 do_eval", key)
|
||||
self.do_eval()
|
||||
if key == "electron":
|
||||
if self.registry["electron"].is_activated() == 0:
|
||||
self.registry["runtime"]["newact"] = True
|
||||
else:
|
||||
logger.error("尝试链接不受支持的键: '%s'", key)
|
||||
raise ValueError("不受支持的原子元数据链接操作")
|
||||
def __init__(self, nucleon_obj=None, electron_obj=None, orbital_obj=None):
|
||||
self.ident = nucleon_obj["ident"] # type: ignore
|
||||
self.registry: AtomRegister = { # type: ignore
|
||||
"ident": nucleon_obj["ident"], # type: ignore
|
||||
"nucleon": nucleon_obj,
|
||||
"electron": electron_obj,
|
||||
"orbital": orbital_obj,
|
||||
"runtime": dict(),
|
||||
}
|
||||
self.init_runtime()
|
||||
if self.registry["electron"].is_activated() == 0:
|
||||
self.registry["runtime"]["new_activation"] = True
|
||||
|
||||
def init_runtime(self):
|
||||
self.registry["runtime"] = AtomRegister_runtime(**self.default_runtime)
|
||||
|
||||
def minimize(self, rating):
|
||||
"""效果等同于 self.registry['runtime']['min_rate'] = min(rating, self.registry['runtime']['min_rate'])
|
||||
@@ -109,136 +83,21 @@ class Atom:
|
||||
logger.debug(f"允许总评分: {self.registry['runtime']['min_rate']}")
|
||||
self.registry["electron"].revisor(
|
||||
self.registry["runtime"]["min_rate"],
|
||||
is_new_activation=self.registry["runtime"]["newact"],
|
||||
is_new_activation=self.registry["runtime"]["new_activation"],
|
||||
)
|
||||
else:
|
||||
logger.debug("禁止总评分")
|
||||
|
||||
def do_eval(self):
|
||||
"""
|
||||
执行并以结果替换当前单元的所有 eval 语句
|
||||
TODO: 带有限制的 eval, 异步/多线程执行避免堵塞
|
||||
"""
|
||||
logger.debug("Atom.do_eval 开始")
|
||||
|
||||
# eval 环境设置
|
||||
def eval_with_env(s: str):
|
||||
# 初始化默认值
|
||||
nucleon = self.registry["nucleon"]
|
||||
default = {}
|
||||
metadata = {}
|
||||
try:
|
||||
default = config_var.get()["puzzles"]
|
||||
metadata = nucleon.metadata
|
||||
except Exception:
|
||||
# 如果无法获取配置或元数据, 使用空字典
|
||||
logger.debug("无法获取配置或元数据, 使用空字典")
|
||||
pass
|
||||
try:
|
||||
eval_value = eval(s)
|
||||
if isinstance(eval_value, (list, dict)):
|
||||
ret = eval_value
|
||||
else:
|
||||
ret = str(eval_value)
|
||||
logger.debug(
|
||||
"eval 执行成功: '%s' -> '%s'",
|
||||
s,
|
||||
str(ret)[:50] + "..." if len(ret) > 50 else ret,
|
||||
)
|
||||
except Exception as e:
|
||||
ret = f"此 eval 实例发生错误: {e}"
|
||||
logger.warning("eval 执行错误: '%s' -> %s", s, e)
|
||||
return ret
|
||||
|
||||
def traverse(data, modifier):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
data[key] = traverse(value, modifier)
|
||||
return data
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
data[i] = traverse(item, modifier)
|
||||
return data
|
||||
elif isinstance(data, tuple):
|
||||
return tuple(traverse(item, modifier) for item in data)
|
||||
else:
|
||||
if isinstance(data, str):
|
||||
if data.startswith("eval:"):
|
||||
logger.debug("发现 eval 表达式: '%s'", data[5:])
|
||||
return modifier(data[5:])
|
||||
return data
|
||||
|
||||
# 如果 nucleon 存在且有 do_eval 方法, 调用它
|
||||
nucleon = self.registry["nucleon"]
|
||||
if nucleon is not None and hasattr(nucleon, "do_eval"):
|
||||
nucleon.do_eval()
|
||||
logger.debug("已调用 nucleon.do_eval")
|
||||
|
||||
# 如果 electron 存在且其 algodata 包含 eval 字符串, 遍历它
|
||||
electron = self.registry["electron"]
|
||||
if electron is not None and hasattr(electron, "algodata"):
|
||||
traverse(electron.algodata, eval_with_env)
|
||||
logger.debug("已处理 electron algodata eval")
|
||||
|
||||
# 如果 orbital 存在且是字典, 遍历它
|
||||
orbital = self.registry["orbital"]
|
||||
if orbital is not None and isinstance(orbital, dict):
|
||||
traverse(orbital, eval_with_env)
|
||||
logger.debug("orbital eval 完成")
|
||||
|
||||
logger.debug("Atom.do_eval 完成")
|
||||
|
||||
def persist(self, key):
|
||||
logger.debug("Atom.persist: key='%s'", key)
|
||||
path: pathlib.Path | None = self.registry[key + "_path"]
|
||||
if isinstance(path, pathlib.Path):
|
||||
path = typing.cast(pathlib.Path, path)
|
||||
logger.debug("持久化路径: %s, 格式: %s", path, self.registry[key + "_fmt"])
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if self.registry[key + "_fmt"] == "toml":
|
||||
with open(path, "r+") as f:
|
||||
f.seek(0)
|
||||
f.truncate()
|
||||
toml.dump(self.registry[key], f)
|
||||
logger.debug("TOML 数据已保存到: %s", path)
|
||||
elif self.registry[key + "_fmt"] == "json":
|
||||
with open(path, "r+") as f:
|
||||
origin = json.load(f)
|
||||
f.seek(0)
|
||||
f.truncate()
|
||||
origin[self.ident] = self.registry[key].algodata
|
||||
json.dump(origin, f, indent=2, ensure_ascii=False)
|
||||
logger.debug("JSON 数据已保存到: %s", path)
|
||||
else:
|
||||
logger.error("不受支持的持久化格式: %s", self.registry[key + "_fmt"])
|
||||
raise KeyError("不受支持的持久化格式")
|
||||
else:
|
||||
logger.error("路径未初始化: %s_path", key)
|
||||
raise TypeError("对未初始化的路径对象操作")
|
||||
|
||||
def __getitem__(self, key):
|
||||
logger.debug("Atom.__getitem__: key='%s'", key)
|
||||
if key in self.registry:
|
||||
value = self.registry[key]
|
||||
logger.debug("返回 value type: %s", type(value).__name__)
|
||||
return value
|
||||
logger.error("不支持的键: '%s'", key)
|
||||
raise KeyError(f"不支持的键: {key}")
|
||||
return self.registry[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
logger.debug(
|
||||
"Atom.__setitem__: key='%s', value type: %s", key, type(value).__name__
|
||||
)
|
||||
if key in self.registry:
|
||||
self.registry[key] = value
|
||||
logger.debug("键 '%s' 已设置", key)
|
||||
else:
|
||||
logger.error("不支持的键: '%s'", key)
|
||||
raise KeyError(f"不支持的键: {key}")
|
||||
if key == "ident":
|
||||
raise AttributeError("应为只读")
|
||||
self.registry[key] = value
|
||||
|
||||
@staticmethod
|
||||
def placeholder():
|
||||
return (Electron.placeholder(), Nucleon.placeholder(), {})
|
||||
def __repr__(self):
|
||||
from pprint import pformat
|
||||
|
||||
|
||||
atom_registry: bidict.bidict[str, Atom] = bidict.bidict()
|
||||
s = pformat(self.registry, indent=4)
|
||||
return s
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from copy import deepcopy
|
||||
from typing import TypedDict
|
||||
|
||||
import heurams.kernel.algorithms as algolib
|
||||
import heurams.services.timer as timer
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.algorithms import algorithms
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
@@ -7,87 +10,70 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Electron:
|
||||
"""电子: 记忆分析元数据及算法"""
|
||||
"""电子: 单算法支持的记忆数据包装"""
|
||||
|
||||
def __init__(self, ident: str, algodata: dict = {}, algo_name: str = "SM-2"):
|
||||
def __init__(self, ident: str, algodata: dict = {}, algo_name: str = ""):
|
||||
"""初始化电子对象 (记忆数据)
|
||||
|
||||
Args:
|
||||
ident: 算法的唯一标识符, 用于区分不同的算法实例, 使用 algodata[ident] 获取
|
||||
algodata: 算法数据字典, 包含算法的各项参数和设置
|
||||
algo: 使用的算法模块标识
|
||||
algodata: 算法数据字典引用, 包含算法的各项参数和设置
|
||||
algo_name: 使用的算法模块标识
|
||||
"""
|
||||
logger.debug(
|
||||
"创建 Electron 实例, ident: '%s', algo_name: '%s'", ident, algo_name
|
||||
)
|
||||
if algo_name == "":
|
||||
algo_name = "SM-2"
|
||||
self.algodata = algodata
|
||||
self.ident = ident
|
||||
self.algo = algorithms[algo_name]
|
||||
logger.debug("使用的算法类: %s", self.algo.__name__)
|
||||
self.algo: algolib.BaseAlgorithm = algorithms[algo_name]
|
||||
|
||||
if self.algo not in self.algodata.keys():
|
||||
self.algodata[self.algo.algo_name] = {}
|
||||
logger.debug("算法键 '%s' 不存在, 已创建空字典", self.algo)
|
||||
if not self.algodata[self.algo.algo_name]:
|
||||
logger.debug("算法数据为空, 使用默认值初始化")
|
||||
self._default_init(self.algo.defaults)
|
||||
else:
|
||||
logger.debug("算法数据已存在, 跳过默认初始化")
|
||||
logger.debug(
|
||||
"Electron 初始化完成, algodata keys: %s", list(self.algodata.keys())
|
||||
)
|
||||
if not self.algo.check_integrity(self.algodata):
|
||||
self.algodata[self.algo.algo_name] = deepcopy(self.algo.defaults)
|
||||
|
||||
def _default_init(self, defaults: dict):
|
||||
"""默认初始化包装"""
|
||||
logger.debug(
|
||||
"Electron._default_init: 使用默认值, keys: %s", list(defaults.keys())
|
||||
)
|
||||
self.algodata[self.algo.algo_name] = defaults.copy()
|
||||
def __repr__(self):
|
||||
from pprint import pformat
|
||||
|
||||
s = pformat(self.algodata, indent=4)
|
||||
return s
|
||||
|
||||
def activate(self):
|
||||
"""激活此电子"""
|
||||
logger.debug("Electron.activate: 激活 ident='%s'", self.ident)
|
||||
self.algodata[self.algo.algo_name]["is_activated"] = 1
|
||||
self.algodata[self.algo.algo_name]["last_modify"] = timer.get_timestamp()
|
||||
logger.debug("电子已激活, is_activated=1")
|
||||
|
||||
def modify(self, var: str, value):
|
||||
def modify(self, key, value):
|
||||
"""修改 algodata[algo] 中子字典数据"""
|
||||
logger.debug("Electron.modify: var='%s', value=%s", var, value)
|
||||
if var in self.algodata[self.algo.algo_name]:
|
||||
self.algodata[self.algo.algo_name][var] = value
|
||||
if key in self.algodata[self.algo.algo_name]:
|
||||
self.algodata[self.algo.algo_name][key] = value
|
||||
self.algodata[self.algo.algo_name]["last_modify"] = timer.get_timestamp()
|
||||
logger.debug("变量 '%s' 已修改, 更新 last_modify", var)
|
||||
else:
|
||||
logger.warning("'%s' 非已知元数据字段", var)
|
||||
print(f"警告: '{var}' 非已知元数据字段")
|
||||
raise AttributeError("不存在的子键")
|
||||
|
||||
def is_due(self):
|
||||
"""是否应该复习"""
|
||||
logger.debug("Electron.is_due: 检查 ident='%s'", self.ident)
|
||||
result = self.algo.is_due(self.algodata)
|
||||
logger.debug("is_due 结果: %s", result)
|
||||
return result and self.is_activated()
|
||||
|
||||
def is_activated(self):
|
||||
result = self.algodata[self.algo.algo_name]["is_activated"]
|
||||
logger.debug("Electron.is_activated: ident='%s', 结果: %d", self.ident, result)
|
||||
return result
|
||||
|
||||
def get_rate(self):
|
||||
"评价"
|
||||
def last_modify(self):
|
||||
result = self.algodata[self.algo.algo_name]["last_modify"]
|
||||
return result
|
||||
|
||||
def get_rating(self):
|
||||
try:
|
||||
logger.debug("Electron.rate: ident='%s'", self.ident)
|
||||
result = self.algo.rate(self.algodata)
|
||||
logger.debug("rate 结果: %s", result)
|
||||
result = self.algo.get_rating(self.algodata)
|
||||
return result
|
||||
except:
|
||||
return 0
|
||||
|
||||
def nextdate(self) -> int:
|
||||
logger.debug("Electron.nextdate: ident='%s'", self.ident)
|
||||
result = self.algo.nextdate(self.algodata)
|
||||
logger.debug("nextdate 结果: %d", result)
|
||||
return result
|
||||
|
||||
def lastdate(self) -> int:
|
||||
result = self.algodata[self.algo.algo_name]["lastdate"]
|
||||
return result
|
||||
|
||||
def revisor(self, quality: int = 5, is_new_activation: bool = False):
|
||||
@@ -97,32 +83,7 @@ class Electron:
|
||||
quality (int): 记忆保留率量化参数 (0-5)
|
||||
is_new_activation (bool): 是否为初次激活
|
||||
"""
|
||||
logger.debug(
|
||||
"Electron.revisor: ident='%s', quality=%d, is_new_activation=%s",
|
||||
self.ident,
|
||||
quality,
|
||||
is_new_activation,
|
||||
)
|
||||
self.algo.revisor(self.algodata, quality, is_new_activation)
|
||||
logger.debug(
|
||||
"revisor 完成, 更新后的 algodata: %s", self.algodata.get(self.algo, {})
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"记忆单元预览 \n"
|
||||
f"标识符: '{self.ident}' \n"
|
||||
f"算法: {self.algo} \n"
|
||||
f"易度系数: {self.algodata[self.algo.algo_name]['efactor']:.2f} \n"
|
||||
f"已经重复的次数: {self.algodata[self.algo.algo_name]['rept']} \n"
|
||||
f"下次间隔: {self.algodata[self.algo.algo_name]['interval']} 天 \n"
|
||||
f"下次复习日期时间戳: {self.algodata[self.algo.algo_name]['next_date']}"
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.ident == other.ident:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.ident)
|
||||
@@ -146,6 +107,9 @@ class Electron:
|
||||
return len(self.algodata[self.algo.algo_name])
|
||||
|
||||
@staticmethod
|
||||
def placeholder():
|
||||
"""生成一个电子占位符"""
|
||||
return Electron("电子对象样例内容", {})
|
||||
def create_on_electonic_data(electronic_data: tuple, algo_name: str = ""):
|
||||
_data = electronic_data
|
||||
ident = _data[0]
|
||||
algodata = _data[1]
|
||||
ident = ident
|
||||
return Electron(ident, algodata, algo_name)
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
import json
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
|
||||
import toml
|
||||
|
||||
import heurams.services.hasher as hasher
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .electron import Electron
|
||||
from .nucleon import Nucleon
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def load_nucleon(path: pathlib.Path, fmt="toml"):
|
||||
logger.debug("load_nucleon: 加载文件 %s, 格式: %s", path, fmt)
|
||||
with open(path, "r") as f:
|
||||
dictdata = dict()
|
||||
dictdata = toml.load(f) # type: ignore
|
||||
logger.debug("TOML 解析成功, keys: %s", list(dictdata.keys()))
|
||||
lst = list()
|
||||
nested_data = dict()
|
||||
# 修正 toml 解析器的不管嵌套行为
|
||||
for key, value in dictdata.items():
|
||||
if "__metadata__" in key: # 以免影响句号
|
||||
if "." in key:
|
||||
parts = key.split(".")
|
||||
current = nested_data
|
||||
for part in parts[:-1]:
|
||||
if part not in current:
|
||||
current[part] = {}
|
||||
current = current[part]
|
||||
current[parts[-1]] = value
|
||||
logger.debug("处理元数据键: %s", key)
|
||||
else:
|
||||
nested_data[key] = value
|
||||
logger.debug("嵌套数据处理完成, keys: %s", list(nested_data.keys()))
|
||||
# print(nested_data)
|
||||
for item, attr in nested_data.items():
|
||||
if item == "__metadata__":
|
||||
continue
|
||||
logger.debug("处理项目: %s", item)
|
||||
lst.append(
|
||||
(
|
||||
Nucleon(item, attr, deepcopy(nested_data["__metadata__"])),
|
||||
deepcopy(nested_data["__metadata__"]["orbital"]),
|
||||
)
|
||||
)
|
||||
logger.debug("load_nucleon 完成, 加载了 %d 个 Nucleon 对象", len(lst))
|
||||
return lst
|
||||
|
||||
|
||||
def load_electron(path: pathlib.Path, fmt="json") -> dict:
|
||||
"""从文件路径加载电子对象
|
||||
|
||||
Args:
|
||||
path (pathlib.Path): 路径
|
||||
fmt (str): 文件格式(可选, 默认 json)
|
||||
|
||||
Returns:
|
||||
dict: 键名是电子对象名称, 值是电子对象
|
||||
"""
|
||||
logger.debug("load_electron: 加载文件 %s, 格式: %s", path, fmt)
|
||||
with open(path, "r") as f:
|
||||
dictdata = dict()
|
||||
dictdata = json.load(f) # type: ignore
|
||||
logger.debug("JSON 解析成功, keys: %s", list(dictdata.keys()))
|
||||
dic = dict()
|
||||
for item, attr in dictdata.items():
|
||||
logger.debug("处理电子项目: %s", item)
|
||||
dic[item] = Electron(item, attr)
|
||||
logger.debug("load_electron 完成, 加载了 %d 个 Electron 对象", len(dic))
|
||||
return dic
|
||||
@@ -1,104 +1,64 @@
|
||||
from copy import deepcopy
|
||||
from logging import config
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.kernel.auxiliary.evalizor import Evalizer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Nucleon:
|
||||
"""原子核: 材料元数据"""
|
||||
"""原子核: 带有运行时隔离的模板化只读材料元数据容器"""
|
||||
|
||||
def __init__(self, ident: str, payload: dict, metadata: dict = {}):
|
||||
"""初始化原子核 (记忆内容)
|
||||
|
||||
Args:
|
||||
ident: 唯一标识符
|
||||
payload: 记忆内容信息
|
||||
metadata: 可选元数据信息
|
||||
"""
|
||||
logger.debug(
|
||||
"创建 Nucleon 实例, ident: '%s', payload keys: %s, metadata keys: %s",
|
||||
ident,
|
||||
list(payload.keys()) if payload else [],
|
||||
list(metadata.keys()) if metadata else [],
|
||||
)
|
||||
self.metadata = metadata
|
||||
self.payload = payload
|
||||
def __init__(self, ident, payload, common):
|
||||
self.ident = ident
|
||||
logger.debug("Nucleon 初始化完成")
|
||||
env = {
|
||||
"payload": payload,
|
||||
"default": config_var.get()["puzzles"],
|
||||
"nucleon": (payload | common),
|
||||
}
|
||||
self.evalizer = Evalizer(environment=env)
|
||||
self.data: dict = self.evalizer(deepcopy((payload | common))) # type: ignore
|
||||
|
||||
def __getitem__(self, key):
|
||||
logger.debug("Nucleon.__getitem__: key='%s'", key)
|
||||
if key == "ident":
|
||||
logger.debug("返回 ident: '%s'", self.ident)
|
||||
return self.ident
|
||||
if key in self.payload:
|
||||
value = self.payload[key]
|
||||
logger.debug(
|
||||
"返回 payload['%s'], value type: %s", key, type(value).__name__
|
||||
)
|
||||
return value
|
||||
if isinstance(key, str):
|
||||
if key == "ident":
|
||||
return self.ident
|
||||
return self.data[key]
|
||||
else:
|
||||
logger.error("键 '%s' 未在 payload 中找到", key)
|
||||
raise KeyError(f"Key '{key}' not found in payload.")
|
||||
raise AttributeError
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
raise AttributeError("应为只读")
|
||||
|
||||
def __delitem__(self, key):
|
||||
raise AttributeError("应为只读")
|
||||
|
||||
def __iter__(self):
|
||||
yield from self.payload.keys()
|
||||
return iter(self.data)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in (self.data)
|
||||
|
||||
def get(self, key, default=None):
|
||||
if key in self:
|
||||
return self[key]
|
||||
return default
|
||||
|
||||
def __len__(self):
|
||||
return len(self.payload)
|
||||
return len(self.data)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.ident)
|
||||
def __repr__(self):
|
||||
from pprint import pformat
|
||||
|
||||
def do_eval(self):
|
||||
"""
|
||||
执行并以结果替换当前单元的所有 eval 语句
|
||||
TODO: 带有限制的 eval, 异步/多线程执行避免堵塞
|
||||
"""
|
||||
logger.debug("Nucleon.do_eval 开始")
|
||||
|
||||
# eval 环境设置
|
||||
def eval_with_env(s: str):
|
||||
try:
|
||||
nucleon = self
|
||||
eval_value = eval(s)
|
||||
if isinstance(eval_value, (int, float)):
|
||||
ret = str(eval_value)
|
||||
else:
|
||||
ret = eval_value
|
||||
logger.debug(
|
||||
"eval 执行成功: '%s' -> '%s'",
|
||||
s,
|
||||
str(ret)[:50] + "..." if len(ret) > 50 else ret,
|
||||
)
|
||||
except Exception as e:
|
||||
ret = f"此 eval 实例发生错误: {e}"
|
||||
logger.warning("eval 执行错误: '%s' -> %s", s, e)
|
||||
return ret
|
||||
|
||||
def traverse(data, modifier):
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
data[key] = traverse(value, modifier)
|
||||
return data
|
||||
elif isinstance(data, list):
|
||||
for i, item in enumerate(data):
|
||||
data[i] = traverse(item, modifier)
|
||||
return data
|
||||
elif isinstance(data, tuple):
|
||||
return tuple(traverse(item, modifier) for item in data)
|
||||
else:
|
||||
if isinstance(data, str):
|
||||
if data.startswith("eval:"):
|
||||
logger.debug("发现 eval 表达式: '%s'", data[5:])
|
||||
return modifier(data[5:])
|
||||
return data
|
||||
|
||||
traverse(self.payload, eval_with_env)
|
||||
traverse(self.metadata, eval_with_env)
|
||||
logger.debug("Nucleon.do_eval 完成")
|
||||
s = pformat(self.data, indent=4)
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def placeholder():
|
||||
"""生成一个占位原子核"""
|
||||
logger.debug("创建 Nucleon 占位符")
|
||||
return Nucleon("核子对象样例内容", {})
|
||||
def create_on_nucleonic_data(nucleonic_data: tuple):
|
||||
_data = nucleonic_data
|
||||
payload = _data[1][0]
|
||||
common = _data[1][1]
|
||||
ident = _data[0] # TODO:实现eval
|
||||
return Nucleon(ident, payload, common)
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
from typing import TypedDict
|
||||
"""轨道对象"""
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
# 似乎没有实现这个类的必要...
|
||||
# 那不妨在这儿写点文档
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.debug("Orbital 类型定义模块已加载")
|
||||
|
||||
|
||||
class OrbitalSchedule(TypedDict):
|
||||
quick_review: list
|
||||
recognition: list
|
||||
final_review: list
|
||||
|
||||
|
||||
class Orbital(TypedDict):
|
||||
schedule: OrbitalSchedule
|
||||
puzzles: dict
|
||||
|
||||
|
||||
"""一份示例
|
||||
["__metadata__.orbital.puzzles"] # 谜题定义
|
||||
"Recognition" = { __origin__ = "recognition", __hint__ = "", primary = "eval:nucleon['content']", secondery = ["eval:nucleon['keyword_note']", "eval:nucleon['note']"], top_dim = ["eval:nucleon['translation']"] }
|
||||
"SelectMeaning" = { __origin__ = "mcq", __hint__ = "eval:nucleon['content']", jammer = "eval:nucleon['keyword_note']", max_riddles_num = "eval:default['mcq']['max_riddles_num']", prefix = "选择正确项: " }
|
||||
"FillBlank" = { __origin__ = "cloze", __hint__ = "", text = "eval:nucleon['content']", delimiter = "eval:metadata['formation']['delimiter']", min_denominator = "eval:default['cloze']['min_denominator']"}
|
||||
|
||||
["__metadata__.orbital.schedule"] # 内置的推荐学习方案
|
||||
quick_review = [["FillBlank", "1.0"], ["SelectMeaning", "0.5"], ["recognition", "1.0"]]
|
||||
recognition = [["recognition", "1.0"]]
|
||||
final_review = [["FillBlank", "0.7"], ["SelectMeaning", "0.7"], ["recognition", "1.0"]]
|
||||
"""
|
||||
orbital, 即轨道, 是定义队列式复习阶段流程的数据结构, 其实就是个字典, 至于为何不用typeddict, 因为懒.
|
||||
|
||||
orbital_example = {
|
||||
"schedule": [列表 存储阶段(phases)名称]
|
||||
"phases":{
|
||||
阶段名称 = [["谜题(puzzle 现称 Puzzles 评估器)名称", "概率系数 可大于1(整数部分为重复次数) 注意使用字符串包裹(toml 规范)"], ...],
|
||||
...
|
||||
}
|
||||
}
|
||||
至于谜题定义 放在 nucleon['puzzles'], 这样设计是为了兼容多种不同谜题实现的记忆单元, 尽管如此, 你也可见其谜题调度方式必须是相同的.
|
||||
"""
|
||||
|
||||
42
src/heurams/kernel/particles/placeholders.py
Normal file
42
src/heurams/kernel/particles/placeholders.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from heurams.kernel.particles import orbital
|
||||
|
||||
from .atom import Atom
|
||||
from .electron import Electron
|
||||
from .nucleon import Nucleon
|
||||
|
||||
orbital_placeholder = {
|
||||
"schedule": ["quick_review", "recognition", "final_review"],
|
||||
"phases": {
|
||||
"quick_review": [
|
||||
["FillBlank", 1.0],
|
||||
["SelectMeaning", 0.5],
|
||||
["Recognition", 1.0],
|
||||
],
|
||||
"recognition": [["Recognition", 1.0]],
|
||||
"final_review": [
|
||||
["FillBlank", 0.7],
|
||||
["SelectMeaning", 0.7],
|
||||
["Recognition", 1.0],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class NucleonPlaceholder(Nucleon):
|
||||
def __init__(self):
|
||||
super().__init__("__placeholder__", {}, {})
|
||||
|
||||
def __getitem__(self, key):
|
||||
return f"__placeholder__ attempted {key}"
|
||||
|
||||
|
||||
class ElectronPlaceholder(Electron):
|
||||
def __init__(self):
|
||||
super().__init__("__placeholder__", {"": {"": ""}}, "")
|
||||
|
||||
|
||||
class AtomPlaceholder(Atom):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
NucleonPlaceholder(), ElectronPlaceholder(), orbital_placeholder
|
||||
)
|
||||
@@ -1,62 +0,0 @@
|
||||
import pathlib
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def probe_by_filename(filename):
|
||||
"""探测指定文件 (无扩展名) 的所有信息"""
|
||||
logger.debug("probe_by_filename: 探测文件 '%s'", filename)
|
||||
paths: dict = config_var.get().get("paths")
|
||||
logger.debug("配置路径: %s", paths)
|
||||
formats = ["toml", "json"]
|
||||
result = {}
|
||||
for item, attr in paths.items():
|
||||
for i in formats:
|
||||
attr: pathlib.Path = pathlib.Path(attr) / filename + "." + i
|
||||
if attr.exists():
|
||||
logger.debug("找到文件: %s", attr)
|
||||
result[item.replace("_dir", "")] = str(attr)
|
||||
else:
|
||||
logger.debug("文件不存在: %s", attr)
|
||||
logger.debug("probe_by_filename 结果: %s", result)
|
||||
return result
|
||||
|
||||
|
||||
def probe_all(is_stem=1):
|
||||
"""依据目录探测所有信息
|
||||
|
||||
Args:
|
||||
is_stem (boolean): 是否**删除**文件扩展名
|
||||
|
||||
Returns:
|
||||
dict: 有三项, 每一项的键名都是文件组类型, 值都是文件组列表, 只包含文件名
|
||||
"""
|
||||
logger.debug("probe_all: 开始探测, is_stem=%d", is_stem)
|
||||
paths: dict = config_var.get().get("paths")
|
||||
logger.debug("配置路径: %s", paths)
|
||||
result = {}
|
||||
for item, attr in paths.items():
|
||||
attr: pathlib.Path = pathlib.Path(attr)
|
||||
result[item.replace("_dir", "")] = list()
|
||||
logger.debug("扫描目录: %s", attr)
|
||||
file_count = 0
|
||||
for i in attr.iterdir():
|
||||
if not i.is_dir():
|
||||
file_count += 1
|
||||
if is_stem:
|
||||
result[item.replace("_dir", "")].append(str(i.stem))
|
||||
else:
|
||||
result[item.replace("_dir", "")].append(str(i.name))
|
||||
logger.debug("目录 %s 中找到 %d 个文件", attr, file_count)
|
||||
logger.debug("probe_all 完成, 结果 keys: %s", list(result.keys()))
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
print(os.getcwd())
|
||||
print(probe_all())
|
||||
@@ -1,13 +1,11 @@
|
||||
"""
|
||||
Puzzle 模块 - 谜题生成系统
|
||||
Puzzles 模块 - 生成评估模块
|
||||
|
||||
提供多种类型的谜题生成器, 支持从字符串、字典等数据源导入题目
|
||||
提供多种类型的辅助评估生成器, 支持从字符串、字典等数据源导入题目
|
||||
"""
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from .base import BasePuzzle
|
||||
from .cloze import ClozePuzzle
|
||||
from .mcq import MCQPuzzle
|
||||
@@ -26,38 +24,3 @@ puzzles = {
|
||||
"recognition": RecognitionPuzzle,
|
||||
"base": BasePuzzle,
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_by_dict(config_dict: dict) -> BasePuzzle:
|
||||
"""
|
||||
根据配置字典创建谜题
|
||||
|
||||
Args:
|
||||
config_dict: 配置字典, 包含谜题类型和参数
|
||||
|
||||
Returns:
|
||||
BasePuzzle: 谜题实例
|
||||
|
||||
Raises:
|
||||
ValueError: 当配置无效时抛出
|
||||
"""
|
||||
logger.debug(
|
||||
"puzzles.create_by_dict: config_dict keys=%s", list(config_dict.keys())
|
||||
)
|
||||
puzzle_type = config_dict.get("type")
|
||||
|
||||
if puzzle_type == "cloze":
|
||||
return puzzles["cloze"](
|
||||
text=config_dict["text"],
|
||||
min_denominator=config_dict.get("min_denominator", 7),
|
||||
)
|
||||
elif puzzle_type == "mcq":
|
||||
return puzzles["mcq"](
|
||||
mapping=config_dict["mapping"],
|
||||
jammer=config_dict.get("jammer", []),
|
||||
max_riddles_num=config_dict.get("max_riddles_num", 2),
|
||||
prefix=config_dict.get("prefix", ""),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"未知的谜题类型: {puzzle_type}")
|
||||
|
||||
12
src/heurams/kernel/puzzles/guess.py
Normal file
12
src/heurams/kernel/puzzles/guess.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import random
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BasePuzzle
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GuessPuzzle(BasePuzzle):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
123
src/heurams/kernel/reactor/README.md
Normal file
123
src/heurams/kernel/reactor/README.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# Reactor - 记忆流程状态机模块
|
||||
Reactor 是 HeurAMS 的记忆流程状态机模块, 和界面 (interface) 的实现是解耦的, 以便后期与其他框架的适配.
|
||||
得益于 Pickle, 状态机模块支持快照!
|
||||
## Phaser - 全局阶段控制器
|
||||
在一次队列记忆流程中, Phaser 代表记忆流程本身.
|
||||
### 属性
|
||||
#### 状态属性
|
||||
其有状态属性:
|
||||
- unsure - 用于初始化
|
||||
- *quick_review - 复习逾期的单元
|
||||
- *recognition - 辨识新单元
|
||||
- *final_review - 复习所有逾期的和新辨认的单元
|
||||
- finished - 表示完成
|
||||
> 逾期的: 指 SM-2 算法间隔显示应该复习的单元
|
||||
|
||||
带 * 的属性表示实际的记忆阶段, 由 repo 中 schedule.toml 中 schedule 列表显式声明, 运行过程中可以选择性执行, "空的" Procession 会被直接跳过.
|
||||
|
||||
在初始化 Procession 时, 每个 Procession 被赋予一个不重复的状态属性 作为"阶段状态"属性, 以此标识 Procession 的阶段属性, 因为每个 Procession 管理一个阶段下的复习进程.
|
||||
|
||||
你可以用 state 属性获取 Phaser 的当前状态.
|
||||
#### Procession 属性
|
||||
储存一个顺序列表, 保存所有构造的 Procession.
|
||||
顺序与 repo 中 schedule.toml 中 schedule 列表中的顺序完全相同
|
||||
|
||||
### 初始化
|
||||
Phaser 接受一个存储 Atom 对象的列表, 作为组织记忆流程的材料
|
||||
在内部, 根据是否激活将其分为 new_atoms 与 old_atoms.
|
||||
因此, 如果你传入的列表中有算法上"无所事事"的 Atom, 流程会对其进行"加强复习"
|
||||
由此创建 Procession.
|
||||
|
||||
### 直接输出呈现形式
|
||||
Phaser 的 __repr__ 定义了此对象"官方的显示"用作直观的调试.
|
||||
其以 ascii 表格形式输出, 格式也符合 markdown 表格规范, 你可以直接复制到 markdown.
|
||||
示例:
|
||||
```text
|
||||
| Type | State | Processions | Current Procession |
|
||||
|:-------|:--------|:-----------------------|:---------------------|
|
||||
| Phaser | unsure | ['新记忆', '总体复习'] | 新记忆 |
|
||||
```
|
||||
| Type | State | Processions | Current Procession |
|
||||
|:-------|:--------|:-----------------------|:---------------------|
|
||||
| Phaser | unsure | ['新记忆', '总体复习'] | 新记忆 |
|
||||
|
||||
### 方法
|
||||
作为一个 Transition Machine 对象的继承, 其拥有 Machine 对象拥有的所有方法.
|
||||
除此之外, 它也拥有一些其他方法.
|
||||
#### current_procession(self)
|
||||
用于查询当前的 Procession, 并且根据当前 Procession 更新自身状态.
|
||||
返回一个 Procession 对象, 是当前阶段的 Procession.
|
||||
内部运作是返回第一个状态不为 finished 的 Procession, 并将自身状态变更为 Procession 的"阶段状态"属性
|
||||
若所有 Procession 都已完成, 将返回一个"阶段状态"为 finished 的 Procession 占位符对象(它不在 procession 属性中), 并更新自身状态为 finished.
|
||||
|
||||
## Procession - 阶段管理器
|
||||
### 属性
|
||||
#### 状态属性
|
||||
其有状态属性:
|
||||
- active - 标识未完成, 初始化的默认属性
|
||||
- finished - 完成了
|
||||
#### 其他属性
|
||||
- current_atom: 当前记忆原子的引用
|
||||
- atoms: 队列中所有原子列表
|
||||
- cursor: 指针, 是当前原子在 atoms 列表中的索引
|
||||
- phase: "阶段属性"
|
||||
> 注意区分 "Phaser" 和 "Phase", 其中 "Phase" 表示 "Phaser State".
|
||||
- name_: 阶段的命名
|
||||
- state: 当前状态属性
|
||||
### 初始化
|
||||
接受一个 atoms 列表与 phase_state (PhaserState Enum 类型)对象
|
||||
### 直接输出呈现形式
|
||||
同 Phaser, 但显示数据有所不同
|
||||
与 Phaser 不同, Procession 显示队列会对过长的 atom.ident 进行缩略(末尾 `>` 符号)
|
||||
```text
|
||||
| Type | Name | State | Progress | Queue | Current Atom |
|
||||
|:-----------|:-------|:--------|:-----------|:-----------------------|:------------------------------|
|
||||
| Procession | 新记忆 | active | 1 / 2 | ['秦孝公>', '君臣固>'] | 秦孝公据崤函之固, 拥雍州之地, |
|
||||
```
|
||||
| Type | Name | State | Progress | Queue | Current Atom |
|
||||
|:-----------|:-------|:--------|:-----------|:-----------------------|:------------------------------|
|
||||
| Procession | 新记忆 | active | 1 / 2 | ['秦孝公>', '君臣固>'] | 秦孝公据崤函之固, 拥雍州之地, |
|
||||
### 方法
|
||||
作为一个 Transition Machine 对象的继承, 其拥有 Machine 对象拥有的所有方法.
|
||||
除此之外, 它也拥有一些其他方法.
|
||||
#### forward(self, step=1)
|
||||
移动 cursor 并依情况更新 current_atom 和状态属性
|
||||
无论 Procession 是否处于完成状态, forward 操作都是可逆的, 你可以传入负数, 此时已完成的 Procession 会自动"重启".
|
||||
|
||||
#### append(self, atom=None)
|
||||
追加(回忆失败的)原子(默认为当前原子, 传入 None 会自动转化为当前原子)到队列末端
|
||||
如果这个原子已经处于队列末端, 不会重复追加, 除非队列只剩下这个原子还没完成(此时最多重复两个)
|
||||
|
||||
#### process(self)
|
||||
返回 cursor 值
|
||||
|
||||
#### __len__(self)
|
||||
返回剩余原子量(而不是原子总量)
|
||||
可以使用 len 函数调用
|
||||
获取原子总量请用 len(obj.atoms), 或者 total_length(self) 方法
|
||||
|
||||
#### total_length(self)
|
||||
返回队列原子总量
|
||||
|
||||
#### is_empty(self)
|
||||
判断是否为空队列(传入原子列表对象是空列表的队列)
|
||||
|
||||
#### get_fission(self)
|
||||
获取当前原子的 Fission 对象, 用于单原子调度展开
|
||||
|
||||
## Fission - 单原子调度控制器
|
||||
### 属性
|
||||
#### 状态属性
|
||||
- exammode: 测试模式(默认)
|
||||
- retronly: 仅回顾模式
|
||||
#### 其他属性
|
||||
- cursor
|
||||
- atom
|
||||
- current_puzzle
|
||||
- orbital_schedule
|
||||
- orbital_puzzles
|
||||
- puzzles
|
||||
### 初始化
|
||||
接受 atom 对象和 phase 参数
|
||||
### 方法
|
||||
#### get_puzzles(self)
|
||||
@@ -5,8 +5,4 @@ from .phaser import Phaser
|
||||
from .procession import Procession
|
||||
from .states import PhaserState, ProcessionState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
__all__ = ["PhaserState", "ProcessionState", "Procession", "Fission", "Phaser"]
|
||||
|
||||
logger.debug("反应堆模块已加载")
|
||||
|
||||
@@ -1,45 +1,123 @@
|
||||
import random
|
||||
from functools import reduce
|
||||
|
||||
from tabulate import tabulate as tabu
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as puz
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .states import PhaserState
|
||||
from .states import FissionState, PhaserState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Fission:
|
||||
"""裂变器: 单原子调度展开器"""
|
||||
class Fission(Machine):
|
||||
"""单原子调度展开器"""
|
||||
|
||||
def __init__(self, atom: pt.Atom, phase=PhaserState.RECOGNITION):
|
||||
self.logger = get_logger(__name__)
|
||||
self.phase = phase
|
||||
self.cursor = 0
|
||||
self.atom = atom
|
||||
# print(f"{phase.value}")
|
||||
self.orbital_schedule = atom.registry["orbital"]["schedule"][phase.value] # type: ignore
|
||||
self.orbital_puzzles = atom.registry["orbital"]["puzzles"]
|
||||
# print(self.orbital_schedule)
|
||||
self.puzzles = list()
|
||||
for item, possibility in self.orbital_schedule: # type: ignore
|
||||
print(f"ad:{item}")
|
||||
self.logger.debug(f"开始处理 orbital 项: {item}")
|
||||
self.current_puzzle_inf: dict
|
||||
# phase 为 PhaserState 枚举实例, 需要获取其value
|
||||
phase_value = phase.value
|
||||
states = [
|
||||
{"name": FissionState.EXAMMODE.value},
|
||||
{"name": FissionState.RETRONLY.value},
|
||||
]
|
||||
|
||||
transitions = [
|
||||
{
|
||||
"trigger": "finish",
|
||||
"source": FissionState.EXAMMODE.value,
|
||||
"dest": FissionState.RETRONLY.value,
|
||||
},
|
||||
]
|
||||
if phase == PhaserState.FINISHED:
|
||||
Machine.__init__(
|
||||
self,
|
||||
states=states,
|
||||
transitions=transitions,
|
||||
initial=FissionState.EXAMMODE.value,
|
||||
)
|
||||
return
|
||||
orbital_schedule = atom.registry["orbital"]["phases"][phase_value] # type: ignore
|
||||
orbital_puzzles = atom.registry["nucleon"]["puzzles"]
|
||||
self.puzzles_inf = list()
|
||||
self.min_ratings = []
|
||||
for item, possibility in orbital_schedule: # type: ignore
|
||||
logger.debug(f"开始处理: {item}")
|
||||
if not isinstance(possibility, float):
|
||||
possibility = float(possibility)
|
||||
|
||||
while possibility > 1:
|
||||
self.puzzles.append(
|
||||
self.puzzles_inf.append(
|
||||
{
|
||||
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
||||
"puzzle": puz.puzzles[orbital_puzzles[item]["__origin__"]],
|
||||
"alia": item,
|
||||
}
|
||||
)
|
||||
possibility -= 1
|
||||
|
||||
if random.random() <= possibility:
|
||||
self.puzzles.append(
|
||||
self.puzzles_inf.append(
|
||||
{
|
||||
"puzzle": puz.puzzles[self.orbital_puzzles[item]["__origin__"]],
|
||||
"puzzle": puz.puzzles[orbital_puzzles[item]["__origin__"]],
|
||||
"alia": item,
|
||||
}
|
||||
)
|
||||
print(f"ok:{item}")
|
||||
self.logger.debug(f"orbital 项处理完成: {item}")
|
||||
self.current_puzzle_inf = self.puzzles_inf[0]
|
||||
|
||||
def generate(self):
|
||||
yield from self.puzzles
|
||||
for i in range(len(self.puzzles_inf)):
|
||||
self.min_ratings.append(0x3F3F3F3F)
|
||||
|
||||
Machine.__init__(
|
||||
self,
|
||||
states=states,
|
||||
transitions=transitions,
|
||||
initial=FissionState.EXAMMODE.value,
|
||||
)
|
||||
|
||||
def get_puzzles_inf(self):
|
||||
if self.state == "retronly":
|
||||
return [{"puzzle": puz.puzzles["recognition"], "alia": "Recognition"}]
|
||||
return self.puzzles_inf
|
||||
|
||||
def get_current_puzzle_inf(self):
|
||||
if self.state == "retronly":
|
||||
return {"puzzle": puz.puzzles["recognition"], "alia": "Recognition"}
|
||||
return self.current_puzzle_inf
|
||||
|
||||
def report(self, rating):
|
||||
self.min_ratings[self.cursor] = min(rating, self.min_ratings[self.cursor])
|
||||
|
||||
def get_quality(self):
|
||||
if self.is_state("retronly", self):
|
||||
return reduce(lambda x, y: min(x, y), self.min_ratings)
|
||||
raise IndexError
|
||||
|
||||
def forward(self, step=1):
|
||||
"""将谜题指针向前移动并依情况更新或完成"""
|
||||
self.cursor += step
|
||||
if self.cursor >= len(self.puzzles_inf):
|
||||
if self.state != "retronly":
|
||||
self.finish()
|
||||
else:
|
||||
self.current_puzzle_inf = self.puzzles_inf[self.cursor]
|
||||
|
||||
def __repr__(self, style="pipe", ends="\n") -> str:
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
dic = [
|
||||
{
|
||||
"Type": "Fission",
|
||||
"Atom": truncate(self.atom.ident),
|
||||
"State": self.state,
|
||||
"Progress": f"{self.cursor + 1} / {len(self.puzzles_inf)}",
|
||||
"Queue": list(map(lambda f: truncate(f["alia"]), self.puzzles_inf)),
|
||||
"Current Puzzle": f"{self.current_puzzle_inf['alia']}@{self.current_puzzle_inf['puzzle'].__name__}", # type: ignore
|
||||
}
|
||||
]
|
||||
return str(tabu(dic, headers="keys", tablefmt=style)) + ends
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# 移相器类定义
|
||||
from click import style
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.kernel.particles.placeholders import AtomPlaceholder
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .procession import Procession
|
||||
@@ -9,43 +11,139 @@ from .states import PhaserState, ProcessionState
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Phaser:
|
||||
"""移相器: 全局调度阶段管理器"""
|
||||
class Phaser(Machine):
|
||||
"""全局调度阶段管理器"""
|
||||
|
||||
def __init__(self, atoms: list[pt.Atom]) -> None:
|
||||
logger.debug("Phaser.__init__: 原子数量=%d", len(atoms))
|
||||
|
||||
self.atoms = atoms
|
||||
new_atoms = list()
|
||||
old_atoms = list()
|
||||
self.state = PhaserState.UNSURE
|
||||
|
||||
for i in atoms:
|
||||
if not i.registry["electron"].is_activated():
|
||||
new_atoms.append(i)
|
||||
else:
|
||||
old_atoms.append(i)
|
||||
|
||||
logger.debug("新原子数量=%d, 旧原子数量=%d", len(new_atoms), len(old_atoms))
|
||||
|
||||
self.processions = list()
|
||||
# TODO: 改进为基于配置文件的可选复习阶段
|
||||
if len(old_atoms):
|
||||
self.processions.append(
|
||||
Procession(old_atoms, PhaserState.QUICK_REVIEW, "初始复习")
|
||||
)
|
||||
logger.debug("创建初始复习 Procession")
|
||||
|
||||
if len(new_atoms):
|
||||
self.processions.append(
|
||||
Procession(new_atoms, PhaserState.RECOGNITION, "新记忆")
|
||||
)
|
||||
logger.debug("创建新记忆 Procession")
|
||||
|
||||
self.processions.append(Procession(atoms, PhaserState.FINAL_REVIEW, "总体复习"))
|
||||
logger.debug("创建总体复习 Procession")
|
||||
logger.debug("Phaser 初始化完成, processions 数量=%d", len(self.processions))
|
||||
|
||||
# 设置transitions状态机
|
||||
states = [
|
||||
{"name": PhaserState.UNSURE.value, "on_enter": "on_unsure"},
|
||||
{"name": PhaserState.QUICK_REVIEW.value, "on_enter": "on_quick_review"},
|
||||
{"name": PhaserState.RECOGNITION.value, "on_enter": "on_recognition"},
|
||||
{"name": PhaserState.FINAL_REVIEW.value, "on_enter": "on_final_review"},
|
||||
{"name": PhaserState.FINISHED.value, "on_enter": "on_finished"},
|
||||
]
|
||||
|
||||
transitions = [
|
||||
{"trigger": "to_unsure", "source": "*", "dest": PhaserState.UNSURE.value},
|
||||
{
|
||||
"trigger": "to_quick_review",
|
||||
"source": "*",
|
||||
"dest": PhaserState.QUICK_REVIEW.value,
|
||||
},
|
||||
{
|
||||
"trigger": "to_recognition",
|
||||
"source": "*",
|
||||
"dest": PhaserState.RECOGNITION.value,
|
||||
},
|
||||
{
|
||||
"trigger": "to_final_review",
|
||||
"source": "*",
|
||||
"dest": PhaserState.FINAL_REVIEW.value,
|
||||
},
|
||||
{
|
||||
"trigger": "to_finished",
|
||||
"source": "*",
|
||||
"dest": PhaserState.FINISHED.value,
|
||||
},
|
||||
]
|
||||
|
||||
Machine.__init__(
|
||||
self,
|
||||
states=states,
|
||||
transitions=transitions,
|
||||
initial=PhaserState.UNSURE.value,
|
||||
)
|
||||
|
||||
self.to_unsure()
|
||||
|
||||
def on_unsure(self):
|
||||
"""进入UNSURE状态时的回调"""
|
||||
logger.debug("Phaser 进入 UNSURE 状态")
|
||||
|
||||
def on_quick_review(self):
|
||||
"""进入QUICK_REVIEW状态时的回调"""
|
||||
logger.debug("Phaser 进入 QUICK_REVIEW 状态")
|
||||
|
||||
def on_recognition(self):
|
||||
"""进入RECOGNITION状态时的回调"""
|
||||
logger.debug("Phaser 进入 RECOGNITION 状态")
|
||||
|
||||
def on_final_review(self):
|
||||
"""进入FINAL_REVIEW状态时的回调"""
|
||||
logger.debug("Phaser 进入 FINAL_REVIEW 状态")
|
||||
|
||||
def on_finished(self):
|
||||
"""进入FINISHED状态时的回调"""
|
||||
for i in self.atoms:
|
||||
i.lock(1)
|
||||
i.revise()
|
||||
logger.debug("Phaser 进入 FINISHED 状态")
|
||||
|
||||
def current_procession(self):
|
||||
logger.debug("Phaser.current_procession 被调用")
|
||||
for i in self.processions:
|
||||
i: Procession
|
||||
if not i.state == ProcessionState.FINISHED:
|
||||
self.state = i.phase
|
||||
if i.state != ProcessionState.FINISHED.value:
|
||||
# if i.phase == PhaserState.UNSURE: 此判断是不必要的 因为没有这种 Procession
|
||||
if i.phase == PhaserState.QUICK_REVIEW:
|
||||
self.to_quick_review()
|
||||
elif i.phase == PhaserState.RECOGNITION:
|
||||
self.to_recognition()
|
||||
elif i.phase == PhaserState.FINAL_REVIEW:
|
||||
self.to_final_review()
|
||||
|
||||
logger.debug("找到未完成的 Procession: phase=%s", i.phase)
|
||||
return i
|
||||
self.state = PhaserState.FINISHED
|
||||
|
||||
# 所有Procession都已完成
|
||||
self.to_finished()
|
||||
logger.debug("所有 Procession 已完成, 状态设置为 FINISHED")
|
||||
return 0
|
||||
return Procession([AtomPlaceholder()], PhaserState.FINISHED)
|
||||
|
||||
def __repr__(self, style="pipe", ends="\n"):
|
||||
from tabulate import tabulate as tabu
|
||||
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
lst = [
|
||||
{
|
||||
"Type": "Phaser",
|
||||
"State": self.state,
|
||||
"Processions": list(map(lambda f: (f.name_), self.processions)),
|
||||
"Current Procession": "None" if not self.current_procession() else self.current_procession().name_, # type: ignore
|
||||
},
|
||||
]
|
||||
return str(tabu(tabular_data=lst, headers="keys", tablefmt=style)) + ends
|
||||
|
||||
@@ -1,61 +1,101 @@
|
||||
from tabulate import tabulate as tabu
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .fission import Fission
|
||||
from .states import PhaserState, ProcessionState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Procession:
|
||||
class Procession(Machine):
|
||||
"""队列: 标识单次记忆流程"""
|
||||
|
||||
def __init__(self, atoms: list, phase: PhaserState, name: str = ""):
|
||||
def __init__(self, atoms: list, phase_state: PhaserState, name_: str = ""):
|
||||
logger.debug(
|
||||
"Procession.__init__: 原子数量=%d, phase=%s, name='%s'",
|
||||
len(atoms),
|
||||
phase.value,
|
||||
name,
|
||||
phase_state.value,
|
||||
name_,
|
||||
)
|
||||
self.current_atom: pt.Atom | None
|
||||
self.atoms = atoms
|
||||
self.queue = atoms.copy()
|
||||
self.current_atom = atoms[0]
|
||||
self.current_atom = atoms[0] if atoms else None
|
||||
self.cursor = 0
|
||||
self.name = name
|
||||
self.phase = phase
|
||||
self.state: ProcessionState = ProcessionState.RUNNING
|
||||
logger.debug("Procession 初始化完成, 队列长度=%d", len(self.queue))
|
||||
self.name_ = name_
|
||||
self.phase = phase_state
|
||||
|
||||
states = [
|
||||
{"name": ProcessionState.ACTIVE.value, "on_enter": "on_active"},
|
||||
{"name": ProcessionState.FINISHED.value, "on_enter": "on_finished"},
|
||||
]
|
||||
|
||||
transitions = [
|
||||
{
|
||||
"trigger": "finish",
|
||||
"source": ProcessionState.ACTIVE.value,
|
||||
"dest": ProcessionState.FINISHED.value,
|
||||
},
|
||||
{
|
||||
"trigger": "restart",
|
||||
"source": ProcessionState.FINISHED.value,
|
||||
"dest": ProcessionState.ACTIVE.value,
|
||||
},
|
||||
]
|
||||
|
||||
Machine.__init__(
|
||||
self,
|
||||
states=states,
|
||||
transitions=transitions,
|
||||
initial=ProcessionState.ACTIVE.value,
|
||||
)
|
||||
|
||||
logger.debug("Procession 初始化完成, 队列长度=%d", len(self.atoms))
|
||||
|
||||
def on_active(self):
|
||||
"""进入active状态时的回调"""
|
||||
logger.debug("Procession 进入 active 状态")
|
||||
|
||||
def on_finished(self):
|
||||
"""进入FINISHED状态时的回调"""
|
||||
logger.debug("Procession 进入 FINISHED 状态")
|
||||
|
||||
def forward(self, step=1):
|
||||
"""将记忆原子指针向前移动并依情况更新原子(返回 1)或完成队列(返回 0)"""
|
||||
logger.debug("Procession.forward: step=%d, 当前 cursor=%d", step, self.cursor)
|
||||
self.cursor += step
|
||||
if self.cursor == len(self.queue):
|
||||
self.state = ProcessionState.FINISHED
|
||||
if self.cursor >= len(self.atoms):
|
||||
if self.state != ProcessionState.FINISHED.value:
|
||||
self.finish() # 触发状态转换
|
||||
logger.debug("Procession 已完成")
|
||||
else:
|
||||
self.state = ProcessionState.RUNNING
|
||||
try:
|
||||
if self.state != ProcessionState.ACTIVE.value:
|
||||
self.restart() # 确保在active状态
|
||||
self.current_atom = self.atoms[self.cursor]
|
||||
logger.debug("cursor 更新为: %d", self.cursor)
|
||||
self.current_atom = self.queue[self.cursor]
|
||||
logger.debug("当前原子更新为: %s", self.current_atom.ident)
|
||||
return 1 # 成功
|
||||
except IndexError as e:
|
||||
logger.debug("IndexError: %s", e)
|
||||
self.state = ProcessionState.FINISHED
|
||||
logger.debug("Procession 因索引错误而完成")
|
||||
return 0
|
||||
logger.debug(
|
||||
"当前原子更新为: %s",
|
||||
self.current_atom.ident if self.current_atom else "None",
|
||||
)
|
||||
|
||||
def append(self, atom=None):
|
||||
if atom == None:
|
||||
"""追加(回忆失败的)原子(默认为当前原子)到队列末端"""
|
||||
if atom is None:
|
||||
atom = self.current_atom
|
||||
logger.debug("Procession.append: atom=%s", atom.ident if atom else "None")
|
||||
if self.queue[len(self.queue) - 1] != atom or len(self) <= 1:
|
||||
self.queue.append(atom)
|
||||
logger.debug("原子已追加到队列, 新队列长度=%d", len(self.queue))
|
||||
|
||||
if not self.atoms or self.atoms[-1] != atom or len(self) <= 1:
|
||||
self.atoms.append(atom)
|
||||
logger.debug("原子已追加到队列, 新队列长度=%d", len(self.atoms))
|
||||
else:
|
||||
logger.debug("原子未追加(重复或队列长度<=1)")
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.queue) - self.cursor
|
||||
if not self.atoms:
|
||||
return 0
|
||||
length = len(self.atoms) - self.cursor
|
||||
logger.debug("Procession.__len__: 剩余长度=%d", length)
|
||||
return length
|
||||
|
||||
@@ -64,11 +104,29 @@ class Procession:
|
||||
return self.cursor
|
||||
|
||||
def total_length(self):
|
||||
total = len(self.queue)
|
||||
total = len(self.atoms)
|
||||
logger.debug("Procession.total_length: %d", total)
|
||||
return total
|
||||
|
||||
def is_empty(self):
|
||||
empty = len(self.queue)
|
||||
logger.debug("Procession.is_empty: %d", empty)
|
||||
empty = len(self.atoms) == 0
|
||||
logger.debug("Procession.is_empty: %s", empty)
|
||||
return empty
|
||||
|
||||
def get_fission(self):
|
||||
return Fission(atom=self.current_atom, phase=self.phase) # type: ignore
|
||||
|
||||
def __repr__(self, style="pipe", ends="\n"):
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
dic = [
|
||||
{
|
||||
"Type": "Procession",
|
||||
"Name": self.name_,
|
||||
"State": self.state,
|
||||
"Progress": f"{self.cursor + 1} / {len(self.atoms)}",
|
||||
"Queue": list(map(lambda f: truncate(f.ident), self.atoms)),
|
||||
"Current Atom": self.current_atom.ident, # type: ignore
|
||||
}
|
||||
]
|
||||
return str(tabu(dic, headers="keys", tablefmt=style)) + ends
|
||||
|
||||
@@ -14,8 +14,13 @@ class PhaserState(Enum):
|
||||
|
||||
|
||||
class ProcessionState(Enum):
|
||||
RUNNING = auto()
|
||||
FINISHED = auto()
|
||||
ACTIVE = "active"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
class FissionState(Enum):
|
||||
EXAMMODE = "exammode"
|
||||
RETRONLY = "retronly"
|
||||
|
||||
|
||||
logger.debug("状态枚举定义已加载")
|
||||
|
||||
3
src/heurams/kernel/repolib/__init__.py
Normal file
3
src/heurams/kernel/repolib/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .repo import Repo, RepoManifest
|
||||
|
||||
__all__ = ["Repo", "RepoManifest"]
|
||||
3
src/heurams/kernel/repolib/navi.py
Normal file
3
src/heurams/kernel/repolib/navi.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class Navi:
|
||||
def __init__(self, init) -> None:
|
||||
pass
|
||||
176
src/heurams/kernel/repolib/repo.py
Normal file
176
src/heurams/kernel/repolib/repo.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import json
|
||||
from functools import reduce
|
||||
from pathlib import Path
|
||||
from typing import TypedDict
|
||||
|
||||
import toml
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
|
||||
from heurams.kernel.auxiliary.lict import Lict
|
||||
|
||||
|
||||
class RepoManifest(TypedDict):
|
||||
title: str
|
||||
author: str
|
||||
desc: str
|
||||
|
||||
|
||||
class Repo:
|
||||
file_mapping = {
|
||||
"schedule": "schedule.toml",
|
||||
"payload": "payload.toml",
|
||||
"algodata": "algodata.json",
|
||||
"manifest": "manifest.toml",
|
||||
"typedef": "typedef.toml",
|
||||
}
|
||||
|
||||
type_mapping = {
|
||||
"schedule": "dict",
|
||||
"payload": "lict",
|
||||
"algodata": "lict",
|
||||
"manifest": "dict",
|
||||
"typedef": "dict",
|
||||
}
|
||||
|
||||
default_save_list = ["algodata"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schedule: dict,
|
||||
payload: Lict,
|
||||
manifest: dict,
|
||||
typedef: dict,
|
||||
algodata: Lict,
|
||||
source=None,
|
||||
) -> None:
|
||||
self.schedule: dict = schedule
|
||||
self.manifest: RepoManifest = manifest # type: ignore
|
||||
self.typedef: dict = typedef
|
||||
self.payload: Lict = payload
|
||||
self.algodata: Lict = algodata
|
||||
self.source: Path | None = source # 若存在, 指向 repo 所在 dir
|
||||
self.database = {
|
||||
"schedule": self.schedule,
|
||||
"payload": self.payload,
|
||||
"manifest": self.manifest,
|
||||
"typedef": self.typedef,
|
||||
"algodata": self.algodata,
|
||||
"source": self.source,
|
||||
}
|
||||
self.generate_particles_data()
|
||||
|
||||
def generate_particles_data(self):
|
||||
|
||||
self.nucleonic_data_lict = Lict(
|
||||
initlist=list(map(self._nucleonic_proc, self.payload))
|
||||
)
|
||||
self.orbitic_data = self.schedule
|
||||
self.ident_index = self.nucleonic_data_lict.keys()
|
||||
for i in self.ident_index:
|
||||
self.algodata.append_new((i, {}))
|
||||
self.electronic_data_lict = self.algodata
|
||||
|
||||
def _nucleonic_proc(self, unit):
|
||||
ident = unit[0]
|
||||
common = self.typedef["common"]
|
||||
return (ident, (unit[1], common))
|
||||
|
||||
@staticmethod
|
||||
def _merge(value):
|
||||
def inner(x):
|
||||
return (x, value)
|
||||
|
||||
return inner
|
||||
|
||||
def __len__(self):
|
||||
return len(self.payload)
|
||||
|
||||
def __repr__(self):
|
||||
from pprint import pformat
|
||||
|
||||
s = pformat(self.database, indent=4)
|
||||
return s
|
||||
|
||||
def persist_to_repodir(
|
||||
self, save_list: list | None = None, source: Path | None = None
|
||||
):
|
||||
if save_list == None:
|
||||
save_list = self.default_save_list
|
||||
if self.source != None and source == None:
|
||||
source = self.source
|
||||
if source == None:
|
||||
raise FileNotFoundError("不存在仓库到文件的映射")
|
||||
source.mkdir(parents=True, exist_ok=True)
|
||||
for keyname in save_list:
|
||||
filename = self.file_mapping[keyname]
|
||||
with open(source / filename, "w") as f:
|
||||
try:
|
||||
dict_data = self.database[keyname].dicted_data
|
||||
except:
|
||||
dict_data = dict(self.database[keyname])
|
||||
if filename.endswith("toml"):
|
||||
toml.dump(dict_data, f)
|
||||
elif filename.endswith("json"):
|
||||
json.dump(dict_data, f, ensure_ascii=False, indent=4)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {filename}")
|
||||
|
||||
def export_to_single_dict(self):
|
||||
return self.database
|
||||
|
||||
@classmethod
|
||||
def create_new_repo(cls, source=None):
|
||||
default_database = {
|
||||
"schedule": {},
|
||||
"payload": Lict([]),
|
||||
"algodata": Lict([]),
|
||||
"manifest": {},
|
||||
"typedef": {},
|
||||
"source": source,
|
||||
}
|
||||
return Repo(**default_database)
|
||||
|
||||
@classmethod
|
||||
def create_from_repodir(cls, source: Path):
|
||||
database = {}
|
||||
for keyname, filename in cls.file_mapping.items():
|
||||
with open(source / filename, "r") as f:
|
||||
loaded: dict
|
||||
if filename.endswith("toml"):
|
||||
loaded = toml.load(f)
|
||||
elif filename.endswith("json"):
|
||||
loaded = json.load(f)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {filename}")
|
||||
if cls.type_mapping[keyname] == "lict":
|
||||
database[keyname] = Lict(list(loaded.items()))
|
||||
elif cls.type_mapping[keyname] == "dict":
|
||||
database[keyname] = loaded
|
||||
else:
|
||||
raise ValueError(f"不支持的数据容器: {cls.type_mapping[keyname]}")
|
||||
database["source"] = source
|
||||
return Repo(**database)
|
||||
|
||||
@classmethod
|
||||
def create_from_single_dict(cls, dictdata, source: Path | None = None):
|
||||
database = dictdata
|
||||
database["source"] = source
|
||||
return Repo(**database)
|
||||
|
||||
@classmethod
|
||||
def check_repodir(cls, source: Path):
|
||||
try:
|
||||
cls.create_from_repodir(source)
|
||||
return 1
|
||||
except:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def probe_valid_repos_in_dir(cls, folder: Path):
|
||||
lst = list()
|
||||
for i in folder.iterdir():
|
||||
if i.is_dir():
|
||||
if cls.check_repodir(i):
|
||||
lst.append(i)
|
||||
return lst
|
||||
@@ -1,13 +0,0 @@
|
||||
import pathlib
|
||||
from typing import Protocol
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlayFunctionProtocol(Protocol):
|
||||
def __call__(self, path: pathlib.Path) -> None: ...
|
||||
|
||||
|
||||
logger.debug("音频协议模块已加载")
|
||||
@@ -1,6 +1,19 @@
|
||||
# 大语言模型
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseLLM
|
||||
from .openai import OpenAILLM
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("LLM providers 模块已加载")
|
||||
__all__ = [
|
||||
"BaseLLM",
|
||||
"OpenAILLM",
|
||||
]
|
||||
|
||||
providers = {
|
||||
"base": BaseLLM,
|
||||
"openai": OpenAILLM,
|
||||
}
|
||||
|
||||
logger.debug("LLM providers 已注册: %s", list(providers.keys()))
|
||||
|
||||
@@ -1,5 +1,55 @@
|
||||
"""LLM 提供者基类"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("LLM 基类模块已加载")
|
||||
|
||||
class BaseLLM:
|
||||
"""LLM 提供者基类"""
|
||||
|
||||
name = "BaseLLM"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""初始化 LLM 提供者
|
||||
|
||||
Args:
|
||||
config: 提供者配置字典
|
||||
"""
|
||||
self.config = config
|
||||
logger.debug("BaseLLM 初始化完成")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
"""发送聊天消息并获取响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表,每个消息为 {"role": "user"|"assistant"|"system", "content": "消息内容"}
|
||||
**kwargs: 其他参数,如 temperature, max_tokens 等
|
||||
|
||||
Returns:
|
||||
模型返回的文本响应
|
||||
"""
|
||||
logger.debug("BaseLLM.chat: messages=%d, kwargs=%s", len(messages), kwargs)
|
||||
logger.warning("BaseLLM.chat 是基类方法,未实现具体功能")
|
||||
await asyncio.sleep(0) # 避免未使用异步的警告
|
||||
return "BaseLLM 未实现具体功能"
|
||||
|
||||
async def chat_stream(self, messages: List[Dict[str, str]], **kwargs):
|
||||
"""流式聊天(可选实现)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
流式响应的文本块
|
||||
"""
|
||||
logger.debug(
|
||||
"BaseLLM.chat_stream: messages=%d, kwargs=%s", len(messages), kwargs
|
||||
)
|
||||
logger.warning("BaseLLM.chat_stream 是基类方法,未实现具体功能")
|
||||
await asyncio.sleep(0)
|
||||
yield "BaseLLM 未实现流式功能"
|
||||
|
||||
@@ -1,5 +1,96 @@
|
||||
"""OpenAI 兼容 LLM 提供者"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("OpenAI provider 模块已加载(未实现)")
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""OpenAI 兼容 LLM 提供者"""
|
||||
|
||||
name = "OpenAI"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.api_key = config.get("key", "")
|
||||
self.base_url = config.get("url", "https://api.openai.com/v1")
|
||||
self._client = None
|
||||
logger.debug("OpenAILLM 初始化完成: base_url=%s", self.base_url)
|
||||
|
||||
def _get_client(self):
|
||||
"""获取 OpenAI 客户端(延迟导入)"""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
except ImportError:
|
||||
logger.error("未安装 openai 库,请运行: pip install openai")
|
||||
raise ImportError("未安装 openai 库,请运行: pip install openai")
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key if self.api_key else None,
|
||||
base_url=self.base_url if self.base_url else None,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
"""发送聊天消息并获取响应"""
|
||||
logger.debug("OpenAILLM.chat: messages=%d", len(messages))
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# 默认参数
|
||||
default_kwargs = {
|
||||
"model": kwargs.get("model", "gpt-3.5-turbo"),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
}
|
||||
|
||||
# 合并参数,优先使用传入的 kwargs
|
||||
request_kwargs = {**default_kwargs, **kwargs}
|
||||
request_kwargs["messages"] = messages
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(**request_kwargs)
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(
|
||||
"OpenAILLM.chat 成功: response length=%d",
|
||||
len(content) if content else 0,
|
||||
)
|
||||
return content or ""
|
||||
except Exception as e:
|
||||
logger.error("OpenAILLM.chat 失败: %s", e)
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self, messages: List[Dict[str, str]], **kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天"""
|
||||
logger.debug("OpenAILLM.chat_stream: messages=%d", len(messages))
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# 默认参数
|
||||
default_kwargs = {
|
||||
"model": kwargs.get("model", "gpt-3.5-turbo"),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# 合并参数
|
||||
request_kwargs = {**default_kwargs, **kwargs}
|
||||
request_kwargs["messages"] = messages
|
||||
|
||||
try:
|
||||
stream = await client.chat.completions.create(**request_kwargs)
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as e:
|
||||
logger.error("OpenAILLM.chat_stream 失败: %s", e)
|
||||
raise
|
||||
|
||||
@@ -2,6 +2,7 @@ import pathlib
|
||||
|
||||
import edge_tts
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseTTS
|
||||
@@ -18,7 +19,7 @@ class EdgeTTS(BaseTTS):
|
||||
try:
|
||||
communicate = edge_tts.Communicate(
|
||||
text,
|
||||
"zh-CN-YunjianNeural",
|
||||
config_var.get()["providers"]["tts"]["edgetts"]["voice"],
|
||||
)
|
||||
logger.debug("EdgeTTS 通信对象创建成功, 正在保存音频")
|
||||
communicate.save_sync(str(path))
|
||||
|
||||
@@ -9,5 +9,5 @@ logger = get_logger(__name__)
|
||||
|
||||
play_by_path: Callable = prov[config_var.get()["services"]["audio"]].play_by_path
|
||||
logger.debug(
|
||||
"音频服务初始化完成, 使用 provider: %s", config_var.get()["services"]["audio"]
|
||||
"音频服务初始化完成, 使用 Provider: %s", config_var.get()["services"]["audio"]
|
||||
)
|
||||
|
||||
163
src/heurams/services/favorite_service.py
Normal file
163
src/heurams/services/favorite_service.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# 收藏服务
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FavoriteItem:
|
||||
"""收藏项"""
|
||||
|
||||
repo_path: str # 仓库相对路径 (相对于 data/repo)
|
||||
ident: str # 原子标识符
|
||||
added: int # 添加时间戳 (UNIX 秒)
|
||||
# 可选标签
|
||||
tags: List[str] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"repo_path": self.repo_path,
|
||||
"ident": self.ident,
|
||||
"added": self.added,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FavoriteItem":
|
||||
return cls(
|
||||
repo_path=data["repo_path"],
|
||||
ident=data["ident"],
|
||||
added=data["added"],
|
||||
tags=data.get("tags", []),
|
||||
)
|
||||
|
||||
|
||||
class FavoriteManager:
|
||||
"""收藏管理器"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_loaded"):
|
||||
self._loaded = True
|
||||
self._favorites: List[FavoriteItem] = []
|
||||
self._file_path = self._get_file_path()
|
||||
self.load()
|
||||
|
||||
def _get_file_path(self) -> Path:
|
||||
"""获取收藏文件路径"""
|
||||
config_path = Path(config_var.get()["paths"]["data"])
|
||||
fav_path = config_path / "global" / "favorites.json"
|
||||
fav_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return fav_path
|
||||
|
||||
def load(self) -> None:
|
||||
"""从文件加载收藏列表"""
|
||||
if self._file_path.exists():
|
||||
try:
|
||||
with open(self._file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self._favorites = [FavoriteItem.from_dict(item) for item in data]
|
||||
logger.debug("收藏列表加载成功,共 %d 项", len(self._favorites))
|
||||
except Exception as e:
|
||||
logger.error("加载收藏列表失败: %s", e)
|
||||
self._favorites = []
|
||||
else:
|
||||
self._favorites = []
|
||||
|
||||
def save(self) -> None:
|
||||
"""保存收藏列表到文件"""
|
||||
try:
|
||||
data = [item.to_dict() for item in self._favorites]
|
||||
with open(self._file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("收藏列表保存成功,共 %d 项", len(self._favorites))
|
||||
except Exception as e:
|
||||
logger.error("保存收藏列表失败: %s", e)
|
||||
|
||||
def add(self, repo_path: str, ident: str, tags: List[str] | None = None) -> bool:
|
||||
"""添加收藏
|
||||
|
||||
Args:
|
||||
repo_path: 仓库相对路径
|
||||
ident: 原子标识符
|
||||
tags: 标签列表
|
||||
Returns:
|
||||
是否成功添加 (若已存在则返回 False)
|
||||
"""
|
||||
# 检查是否已存在
|
||||
for item in self._favorites:
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
logger.debug("收藏已存在: %s/%s", repo_path, ident)
|
||||
return False
|
||||
item = FavoriteItem(
|
||||
repo_path=repo_path,
|
||||
ident=ident,
|
||||
added=int(time.time()),
|
||||
tags=tags if tags else [],
|
||||
)
|
||||
self._favorites.append(item)
|
||||
self.save()
|
||||
logger.info("添加收藏: %s/%s", repo_path, ident)
|
||||
return True
|
||||
|
||||
def remove(self, repo_path: str, ident: str) -> bool:
|
||||
"""移除收藏
|
||||
|
||||
Returns:
|
||||
是否成功移除 (若不存在则返回 False)
|
||||
"""
|
||||
for idx, item in enumerate(self._favorites):
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
del self._favorites[idx]
|
||||
self.save()
|
||||
logger.info("移除收藏: %s/%s", repo_path, ident)
|
||||
return True
|
||||
logger.debug("收藏不存在: %s/%s", repo_path, ident)
|
||||
return False
|
||||
|
||||
def has(self, repo_path: str, ident: str) -> bool:
|
||||
"""检查是否已收藏"""
|
||||
for item in self._favorites:
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all(self) -> List[FavoriteItem]:
|
||||
"""获取所有收藏项(按添加时间倒序)"""
|
||||
return sorted(self._favorites, key=lambda x: x.added, reverse=True)
|
||||
|
||||
def get_by_repo(self, repo_path: str) -> List[FavoriteItem]:
|
||||
"""获取指定仓库的所有收藏项"""
|
||||
return [item for item in self._favorites if item.repo_path == repo_path]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空收藏列表"""
|
||||
self._favorites = []
|
||||
self.save()
|
||||
logger.info("清空收藏列表")
|
||||
|
||||
def count(self) -> int:
|
||||
"""收藏总数"""
|
||||
return len(self._favorites)
|
||||
|
||||
|
||||
# 全局单例实例
|
||||
favorite_manager = FavoriteManager()
|
||||
228
src/heurams/services/llm_service.py
Normal file
228
src/heurams/services/llm_service.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""LLM 聊天服务"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.providers.llm import providers as prov
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChatSession:
|
||||
"""聊天会话,管理单个对话的历史和参数"""
|
||||
|
||||
def __init__(
|
||||
self, session_id: str, llm_provider, system_prompt: str = "", **default_params
|
||||
):
|
||||
"""初始化聊天会话
|
||||
|
||||
Args:
|
||||
session_id: 会话唯一标识符
|
||||
llm_provider: LLM 提供者实例
|
||||
system_prompt: 系统提示词
|
||||
**default_params: 默认参数(temperature, max_tokens, model 等)
|
||||
"""
|
||||
self.session_id = session_id
|
||||
self.llm_provider = llm_provider
|
||||
self.system_prompt = system_prompt
|
||||
self.default_params = default_params
|
||||
|
||||
# 消息历史
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
if system_prompt:
|
||||
self.messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
logger.debug("创建聊天会话: id=%s", session_id)
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
"""添加消息到历史"""
|
||||
self.messages.append({"role": role, "content": content})
|
||||
logger.debug(
|
||||
"会话 %s 添加消息: role=%s, length=%d", self.session_id, role, len(content)
|
||||
)
|
||||
|
||||
def clear_history(self):
|
||||
"""清空消息历史(保留系统提示)"""
|
||||
self.messages = []
|
||||
if self.system_prompt:
|
||||
self.messages.append({"role": "system", "content": self.system_prompt})
|
||||
logger.debug("会话 %s 清空历史", self.session_id)
|
||||
|
||||
def set_system_prompt(self, prompt: str):
|
||||
"""设置系统提示词"""
|
||||
self.system_prompt = prompt
|
||||
# 更新消息历史中的系统消息
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
self.messages[0]["content"] = prompt
|
||||
elif prompt:
|
||||
self.messages.insert(0, {"role": "system", "content": prompt})
|
||||
logger.debug("会话 %s 设置系统提示: length=%d", self.session_id, len(prompt))
|
||||
|
||||
async def send_message(self, message: str, **override_params) -> str:
|
||||
"""发送消息并获取响应
|
||||
|
||||
Args:
|
||||
message: 用户消息内容
|
||||
**override_params: 覆盖默认参数
|
||||
|
||||
Returns:
|
||||
模型响应内容
|
||||
"""
|
||||
# 添加用户消息
|
||||
self.add_message("user", message)
|
||||
|
||||
# 合并参数
|
||||
params = {**self.default_params, **override_params}
|
||||
|
||||
# 发送请求
|
||||
logger.debug("会话 %s 发送消息: length=%d", self.session_id, len(message))
|
||||
response = await self.llm_provider.chat(self.messages, **params)
|
||||
|
||||
# 添加助手响应
|
||||
self.add_message("assistant", response)
|
||||
|
||||
return response
|
||||
|
||||
async def send_message_stream(self, message: str, **override_params):
|
||||
"""流式发送消息
|
||||
|
||||
Args:
|
||||
message: 用户消息内容
|
||||
**override_params: 覆盖默认参数
|
||||
|
||||
Yields:
|
||||
流式响应的文本块
|
||||
"""
|
||||
# 添加用户消息
|
||||
self.add_message("user", message)
|
||||
|
||||
# 合并参数
|
||||
params = {**self.default_params, **override_params}
|
||||
|
||||
# 发送流式请求
|
||||
logger.debug("会话 %s 发送流式消息: length=%d", self.session_id, len(message))
|
||||
|
||||
full_response = ""
|
||||
async for chunk in self.llm_provider.chat_stream(self.messages, **params):
|
||||
yield chunk
|
||||
full_response += chunk
|
||||
|
||||
# 添加完整的助手响应到历史
|
||||
self.add_message("assistant", full_response)
|
||||
|
||||
def get_history(self) -> List[Dict[str, str]]:
|
||||
"""获取消息历史(不包括系统消息)"""
|
||||
# 返回用户和助手的消息,可选排除系统消息
|
||||
return [msg for msg in self.messages if msg["role"] != "system"]
|
||||
|
||||
def save_to_file(self, file_path: Path):
|
||||
"""保存会话到文件"""
|
||||
data = {
|
||||
"session_id": self.session_id,
|
||||
"system_prompt": self.system_prompt,
|
||||
"default_params": self.default_params,
|
||||
"messages": self.messages,
|
||||
}
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("会话 %s 保存到: %s", self.session_id, file_path)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: Path, llm_provider):
|
||||
"""从文件加载会话"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
session = cls(
|
||||
session_id=data["session_id"],
|
||||
llm_provider=llm_provider,
|
||||
system_prompt=data.get("system_prompt", ""),
|
||||
**data.get("default_params", {})
|
||||
)
|
||||
session.messages = data["messages"]
|
||||
logger.debug("从文件加载会话: %s", file_path)
|
||||
return session
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,管理多个会话"""
|
||||
|
||||
def __init__(self):
|
||||
self.sessions: Dict[str, ChatSession] = {}
|
||||
self.default_session_id = "default"
|
||||
logger.debug("聊天管理器初始化完成")
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
session_id: Optional[str] = None,
|
||||
create_if_missing: bool = True,
|
||||
**session_params
|
||||
) -> Optional[ChatSession]:
|
||||
"""获取或创建聊天会话
|
||||
|
||||
Args:
|
||||
session_id: 会话标识符,None 则使用默认会话
|
||||
create_if_missing: 如果会话不存在则创建
|
||||
**session_params: 传递给 ChatSession 的参数
|
||||
|
||||
Returns:
|
||||
聊天会话实例,如果不存在且不创建则返回 None
|
||||
"""
|
||||
if session_id is None:
|
||||
session_id = self.default_session_id
|
||||
|
||||
if session_id in self.sessions:
|
||||
return self.sessions[session_id]
|
||||
|
||||
if create_if_missing:
|
||||
# 获取 LLM 提供者
|
||||
provider_name = config_var.get()["services"]["llm"]
|
||||
provider_config = config_var.get()["providers"]["llm"][provider_name]
|
||||
llm_provider = prov[provider_name](provider_config)
|
||||
|
||||
session = ChatSession(
|
||||
session_id=session_id, llm_provider=llm_provider, **session_params
|
||||
)
|
||||
self.sessions[session_id] = session
|
||||
logger.debug("创建新会话: id=%s", session_id)
|
||||
return session
|
||||
|
||||
return None
|
||||
|
||||
def delete_session(self, session_id: str):
|
||||
"""删除会话"""
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
logger.debug("删除会话: id=%s", session_id)
|
||||
|
||||
def list_sessions(self) -> List[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self.sessions.keys())
|
||||
|
||||
|
||||
# 全局聊天管理器实例
|
||||
_chat_manager: Optional[ChatManager] = None
|
||||
|
||||
|
||||
def get_chat_manager() -> ChatManager:
|
||||
"""获取全局聊天管理器实例"""
|
||||
global _chat_manager
|
||||
if _chat_manager is None:
|
||||
_chat_manager = ChatManager()
|
||||
logger.debug("创建全局聊天管理器")
|
||||
return _chat_manager
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
session_id: Optional[str] = None, **session_params
|
||||
) -> ChatSession:
|
||||
"""创建或获取聊天会话(便捷函数)"""
|
||||
manager = get_chat_manager()
|
||||
return manager.get_session(session_id, True, **session_params)
|
||||
|
||||
|
||||
logger.debug("LLM 服务初始化完成")
|
||||
4
src/heurams/services/textproc.py
Normal file
4
src/heurams/services/textproc.py
Normal file
@@ -0,0 +1,4 @@
|
||||
def truncate(text):
|
||||
if len(text) <= 3:
|
||||
return text
|
||||
return text[:3] + ">"
|
||||
@@ -2,12 +2,12 @@
|
||||
from typing import Callable
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.providers.tts import TTSs
|
||||
from heurams.providers.tts import providers as prov
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
convert: Callable = TTSs[config_var.get().get("tts_provider")]
|
||||
convertor: Callable = prov[config_var.get()["services"]["tts"]].convert
|
||||
logger.debug(
|
||||
"TTS服务初始化完成, 使用 provider: %s", config_var.get().get("tts_provider")
|
||||
"TTS服务初始化完成, 使用 provider: %s", config_var.get()["services"]["tts"]
|
||||
)
|
||||
|
||||
@@ -3,8 +3,9 @@ from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
ver = "0.4.1"
|
||||
ver = "0.5.0"
|
||||
stage = "prototype"
|
||||
codename = "fledge" # 雏鸟, 0.4.x 版本
|
||||
codename = "fulcrom"
|
||||
codename_cn = "支点"
|
||||
|
||||
logger.info("HeurAMS 版本: %s (%s), 阶段: %s", ver, codename, stage)
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
# Utils - 实用工具
|
||||
脚本与部分分离式工具函数
|
||||
@@ -1,153 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DashboardScreen 的测试, 包括单元测试和 pilot 测试.
|
||||
"""
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from textual.pilot import Pilot
|
||||
|
||||
from heurams.context import ConfigContext
|
||||
from heurams.interface.__main__ import HeurAMSApp
|
||||
from heurams.interface.screens.dashboard import DashboardScreen
|
||||
from heurams.services.config import ConfigFile
|
||||
|
||||
|
||||
class TestDashboardScreenUnit(unittest.TestCase):
|
||||
"""DashboardScreen 的单元测试(不启动完整应用)."""
|
||||
|
||||
def setUp(self):
|
||||
"""在每个测试之前运行, 设置临时目录和配置."""
|
||||
# 创建临时目录用于测试数据
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
||||
|
||||
# 创建默认配置, 并修改路径指向临时目录
|
||||
default_config_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent
|
||||
/ "src/heurams/default/config/config.toml"
|
||||
)
|
||||
self.config = ConfigFile(default_config_path)
|
||||
# 更新配置中的路径
|
||||
config_data = self.config.data
|
||||
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
||||
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
||||
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
||||
# 禁用快速通过, 避免测试干扰
|
||||
config_data["quick_pass"] = 0
|
||||
# 禁用时间覆盖
|
||||
config_data["daystamp_override"] = -1
|
||||
config_data["timestamp_override"] = -1
|
||||
|
||||
# 创建目录
|
||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
# 使用 ConfigContext 设置配置
|
||||
self.config_ctx = ConfigContext(self.config)
|
||||
self.config_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
"""在每个测试之后清理."""
|
||||
self.config_ctx.__exit__(None, None, None)
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_compose(self):
|
||||
"""测试 compose 方法返回正确的部件."""
|
||||
screen = DashboardScreen()
|
||||
# 手动调用 compose 并收集部件
|
||||
from textual.app import ComposeResult
|
||||
|
||||
result = screen.compose()
|
||||
widgets = list(result)
|
||||
# 检查是否包含 Header 和 Footer
|
||||
from textual.widgets import Footer, Header
|
||||
|
||||
header_present = any(isinstance(w, Header) for w in widgets)
|
||||
footer_present = any(isinstance(w, Footer) for w in widgets)
|
||||
self.assertTrue(header_present)
|
||||
self.assertTrue(footer_present)
|
||||
# 检查是否有 ScrollableContainer
|
||||
from textual.containers import ScrollableContainer
|
||||
|
||||
container_present = any(isinstance(w, ScrollableContainer) for w in widgets)
|
||||
self.assertTrue(container_present)
|
||||
# 使用 query_one 查找 union-list, 即使屏幕未挂载也可能有效
|
||||
list_view = screen.query_one("#union-list")
|
||||
self.assertIsNotNone(list_view)
|
||||
self.assertEqual(list_view.id, "union-list")
|
||||
self.assertEqual(list_view.__class__.__name__, "ListView")
|
||||
|
||||
def test_item_desc_generator(self):
|
||||
"""测试 item_desc_generator 函数."""
|
||||
screen = DashboardScreen()
|
||||
# 模拟一个文件名
|
||||
filename = "test.toml"
|
||||
result = screen.item_desc_generator(filename)
|
||||
self.assertIsInstance(result, dict)
|
||||
self.assertIn(0, result)
|
||||
self.assertIn(1, result)
|
||||
# 检查内容
|
||||
self.assertIn("test.toml", result[0])
|
||||
# 由于 electron 文件不存在, 应显示“尚未激活”
|
||||
self.assertIn("尚未激活", result[1])
|
||||
|
||||
|
||||
@unittest.skip("Pilot 测试需要进一步配置, 暂不运行")
|
||||
class TestDashboardScreenPilot(unittest.TestCase):
|
||||
"""使用 Textual Pilot 的集成测试."""
|
||||
|
||||
def setUp(self):
|
||||
"""配置临时目录和配置."""
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
||||
|
||||
default_config_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent
|
||||
/ "src/heurams/default/config/config.toml"
|
||||
)
|
||||
self.config = ConfigFile(default_config_path)
|
||||
config_data = self.config.data
|
||||
config_data["paths"]["nucleon_dir"] = str(self.temp_path / "nucleon")
|
||||
config_data["paths"]["electron_dir"] = str(self.temp_path / "electron")
|
||||
config_data["paths"]["orbital_dir"] = str(self.temp_path / "orbital")
|
||||
config_data["paths"]["cache_dir"] = str(self.temp_path / "cache")
|
||||
config_data["quick_pass"] = 0
|
||||
config_data["daystamp_override"] = -1
|
||||
config_data["timestamp_override"] = -1
|
||||
|
||||
for dir_key in ["nucleon_dir", "electron_dir", "orbital_dir", "cache_dir"]:
|
||||
pathlib.Path(config_data["paths"][dir_key]).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
self.config_ctx = ConfigContext(self.config)
|
||||
self.config_ctx.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
self.config_ctx.__exit__(None, None, None)
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
def test_dashboard_loads_with_pilot(self):
|
||||
"""使用 Pilot 测试 DashboardScreen 加载."""
|
||||
with patch("heurams.interface.__main__.environment_check"):
|
||||
app = HeurAMSApp()
|
||||
# 注意: Pilot 在 Textual 6.9.0 中的用法可能不同
|
||||
# 以下为示例代码, 可能需要调整
|
||||
pilot = Pilot(app)
|
||||
# 等待应用启动
|
||||
pilot.pause()
|
||||
screen = app.screen
|
||||
self.assertEqual(screen.__class__.__name__, "DashboardScreen")
|
||||
union_list = app.query_one("#union-list")
|
||||
self.assertIsNotNone(union_list)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,186 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from heurams.kernel.algorithms.sm2 import SM2Algorithm
|
||||
|
||||
|
||||
class TestSM2Algorithm(unittest.TestCase):
|
||||
"""测试 SM2Algorithm 类"""
|
||||
|
||||
def setUp(self):
|
||||
# 模拟 timer 函数
|
||||
self.timestamp_patcher = patch(
|
||||
"heurams.kernel.algorithms.sm2.timer.get_timestamp"
|
||||
)
|
||||
self.daystamp_patcher = patch(
|
||||
"heurams.kernel.algorithms.sm2.timer.get_daystamp"
|
||||
)
|
||||
self.mock_get_timestamp = self.timestamp_patcher.start()
|
||||
self.mock_get_daystamp = self.daystamp_patcher.start()
|
||||
|
||||
# 设置固定返回值
|
||||
self.mock_get_timestamp.return_value = 1000.0
|
||||
self.mock_get_daystamp.return_value = 100
|
||||
|
||||
def tearDown(self):
|
||||
self.timestamp_patcher.stop()
|
||||
self.daystamp_patcher.stop()
|
||||
|
||||
def test_defaults(self):
|
||||
"""测试默认值"""
|
||||
defaults = SM2Algorithm.defaults
|
||||
self.assertEqual(defaults["efactor"], 2.5)
|
||||
self.assertEqual(defaults["real_rept"], 0)
|
||||
self.assertEqual(defaults["rept"], 0)
|
||||
self.assertEqual(defaults["interval"], 0)
|
||||
self.assertEqual(defaults["last_date"], 0)
|
||||
self.assertEqual(defaults["next_date"], 0)
|
||||
self.assertEqual(defaults["is_activated"], 0)
|
||||
# last_modify 是动态的, 仅检查存在性
|
||||
self.assertIn("last_modify", defaults)
|
||||
|
||||
def test_revisor_feedback_minus_one(self):
|
||||
"""测试 feedback = -1 时跳过更新"""
|
||||
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
|
||||
SM2Algorithm.revisor(algodata, feedback=-1)
|
||||
# 数据应保持不变
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 0)
|
||||
|
||||
def test_revisor_feedback_less_than_3(self):
|
||||
"""测试 feedback < 3 重置 rept 和 interval"""
|
||||
algodata = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 2.5,
|
||||
"rept": 5,
|
||||
"interval": 10,
|
||||
"real_rept": 3,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata, feedback=2)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
||||
# rept=0 导致 interval 被设置为 1
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 4) # 递增
|
||||
|
||||
def test_revisor_feedback_greater_equal_3(self):
|
||||
"""测试 feedback >= 3 递增 rept"""
|
||||
algodata = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 2.5,
|
||||
"rept": 2,
|
||||
"interval": 6,
|
||||
"real_rept": 2,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata, feedback=4)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 3)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["real_rept"], 3)
|
||||
# interval 应根据 rept 和 efactor 重新计算
|
||||
# rept=3, interval = round(6 * 2.5) = 15
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
|
||||
|
||||
def test_revisor_new_activation(self):
|
||||
"""测试 is_new_activation 重置 rept 和 efactor"""
|
||||
algodata = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 3.0,
|
||||
"rept": 5,
|
||||
"interval": 20,
|
||||
"real_rept": 5,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata, feedback=5, is_new_activation=True)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["rept"], 0)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 2.5)
|
||||
# interval 应为 1(因为 rept=0)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 1)
|
||||
|
||||
def test_revisor_efactor_calculation(self):
|
||||
"""测试 efactor 计算"""
|
||||
algodata = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 2.5,
|
||||
"rept": 1,
|
||||
"interval": 6,
|
||||
"real_rept": 1,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata, feedback=5)
|
||||
# efactor = 2.5 + (0.1 - (5-5)*(0.08 + (5-5)*0.02)) = 2.5 + 0.1 = 2.6
|
||||
self.assertAlmostEqual(
|
||||
algodata[SM2Algorithm.algo_name]["efactor"], 2.6, places=6
|
||||
)
|
||||
|
||||
# 测试 efactor 下限为 1.3
|
||||
algodata[SM2Algorithm.algo_name]["efactor"] = 1.2
|
||||
SM2Algorithm.revisor(algodata, feedback=5)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["efactor"], 1.3)
|
||||
|
||||
def test_revisor_interval_calculation(self):
|
||||
"""测试 interval 计算规则"""
|
||||
algodata = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 2.5,
|
||||
"rept": 0,
|
||||
"interval": 0,
|
||||
"real_rept": 0,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata, feedback=4)
|
||||
# rept 从 0 递增到 1, 因此 interval 应为 6
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 6)
|
||||
|
||||
# 现在 rept=1, 再次调用 revisor 递增到 2
|
||||
SM2Algorithm.revisor(algodata, feedback=4)
|
||||
# rept=2, interval = round(6 * 2.5) = 15
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["interval"], 15)
|
||||
|
||||
# 单独测试 rept=1 的情况
|
||||
algodata2 = {
|
||||
SM2Algorithm.algo_name: {
|
||||
"efactor": 2.5,
|
||||
"rept": 1,
|
||||
"interval": 0,
|
||||
"real_rept": 0,
|
||||
}
|
||||
}
|
||||
SM2Algorithm.revisor(algodata2, feedback=4)
|
||||
# rept 递增到 2, interval = round(0 * 2.5) = 0
|
||||
self.assertEqual(algodata2[SM2Algorithm.algo_name]["interval"], 0)
|
||||
|
||||
def test_revisor_updates_dates(self):
|
||||
"""测试更新日期字段"""
|
||||
algodata = {SM2Algorithm.algo_name: SM2Algorithm.defaults.copy()}
|
||||
self.mock_get_daystamp.return_value = 200
|
||||
SM2Algorithm.revisor(algodata, feedback=5)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_date"], 200)
|
||||
self.assertEqual(
|
||||
algodata[SM2Algorithm.algo_name]["next_date"],
|
||||
200 + algodata[SM2Algorithm.algo_name]["interval"],
|
||||
)
|
||||
self.assertEqual(algodata[SM2Algorithm.algo_name]["last_modify"], 1000.0)
|
||||
|
||||
def test_is_due(self):
|
||||
"""测试 is_due 方法"""
|
||||
algodata = {SM2Algorithm.algo_name: {"next_date": 100}}
|
||||
self.mock_get_daystamp.return_value = 150
|
||||
self.assertTrue(SM2Algorithm.is_due(algodata))
|
||||
|
||||
algodata[SM2Algorithm.algo_name]["next_date"] = 200
|
||||
self.assertFalse(SM2Algorithm.is_due(algodata))
|
||||
|
||||
def test_rate(self):
|
||||
"""测试 rate 方法"""
|
||||
algodata = {SM2Algorithm.algo_name: {"efactor": 2.7}}
|
||||
self.assertEqual(SM2Algorithm.rate(algodata), "2.7")
|
||||
|
||||
def test_nextdate(self):
|
||||
"""测试 nextdate 方法"""
|
||||
algodata = {SM2Algorithm.algo_name: {"next_date": 12345}}
|
||||
self.assertEqual(SM2Algorithm.nextdate(algodata), 12345)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,202 +0,0 @@
|
||||
import json
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import toml
|
||||
|
||||
from heurams.context import ConfigContext
|
||||
from heurams.kernel.particles.atom import Atom, atom_registry
|
||||
from heurams.kernel.particles.electron import Electron
|
||||
from heurams.kernel.particles.nucleon import Nucleon
|
||||
from heurams.kernel.particles.orbital import Orbital
|
||||
from heurams.services.config import ConfigFile
|
||||
|
||||
|
||||
class TestAtom(unittest.TestCase):
|
||||
"""测试 Atom 类"""
|
||||
|
||||
def setUp(self):
|
||||
"""在每个测试之前运行"""
|
||||
# 创建临时目录用于持久化测试
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.temp_path = pathlib.Path(self.temp_dir.name)
|
||||
|
||||
# 创建默认配置
|
||||
self.config = ConfigFile(
|
||||
pathlib.Path(__file__).parent.parent.parent.parent
|
||||
/ "src/heurams/default/config/config.toml"
|
||||
)
|
||||
|
||||
# 使用 ConfigContext 设置配置
|
||||
self.config_ctx = ConfigContext(self.config)
|
||||
self.config_ctx.__enter__()
|
||||
|
||||
# 清空全局注册表
|
||||
atom_registry.clear()
|
||||
|
||||
def tearDown(self):
|
||||
"""在每个测试之后运行"""
|
||||
self.config_ctx.__exit__(None, None, None)
|
||||
self.temp_dir.cleanup()
|
||||
atom_registry.clear()
|
||||
|
||||
def test_init(self):
|
||||
"""测试 Atom 初始化"""
|
||||
atom = Atom("test_atom")
|
||||
self.assertEqual(atom.ident, "test_atom")
|
||||
self.assertIn("test_atom", atom_registry)
|
||||
self.assertEqual(atom_registry["test_atom"], atom)
|
||||
|
||||
# 检查 registry 默认值
|
||||
self.assertIsNone(atom.registry["nucleon"])
|
||||
self.assertIsNone(atom.registry["electron"])
|
||||
self.assertIsNone(atom.registry["orbital"])
|
||||
self.assertEqual(atom.registry["nucleon_fmt"], "toml")
|
||||
self.assertEqual(atom.registry["electron_fmt"], "json")
|
||||
self.assertEqual(atom.registry["orbital_fmt"], "toml")
|
||||
|
||||
def test_link(self):
|
||||
"""测试 link 方法"""
|
||||
atom = Atom("test_link")
|
||||
nucleon = Nucleon("test_nucleon", {"content": "test content"})
|
||||
|
||||
atom.link("nucleon", nucleon)
|
||||
self.assertEqual(atom.registry["nucleon"], nucleon)
|
||||
|
||||
# 测试链接不支持的键
|
||||
with self.assertRaises(ValueError):
|
||||
atom.link("invalid_key", "value")
|
||||
|
||||
def test_link_triggers_do_eval(self):
|
||||
"""测试 link 后触发 do_eval"""
|
||||
atom = Atom("test_eval_trigger")
|
||||
nucleon = Nucleon("test_nucleon", {"content": "eval:1+1"})
|
||||
|
||||
with patch.object(atom, "do_eval") as mock_do_eval:
|
||||
atom.link("nucleon", nucleon)
|
||||
mock_do_eval.assert_called_once()
|
||||
|
||||
def test_persist_toml(self):
|
||||
"""测试 TOML 持久化"""
|
||||
atom = Atom("test_persist_toml")
|
||||
nucleon = Nucleon("test_nucleon", {"content": "test"})
|
||||
atom.link("nucleon", nucleon)
|
||||
|
||||
# 设置路径
|
||||
test_path = self.temp_path / "test.toml"
|
||||
atom.link("nucleon_path", test_path)
|
||||
|
||||
atom.persist("nucleon")
|
||||
|
||||
# 验证文件存在且内容正确
|
||||
self.assertTrue(test_path.exists())
|
||||
with open(test_path, "r") as f:
|
||||
data = toml.load(f)
|
||||
self.assertEqual(data["ident"], "test_nucleon")
|
||||
self.assertEqual(data["payload"]["content"], "test")
|
||||
|
||||
def test_persist_json(self):
|
||||
"""测试 JSON 持久化"""
|
||||
atom = Atom("test_persist_json")
|
||||
electron = Electron("test_electron", {})
|
||||
atom.link("electron", electron)
|
||||
|
||||
test_path = self.temp_path / "test.json"
|
||||
atom.link("electron_path", test_path)
|
||||
|
||||
atom.persist("electron")
|
||||
|
||||
self.assertTrue(test_path.exists())
|
||||
with open(test_path, "r") as f:
|
||||
data = json.load(f)
|
||||
self.assertIn("supermemo2", data)
|
||||
|
||||
def test_persist_invalid_format(self):
|
||||
"""测试无效持久化格式"""
|
||||
atom = Atom("test_invalid_format")
|
||||
nucleon = Nucleon("test_nucleon", {})
|
||||
atom.link("nucleon", nucleon)
|
||||
atom.link("nucleon_path", self.temp_path / "test.txt")
|
||||
atom.registry["nucleon_fmt"] = "invalid"
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
atom.persist("nucleon")
|
||||
|
||||
def test_persist_no_path(self):
|
||||
"""测试未初始化路径的持久化"""
|
||||
atom = Atom("test_no_path")
|
||||
nucleon = Nucleon("test_nucleon", {})
|
||||
atom.link("nucleon", nucleon)
|
||||
# 不设置 nucleon_path
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
atom.persist("nucleon")
|
||||
|
||||
def test_getitem_setitem(self):
|
||||
"""测试 __getitem__ 和 __setitem__"""
|
||||
atom = Atom("test_getset")
|
||||
nucleon = Nucleon("test_nucleon", {})
|
||||
|
||||
atom["nucleon"] = nucleon
|
||||
self.assertEqual(atom["nucleon"], nucleon)
|
||||
|
||||
# 测试不支持的键
|
||||
with self.assertRaises(KeyError):
|
||||
_ = atom["invalid_key"]
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
atom["invalid_key"] = "value"
|
||||
|
||||
def test_do_eval_with_eval_string(self):
|
||||
"""测试 do_eval 处理 eval: 字符串"""
|
||||
atom = Atom("test_do_eval")
|
||||
nucleon = Nucleon(
|
||||
"test_nucleon",
|
||||
{"content": "eval:'hello' + ' world'", "number": "eval:2 + 3"},
|
||||
)
|
||||
atom.link("nucleon", nucleon)
|
||||
|
||||
# do_eval 应该在链接时自动调用
|
||||
# 检查 eval 表达式是否被求值
|
||||
self.assertEqual(nucleon.payload["content"], "hello world")
|
||||
self.assertEqual(nucleon.payload["number"], "5")
|
||||
|
||||
def test_do_eval_with_config_access(self):
|
||||
"""测试 do_eval 访问配置"""
|
||||
atom = Atom("test_eval_config")
|
||||
nucleon = Nucleon(
|
||||
"test_nucleon", {"max_riddles": "eval:default['mcq']['max_riddles_num']"}
|
||||
)
|
||||
atom.link("nucleon", nucleon)
|
||||
|
||||
# 配置中 puzzles.mcq.max_riddles_num = 2
|
||||
self.assertEqual(nucleon.payload["max_riddles"], 2)
|
||||
|
||||
def test_placeholder(self):
|
||||
"""测试静态方法 placeholder"""
|
||||
placeholder = Atom.placeholder()
|
||||
self.assertIsInstance(placeholder, tuple)
|
||||
self.assertEqual(len(placeholder), 3)
|
||||
self.assertIsInstance(placeholder[0], Electron)
|
||||
self.assertIsInstance(placeholder[1], Nucleon)
|
||||
self.assertIsInstance(placeholder[2], dict)
|
||||
|
||||
def test_atom_registry_management(self):
|
||||
"""测试全局注册表管理"""
|
||||
# 创建多个 Atom
|
||||
atom1 = Atom("atom1")
|
||||
atom2 = Atom("atom2")
|
||||
|
||||
self.assertEqual(len(atom_registry), 2)
|
||||
self.assertEqual(atom_registry["atom1"], atom1)
|
||||
self.assertEqual(atom_registry["atom2"], atom2)
|
||||
|
||||
# 测试 bidict 的反向查找
|
||||
self.assertEqual(atom_registry.inverse[atom1], "atom1")
|
||||
self.assertEqual(atom_registry.inverse[atom2], "atom2")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,179 +0,0 @@
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from heurams.kernel.algorithms import algorithms
|
||||
from heurams.kernel.particles.electron import Electron
|
||||
|
||||
|
||||
class TestElectron(unittest.TestCase):
|
||||
"""测试 Electron 类"""
|
||||
|
||||
def setUp(self):
|
||||
# 模拟 timer.get_timestamp 返回固定值
|
||||
self.timestamp_patcher = patch(
|
||||
"heurams.kernel.particles.electron.timer.get_timestamp"
|
||||
)
|
||||
self.mock_get_timestamp = self.timestamp_patcher.start()
|
||||
self.mock_get_timestamp.return_value = 1234567890.0
|
||||
|
||||
def tearDown(self):
|
||||
self.timestamp_patcher.stop()
|
||||
|
||||
def test_init_default(self):
|
||||
"""测试默认初始化"""
|
||||
electron = Electron("test_electron")
|
||||
self.assertEqual(electron.ident, "test_electron")
|
||||
self.assertEqual(electron.algo, algorithms["supermemo2"])
|
||||
self.assertIn(electron.algo, electron.algodata)
|
||||
self.assertIsInstance(electron.algodata[electron.algo], dict)
|
||||
# 检查默认值(排除动态字段)
|
||||
defaults = electron.algo.defaults
|
||||
for key, value in defaults.items():
|
||||
if key == "last_modify":
|
||||
# last_modify 是动态的, 只检查存在性
|
||||
self.assertIn(key, electron.algodata[electron.algo])
|
||||
elif key == "is_activated":
|
||||
# TODO: 调查为什么 is_activated 是 1
|
||||
self.assertEqual(electron.algodata[electron.algo][key], 1)
|
||||
else:
|
||||
self.assertEqual(electron.algodata[electron.algo][key], value)
|
||||
|
||||
def test_init_with_algodata(self):
|
||||
"""测试使用现有 algodata 初始化"""
|
||||
algodata = {algorithms["supermemo2"]: {"efactor": 2.5, "interval": 1}}
|
||||
electron = Electron("test_electron", algodata=algodata)
|
||||
self.assertEqual(electron.algodata[electron.algo]["efactor"], 2.5)
|
||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 1)
|
||||
# 其他字段可能不存在, 因为未提供默认初始化
|
||||
# 检查 real_rept 不存在
|
||||
self.assertNotIn("real_rept", electron.algodata[electron.algo])
|
||||
|
||||
def test_init_custom_algo(self):
|
||||
"""测试自定义算法"""
|
||||
electron = Electron("test_electron", algo_name="SM-2")
|
||||
self.assertEqual(electron.algo, algorithms["SM-2"])
|
||||
self.assertIn(electron.algo, electron.algodata)
|
||||
|
||||
def test_activate(self):
|
||||
"""测试 activate 方法"""
|
||||
electron = Electron("test_electron")
|
||||
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 0)
|
||||
electron.activate()
|
||||
self.assertEqual(electron.algodata[electron.algo]["is_activated"], 1)
|
||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
||||
|
||||
def test_modify(self):
|
||||
"""测试 modify 方法"""
|
||||
electron = Electron("test_electron")
|
||||
electron.modify("interval", 5)
|
||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 5)
|
||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
||||
|
||||
# 修改不存在的字段应记录警告但不引发异常
|
||||
with patch("heurams.kernel.particles.electron.logger.warning") as mock_warning:
|
||||
electron.modify("unknown_field", 99)
|
||||
mock_warning.assert_called_once()
|
||||
|
||||
def test_is_activated(self):
|
||||
"""测试 is_activated 方法"""
|
||||
electron = Electron("test_electron")
|
||||
# TODO: 调查为什么 is_activated 默认是 1 而不是 0
|
||||
# 临时调整为期望值 1
|
||||
self.assertEqual(electron.is_activated(), 1)
|
||||
electron.activate()
|
||||
self.assertEqual(electron.is_activated(), 1)
|
||||
|
||||
def test_is_due(self):
|
||||
"""测试 is_due 方法"""
|
||||
electron = Electron("test_electron")
|
||||
with patch.object(electron.algo, "is_due") as mock_is_due:
|
||||
mock_is_due.return_value = 1
|
||||
result = electron.is_due()
|
||||
mock_is_due.assert_called_once_with(electron.algodata)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
def test_rate(self):
|
||||
"""测试 rate 方法"""
|
||||
electron = Electron("test_electron")
|
||||
with patch.object(electron.algo, "rate") as mock_rate:
|
||||
mock_rate.return_value = "good"
|
||||
result = electron.get_rate()
|
||||
mock_rate.assert_called_once_with(electron.algodata)
|
||||
self.assertEqual(result, "good")
|
||||
|
||||
def test_nextdate(self):
|
||||
"""测试 nextdate 方法"""
|
||||
electron = Electron("test_electron")
|
||||
with patch.object(electron.algo, "nextdate") as mock_nextdate:
|
||||
mock_nextdate.return_value = 1234568000
|
||||
result = electron.nextdate()
|
||||
mock_nextdate.assert_called_once_with(electron.algodata)
|
||||
self.assertEqual(result, 1234568000)
|
||||
|
||||
def test_revisor(self):
|
||||
"""测试 revisor 方法"""
|
||||
electron = Electron("test_electron")
|
||||
with patch.object(electron.algo, "revisor") as mock_revisor:
|
||||
electron.revisor(quality=3, is_new_activation=True)
|
||||
mock_revisor.assert_called_once_with(electron.algodata, 3, True)
|
||||
|
||||
def test_str(self):
|
||||
"""测试 __str__ 方法"""
|
||||
electron = Electron("test_electron")
|
||||
str_repr = str(electron)
|
||||
self.assertIn("记忆单元预览", str_repr)
|
||||
self.assertIn("test_electron", str_repr)
|
||||
# 算法类名会出现在字符串表示中
|
||||
self.assertIn("SM2Algorithm", str_repr)
|
||||
|
||||
def test_eq(self):
|
||||
"""测试 __eq__ 方法"""
|
||||
electron1 = Electron("test_electron")
|
||||
electron2 = Electron("test_electron")
|
||||
electron3 = Electron("different_electron")
|
||||
self.assertEqual(electron1, electron2)
|
||||
self.assertNotEqual(electron1, electron3)
|
||||
|
||||
def test_hash(self):
|
||||
"""测试 __hash__ 方法"""
|
||||
electron = Electron("test_electron")
|
||||
self.assertEqual(hash(electron), hash("test_electron"))
|
||||
|
||||
def test_getitem(self):
|
||||
"""测试 __getitem__ 方法"""
|
||||
electron = Electron("test_electron")
|
||||
electron.activate()
|
||||
self.assertEqual(electron["ident"], "test_electron")
|
||||
self.assertEqual(electron["is_activated"], 1)
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
_ = electron["nonexistent_key"]
|
||||
|
||||
def test_setitem(self):
|
||||
"""测试 __setitem__ 方法"""
|
||||
electron = Electron("test_electron")
|
||||
electron["interval"] = 10
|
||||
self.assertEqual(electron.algodata[electron.algo]["interval"], 10)
|
||||
self.assertEqual(electron.algodata[electron.algo]["last_modify"], 1234567890.0)
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
electron["ident"] = "new_ident"
|
||||
|
||||
def test_len(self):
|
||||
"""测试 __len__ 方法"""
|
||||
electron = Electron("test_electron")
|
||||
# len 返回当前算法的配置数量
|
||||
expected_len = len(electron.algo.defaults)
|
||||
self.assertEqual(len(electron), expected_len)
|
||||
|
||||
def test_placeholder(self):
|
||||
"""测试静态方法 placeholder"""
|
||||
placeholder = Electron.placeholder()
|
||||
self.assertIsInstance(placeholder, Electron)
|
||||
self.assertEqual(placeholder.ident, "电子对象样例内容")
|
||||
self.assertEqual(placeholder.algo, algorithms["supermemo2"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,108 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from heurams.kernel.particles.nucleon import Nucleon
|
||||
|
||||
|
||||
class TestNucleon(unittest.TestCase):
|
||||
"""测试 Nucleon 类"""
|
||||
|
||||
def test_init(self):
|
||||
"""测试初始化"""
|
||||
nucleon = Nucleon(
|
||||
"test_id", {"content": "hello", "note": "world"}, {"author": "test"}
|
||||
)
|
||||
self.assertEqual(nucleon.ident, "test_id")
|
||||
self.assertEqual(nucleon.payload, {"content": "hello", "note": "world"})
|
||||
self.assertEqual(nucleon.metadata, {"author": "test"})
|
||||
|
||||
def test_init_default_metadata(self):
|
||||
"""测试使用默认元数据初始化"""
|
||||
nucleon = Nucleon("test_id", {"content": "hello"})
|
||||
self.assertEqual(nucleon.ident, "test_id")
|
||||
self.assertEqual(nucleon.payload, {"content": "hello"})
|
||||
self.assertEqual(nucleon.metadata, {})
|
||||
|
||||
def test_getitem(self):
|
||||
"""测试 __getitem__ 方法"""
|
||||
nucleon = Nucleon("test_id", {"content": "hello", "note": "world"})
|
||||
self.assertEqual(nucleon["ident"], "test_id")
|
||||
self.assertEqual(nucleon["content"], "hello")
|
||||
self.assertEqual(nucleon["note"], "world")
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
_ = nucleon["nonexistent"]
|
||||
|
||||
def test_iter(self):
|
||||
"""测试 __iter__ 方法"""
|
||||
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
|
||||
keys = list(nucleon)
|
||||
self.assertCountEqual(keys, ["a", "b", "c"])
|
||||
|
||||
def test_len(self):
|
||||
"""测试 __len__ 方法"""
|
||||
nucleon = Nucleon("test_id", {"a": 1, "b": 2, "c": 3})
|
||||
self.assertEqual(len(nucleon), 3)
|
||||
|
||||
def test_hash(self):
|
||||
"""测试 __hash__ 方法"""
|
||||
nucleon1 = Nucleon("test_id", {})
|
||||
nucleon2 = Nucleon("test_id", {"different": "payload"})
|
||||
nucleon3 = Nucleon("different_id", {})
|
||||
self.assertEqual(hash(nucleon1), hash(nucleon2)) # 相同 ident
|
||||
self.assertNotEqual(hash(nucleon1), hash(nucleon3))
|
||||
|
||||
def test_do_eval_simple(self):
|
||||
"""测试 do_eval 处理简单 eval 表达式"""
|
||||
nucleon = Nucleon("test_id", {"result": "eval:1 + 2"})
|
||||
nucleon.do_eval()
|
||||
self.assertEqual(nucleon.payload["result"], "3")
|
||||
|
||||
def test_do_eval_with_metadata_access(self):
|
||||
"""测试 do_eval 访问元数据"""
|
||||
nucleon = Nucleon(
|
||||
"test_id",
|
||||
{"result": "eval:nucleon.metadata.get('value', 0)"},
|
||||
{"value": 42},
|
||||
)
|
||||
nucleon.do_eval()
|
||||
self.assertEqual(nucleon.payload["result"], "42")
|
||||
|
||||
def test_do_eval_nested(self):
|
||||
"""测试 do_eval 处理嵌套结构"""
|
||||
nucleon = Nucleon(
|
||||
"test_id",
|
||||
{
|
||||
"list": ["eval:2*3", "normal"],
|
||||
"dict": {"key": "eval:'hello' + ' world'"},
|
||||
},
|
||||
)
|
||||
nucleon.do_eval()
|
||||
self.assertEqual(nucleon.payload["list"][0], "6")
|
||||
self.assertEqual(nucleon.payload["list"][1], "normal")
|
||||
self.assertEqual(nucleon.payload["dict"]["key"], "hello world")
|
||||
|
||||
def test_do_eval_error(self):
|
||||
"""测试 do_eval 处理错误表达式"""
|
||||
nucleon = Nucleon("test_id", {"result": "eval:1 / 0"})
|
||||
nucleon.do_eval()
|
||||
self.assertIn("此 eval 实例发生错误", nucleon.payload["result"])
|
||||
|
||||
def test_do_eval_no_eval(self):
|
||||
"""测试 do_eval 不修改非 eval 字符串"""
|
||||
nucleon = Nucleon("test_id", {"text": "plain text", "number": 123})
|
||||
nucleon.do_eval()
|
||||
self.assertEqual(nucleon.payload["text"], "plain text")
|
||||
self.assertEqual(nucleon.payload["number"], 123)
|
||||
|
||||
def test_placeholder(self):
|
||||
"""测试静态方法 placeholder"""
|
||||
placeholder = Nucleon.placeholder()
|
||||
self.assertIsInstance(placeholder, Nucleon)
|
||||
self.assertEqual(placeholder.ident, "核子对象样例内容")
|
||||
self.assertEqual(placeholder.payload, {})
|
||||
self.assertEqual(placeholder.metadata, {})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,23 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from heurams.kernel.puzzles.base import BasePuzzle
|
||||
|
||||
|
||||
class TestBasePuzzle(unittest.TestCase):
|
||||
"""测试 BasePuzzle 基类"""
|
||||
|
||||
def test_refresh_not_implemented(self):
|
||||
"""测试 refresh 方法未实现时抛出异常"""
|
||||
puzzle = BasePuzzle()
|
||||
with self.assertRaises(NotImplementedError):
|
||||
puzzle.refresh()
|
||||
|
||||
def test_str(self):
|
||||
"""测试 __str__ 方法"""
|
||||
puzzle = BasePuzzle()
|
||||
self.assertEqual(str(puzzle), "谜题: BasePuzzle")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user