feat: 一系列新功能
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@
|
||||
__pycache__/
|
||||
.idea/
|
||||
cache/
|
||||
data/repo/cngk
|
||||
#nucleon/test.toml
|
||||
electron/test.toml
|
||||
*.egg-info/
|
||||
|
||||
@@ -9,7 +9,7 @@ timestamp_override = -1
|
||||
quick_pass = 1
|
||||
|
||||
# 对于每个项目的默认新记忆原子数量
|
||||
scheduled_num = 8
|
||||
scheduled_num = 999
|
||||
|
||||
# UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒
|
||||
timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||
@@ -17,7 +17,7 @@ timezone_offset = +28800 # 中国标准时间 (UTC+8)
|
||||
[interface]
|
||||
|
||||
[interface.memorizor]
|
||||
autovoice = true # 自动语音播放, 仅限于 recognition 组件
|
||||
autovoice = 0 # 自动语音播放, 仅限于 recognition 组件
|
||||
|
||||
[algorithm]
|
||||
default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS
|
||||
|
||||
@@ -340,7 +340,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from heurams.utils.lict import Lict\n",
|
||||
"from heurams.kernel.auxiliary.lict import Lict\n",
|
||||
"\n",
|
||||
"lct = Lict() # 空的\n",
|
||||
"lct = Lict(initlist=[(\"name\", \"tom\"), (\"age\", 12), (\"enemy\", \"jerry\")]) # 基于列表\n",
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import heurams.kernel.repolib as repolib
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.services.textproc import truncate
|
||||
from pathlib import Path
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.repolib as repolib
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
repo = repolib.Repo.create_from_repodir(Path("./test_repo"))
|
||||
alist = list()
|
||||
|
||||
6
package-lock.json
generated
Normal file
6
package-lock.json
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"name": "HeurAMS",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
edge-tts==7.0.2
|
||||
fsspec==2025.12.0
|
||||
jieba==0.42.1
|
||||
openai>=1.0.0
|
||||
playsound==1.2.2
|
||||
tabulate==0.9.0
|
||||
textual==5.3.0
|
||||
textual==7.0.0
|
||||
toml==0.10.2
|
||||
transitions==0.9.3
|
||||
@@ -4,8 +4,8 @@
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
from contextvars import ContextVar
|
||||
import shutil
|
||||
from contextvars import ContextVar
|
||||
|
||||
from heurams.services.config import ConfigFile
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Type
|
||||
|
||||
from textual.app import App
|
||||
from textual.driver import Driver
|
||||
from textual.widgets import Button
|
||||
|
||||
from heurams.context import config_var
|
||||
@@ -6,8 +9,12 @@ from heurams.services.logger import get_logger
|
||||
|
||||
from .screens.about import AboutScreen
|
||||
from .screens.dashboard import DashboardScreen
|
||||
from .screens.llmchat import LLMChatScreen
|
||||
from .screens.navigator import NavigatorScreen
|
||||
from .screens.precache import PrecachingScreen
|
||||
from .screens.radio import RadioScreen
|
||||
from .screens.repocreator import RepoCreatorScreen
|
||||
from .screens.repoeditor import RepoEditorScreen
|
||||
from .screens.synctool import SyncScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -35,13 +42,10 @@ class HeurAMSApp(App):
|
||||
CSS_PATH = "css/main.tcss"
|
||||
SUB_TITLE = "启发式辅助记忆调度器"
|
||||
BINDINGS = [
|
||||
("q", "quit", "退出"),
|
||||
("d", "toggle_dark", "切换色调"),
|
||||
("1", "app.push_screen('dashboard')", "仪表盘"),
|
||||
("2", "app.push_screen('precache_all')", "缓存管理器"),
|
||||
("3", "app.push_screen('repo_creator')", "创建新仓库"),
|
||||
# ("4", "app.push_screen('synctool')", "同步工具"),
|
||||
("0", "app.push_screen('about')", "版本信息"),
|
||||
("q", "go_back", "退出"),
|
||||
("d", "toggle_dark", "主题"),
|
||||
("n", "app.push_screen('navigator')", "导航"),
|
||||
("z", "app.push_screen('about')", "关于"),
|
||||
]
|
||||
SCREENS = {
|
||||
"dashboard": DashboardScreen,
|
||||
@@ -49,6 +53,10 @@ class HeurAMSApp(App):
|
||||
"precache_all": PrecachingScreen,
|
||||
"synctool": SyncScreen,
|
||||
"about": AboutScreen,
|
||||
"navigator": NavigatorScreen,
|
||||
"radio": RadioScreen,
|
||||
"repo_editor": RepoEditorScreen,
|
||||
"llmchat": LLMChatScreen,
|
||||
}
|
||||
|
||||
def on_mount(self) -> None:
|
||||
@@ -56,8 +64,11 @@ class HeurAMSApp(App):
|
||||
self.push_screen("dashboard")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
self.exit(event.button.id)
|
||||
pass
|
||||
# self.exit(event.button.id)
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
quit()
|
||||
|
||||
def action_do_nothing(self):
|
||||
print("DO NOTHING")
|
||||
self.refresh()
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
NavigatorScreen {
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#dialog {
|
||||
grid-size: 2;
|
||||
grid-gutter: 1 1;
|
||||
grid-rows: 1fr 3;
|
||||
padding: 0 1;
|
||||
width: 46;
|
||||
height: 12;
|
||||
border: thick $background 80%;
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
/* LLM 聊天界面样式 */
|
||||
LLMChatScreen {
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#chat-container {
|
||||
height: 100%;
|
||||
padding: 1;
|
||||
}
|
||||
|
||||
#toolbar {
|
||||
height: 3;
|
||||
margin-bottom: 1;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#toolbar Button {
|
||||
margin: 0 1;
|
||||
}
|
||||
|
||||
#chat-log {
|
||||
height: 1fr;
|
||||
border: solid $primary;
|
||||
padding: 1;
|
||||
background: $surface;
|
||||
}
|
||||
|
||||
#input-container {
|
||||
height: 3;
|
||||
margin-top: 1;
|
||||
align: center middle;
|
||||
}
|
||||
|
||||
#message-input {
|
||||
width: 1fr;
|
||||
margin-right: 1;
|
||||
}
|
||||
|
||||
#status-bar {
|
||||
height: 1;
|
||||
margin-top: 1;
|
||||
text-style: italic;
|
||||
color: $text-muted;
|
||||
}
|
||||
|
||||
.session-label {
|
||||
color: $primary;
|
||||
text-style: bold;
|
||||
}
|
||||
|
||||
@@ -10,6 +10,9 @@ from heurams.context import *
|
||||
|
||||
|
||||
class AboutScreen(Screen):
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
@@ -22,10 +25,14 @@ class AboutScreen(Screen):
|
||||
|
||||
开发代号: {version.codename.capitalize()} {version.codename_cn}
|
||||
|
||||
一个基于启发式算法的开放源代码记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划.
|
||||
一个基于启发式算法的辅助记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划.
|
||||
|
||||
以 AGPL-3.0 开放源代码
|
||||
|
||||
您可在项目主页 https://ams.imwangzhiyu.xyz 获取用户指南, 开发文档与软件更新
|
||||
|
||||
如果您觉得这个软件有用, 请给它添加一个星标 :)
|
||||
|
||||
开发人员:
|
||||
|
||||
- Wang Zhiyu([@pluvium27](https://github.com/pluvium27)): 项目作者
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""仪表盘界面"""
|
||||
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.services.timer as timer
|
||||
import heurams.services.version as version
|
||||
from heurams.context import *
|
||||
@@ -14,10 +16,10 @@ from heurams.kernel.particles import *
|
||||
from heurams.kernel.repolib import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from pathlib import Path
|
||||
from .about import AboutScreen
|
||||
from .navigator import NavigatorScreen
|
||||
from .preparation import PreparationScreen
|
||||
from .radio import RadioScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -26,6 +28,9 @@ class DashboardScreen(Screen):
|
||||
"""主仪表盘屏幕"""
|
||||
|
||||
SUB_TITLE = "仪表盘"
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -50,12 +55,12 @@ class DashboardScreen(Screen):
|
||||
Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"),
|
||||
Label("选择待学习或待修改的项目:", classes="title-label"),
|
||||
ListView(id="repo-list", classes="repo-list-view"),
|
||||
Label(f'"潜进" 启发式辅助记忆调度器 | 版本 {version.ver} '),
|
||||
Label(f'"潜进" 启发式辅助记忆调度器 版本 {version.ver} '),
|
||||
)
|
||||
yield Footer()
|
||||
|
||||
def _load_data(self):
|
||||
self.repo_dirs = Repo.probe_vaild_repos_in_dir(
|
||||
self.repo_dirs = Repo.probe_valid_repos_in_dir(
|
||||
Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
)
|
||||
for repo_dir in self.repo_dirs:
|
||||
@@ -69,7 +74,6 @@ class DashboardScreen(Screen):
|
||||
unit_sum = len(repo)
|
||||
activated_sum = 0
|
||||
nextdate = 0x3F3F3F3F
|
||||
is_unfinished = unit_sum > activated_sum
|
||||
for i in repo.ident_index:
|
||||
nucleon = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
@@ -82,10 +86,11 @@ class DashboardScreen(Screen):
|
||||
if electron.is_due():
|
||||
is_due = 1
|
||||
nextdate = min(nextdate, electron.nextdate())
|
||||
is_unfinished = unit_sum > activated_sum
|
||||
if is_unfinished:
|
||||
nextdate = min(nextdate, timer.get_daystamp())
|
||||
need_to_study = is_due or is_unfinished
|
||||
prompt = f"{title}\0\n 进度: {activated_sum}/{unit_sum}\n {"需要学习" if need_to_study else "无需操作"}"
|
||||
prompt = f"{title}\0\n 进度: {activated_sum}/{unit_sum} ({round(activated_sum/unit_sum*100)}%)\n {"需要学习" if need_to_study else "无需操作"}"
|
||||
stat = {
|
||||
"is_due": is_due,
|
||||
"unit_sum": unit_sum,
|
||||
@@ -139,7 +144,7 @@ class DashboardScreen(Screen):
|
||||
return
|
||||
|
||||
selected_label = event.item.query_one(Label)
|
||||
label_text = str(selected_label.renderable)
|
||||
label_text = str(selected_label.render())
|
||||
|
||||
if "未找到任何仓库" in label_text:
|
||||
return
|
||||
@@ -158,3 +163,12 @@ class DashboardScreen(Screen):
|
||||
def action_quit_app(self) -> None:
|
||||
"""退出应用程序"""
|
||||
self.app.exit()
|
||||
|
||||
def action_open_navigator(self) -> None:
|
||||
"""打开导航器"""
|
||||
self.app.push_screen(NavigatorScreen())
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
if event.button.id == "navigator-button":
|
||||
self.action_open_navigator()
|
||||
|
||||
204
src/heurams/interface/screens/favmgr.py
Normal file
204
src/heurams/interface/screens/favmgr.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""收藏夹管理器界面"""
|
||||
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import ScrollableContainer
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Markdown,
|
||||
Static,
|
||||
)
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.services.favorite_service import FavoriteItem, favorite_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FavoriteManagerScreen(Screen):
|
||||
"""收藏夹管理器屏幕"""
|
||||
|
||||
SUB_TITLE = "收藏夹"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("d", "toggle_dark", ""),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.favorites: List[FavoriteItem] = []
|
||||
self._load_favorites()
|
||||
|
||||
def _load_favorites(self) -> None:
|
||||
"""加载收藏列表"""
|
||||
self.favorites = favorite_manager.get_all()
|
||||
logger.debug("加载 %d 个收藏项", len(self.favorites))
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="favorites-container"):
|
||||
if not self.favorites:
|
||||
yield Label("暂无收藏", classes="empty-label")
|
||||
yield Static("使用 * 键在记忆界面中添加收藏.")
|
||||
else:
|
||||
yield Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label")
|
||||
yield ListView(id="favorites-list")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载后填充列表"""
|
||||
if self.favorites:
|
||||
list_view = self.query_one("#favorites-list")
|
||||
for fav in self.favorites:
|
||||
list_view.append(self._create_favorite_item(fav)) # type: ignore
|
||||
|
||||
def _encode_favorite_key(self, repo_path: str, ident: str) -> str:
|
||||
"""编码仓库路径和标识符为安全的按钮 ID 部分"""
|
||||
# 使用 \x00 分隔两部分,然后进行 base64 编码
|
||||
combined = f"{repo_path}\x00{ident}"
|
||||
encoded = base64.urlsafe_b64encode(combined.encode()).decode()
|
||||
# 去掉填充的等号
|
||||
return encoded.rstrip("=")
|
||||
|
||||
def _decode_favorite_key(self, key: str) -> tuple[str, str]:
|
||||
"""解码按钮 ID 部分为仓库路径和标识符"""
|
||||
# 补全等号以使长度是4的倍数
|
||||
padded = key + "=" * ((4 - len(key) % 4) % 4)
|
||||
decoded = base64.urlsafe_b64decode(padded.encode()).decode()
|
||||
repo_path, ident = decoded.split("\x00", 1)
|
||||
return repo_path, ident
|
||||
|
||||
def _create_favorite_item(self, fav: FavoriteItem) -> ListItem:
|
||||
"""创建收藏项列表项"""
|
||||
# 尝试获取仓库信息
|
||||
repo_info = self._get_repo_info(fav.repo_path, fav)
|
||||
title = repo_info.get("title", fav.repo_path) if repo_info else fav.repo_path
|
||||
content_preview = repo_info.get("content_preview", "") if repo_info else ""
|
||||
added_time = self._format_time(fav.added)
|
||||
|
||||
# 构建显示文本
|
||||
display_text = f"[b]{title}[/b] ({fav.ident})\n"
|
||||
if content_preview:
|
||||
display_text += f"{content_preview}\n"
|
||||
display_text += f"添加于: {added_time}"
|
||||
if fav.tags:
|
||||
display_text += f" 标签: {', '.join(fav.tags)}"
|
||||
|
||||
# 创建安全的按钮 ID
|
||||
button_key = self._encode_favorite_key(fav.repo_path, fav.ident)
|
||||
# 创建列表项,包含移除按钮
|
||||
container = ScrollableContainer(
|
||||
Markdown(display_text, classes="favorite-content"),
|
||||
Button("移除", id=f"remove-{button_key}", variant="error"),
|
||||
classes="favorite-item",
|
||||
)
|
||||
return ListItem(container)
|
||||
|
||||
def _get_repo_info(self, repo_path: str, fav: FavoriteItem) -> Optional[dict]:
|
||||
"""获取仓库信息(标题、原子内容预览)"""
|
||||
try:
|
||||
data_repo = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dir = data_repo / repo_path
|
||||
if not repo_dir.exists():
|
||||
logger.warning("仓库目录不存在: %s", repo_dir)
|
||||
return None
|
||||
repo = Repo.create_from_repodir(repo_dir)
|
||||
# 获取原子内容预览
|
||||
content_preview = ""
|
||||
payload = repo.payload
|
||||
# 查找对应 ident 的 payload 条目
|
||||
for ident_key, content in payload:
|
||||
if ident_key == fav.ident:
|
||||
# 截断过长的内容
|
||||
if isinstance(content, dict) and "content" in content:
|
||||
text = content["content"]
|
||||
else:
|
||||
text = str(content)
|
||||
if len(text) > 100:
|
||||
content_preview = text[:100] + "..."
|
||||
else:
|
||||
content_preview = text
|
||||
break
|
||||
return {
|
||||
"title": repo.manifest["title"],
|
||||
"content_preview": content_preview,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("获取仓库信息失败: %s", e)
|
||||
return None
|
||||
|
||||
def _format_time(self, timestamp: int) -> str:
|
||||
"""格式化时间戳"""
|
||||
from datetime import datetime
|
||||
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
button_id = event.button.id
|
||||
if button_id and button_id.startswith("remove-"):
|
||||
# 提取编码后的键
|
||||
key = button_id[7:] # 去掉 "remove-" 前缀
|
||||
try:
|
||||
repo_path, ident = self._decode_favorite_key(key)
|
||||
self._remove_favorite(repo_path, ident)
|
||||
except Exception as e:
|
||||
logger.error("解析按钮 ID 失败: %s", e)
|
||||
self.app.notify("操作失败: 无效的按钮标识", severity="error")
|
||||
|
||||
def _remove_favorite(self, repo_path: str, ident: str) -> None:
|
||||
"""移除收藏项"""
|
||||
if favorite_manager.remove(repo_path, ident):
|
||||
self.app.notify(f"已移除收藏: {ident}", severity="information")
|
||||
# 重新加载列表
|
||||
self._load_favorites()
|
||||
# 刷新界面
|
||||
self._refresh_list()
|
||||
else:
|
||||
self.app.notify(f"移除失败: {ident}", severity="error")
|
||||
|
||||
def _refresh_list(self) -> None:
|
||||
"""刷新列表显示"""
|
||||
container = self.query_one("#favorites-container")
|
||||
# 清空容器
|
||||
for child in container.children:
|
||||
child.remove()
|
||||
# 重新组合
|
||||
if not self.favorites:
|
||||
container.mount(Label("暂无收藏", classes="empty-label"))
|
||||
container.mount(Static("使用 * 键在记忆界面中添加收藏。"))
|
||||
else:
|
||||
container.mount(
|
||||
Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label")
|
||||
)
|
||||
list_view = ListView(id="favorites-list")
|
||||
container.mount(list_view)
|
||||
for fav in self.favorites:
|
||||
list_view.append(self._create_favorite_item(fav))
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
"""返回上一屏幕"""
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_toggle_dark(self) -> None:
|
||||
"""切换暗黑模式"""
|
||||
self.app.dark = not self.app.dark # type: ignore
|
||||
@@ -1 +0,0 @@
|
||||
"""笔记界面"""
|
||||
333
src/heurams/interface/screens/llmchat.py
Normal file
333
src/heurams/interface/screens/llmchat.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""LLM 聊天界面"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Horizontal
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Input, Label, RichLog, Static
|
||||
|
||||
from heurams.context import *
|
||||
from heurams.services.llm_service import ChatSession, get_chat_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMChatScreen(Screen):
|
||||
"""LLM 聊天屏幕"""
|
||||
|
||||
SUB_TITLE = "AI 聊天"
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("ctrl+s", "save_session", "保存会话"),
|
||||
("ctrl+l", "load_session", "加载会话"),
|
||||
("ctrl+n", "new_session", "新建会话"),
|
||||
("ctrl+c", "clear_history", "清空历史"),
|
||||
("escape", "focus_input", "聚焦输入"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.session_id = session_id
|
||||
self.chat_manager = get_chat_manager()
|
||||
self.current_session: Optional[ChatSession] = None
|
||||
self.is_streaming = False
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
|
||||
with Container(id="chat-container"):
|
||||
# 顶部工具栏
|
||||
with Horizontal(id="toolbar"):
|
||||
yield Button("新建会话", id="new-session", variant="primary")
|
||||
yield Button("保存会话", id="save-session", variant="default")
|
||||
yield Button("加载会话", id="load-session", variant="default")
|
||||
yield Button("清空历史", id="clear-history", variant="default")
|
||||
yield Button("设置系统提示", id="set-system-prompt", variant="default")
|
||||
yield Static(" | ", classes="separator")
|
||||
yield Label("当前会话:", classes="label")
|
||||
yield Static(id="current-session-label", classes="session-label")
|
||||
|
||||
# 聊天记录显示区域
|
||||
yield RichLog(
|
||||
id="chat-log",
|
||||
wrap=True,
|
||||
highlight=True,
|
||||
markup=True,
|
||||
classes="chat-log",
|
||||
)
|
||||
|
||||
# 输入区域
|
||||
with Horizontal(id="input-container"):
|
||||
yield Input(
|
||||
id="message-input",
|
||||
placeholder="输入消息... (按 Ctrl+Enter 发送, Esc 聚焦)",
|
||||
classes="message-input",
|
||||
)
|
||||
yield Button(
|
||||
"发送", id="send-button", variant="primary", classes="send-button"
|
||||
)
|
||||
|
||||
# 状态栏
|
||||
yield Static(id="status-bar", classes="status-bar")
|
||||
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载组件时初始化"""
|
||||
# 获取或创建会话
|
||||
self.current_session = self.chat_manager.get_session(self.session_id)
|
||||
if self.current_session is None:
|
||||
self.notify("无法创建 LLM 会话,请检查配置", severity="error")
|
||||
return
|
||||
|
||||
# 更新会话标签
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
|
||||
# 加载历史消息到聊天记录
|
||||
self._display_history()
|
||||
|
||||
# 聚焦输入框
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
# 检查配置
|
||||
self._check_config()
|
||||
|
||||
def _check_config(self):
|
||||
"""检查 LLM 配置"""
|
||||
config = config_var.get()
|
||||
provider_name = config["services"]["llm"]
|
||||
provider_config = config["providers"]["llm"][provider_name]
|
||||
|
||||
if provider_name == "openai":
|
||||
if not provider_config.get("key") and not provider_config.get("url"):
|
||||
self.notify(
|
||||
"未配置 OpenAI API key 或 URL,请在 config.toml 中配置 [providers.llm.openai]",
|
||||
severity="warning",
|
||||
)
|
||||
|
||||
def _display_history(self):
|
||||
"""显示当前会话的历史消息"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
chat_log = self.query_one("#chat-log", RichLog)
|
||||
chat_log.clear()
|
||||
|
||||
for msg in self.current_session.get_history():
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "user":
|
||||
chat_log.write(f"[bold cyan]你:[/bold cyan] {content}")
|
||||
elif role == "assistant":
|
||||
chat_log.write(f"[bold green]AI:[/bold green] {content}")
|
||||
elif role == "system":
|
||||
# 系统消息不显示在聊天记录中
|
||||
pass
|
||||
|
||||
def _add_message_to_log(self, role: str, content: str):
|
||||
"""添加消息到聊天记录显示"""
|
||||
chat_log = self.query_one("#chat-log", RichLog)
|
||||
if role == "user":
|
||||
chat_log.write(f"[bold cyan]你:[/bold cyan] {content}")
|
||||
elif role == "assistant":
|
||||
chat_log.write(f"[bold green]AI:[/bold green] {content}")
|
||||
chat_log.scroll_end()
|
||||
|
||||
async def on_input_submitted(self, event: Input.Submitted):
|
||||
"""处理输入提交"""
|
||||
if event.input.id == "message-input":
|
||||
await self._send_message()
|
||||
|
||||
async def on_button_pressed(self, event: Button.Pressed):
|
||||
"""处理按钮点击"""
|
||||
button_id = event.button.id
|
||||
|
||||
if button_id == "send-button":
|
||||
await self._send_message()
|
||||
elif button_id == "new-session":
|
||||
self.action_new_session()
|
||||
elif button_id == "save-session":
|
||||
self.action_save_session()
|
||||
elif button_id == "load-session":
|
||||
self.action_load_session()
|
||||
elif button_id == "clear-history":
|
||||
self.action_clear_history()
|
||||
elif button_id == "set-system-prompt":
|
||||
self.action_set_system_prompt()
|
||||
|
||||
async def _send_message(self):
|
||||
"""发送当前输入的消息"""
|
||||
if not self.current_session or self.is_streaming:
|
||||
return
|
||||
|
||||
input_widget = self.query_one("#message-input", Input)
|
||||
message = input_widget.value.strip()
|
||||
|
||||
if not message:
|
||||
return
|
||||
|
||||
# 清空输入框
|
||||
input_widget.value = ""
|
||||
|
||||
# 显示用户消息
|
||||
self._add_message_to_log("user", message)
|
||||
|
||||
# 禁用输入和按钮
|
||||
self._set_input_state(disabled=True)
|
||||
self.is_streaming = True
|
||||
|
||||
# 更新状态
|
||||
self.query_one("#status-bar", Static).update("AI 正在思考...")
|
||||
|
||||
try:
|
||||
# 发送消息并获取响应
|
||||
response = await self.current_session.send_message(message)
|
||||
|
||||
# 显示AI响应
|
||||
self._add_message_to_log("assistant", response)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"请求失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
self._add_message_to_log("assistant", f"[red]{error_msg}[/red]")
|
||||
self.notify(error_msg, severity="error")
|
||||
|
||||
finally:
|
||||
# 恢复输入和按钮
|
||||
self._set_input_state(disabled=False)
|
||||
self.is_streaming = False
|
||||
self.query_one("#status-bar", Static).update("就绪")
|
||||
input_widget.focus()
|
||||
|
||||
def _set_input_state(self, disabled: bool):
|
||||
"""设置输入控件状态"""
|
||||
self.query_one("#message-input", Input).disabled = disabled
|
||||
self.query_one("#send-button", Button).disabled = disabled
|
||||
|
||||
async def action_save_session(self):
|
||||
"""保存当前会话到文件"""
|
||||
if not self.current_session:
|
||||
self.notify("无当前会话", severity="error")
|
||||
return
|
||||
|
||||
# 默认保存到 data/chat_sessions/ 目录
|
||||
save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
file_path = save_dir / f"{self.current_session.session_id}.json"
|
||||
self.current_session.save_to_file(file_path)
|
||||
|
||||
self.notify(f"会话已保存到 {file_path}", severity="information")
|
||||
|
||||
async def action_load_session(self):
|
||||
"""从文件加载会话"""
|
||||
# 简化实现:加载默认目录下的第一个会话文件
|
||||
save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions"
|
||||
if not save_dir.exists():
|
||||
self.notify(f"目录不存在: {save_dir}", severity="error")
|
||||
return
|
||||
|
||||
session_files = list(save_dir.glob("*.json"))
|
||||
if not session_files:
|
||||
self.notify("未找到会话文件", severity="error")
|
||||
return
|
||||
|
||||
# 使用第一个文件(在实际应用中可以让用户选择)
|
||||
file_path = session_files[0]
|
||||
|
||||
try:
|
||||
# 获取 LLM 提供者
|
||||
provider_name = config_var.get()["services"]["llm"]
|
||||
provider_config = config_var.get()["providers"]["llm"][provider_name]
|
||||
from heurams.providers.llm import providers as prov
|
||||
|
||||
llm_provider = prov[provider_name](provider_config)
|
||||
|
||||
# 加载会话
|
||||
self.current_session = ChatSession.load_from_file(file_path, llm_provider)
|
||||
|
||||
# 更新聊天管理器
|
||||
self.chat_manager.sessions[self.current_session.session_id] = (
|
||||
self.current_session
|
||||
)
|
||||
|
||||
# 更新UI
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
self._display_history()
|
||||
|
||||
self.notify(f"已加载会话: {file_path.name}", severity="information")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("加载会话失败: %s", e)
|
||||
self.notify(f"加载失败: {str(e)}", severity="error")
|
||||
|
||||
async def action_new_session(self):
|
||||
"""创建新会话"""
|
||||
# 简单实现:使用时间戳作为会话ID
|
||||
import time
|
||||
|
||||
new_session_id = f"session_{int(time.time())}"
|
||||
|
||||
self.current_session = self.chat_manager.get_session(new_session_id)
|
||||
|
||||
# 更新UI
|
||||
self.query_one("#current-session-label", Static).update(
|
||||
f"{self.current_session.session_id}"
|
||||
)
|
||||
self._display_history()
|
||||
|
||||
self.notify(f"已创建新会话: {new_session_id}", severity="information")
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
async def action_clear_history(self):
|
||||
"""清空当前会话历史"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
self.current_session.clear_history()
|
||||
self._display_history()
|
||||
self.notify("历史已清空", severity="information")
|
||||
|
||||
async def action_set_system_prompt(self):
|
||||
"""设置系统提示词"""
|
||||
if not self.current_session:
|
||||
return
|
||||
|
||||
# 使用输入框获取新提示词
|
||||
input_widget = self.query_one("#message-input", Input)
|
||||
current_value = input_widget.value
|
||||
|
||||
# 临时修改输入框提示
|
||||
input_widget.placeholder = "输入系统提示词... (按 Enter 确认, Esc 取消)"
|
||||
input_widget.value = self.current_session.system_prompt
|
||||
|
||||
# 等待用户输入
|
||||
self.notify("请输入系统提示词,按 Enter 确认", severity="information")
|
||||
|
||||
# 实际应用中需要更复杂的交互,这里简化处理
|
||||
# 用户手动输入后按 Enter 会触发 on_input_submitted
|
||||
# 这里我们只修改占位符,实际系统提示词设置需要额外界面
|
||||
|
||||
def action_focus_input(self):
|
||||
"""聚焦到输入框"""
|
||||
self.query_one("#message-input", Input).focus()
|
||||
|
||||
def action_go_back(self):
|
||||
"""返回上级屏幕"""
|
||||
self.app.pop_screen()
|
||||
@@ -1,6 +1,7 @@
|
||||
"""队列式记忆工作界面"""
|
||||
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from textual.app import ComposeResult
|
||||
@@ -9,10 +10,11 @@ from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, Static
|
||||
|
||||
import heurams.kernel.puzzles as pz
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.reactor import *
|
||||
from heurams.services.favorite_service import favorite_manager
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .. import shim
|
||||
@@ -32,6 +34,7 @@ class MemScreen(Screen):
|
||||
("p", "prev", "查看上一个"),
|
||||
("d", "toggle_dark", ""),
|
||||
("v", "play_voice", "朗读"),
|
||||
("*", "toggle_favorite", "收藏"),
|
||||
("0,1,2,3", "app.push_screen('about')", ""),
|
||||
]
|
||||
|
||||
@@ -44,6 +47,7 @@ class MemScreen(Screen):
|
||||
self,
|
||||
phaser: Phaser,
|
||||
save_func: Callable,
|
||||
repo=None,
|
||||
name=None,
|
||||
id=None,
|
||||
classes=None,
|
||||
@@ -51,10 +55,10 @@ class MemScreen(Screen):
|
||||
super().__init__(name, id, classes)
|
||||
self.phaser = phaser
|
||||
self.save_func = save_func
|
||||
self.repo = repo
|
||||
self.update_state()
|
||||
self.fission: Fission
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer():
|
||||
@@ -84,6 +88,10 @@ class MemScreen(Screen):
|
||||
|
||||
def _get_progress_text(self):
|
||||
s = f"阶段: {self.procession.phase.name}\n"
|
||||
# 收藏状态
|
||||
if self.repo is not None:
|
||||
fav_status = "★" if self._is_current_atom_favorited() else "☆"
|
||||
s += f"收藏: {fav_status}\n"
|
||||
if config_var.get().get("debug_topline", 0):
|
||||
try:
|
||||
alia = self.fission.get_current_puzzle_inf()["alia"] # type: ignore
|
||||
@@ -129,6 +137,7 @@ class MemScreen(Screen):
|
||||
for i in container.children:
|
||||
i.remove()
|
||||
from heurams.interface.widgets.finished import Finished
|
||||
|
||||
if config_var.get().get("persist_to_file", 0):
|
||||
self.save_func()
|
||||
container.mount(Finished(is_saved=config_var.get().get("persist_to_file", 0)))
|
||||
@@ -208,3 +217,40 @@ class MemScreen(Screen):
|
||||
|
||||
def action_quick_fail(self):
|
||||
self.rating = 3
|
||||
|
||||
def _get_repo_rel_path(self) -> str:
|
||||
"""获取仓库相对路径(相对于 data/repo)"""
|
||||
if self.repo is None:
|
||||
return ""
|
||||
# self.repo.source 是 Path 对象,指向仓库目录
|
||||
repo_full_path = self.repo.source
|
||||
data_repo_path = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
try:
|
||||
rel_path = repo_full_path.relative_to(data_repo_path)
|
||||
return str(rel_path)
|
||||
except ValueError:
|
||||
# 如果不在 data/repo 下,则返回完整路径(字符串形式)
|
||||
return str(repo_full_path)
|
||||
|
||||
def _is_current_atom_favorited(self) -> bool:
|
||||
"""检查当前原子是否已收藏"""
|
||||
if self.repo is None:
|
||||
return False
|
||||
repo_path = self._get_repo_rel_path()
|
||||
return favorite_manager.has(repo_path, self.atom.ident)
|
||||
|
||||
def action_toggle_favorite(self):
|
||||
"""切换收藏状态"""
|
||||
if self.repo is None:
|
||||
self.app.notify("无法收藏:未关联仓库", severity="error")
|
||||
return
|
||||
repo_path = self._get_repo_rel_path()
|
||||
ident = self.atom.ident
|
||||
if favorite_manager.has(repo_path, ident):
|
||||
favorite_manager.remove(repo_path, ident)
|
||||
self.app.notify(f"已取消收藏:{ident}", severity="information")
|
||||
else:
|
||||
favorite_manager.add(repo_path, ident)
|
||||
self.app.notify(f"已收藏:{ident}", severity="information")
|
||||
# 更新显示(如果需要)
|
||||
self.update_display()
|
||||
|
||||
93
src/heurams/interface/screens/navigator.py
Normal file
93
src/heurams/interface/screens/navigator.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import webbrowser
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Grid, ScrollableContainer
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static
|
||||
|
||||
from heurams.context import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .favmgr import FavoriteManagerScreen
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class NavigatorScreen(ModalScreen):
|
||||
"""导航器模态窗口"""
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("escape", "go_back", "返回"),
|
||||
("n", "go_back", "切换"),
|
||||
]
|
||||
|
||||
SCREENS = [
|
||||
("仪表盘", "dashboard"),
|
||||
("电台", "radio"),
|
||||
("语言模型集成", "llmchat"),
|
||||
# ("创建仓库", "repo_creator"),
|
||||
("缓存管理器", "precache_all"),
|
||||
("收藏夹管理器", FavoriteManagerScreen),
|
||||
("关于此软件", "about"),
|
||||
("调试日志", "logviewer"),
|
||||
# ("同步工具", "synctool"),
|
||||
# ("仓库编辑器", "repo_editor"),
|
||||
]
|
||||
|
||||
OTHERS = [
|
||||
("退出程序", "self.app.exit()"),
|
||||
("项目主页", "webbrowser.open('https://ams.imwangzhiyu.xyz')"),
|
||||
]
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
with Grid(id="dialog"):
|
||||
yield Label(
|
||||
"[b]请选择要跳转的功能\n或记忆会话实例[/b]\n\n将在此处显示提示",
|
||||
classes="title-label",
|
||||
)
|
||||
yield ListView(
|
||||
*[ListItem(Label(title)) for title, _ in (self.SCREENS + self.OTHERS)],
|
||||
id="nav-list",
|
||||
classes="nav-list-view",
|
||||
)
|
||||
yield Static("按下回车以完成切换\n所有会话将被保存")
|
||||
yield Button(
|
||||
"关闭 (n)", id="close_button", variant="primary", classes="close-button", flat=True
|
||||
)
|
||||
|
||||
def on_mount(self) -> None:
|
||||
# 设置焦点到列表
|
||||
nav_list = self.query_one("#nav-list", ListView)
|
||||
nav_list.focus()
|
||||
|
||||
def on_list_view_selected(self, event) -> None:
|
||||
if not isinstance(event.item, ListItem):
|
||||
return
|
||||
selected_label = event.item.query_one(Label)
|
||||
label_text = str(selected_label.render())
|
||||
# 查找对应的屏幕标识
|
||||
for title, screen_id in self.SCREENS:
|
||||
if title == label_text:
|
||||
self.app.pop_screen()
|
||||
# 跳转到目标屏幕
|
||||
if isinstance(screen_id, str):
|
||||
# 已注册的字符串标识符
|
||||
self.app.push_screen(screen_id)
|
||||
else:
|
||||
self.app.push_screen(screen_id())
|
||||
return
|
||||
for title, cmd in self.OTHERS:
|
||||
if title == label_text:
|
||||
exec(cmd)
|
||||
return
|
||||
return
|
||||
|
||||
def on_button_pressed(self, event) -> None:
|
||||
event.stop()
|
||||
if event.button.id == "close_button":
|
||||
self.action_go_back()
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
self.app.pop_screen()
|
||||
@@ -3,7 +3,7 @@
|
||||
import pathlib
|
||||
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Horizontal, ScrollableContainer
|
||||
from textual.containers import Horizontal, ScrollableContainer, Container
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, ProgressBar, Static
|
||||
from textual.worker import get_current_worker
|
||||
@@ -12,7 +12,18 @@ import heurams.kernel.particles as pt
|
||||
import heurams.services.hasher as hasher
|
||||
from heurams.context import *
|
||||
|
||||
cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||
# 兼容性缓存路径:优先使用 paths.cache,否则使用 data/cache
|
||||
paths = config_var.get()["paths"]
|
||||
cache_dir = pathlib.Path(paths.get("cache", paths["data"] + "/cache")) / "voice"
|
||||
|
||||
|
||||
def format_size(bytes_num: int) -> str:
|
||||
"""将字节数格式化为人类可读的字符串"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if bytes_num < 1024.0:
|
||||
return f"{bytes_num:.2f} {unit}"
|
||||
bytes_num /= 1024.0 # type: ignore
|
||||
return f"{bytes_num:.2f} PB"
|
||||
|
||||
|
||||
class PrecachingScreen(Screen):
|
||||
@@ -26,7 +37,9 @@ class PrecachingScreen(Screen):
|
||||
"""
|
||||
|
||||
SUB_TITLE = "缓存管理器"
|
||||
BINDINGS = [("q", "go_back", "返回")]
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
]
|
||||
|
||||
def __init__(self, nucleons: list = [], desc: str = ""):
|
||||
super().__init__(name=None, id=None, classes=None)
|
||||
@@ -40,21 +53,70 @@ class PrecachingScreen(Screen):
|
||||
self.precache_worker = None
|
||||
self.cancel_flag = 0
|
||||
self.desc = desc
|
||||
# 不再需要缓存配置,保留配置读取以兼容
|
||||
self.cache_stats = {"total_size": 0, "file_count": 0, "human_size": "0 B", "cached_units": 0, "total_units": 0, "cache_rate": 0}
|
||||
self._update_cache_stats()
|
||||
|
||||
def _get_total_units(self) -> int:
|
||||
"""获取所有仓库的总单元数"""
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = Repo.probe_valid_repos_in_dir(repo_path)
|
||||
repos = map(Repo.create_from_repodir, repo_dirs)
|
||||
total = 0
|
||||
for repo in repos:
|
||||
try:
|
||||
total += len(repo.ident_index)
|
||||
except:
|
||||
continue
|
||||
return total
|
||||
|
||||
def _update_cache_stats(self) -> None:
|
||||
"""更新缓存统计信息"""
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
cached_units = 0
|
||||
if cache_dir.exists():
|
||||
for file in cache_dir.rglob("*"):
|
||||
if file.is_file():
|
||||
total_size += file.stat().st_size
|
||||
file_count += 1
|
||||
if file.suffix.lower() == ".wav":
|
||||
cached_units += 1
|
||||
total_units = self._get_total_units()
|
||||
cache_rate = (cached_units / total_units * 100) if total_units > 0 else 0
|
||||
|
||||
self.cache_stats["total_size"] = total_size
|
||||
self.cache_stats["file_count"] = file_count
|
||||
self.cache_stats["human_size"] = format_size(total_size)
|
||||
self.cache_stats["cached_units"] = cached_units
|
||||
self.cache_stats["total_units"] = total_units
|
||||
self.cache_stats["cache_rate"] = cache_rate
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with ScrollableContainer(id="precache_container"):
|
||||
yield Label("[b]音频预缓存[/b]", classes="title-label")
|
||||
with Container(classes="cache-info"):
|
||||
yield Static(f"缓存路径: {cache_dir}", classes="cache-path")
|
||||
yield Static(f"文件数: {self.cache_stats['file_count']}", classes="cache-count")
|
||||
yield Static(f"总大小: {self.cache_stats['human_size']}", classes="cache-size")
|
||||
yield Button("刷新", id="refresh_cache_stats", variant="default")
|
||||
with Container():
|
||||
yield Static(
|
||||
f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% (已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)",
|
||||
classes="cache-usage-text"
|
||||
)
|
||||
if self.nucleons:
|
||||
yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info")
|
||||
yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info")
|
||||
else:
|
||||
yield Static("目标: 所有单元", classes="target-info")
|
||||
|
||||
if self.nucleons:
|
||||
yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info")
|
||||
yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info")
|
||||
else:
|
||||
yield Static("目标: 所有单元", classes="target-info")
|
||||
|
||||
yield Static(id="status", classes="status-info")
|
||||
yield Static(id="current_item", classes="current-item")
|
||||
yield ProgressBar(total=100, show_eta=False, id="progress_bar")
|
||||
yield Static(id="status", classes="status-info")
|
||||
yield Static(id="current_item", classes="current-item")
|
||||
yield ProgressBar(total=100, show_eta=False, id="progress_bar")
|
||||
|
||||
with Horizontal(classes="button-group"):
|
||||
if not self.is_precaching:
|
||||
@@ -72,6 +134,7 @@ class PrecachingScreen(Screen):
|
||||
def on_mount(self):
|
||||
"""挂载时初始化状态"""
|
||||
self.update_status("就绪", "等待开始...")
|
||||
self._update_cache_display()
|
||||
|
||||
def update_status(self, status, current_item="", progress=None):
|
||||
"""更新状态显示"""
|
||||
@@ -86,6 +149,25 @@ class PrecachingScreen(Screen):
|
||||
progress_bar.progress = progress
|
||||
progress_bar.advance(0) # 刷新显示
|
||||
|
||||
def _update_cache_display(self) -> None:
|
||||
"""更新缓存信息显示"""
|
||||
# 更新统计信息
|
||||
self._update_cache_stats()
|
||||
# 更新缓存率进度条
|
||||
# 更新缓存大小和文件数显示
|
||||
cache_count_widget = self.query_one(".cache-count", Static)
|
||||
cache_size_widget = self.query_one(".cache-size", Static)
|
||||
cache_usage_text = self.query_one(".cache-usage-text", Static)
|
||||
if cache_count_widget:
|
||||
cache_count_widget.update(f"文件数: {self.cache_stats['file_count']}")
|
||||
if cache_size_widget:
|
||||
cache_size_widget.update(f"总大小: {self.cache_stats['human_size']}")
|
||||
if cache_usage_text:
|
||||
cache_usage_text.update(
|
||||
f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% "
|
||||
f"(已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)"
|
||||
)
|
||||
|
||||
def precache_by_text(self, text: str):
|
||||
"""预缓存单段文本的音频"""
|
||||
from heurams.context import config_var, rootdir, workdir
|
||||
@@ -151,7 +233,7 @@ class PrecachingScreen(Screen):
|
||||
from heurams.kernel.repolib import Repo
|
||||
|
||||
repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = Repo.probe_vaild_repos_in_dir(repo_path)
|
||||
repo_dirs = Repo.probe_valid_repos_in_dir(repo_path)
|
||||
repos = map(Repo.create_from_repodir, repo_dirs)
|
||||
|
||||
# 计算总项目数
|
||||
@@ -207,12 +289,17 @@ class PrecachingScreen(Screen):
|
||||
|
||||
shutil.rmtree(cache_dir, ignore_errors=True)
|
||||
self.update_status("已清空", "音频缓存已清空", 0)
|
||||
self._update_cache_display() # 更新缓存统计显示
|
||||
except Exception as e:
|
||||
self.update_status("错误", f"清空缓存失败: {e}")
|
||||
self.cancel_flag = 1
|
||||
self.processed = 0
|
||||
self.progress = 0
|
||||
|
||||
elif event.button.id == "refresh_cache_stats":
|
||||
# 刷新缓存统计信息
|
||||
self._update_cache_display()
|
||||
self.app.notify("缓存信息已刷新", severity="information")
|
||||
elif event.button.id == "go_back":
|
||||
self.action_go_back()
|
||||
|
||||
@@ -220,8 +307,3 @@ class PrecachingScreen(Screen):
|
||||
if self.is_precaching and self.precache_worker:
|
||||
self.precache_worker.cancel()
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_quit_app(self):
|
||||
if self.is_precaching and self.precache_worker:
|
||||
self.precache_worker.cancel()
|
||||
self.app.exit()
|
||||
|
||||
@@ -11,8 +11,8 @@ import heurams.kernel.particles as pt
|
||||
import heurams.services.hasher as hasher
|
||||
from heurams.context import *
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.kernel.repolib import *
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -59,7 +59,8 @@ class PreparationScreen(Screen):
|
||||
)
|
||||
|
||||
yield Static(f"\n单元预览:\n")
|
||||
yield Markdown(self._get_full_content().replace("/", ""), classes="full")
|
||||
for i in self._get_full_content().replace("/", "").splitlines():
|
||||
yield Static(i, classes="full")
|
||||
yield Footer()
|
||||
|
||||
# def watch_scheduled_num(self, old_scheduled_num, new_scheduled_num):
|
||||
@@ -76,7 +77,7 @@ class PreparationScreen(Screen):
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
content += f"- {n['content']} \n"
|
||||
content += f" • {n['content']} \n"
|
||||
return content
|
||||
|
||||
def action_go_back(self):
|
||||
@@ -126,14 +127,14 @@ class PreparationScreen(Screen):
|
||||
left_new -= 1
|
||||
if left_new >= 0:
|
||||
atoms_to_provide.append(i)
|
||||
from .memoqueue import MemScreen
|
||||
import heurams.kernel.reactor as rt
|
||||
|
||||
from .memoqueue import MemScreen
|
||||
|
||||
pheser = rt.Phaser(atoms_to_provide)
|
||||
save_func = self.repo.persist_to_repodir
|
||||
memscreen = MemScreen(pheser, save_func)
|
||||
memscreen = MemScreen(pheser, save_func, repo=self.repo)
|
||||
self.app.push_screen(memscreen)
|
||||
|
||||
|
||||
elif event.button.id == "precache_button":
|
||||
self.action_precache()
|
||||
|
||||
@@ -1 +1,218 @@
|
||||
"""用于筛选当日记忆的条目 以音频形式重放"""
|
||||
|
||||
""" "前进电台" 界面"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from matplotlib.cbook import ls_mapper
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, ScrollableContainer
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import Button, Footer, Header, Label, Static
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.context import config_var
|
||||
from heurams.services.audio_service import play_by_path
|
||||
from heurams.services.hasher import get_md5
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.services.tts_service import convertor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RadioScreen(Screen):
|
||||
SUB_TITLE = "电台"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("space", "toggle_play", "播放/暂停"),
|
||||
]
|
||||
|
||||
# 当前播放的原子索引
|
||||
current_index = reactive(0)
|
||||
# 播放状态: 'stopped', 'playing', 'paused'
|
||||
play_state = reactive("stopped")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self._organizer()
|
||||
|
||||
def _organizer(self):
|
||||
repodirs = Repo.probe_valid_repos_in_dir(Path(config_var.get()['paths']['data']) / 'repo')
|
||||
repos = list(map(lambda repodir: Repo.create_from_repodir(repodir), repodirs))
|
||||
for repo in repos:
|
||||
last_modify = 0.0
|
||||
for i in repo.ident_index:
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=repo.electronic_data_lict.get_itemic_unit(i)
|
||||
)
|
||||
last_modify = max(last_modify, e.las())
|
||||
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=True)
|
||||
with Container(id="main"):
|
||||
yield Label("[b]前进电台[/b]", classes="title")
|
||||
yield Static(f"共 {len(self.atoms)} 条当日记忆", id="status")
|
||||
with Container(id="controls"):
|
||||
yield Button("播放", id="play", variant="success")
|
||||
yield Button("暂停", id="pause", variant="primary")
|
||||
yield Button("上一首", id="prev", variant="default")
|
||||
yield Button("下一首", id="next", variant="default")
|
||||
yield Button("停止", id="stop", variant="error")
|
||||
yield ScrollableContainer(id="playlist")
|
||||
yield Footer()
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载后更新播放列表显示"""
|
||||
self._update_playlist()
|
||||
|
||||
def _filter_due_atoms(self) -> List[pt.Atom]:
|
||||
"""筛选当日需要复习的原子(已激活且到期)"""
|
||||
atoms = []
|
||||
for ident in self.repo.ident_index:
|
||||
n = pt.Nucleon.create_on_nucleonic_data(
|
||||
nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(ident)
|
||||
)
|
||||
e = pt.Electron.create_on_electonic_data(
|
||||
electronic_data=self.repo.electronic_data_lict.get_itemic_unit(ident)
|
||||
)
|
||||
a = pt.Atom(n, e, self.repo.orbitic_data)
|
||||
# 仅选择已激活且到期的原子
|
||||
if (
|
||||
a.registry["electron"].is_activated()
|
||||
and a.registry["electron"].is_due()
|
||||
):
|
||||
atoms.append(a)
|
||||
return atoms
|
||||
|
||||
def _update_playlist(self) -> None:
|
||||
"""更新播放列表显示"""
|
||||
container = self.query_one("#playlist")
|
||||
container.remove_children()
|
||||
for idx, atom in enumerate(self.atoms):
|
||||
content = atom.registry["nucleon"].get("content", "无内容")
|
||||
prefix = "▶ " if idx == self.current_index else " "
|
||||
widget = Static(f"{prefix}{idx+1}. {content[:50]}...")
|
||||
widget.set_class(idx == self.current_index, "current")
|
||||
container.mount(widget)
|
||||
|
||||
def _get_audio_path(self, atom: pt.Atom) -> Path:
|
||||
"""返回音频文件路径,若不存在则生成"""
|
||||
tts_text = atom.registry["nucleon"].get("tts_text", "")
|
||||
if not tts_text:
|
||||
tts_text = atom.registry["nucleon"].get("content", "")
|
||||
voice_dir = Path(config_var.get()["paths"]["data"]) / "cache" / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = voice_dir / f"{get_md5(tts_text)}.wav"
|
||||
if not path.exists():
|
||||
convertor(tts_text, path)
|
||||
return path
|
||||
|
||||
async def _play_atom(self, idx: int) -> None:
|
||||
"""播放指定索引的原子(异步)"""
|
||||
if idx < 0 or idx >= len(self.atoms):
|
||||
return
|
||||
atom = self.atoms[idx]
|
||||
try:
|
||||
path = self._get_audio_path(atom)
|
||||
self._current_path = path
|
||||
# 在后台线程中播放,避免阻塞UI
|
||||
await self.run_worker(
|
||||
lambda: play_by_path(path), exclusive=True, thread=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("播放失败: %s", e)
|
||||
|
||||
def _stop_playback(self) -> None:
|
||||
"""停止当前播放"""
|
||||
if self._play_task and not self._play_task.done():
|
||||
self._play_task.cancel()
|
||||
self._play_task = None
|
||||
self._current_path = None
|
||||
self.play_state = "stopped"
|
||||
|
||||
async def _play_current(self) -> None:
|
||||
"""播放当前索引的原子"""
|
||||
self._stop_playback()
|
||||
self.play_state = "playing"
|
||||
self._play_task = asyncio.create_task(self._play_atom(self.current_index))
|
||||
try:
|
||||
await self._play_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if self.play_state == "playing":
|
||||
self.play_state = "stopped"
|
||||
|
||||
# 按钮事件处理
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
button_id = event.button.id
|
||||
if button_id == "play":
|
||||
self.action_toggle_play()
|
||||
elif button_id == "pause":
|
||||
self.action_pause()
|
||||
elif button_id == "prev":
|
||||
self.action_prev()
|
||||
elif button_id == "next":
|
||||
self.action_next()
|
||||
elif button_id == "stop":
|
||||
self.action_stop()
|
||||
|
||||
# 键盘动作
|
||||
def action_toggle_play(self) -> None:
|
||||
if self.play_state == "playing":
|
||||
self.action_pause()
|
||||
else:
|
||||
self.action_play()
|
||||
|
||||
def action_play(self) -> None:
|
||||
if self.play_state != "playing":
|
||||
if self.play_state == "paused":
|
||||
# 恢复播放(目前暂停功能简单实现为停止)
|
||||
self.play_state = "playing"
|
||||
else:
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_pause(self) -> None:
|
||||
if self.play_state == "playing":
|
||||
self._stop_playback()
|
||||
self.play_state = "paused"
|
||||
|
||||
def action_stop(self) -> None:
|
||||
self._stop_playback()
|
||||
self.play_state = "stopped"
|
||||
|
||||
def action_next(self) -> None:
|
||||
if self.current_index < len(self.atoms) - 1:
|
||||
self.current_index += 1
|
||||
self._update_playlist()
|
||||
if self.play_state == "playing":
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_prev(self) -> None:
|
||||
if self.current_index > 0:
|
||||
self.current_index -= 1
|
||||
self._update_playlist()
|
||||
if self.play_state == "playing":
|
||||
asyncio.create_task(self._play_current())
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
self._stop_playback()
|
||||
self.app.pop_screen()
|
||||
|
||||
# 响应式更新
|
||||
def watch_current_index(self, old: int, new: int) -> None:
|
||||
self._update_playlist()
|
||||
|
||||
def watch_play_state(self, old: str, new: str) -> None:
|
||||
# 更新按钮状态(可在此添加样式变化)
|
||||
pass
|
||||
|
||||
@@ -24,7 +24,7 @@ class RepoCreatorScreen(Screen):
|
||||
|
||||
from heurams.context import config_var
|
||||
|
||||
template_dir = Path(config_var.get()["paths"]["template_dir"])
|
||||
template_dir = Path(config_var.get()["paths"]["data"]) / "templates"
|
||||
templates = list()
|
||||
for i in template_dir.iterdir():
|
||||
if i.name.endswith(".toml"):
|
||||
|
||||
267
src/heurams/interface/screens/repoeditor.py
Normal file
267
src/heurams/interface/screens/repoeditor.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""仓库编辑器, 使用TextArea控件等实现仓库配置编辑"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import toml
|
||||
from textual.app import ComposeResult
|
||||
from textual.containers import Container, Horizontal, ScrollableContainer, Vertical
|
||||
from textual.reactive import reactive
|
||||
from textual.screen import Screen
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Label,
|
||||
ListItem,
|
||||
ListView,
|
||||
Static,
|
||||
TextArea,
|
||||
)
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.kernel.repolib import Repo
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RepoEditorScreen(Screen):
|
||||
"""仓库编辑器屏幕"""
|
||||
|
||||
SUB_TITLE = "仓库编辑器"
|
||||
|
||||
BINDINGS = [
|
||||
("q", "go_back", "返回"),
|
||||
("s", "save_file", "保存"),
|
||||
("r", "reload_file", "重载"),
|
||||
("d", "toggle_dark", ""),
|
||||
]
|
||||
|
||||
# 当前选择的仓库路径
|
||||
selected_repo_path: reactive[Optional[Path]] = reactive(None)
|
||||
# 当前选择的文件名
|
||||
selected_filename: reactive[Optional[str]] = reactive(None)
|
||||
# 文件内容
|
||||
file_content: reactive[str] = reactive("")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo: Optional[Repo] = None,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(name, id, classes)
|
||||
self.repo = repo
|
||||
self.repo_dir: Optional[Path] = None
|
||||
self.file_list = []
|
||||
if repo is not None and repo.source is not None:
|
||||
self.repo_dir = repo.source
|
||||
self._load_file_list()
|
||||
# selected_repo_path 将在 on_mount 中设置,避免触发watch时组件未就绪
|
||||
|
||||
def _load_file_list(self) -> None:
|
||||
"""加载仓库目录下的文件列表"""
|
||||
if self.repo_dir is None:
|
||||
return
|
||||
self.file_list = []
|
||||
for fname in Repo.file_mapping.values():
|
||||
fpath = self.repo_dir / fname
|
||||
if fpath.exists():
|
||||
self.file_list.append(fname)
|
||||
# 也可能存在其他文件,但暂时只支持标准文件
|
||||
self.file_list.sort()
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
"""组合界面组件"""
|
||||
yield Header(show_clock=True)
|
||||
with Container(id="main_container"):
|
||||
with Horizontal(id="top_panel"):
|
||||
# 左侧: 仓库选择
|
||||
with Vertical(id="repo_selector", classes="panel"):
|
||||
yield Label("仓库列表", classes="panel-title")
|
||||
yield ListView(
|
||||
*[
|
||||
ListItem(Label(repo_dir.name))
|
||||
for repo_dir in self._get_repo_dirs()
|
||||
],
|
||||
id="repo_list",
|
||||
classes="list-view",
|
||||
)
|
||||
# 中间: 文件列表
|
||||
with Vertical(id="file_selector", classes="panel"):
|
||||
yield Label("文件列表", classes="panel-title")
|
||||
yield ListView(
|
||||
*[ListItem(Label(fname)) for fname in self.file_list],
|
||||
id="file_list",
|
||||
classes="list-view",
|
||||
)
|
||||
# 右侧: 编辑区域
|
||||
with Vertical(id="editor_panel", classes="panel"):
|
||||
yield Label("编辑文件", classes="panel-title")
|
||||
yield TextArea(
|
||||
id="text_editor",
|
||||
language="plaintext",
|
||||
classes="text-editor",
|
||||
)
|
||||
with Horizontal(id="button_bar"):
|
||||
yield Button("保存", id="save_button", variant="primary")
|
||||
yield Button("重载", id="reload_button", variant="default")
|
||||
yield Button("返回", id="back_button", variant="error")
|
||||
yield Footer()
|
||||
|
||||
def _get_repo_dirs(self) -> list[Path]:
|
||||
"""获取data/repo/下所有有效仓库目录"""
|
||||
repo_root = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
repo_dirs = []
|
||||
if repo_root.exists():
|
||||
for entry in repo_root.iterdir():
|
||||
if entry.is_dir():
|
||||
# 检查是否存在 manifest.toml
|
||||
if (entry / "manifest.toml").exists():
|
||||
repo_dirs.append(entry)
|
||||
return repo_dirs
|
||||
|
||||
def on_mount(self) -> None:
|
||||
"""挂载组件时初始化"""
|
||||
# 如果已有仓库,设置 selected_repo_path 以触发watch(此时组件已就绪)
|
||||
if self.repo_dir is not None:
|
||||
self.selected_repo_path = self.repo_dir
|
||||
# 焦点放在仓库列表
|
||||
self.query_one("#repo_list", ListView).focus()
|
||||
|
||||
def watch_selected_repo_path(
|
||||
self, old_path: Optional[Path], new_path: Optional[Path]
|
||||
) -> None:
|
||||
"""当选择的仓库路径变化时,加载文件列表"""
|
||||
if new_path is None:
|
||||
self.file_list = []
|
||||
self.selected_filename = None
|
||||
self.file_content = ""
|
||||
return
|
||||
self.repo_dir = new_path
|
||||
self._load_file_list()
|
||||
# 如果组件已挂载,更新UI
|
||||
if self.is_mounted:
|
||||
file_list_view = self.query_one("#file_list", ListView)
|
||||
file_list_view.clear()
|
||||
for fname in self.file_list:
|
||||
file_list_view.append(ListItem(Label(fname)))
|
||||
# 清空编辑器
|
||||
self.query_one("#text_editor", TextArea).text = ""
|
||||
self.selected_filename = None
|
||||
|
||||
def watch_selected_filename(
|
||||
self, old_name: Optional[str], new_name: Optional[str]
|
||||
) -> None:
|
||||
"""当选择的文件名变化时,加载文件内容"""
|
||||
if new_name is None or self.repo_dir is None:
|
||||
self.file_content = ""
|
||||
return
|
||||
file_path = self.repo_dir / new_name
|
||||
if not file_path.exists():
|
||||
self.notify(f"文件不存在: {new_name}", severity="error")
|
||||
return
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
self.file_content = content
|
||||
# 如果组件已挂载,更新编辑器
|
||||
if self.is_mounted:
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
editor.text = content
|
||||
# 根据文件后缀设置语言
|
||||
if new_name.endswith(".toml"):
|
||||
editor.language = "toml"
|
||||
elif new_name.endswith(".json"):
|
||||
editor.language = "json"
|
||||
else:
|
||||
editor.language = "plaintext"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件失败: {e}")
|
||||
self.notify(f"读取文件失败: {e}", severity="error")
|
||||
|
||||
def watch_file_content(self, old_content: str, new_content: str) -> None:
|
||||
"""当文件内容变化时更新编辑器(仅当外部改变时)"""
|
||||
# 目前不需要做任何事情,因为编辑器内容已绑定
|
||||
pass
|
||||
|
||||
def on_list_view_selected(self, event) -> None:
|
||||
"""处理列表项选择事件"""
|
||||
if not isinstance(event.item, ListItem):
|
||||
return
|
||||
list_id = event.list_view.id
|
||||
selected_label = event.item.query_one(Label)
|
||||
selected_text = str(selected_label.render())
|
||||
|
||||
if list_id == "repo_list":
|
||||
# 用户选择了仓库
|
||||
repo_root = Path(config_var.get()["paths"]["data"]) / "repo"
|
||||
selected_dir = repo_root / selected_text
|
||||
if selected_dir.exists():
|
||||
self.selected_repo_path = selected_dir
|
||||
elif list_id == "file_list":
|
||||
# 用户选择了文件
|
||||
if self.repo_dir is None:
|
||||
self.notify("请先选择仓库", severity="warning")
|
||||
return
|
||||
self.selected_filename = selected_text
|
||||
|
||||
def on_button_pressed(self, event) -> None:
|
||||
"""处理按钮点击事件"""
|
||||
event.stop()
|
||||
if event.button.id == "save_button":
|
||||
self.action_save_file()
|
||||
elif event.button.id == "reload_button":
|
||||
self.action_reload_file()
|
||||
elif event.button.id == "back_button":
|
||||
self.action_go_back()
|
||||
|
||||
def action_save_file(self) -> None:
|
||||
"""保存当前编辑的文件"""
|
||||
if self.repo_dir is None or self.selected_filename is None:
|
||||
self.notify("未选择仓库或文件", severity="warning")
|
||||
return
|
||||
file_path = self.repo_dir / self.selected_filename
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
new_content = editor.text
|
||||
# 验证格式
|
||||
try:
|
||||
if self.selected_filename.endswith(".toml"):
|
||||
toml.loads(new_content) # 验证TOML
|
||||
elif self.selected_filename.endswith(".json"):
|
||||
json.loads(new_content) # 验证JSON
|
||||
except Exception as e:
|
||||
self.notify(f"格式错误: {e}", severity="error")
|
||||
return
|
||||
# 写入文件
|
||||
try:
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
self.notify("保存成功", severity="information")
|
||||
except Exception as e:
|
||||
logger.error(f"保存文件失败: {e}")
|
||||
self.notify(f"保存文件失败: {e}", severity="error")
|
||||
|
||||
def action_reload_file(self) -> None:
|
||||
"""重新加载当前文件(放弃修改)"""
|
||||
if self.repo_dir is None or self.selected_filename is None:
|
||||
self.notify("未选择仓库或文件", severity="warning")
|
||||
return
|
||||
file_path = self.repo_dir / self.selected_filename
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
editor = self.query_one("#text_editor", TextArea)
|
||||
editor.text = content
|
||||
self.notify("已重载", severity="information")
|
||||
except Exception as e:
|
||||
logger.error(f"重载文件失败: {e}")
|
||||
self.notify(f"重载文件失败: {e}", severity="error")
|
||||
|
||||
def action_go_back(self) -> None:
|
||||
"""返回上一屏幕"""
|
||||
self.app.pop_screen()
|
||||
|
||||
def action_toggle_dark(self) -> None:
|
||||
"""切换暗色模式"""
|
||||
self.app.dark = not self.app.dark
|
||||
@@ -7,8 +7,8 @@ from textual.message import Message
|
||||
from textual.widget import Widget
|
||||
from textual.widgets import Button, Label
|
||||
|
||||
import heurams.kernel.puzzles as pz
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base_puzzle_widget import BasePuzzleWidget
|
||||
|
||||
@@ -7,12 +7,12 @@ class Finished(Widget):
|
||||
self,
|
||||
*children: Widget,
|
||||
alia="",
|
||||
is_saved = 0,
|
||||
is_saved=0,
|
||||
name: str | None = None,
|
||||
id: str | None = None,
|
||||
classes: str | None = None,
|
||||
disabled: bool = False,
|
||||
markup: bool = True
|
||||
markup: bool = True,
|
||||
) -> None:
|
||||
self.alia = alia
|
||||
self.is_saved = is_saved
|
||||
@@ -22,7 +22,7 @@ class Finished(Widget):
|
||||
id=id,
|
||||
classes=classes,
|
||||
disabled=disabled,
|
||||
markup=markup
|
||||
markup=markup,
|
||||
)
|
||||
|
||||
def compose(self):
|
||||
|
||||
@@ -5,8 +5,8 @@ from textual.containers import Container, ScrollableContainer
|
||||
from textual.widget import Widget
|
||||
from textual.widgets import Button, Label
|
||||
|
||||
import heurams.kernel.puzzles as pz
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as pz
|
||||
from heurams.services.hasher import hash
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ class Recognition(BasePuzzleWidget):
|
||||
for item in cfg["secondary"]:
|
||||
if isinstance(item, list):
|
||||
for j in item:
|
||||
yield Markdown(f"### {metadata['annotation'][item]}: {j}")
|
||||
yield Markdown(f"### {j}") #TODO ANNOTATION
|
||||
continue
|
||||
if isinstance(item, Dict):
|
||||
total = ""
|
||||
|
||||
@@ -3,8 +3,8 @@ from .electron import Electron
|
||||
from .nucleon import Nucleon
|
||||
from .placeholders import (
|
||||
AtomPlaceholder,
|
||||
NucleonPlaceholder,
|
||||
ElectronPlaceholder,
|
||||
NucleonPlaceholder,
|
||||
orbital_placeholder,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .electron import Electron
|
||||
|
||||
@@ -57,6 +57,10 @@ class Electron:
|
||||
result = self.algodata[self.algo.algo_name]["is_activated"]
|
||||
return result
|
||||
|
||||
def last_modify(self):
|
||||
result = self.algodata[self.algo.algo_name]["last_modify"]
|
||||
return result
|
||||
|
||||
def get_rating(self):
|
||||
try:
|
||||
result = self.algo.get_rating(self.algodata)
|
||||
@@ -68,6 +72,10 @@ class Electron:
|
||||
result = self.algo.nextdate(self.algodata)
|
||||
return result
|
||||
|
||||
def lastdate(self) -> int:
|
||||
result = self.algodata[self.algo.algo_name]["lastdate"]
|
||||
return result
|
||||
|
||||
def revisor(self, quality: int = 5, is_new_activation: bool = False):
|
||||
"""算法迭代决策机制实现
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from copy import deepcopy
|
||||
from logging import config
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.utils.evalizor import Evalizer
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
from heurams.kernel.auxiliary.evalizor import Evalizer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from heurams.kernel.particles import orbital
|
||||
|
||||
from .atom import Atom
|
||||
from .electron import Electron
|
||||
from .nucleon import Nucleon
|
||||
from .atom import Atom
|
||||
|
||||
orbital_placeholder = {
|
||||
"schedule": ["quick_review", "recognition", "final_review"],
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from functools import reduce
|
||||
import random
|
||||
|
||||
import heurams.kernel.puzzles as puz
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.services.logger import get_logger
|
||||
from functools import reduce
|
||||
|
||||
from tabulate import tabulate as tabu
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
import heurams.kernel.puzzles as puz
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .states import FissionState, PhaserState
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from click import style
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.kernel.particles.placeholders import AtomPlaceholder
|
||||
from heurams.services.logger import get_logger
|
||||
from transitions import Machine
|
||||
|
||||
from .procession import Procession
|
||||
from .states import PhaserState, ProcessionState
|
||||
@@ -133,9 +134,10 @@ class Phaser(Machine):
|
||||
return Procession([AtomPlaceholder()], PhaserState.FINISHED)
|
||||
|
||||
def __repr__(self, style="pipe", ends="\n"):
|
||||
from heurams.services.textproc import truncate
|
||||
from tabulate import tabulate as tabu
|
||||
|
||||
from heurams.services.textproc import truncate
|
||||
|
||||
lst = [
|
||||
{
|
||||
"Type": "Phaser",
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from tabulate import tabulate as tabu
|
||||
from transitions import Machine
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
from heurams.services.logger import get_logger
|
||||
from transitions import Machine
|
||||
from tabulate import tabulate as tabu
|
||||
|
||||
from .fission import Fission
|
||||
from .states import PhaserState, ProcessionState
|
||||
|
||||
@@ -7,7 +7,7 @@ import toml
|
||||
|
||||
import heurams.kernel.particles as pt
|
||||
|
||||
from ...utils.lict import Lict
|
||||
from heurams.kernel.auxiliary.lict import Lict
|
||||
|
||||
|
||||
class RepoManifest(TypedDict):
|
||||
@@ -167,7 +167,7 @@ class Repo:
|
||||
return 0
|
||||
|
||||
@classmethod
|
||||
def probe_vaild_repos_in_dir(cls, folder: Path):
|
||||
def probe_valid_repos_in_dir(cls, folder: Path):
|
||||
lst = list()
|
||||
for i in folder.iterdir():
|
||||
if i.is_dir():
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
import pathlib
|
||||
from typing import Protocol
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PlayFunctionProtocol(Protocol):
|
||||
def __call__(self, path: pathlib.Path) -> None: ...
|
||||
|
||||
|
||||
logger.debug("音频协议模块已加载")
|
||||
@@ -1,6 +1,19 @@
|
||||
# 大语言模型
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseLLM
|
||||
from .openai import OpenAILLM
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("LLM providers 模块已加载")
|
||||
__all__ = [
|
||||
"BaseLLM",
|
||||
"OpenAILLM",
|
||||
]
|
||||
|
||||
providers = {
|
||||
"base": BaseLLM,
|
||||
"openai": OpenAILLM,
|
||||
}
|
||||
|
||||
logger.debug("LLM providers 已注册: %s", list(providers.keys()))
|
||||
|
||||
@@ -1,5 +1,55 @@
|
||||
"""LLM 提供者基类"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("LLM 基类模块已加载")
|
||||
|
||||
class BaseLLM:
|
||||
"""LLM 提供者基类"""
|
||||
|
||||
name = "BaseLLM"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
"""初始化 LLM 提供者
|
||||
|
||||
Args:
|
||||
config: 提供者配置字典
|
||||
"""
|
||||
self.config = config
|
||||
logger.debug("BaseLLM 初始化完成")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
"""发送聊天消息并获取响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表,每个消息为 {"role": "user"|"assistant"|"system", "content": "消息内容"}
|
||||
**kwargs: 其他参数,如 temperature, max_tokens 等
|
||||
|
||||
Returns:
|
||||
模型返回的文本响应
|
||||
"""
|
||||
logger.debug("BaseLLM.chat: messages=%d, kwargs=%s", len(messages), kwargs)
|
||||
logger.warning("BaseLLM.chat 是基类方法,未实现具体功能")
|
||||
await asyncio.sleep(0) # 避免未使用异步的警告
|
||||
return "BaseLLM 未实现具体功能"
|
||||
|
||||
async def chat_stream(self, messages: List[Dict[str, str]], **kwargs):
|
||||
"""流式聊天(可选实现)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
流式响应的文本块
|
||||
"""
|
||||
logger.debug(
|
||||
"BaseLLM.chat_stream: messages=%d, kwargs=%s", len(messages), kwargs
|
||||
)
|
||||
logger.warning("BaseLLM.chat_stream 是基类方法,未实现具体功能")
|
||||
await asyncio.sleep(0)
|
||||
yield "BaseLLM 未实现流式功能"
|
||||
|
||||
@@ -1,5 +1,96 @@
|
||||
"""OpenAI 兼容 LLM 提供者"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
from .base import BaseLLM
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
logger.debug("OpenAI provider 模块已加载(未实现)")
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""OpenAI 兼容 LLM 提供者"""
|
||||
|
||||
name = "OpenAI"
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
super().__init__(config)
|
||||
self.api_key = config.get("key", "")
|
||||
self.base_url = config.get("url", "https://api.openai.com/v1")
|
||||
self._client = None
|
||||
logger.debug("OpenAILLM 初始化完成: base_url=%s", self.base_url)
|
||||
|
||||
def _get_client(self):
|
||||
"""获取 OpenAI 客户端(延迟导入)"""
|
||||
if self._client is None:
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
except ImportError:
|
||||
logger.error("未安装 openai 库,请运行: pip install openai")
|
||||
raise ImportError("未安装 openai 库,请运行: pip install openai")
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key if self.api_key else None,
|
||||
base_url=self.base_url if self.base_url else None,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
|
||||
"""发送聊天消息并获取响应"""
|
||||
logger.debug("OpenAILLM.chat: messages=%d", len(messages))
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# 默认参数
|
||||
default_kwargs = {
|
||||
"model": kwargs.get("model", "gpt-3.5-turbo"),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
}
|
||||
|
||||
# 合并参数,优先使用传入的 kwargs
|
||||
request_kwargs = {**default_kwargs, **kwargs}
|
||||
request_kwargs["messages"] = messages
|
||||
|
||||
try:
|
||||
response = await client.chat.completions.create(**request_kwargs)
|
||||
content = response.choices[0].message.content
|
||||
logger.debug(
|
||||
"OpenAILLM.chat 成功: response length=%d",
|
||||
len(content) if content else 0,
|
||||
)
|
||||
return content or ""
|
||||
except Exception as e:
|
||||
logger.error("OpenAILLM.chat 失败: %s", e)
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self, messages: List[Dict[str, str]], **kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天"""
|
||||
logger.debug("OpenAILLM.chat_stream: messages=%d", len(messages))
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
# 默认参数
|
||||
default_kwargs = {
|
||||
"model": kwargs.get("model", "gpt-3.5-turbo"),
|
||||
"temperature": kwargs.get("temperature", 0.7),
|
||||
"max_tokens": kwargs.get("max_tokens", 1000),
|
||||
"stream": True,
|
||||
}
|
||||
|
||||
# 合并参数
|
||||
request_kwargs = {**default_kwargs, **kwargs}
|
||||
request_kwargs["messages"] = messages
|
||||
|
||||
try:
|
||||
stream = await client.chat.completions.create(**request_kwargs)
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
except Exception as e:
|
||||
logger.error("OpenAILLM.chat_stream 失败: %s", e)
|
||||
raise
|
||||
|
||||
163
src/heurams/services/favorite_service.py
Normal file
163
src/heurams/services/favorite_service.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# 收藏服务
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FavoriteItem:
|
||||
"""收藏项"""
|
||||
|
||||
repo_path: str # 仓库相对路径 (相对于 data/repo)
|
||||
ident: str # 原子标识符
|
||||
added: int # 添加时间戳 (UNIX 秒)
|
||||
# 可选标签
|
||||
tags: List[str] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"repo_path": self.repo_path,
|
||||
"ident": self.ident,
|
||||
"added": self.added,
|
||||
"tags": self.tags,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "FavoriteItem":
|
||||
return cls(
|
||||
repo_path=data["repo_path"],
|
||||
ident=data["ident"],
|
||||
added=data["added"],
|
||||
tags=data.get("tags", []),
|
||||
)
|
||||
|
||||
|
||||
class FavoriteManager:
|
||||
"""收藏管理器"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, "_loaded"):
|
||||
self._loaded = True
|
||||
self._favorites: List[FavoriteItem] = []
|
||||
self._file_path = self._get_file_path()
|
||||
self.load()
|
||||
|
||||
def _get_file_path(self) -> Path:
|
||||
"""获取收藏文件路径"""
|
||||
config_path = Path(config_var.get()["paths"]["data"])
|
||||
fav_path = config_path / "global" / "favorites.json"
|
||||
fav_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return fav_path
|
||||
|
||||
def load(self) -> None:
|
||||
"""从文件加载收藏列表"""
|
||||
if self._file_path.exists():
|
||||
try:
|
||||
with open(self._file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self._favorites = [FavoriteItem.from_dict(item) for item in data]
|
||||
logger.debug("收藏列表加载成功,共 %d 项", len(self._favorites))
|
||||
except Exception as e:
|
||||
logger.error("加载收藏列表失败: %s", e)
|
||||
self._favorites = []
|
||||
else:
|
||||
self._favorites = []
|
||||
|
||||
def save(self) -> None:
|
||||
"""保存收藏列表到文件"""
|
||||
try:
|
||||
data = [item.to_dict() for item in self._favorites]
|
||||
with open(self._file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("收藏列表保存成功,共 %d 项", len(self._favorites))
|
||||
except Exception as e:
|
||||
logger.error("保存收藏列表失败: %s", e)
|
||||
|
||||
def add(self, repo_path: str, ident: str, tags: List[str] | None = None) -> bool:
|
||||
"""添加收藏
|
||||
|
||||
Args:
|
||||
repo_path: 仓库相对路径
|
||||
ident: 原子标识符
|
||||
tags: 标签列表
|
||||
Returns:
|
||||
是否成功添加 (若已存在则返回 False)
|
||||
"""
|
||||
# 检查是否已存在
|
||||
for item in self._favorites:
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
logger.debug("收藏已存在: %s/%s", repo_path, ident)
|
||||
return False
|
||||
item = FavoriteItem(
|
||||
repo_path=repo_path,
|
||||
ident=ident,
|
||||
added=int(time.time()),
|
||||
tags=tags if tags else [],
|
||||
)
|
||||
self._favorites.append(item)
|
||||
self.save()
|
||||
logger.info("添加收藏: %s/%s", repo_path, ident)
|
||||
return True
|
||||
|
||||
def remove(self, repo_path: str, ident: str) -> bool:
|
||||
"""移除收藏
|
||||
|
||||
Returns:
|
||||
是否成功移除 (若不存在则返回 False)
|
||||
"""
|
||||
for idx, item in enumerate(self._favorites):
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
del self._favorites[idx]
|
||||
self.save()
|
||||
logger.info("移除收藏: %s/%s", repo_path, ident)
|
||||
return True
|
||||
logger.debug("收藏不存在: %s/%s", repo_path, ident)
|
||||
return False
|
||||
|
||||
def has(self, repo_path: str, ident: str) -> bool:
|
||||
"""检查是否已收藏"""
|
||||
for item in self._favorites:
|
||||
if item.repo_path == repo_path and item.ident == ident:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_all(self) -> List[FavoriteItem]:
|
||||
"""获取所有收藏项(按添加时间倒序)"""
|
||||
return sorted(self._favorites, key=lambda x: x.added, reverse=True)
|
||||
|
||||
def get_by_repo(self, repo_path: str) -> List[FavoriteItem]:
|
||||
"""获取指定仓库的所有收藏项"""
|
||||
return [item for item in self._favorites if item.repo_path == repo_path]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空收藏列表"""
|
||||
self._favorites = []
|
||||
self.save()
|
||||
logger.info("清空收藏列表")
|
||||
|
||||
def count(self) -> int:
|
||||
"""收藏总数"""
|
||||
return len(self._favorites)
|
||||
|
||||
|
||||
# 全局单例实例
|
||||
favorite_manager = FavoriteManager()
|
||||
228
src/heurams/services/llm_service.py
Normal file
228
src/heurams/services/llm_service.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""LLM 聊天服务"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from heurams.context import config_var
|
||||
from heurams.providers.llm import providers as prov
|
||||
from heurams.services.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChatSession:
|
||||
"""聊天会话,管理单个对话的历史和参数"""
|
||||
|
||||
def __init__(
|
||||
self, session_id: str, llm_provider, system_prompt: str = "", **default_params
|
||||
):
|
||||
"""初始化聊天会话
|
||||
|
||||
Args:
|
||||
session_id: 会话唯一标识符
|
||||
llm_provider: LLM 提供者实例
|
||||
system_prompt: 系统提示词
|
||||
**default_params: 默认参数(temperature, max_tokens, model 等)
|
||||
"""
|
||||
self.session_id = session_id
|
||||
self.llm_provider = llm_provider
|
||||
self.system_prompt = system_prompt
|
||||
self.default_params = default_params
|
||||
|
||||
# 消息历史
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
if system_prompt:
|
||||
self.messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
logger.debug("创建聊天会话: id=%s", session_id)
|
||||
|
||||
def add_message(self, role: str, content: str):
|
||||
"""添加消息到历史"""
|
||||
self.messages.append({"role": role, "content": content})
|
||||
logger.debug(
|
||||
"会话 %s 添加消息: role=%s, length=%d", self.session_id, role, len(content)
|
||||
)
|
||||
|
||||
def clear_history(self):
|
||||
"""清空消息历史(保留系统提示)"""
|
||||
self.messages = []
|
||||
if self.system_prompt:
|
||||
self.messages.append({"role": "system", "content": self.system_prompt})
|
||||
logger.debug("会话 %s 清空历史", self.session_id)
|
||||
|
||||
def set_system_prompt(self, prompt: str):
|
||||
"""设置系统提示词"""
|
||||
self.system_prompt = prompt
|
||||
# 更新消息历史中的系统消息
|
||||
if self.messages and self.messages[0]["role"] == "system":
|
||||
self.messages[0]["content"] = prompt
|
||||
elif prompt:
|
||||
self.messages.insert(0, {"role": "system", "content": prompt})
|
||||
logger.debug("会话 %s 设置系统提示: length=%d", self.session_id, len(prompt))
|
||||
|
||||
async def send_message(self, message: str, **override_params) -> str:
|
||||
"""发送消息并获取响应
|
||||
|
||||
Args:
|
||||
message: 用户消息内容
|
||||
**override_params: 覆盖默认参数
|
||||
|
||||
Returns:
|
||||
模型响应内容
|
||||
"""
|
||||
# 添加用户消息
|
||||
self.add_message("user", message)
|
||||
|
||||
# 合并参数
|
||||
params = {**self.default_params, **override_params}
|
||||
|
||||
# 发送请求
|
||||
logger.debug("会话 %s 发送消息: length=%d", self.session_id, len(message))
|
||||
response = await self.llm_provider.chat(self.messages, **params)
|
||||
|
||||
# 添加助手响应
|
||||
self.add_message("assistant", response)
|
||||
|
||||
return response
|
||||
|
||||
async def send_message_stream(self, message: str, **override_params):
|
||||
"""流式发送消息
|
||||
|
||||
Args:
|
||||
message: 用户消息内容
|
||||
**override_params: 覆盖默认参数
|
||||
|
||||
Yields:
|
||||
流式响应的文本块
|
||||
"""
|
||||
# 添加用户消息
|
||||
self.add_message("user", message)
|
||||
|
||||
# 合并参数
|
||||
params = {**self.default_params, **override_params}
|
||||
|
||||
# 发送流式请求
|
||||
logger.debug("会话 %s 发送流式消息: length=%d", self.session_id, len(message))
|
||||
|
||||
full_response = ""
|
||||
async for chunk in self.llm_provider.chat_stream(self.messages, **params):
|
||||
yield chunk
|
||||
full_response += chunk
|
||||
|
||||
# 添加完整的助手响应到历史
|
||||
self.add_message("assistant", full_response)
|
||||
|
||||
def get_history(self) -> List[Dict[str, str]]:
|
||||
"""获取消息历史(不包括系统消息)"""
|
||||
# 返回用户和助手的消息,可选排除系统消息
|
||||
return [msg for msg in self.messages if msg["role"] != "system"]
|
||||
|
||||
def save_to_file(self, file_path: Path):
|
||||
"""保存会话到文件"""
|
||||
data = {
|
||||
"session_id": self.session_id,
|
||||
"system_prompt": self.system_prompt,
|
||||
"default_params": self.default_params,
|
||||
"messages": self.messages,
|
||||
}
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
logger.debug("会话 %s 保存到: %s", self.session_id, file_path)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: Path, llm_provider):
|
||||
"""从文件加载会话"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
session = cls(
|
||||
session_id=data["session_id"],
|
||||
llm_provider=llm_provider,
|
||||
system_prompt=data.get("system_prompt", ""),
|
||||
**data.get("default_params", {})
|
||||
)
|
||||
session.messages = data["messages"]
|
||||
logger.debug("从文件加载会话: %s", file_path)
|
||||
return session
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,管理多个会话"""
|
||||
|
||||
def __init__(self):
|
||||
self.sessions: Dict[str, ChatSession] = {}
|
||||
self.default_session_id = "default"
|
||||
logger.debug("聊天管理器初始化完成")
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
session_id: Optional[str] = None,
|
||||
create_if_missing: bool = True,
|
||||
**session_params
|
||||
) -> Optional[ChatSession]:
|
||||
"""获取或创建聊天会话
|
||||
|
||||
Args:
|
||||
session_id: 会话标识符,None 则使用默认会话
|
||||
create_if_missing: 如果会话不存在则创建
|
||||
**session_params: 传递给 ChatSession 的参数
|
||||
|
||||
Returns:
|
||||
聊天会话实例,如果不存在且不创建则返回 None
|
||||
"""
|
||||
if session_id is None:
|
||||
session_id = self.default_session_id
|
||||
|
||||
if session_id in self.sessions:
|
||||
return self.sessions[session_id]
|
||||
|
||||
if create_if_missing:
|
||||
# 获取 LLM 提供者
|
||||
provider_name = config_var.get()["services"]["llm"]
|
||||
provider_config = config_var.get()["providers"]["llm"][provider_name]
|
||||
llm_provider = prov[provider_name](provider_config)
|
||||
|
||||
session = ChatSession(
|
||||
session_id=session_id, llm_provider=llm_provider, **session_params
|
||||
)
|
||||
self.sessions[session_id] = session
|
||||
logger.debug("创建新会话: id=%s", session_id)
|
||||
return session
|
||||
|
||||
return None
|
||||
|
||||
def delete_session(self, session_id: str):
|
||||
"""删除会话"""
|
||||
if session_id in self.sessions:
|
||||
del self.sessions[session_id]
|
||||
logger.debug("删除会话: id=%s", session_id)
|
||||
|
||||
def list_sessions(self) -> List[str]:
|
||||
"""列出所有会话ID"""
|
||||
return list(self.sessions.keys())
|
||||
|
||||
|
||||
# 全局聊天管理器实例
|
||||
_chat_manager: Optional[ChatManager] = None
|
||||
|
||||
|
||||
def get_chat_manager() -> ChatManager:
|
||||
"""获取全局聊天管理器实例"""
|
||||
global _chat_manager
|
||||
if _chat_manager is None:
|
||||
_chat_manager = ChatManager()
|
||||
logger.debug("创建全局聊天管理器")
|
||||
return _chat_manager
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
session_id: Optional[str] = None, **session_params
|
||||
) -> ChatSession:
|
||||
"""创建或获取聊天会话(便捷函数)"""
|
||||
manager = get_chat_manager()
|
||||
return manager.get_session(session_id, True, **session_params)
|
||||
|
||||
|
||||
logger.debug("LLM 服务初始化完成")
|
||||
Reference in New Issue
Block a user