diff --git a/.gitignore b/.gitignore index ff7e71c..c863a00 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ .idea/ cache/ +data/repo/cngk #nucleon/test.toml electron/test.toml *.egg-info/ diff --git a/data/config/config.toml b/data/config/config.toml index eb225ae..6479ca1 100644 --- a/data/config/config.toml +++ b/data/config/config.toml @@ -9,7 +9,7 @@ timestamp_override = -1 quick_pass = 1 # 对于每个项目的默认新记忆原子数量 -scheduled_num = 8 +scheduled_num = 999 # UTC 时间戳修正 仅用于 UNIX 日时间戳的生成修正, 单位为秒 timezone_offset = +28800 # 中国标准时间 (UTC+8) @@ -17,7 +17,7 @@ timezone_offset = +28800 # 中国标准时间 (UTC+8) [interface] [interface.memorizor] -autovoice = true # 自动语音播放, 仅限于 recognition 组件 +autovoice = 0 # 自动语音播放, 仅限于 recognition 组件 [algorithm] default = "SM-2" # 主要算法; 可选项: SM-2, SM-15M, FSRS diff --git a/examples/repo.ipynb b/examples/repo.ipynb index ccd57ed..db08ff7 100644 --- a/examples/repo.ipynb +++ b/examples/repo.ipynb @@ -340,7 +340,7 @@ } ], "source": [ - "from heurams.utils.lict import Lict\n", + "from heurams.kernel.auxiliary.lict import Lict\n", "\n", "lct = Lict() # 空的\n", "lct = Lict(initlist=[(\"name\", \"tom\"), (\"age\", 12), (\"enemy\", \"jerry\")]) # 基于列表\n", diff --git a/examples/simplemem.py b/examples/simplemem.py index e33cbbc..55c0a9f 100644 --- a/examples/simplemem.py +++ b/examples/simplemem.py @@ -1,8 +1,9 @@ -import heurams.kernel.repolib as repolib -import heurams.kernel.particles as pt -from heurams.services.textproc import truncate -from pathlib import Path import time +from pathlib import Path + +import heurams.kernel.particles as pt +import heurams.kernel.repolib as repolib +from heurams.services.textproc import truncate repo = repolib.Repo.create_from_repodir(Path("./test_repo")) alist = list() diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..87ab9ea --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "HeurAMS", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/requirements.txt b/requirements.txt index e7e51e9..f4b26f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ edge-tts==7.0.2 fsspec==2025.12.0 jieba==0.42.1 +openai>=1.0.0 playsound==1.2.2 tabulate==0.9.0 -textual==5.3.0 +textual==7.0.0 toml==0.10.2 -transitions==0.9.3 \ No newline at end of file +transitions==0.9.3 diff --git a/src/heurams/context.py b/src/heurams/context.py index 439640a..b9ea142 100644 --- a/src/heurams/context.py +++ b/src/heurams/context.py @@ -4,8 +4,8 @@ """ import pathlib -from contextvars import ContextVar import shutil +from contextvars import ContextVar from heurams.services.config import ConfigFile from heurams.services.logger import get_logger diff --git a/src/heurams/interface/__init__.py b/src/heurams/interface/__init__.py index f36099f..5142d97 100644 --- a/src/heurams/interface/__init__.py +++ b/src/heurams/interface/__init__.py @@ -1,4 +1,7 @@ +from typing import Type + from textual.app import App +from textual.driver import Driver from textual.widgets import Button from heurams.context import config_var @@ -6,8 +9,12 @@ from heurams.services.logger import get_logger from .screens.about import AboutScreen from .screens.dashboard import DashboardScreen +from .screens.llmchat import LLMChatScreen +from .screens.navigator import NavigatorScreen from .screens.precache import PrecachingScreen +from .screens.radio import RadioScreen from .screens.repocreator import RepoCreatorScreen +from .screens.repoeditor import RepoEditorScreen from .screens.synctool import SyncScreen logger = get_logger(__name__) @@ -35,13 +42,10 @@ class HeurAMSApp(App): CSS_PATH = "css/main.tcss" SUB_TITLE = "启发式辅助记忆调度器" BINDINGS = [ - ("q", "quit", "退出"), - ("d", "toggle_dark", "切换色调"), - ("1", "app.push_screen('dashboard')", "仪表盘"), - ("2", "app.push_screen('precache_all')", "缓存管理器"), - ("3", "app.push_screen('repo_creator')", "创建新仓库"), - # ("4", "app.push_screen('synctool')", "同步工具"), - ("0", "app.push_screen('about')", "版本信息"), + ("q", "go_back", "退出"), + ("d", "toggle_dark", "主题"), + ("n", "app.push_screen('navigator')", "导航"), + ("z", "app.push_screen('about')", "关于"), ] SCREENS = { "dashboard": DashboardScreen, @@ -49,6 +53,10 @@ class HeurAMSApp(App): "precache_all": PrecachingScreen, "synctool": SyncScreen, "about": AboutScreen, + "navigator": NavigatorScreen, + "radio": RadioScreen, + "repo_editor": RepoEditorScreen, + "llmchat": LLMChatScreen, } def on_mount(self) -> None: @@ -56,8 +64,11 @@ class HeurAMSApp(App): self.push_screen("dashboard") def on_button_pressed(self, event: Button.Pressed) -> None: - self.exit(event.button.id) + pass + # self.exit(event.button.id) + + def action_go_back(self) -> None: + quit() def action_do_nothing(self): - print("DO NOTHING") self.refresh() diff --git a/src/heurams/interface/css/main.tcss b/src/heurams/interface/css/main.tcss index e69de29..f060ddd 100644 --- a/src/heurams/interface/css/main.tcss +++ b/src/heurams/interface/css/main.tcss @@ -0,0 +1,64 @@ +NavigatorScreen { + align: center middle; +} + +#dialog { + grid-size: 2; + grid-gutter: 1 1; + grid-rows: 1fr 3; + padding: 0 1; + width: 46; + height: 12; + border: thick $background 80%; + background: $surface; +} + +/* LLM 聊天界面样式 */ +LLMChatScreen { + background: $surface; +} + +#chat-container { + height: 100%; + padding: 1; +} + +#toolbar { + height: 3; + margin-bottom: 1; + align: center middle; +} + +#toolbar Button { + margin: 0 1; +} + +#chat-log { + height: 1fr; + border: solid $primary; + padding: 1; + background: $surface; +} + +#input-container { + height: 3; + margin-top: 1; + align: center middle; +} + +#message-input { + width: 1fr; + margin-right: 1; +} + +#status-bar { + height: 1; + margin-top: 1; + text-style: italic; + color: $text-muted; +} + +.session-label { + color: $primary; + text-style: bold; +} diff --git a/src/heurams/interface/screens/about.py b/src/heurams/interface/screens/about.py index 1bf0e1e..e25cc4d 100644 --- a/src/heurams/interface/screens/about.py +++ b/src/heurams/interface/screens/about.py @@ -10,6 +10,9 @@ from heurams.context import * class AboutScreen(Screen): + BINDINGS = [ + ("q", "go_back", "返回"), + ] def compose(self) -> ComposeResult: yield Header(show_clock=True) @@ -22,10 +25,14 @@ class AboutScreen(Screen): 开发代号: {version.codename.capitalize()} {version.codename_cn} -一个基于启发式算法的开放源代码记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划. +一个基于启发式算法的辅助记忆调度器, 旨在帮助用户更高效地进行记忆工作与学习规划. 以 AGPL-3.0 开放源代码 +您可在项目主页 https://ams.imwangzhiyu.xyz 获取用户指南, 开发文档与软件更新 + +如果您觉得这个软件有用, 请给它添加一个星标 :) + 开发人员: - Wang Zhiyu([@pluvium27](https://github.com/pluvium27)): 项目作者 diff --git a/src/heurams/interface/screens/dashboard.py b/src/heurams/interface/screens/dashboard.py index 34f8902..0d510b9 100644 --- a/src/heurams/interface/screens/dashboard.py +++ b/src/heurams/interface/screens/dashboard.py @@ -1,12 +1,14 @@ """仪表盘界面""" import pathlib +from pathlib import Path from textual.app import ComposeResult from textual.containers import ScrollableContainer from textual.screen import Screen from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static +import heurams.kernel.particles as pt import heurams.services.timer as timer import heurams.services.version as version from heurams.context import * @@ -14,10 +16,10 @@ from heurams.kernel.particles import * from heurams.kernel.repolib import * from heurams.services.logger import get_logger -import heurams.kernel.particles as pt -from pathlib import Path from .about import AboutScreen +from .navigator import NavigatorScreen from .preparation import PreparationScreen +from .radio import RadioScreen logger = get_logger(__name__) @@ -26,6 +28,9 @@ class DashboardScreen(Screen): """主仪表盘屏幕""" SUB_TITLE = "仪表盘" + BINDINGS = [ + ("q", "go_back", "返回"), + ] def __init__( self, @@ -50,12 +55,12 @@ class DashboardScreen(Screen): Label(f"全局算法设置: {config_var.get()['algorithm']['default']}"), Label("选择待学习或待修改的项目:", classes="title-label"), ListView(id="repo-list", classes="repo-list-view"), - Label(f'"潜进" 启发式辅助记忆调度器 | 版本 {version.ver} '), + Label(f'"潜进" 启发式辅助记忆调度器 版本 {version.ver} '), ) yield Footer() def _load_data(self): - self.repo_dirs = Repo.probe_vaild_repos_in_dir( + self.repo_dirs = Repo.probe_valid_repos_in_dir( Path(config_var.get()["paths"]["data"]) / "repo" ) for repo_dir in self.repo_dirs: @@ -69,7 +74,6 @@ class DashboardScreen(Screen): unit_sum = len(repo) activated_sum = 0 nextdate = 0x3F3F3F3F - is_unfinished = unit_sum > activated_sum for i in repo.ident_index: nucleon = pt.Nucleon.create_on_nucleonic_data( nucleonic_data=repo.nucleonic_data_lict.get_itemic_unit(i) @@ -82,10 +86,11 @@ class DashboardScreen(Screen): if electron.is_due(): is_due = 1 nextdate = min(nextdate, electron.nextdate()) + is_unfinished = unit_sum > activated_sum if is_unfinished: nextdate = min(nextdate, timer.get_daystamp()) need_to_study = is_due or is_unfinished - prompt = f"{title}\0\n 进度: {activated_sum}/{unit_sum}\n {"需要学习" if need_to_study else "无需操作"}" + prompt = f"{title}\0\n 进度: {activated_sum}/{unit_sum} ({round(activated_sum/unit_sum*100)}%)\n {"需要学习" if need_to_study else "无需操作"}" stat = { "is_due": is_due, "unit_sum": unit_sum, @@ -139,7 +144,7 @@ class DashboardScreen(Screen): return selected_label = event.item.query_one(Label) - label_text = str(selected_label.renderable) + label_text = str(selected_label.render()) if "未找到任何仓库" in label_text: return @@ -158,3 +163,12 @@ class DashboardScreen(Screen): def action_quit_app(self) -> None: """退出应用程序""" self.app.exit() + + def action_open_navigator(self) -> None: + """打开导航器""" + self.app.push_screen(NavigatorScreen()) + + def on_button_pressed(self, event: Button.Pressed) -> None: + """处理按钮点击事件""" + if event.button.id == "navigator-button": + self.action_open_navigator() diff --git a/src/heurams/interface/screens/favmgr.py b/src/heurams/interface/screens/favmgr.py new file mode 100644 index 0000000..b90a957 --- /dev/null +++ b/src/heurams/interface/screens/favmgr.py @@ -0,0 +1,204 @@ +"""收藏夹管理器界面""" + +import base64 +from pathlib import Path +from typing import List, Optional + +from textual.app import ComposeResult +from textual.containers import ScrollableContainer +from textual.screen import Screen +from textual.widgets import ( + Button, + Footer, + Header, + Label, + ListItem, + ListView, + Markdown, + Static, +) + +from heurams.context import config_var +from heurams.kernel.repolib import Repo +from heurams.services.favorite_service import FavoriteItem, favorite_manager +from heurams.services.logger import get_logger + +logger = get_logger(__name__) + + +class FavoriteManagerScreen(Screen): + """收藏夹管理器屏幕""" + + SUB_TITLE = "收藏夹" + + BINDINGS = [ + ("q", "go_back", "返回"), + ("d", "toggle_dark", ""), + ] + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + super().__init__(name, id, classes) + self.favorites: List[FavoriteItem] = [] + self._load_favorites() + + def _load_favorites(self) -> None: + """加载收藏列表""" + self.favorites = favorite_manager.get_all() + logger.debug("加载 %d 个收藏项", len(self.favorites)) + + def compose(self) -> ComposeResult: + """组合界面组件""" + yield Header(show_clock=True) + with ScrollableContainer(id="favorites-container"): + if not self.favorites: + yield Label("暂无收藏", classes="empty-label") + yield Static("使用 * 键在记忆界面中添加收藏.") + else: + yield Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label") + yield ListView(id="favorites-list") + yield Footer() + + def on_mount(self) -> None: + """挂载后填充列表""" + if self.favorites: + list_view = self.query_one("#favorites-list") + for fav in self.favorites: + list_view.append(self._create_favorite_item(fav)) # type: ignore + + def _encode_favorite_key(self, repo_path: str, ident: str) -> str: + """编码仓库路径和标识符为安全的按钮 ID 部分""" + # 使用 \x00 分隔两部分,然后进行 base64 编码 + combined = f"{repo_path}\x00{ident}" + encoded = base64.urlsafe_b64encode(combined.encode()).decode() + # 去掉填充的等号 + return encoded.rstrip("=") + + def _decode_favorite_key(self, key: str) -> tuple[str, str]: + """解码按钮 ID 部分为仓库路径和标识符""" + # 补全等号以使长度是4的倍数 + padded = key + "=" * ((4 - len(key) % 4) % 4) + decoded = base64.urlsafe_b64decode(padded.encode()).decode() + repo_path, ident = decoded.split("\x00", 1) + return repo_path, ident + + def _create_favorite_item(self, fav: FavoriteItem) -> ListItem: + """创建收藏项列表项""" + # 尝试获取仓库信息 + repo_info = self._get_repo_info(fav.repo_path, fav) + title = repo_info.get("title", fav.repo_path) if repo_info else fav.repo_path + content_preview = repo_info.get("content_preview", "") if repo_info else "" + added_time = self._format_time(fav.added) + + # 构建显示文本 + display_text = f"[b]{title}[/b] ({fav.ident})\n" + if content_preview: + display_text += f"{content_preview}\n" + display_text += f"添加于: {added_time}" + if fav.tags: + display_text += f" 标签: {', '.join(fav.tags)}" + + # 创建安全的按钮 ID + button_key = self._encode_favorite_key(fav.repo_path, fav.ident) + # 创建列表项,包含移除按钮 + container = ScrollableContainer( + Markdown(display_text, classes="favorite-content"), + Button("移除", id=f"remove-{button_key}", variant="error"), + classes="favorite-item", + ) + return ListItem(container) + + def _get_repo_info(self, repo_path: str, fav: FavoriteItem) -> Optional[dict]: + """获取仓库信息(标题、原子内容预览)""" + try: + data_repo = Path(config_var.get()["paths"]["data"]) / "repo" + repo_dir = data_repo / repo_path + if not repo_dir.exists(): + logger.warning("仓库目录不存在: %s", repo_dir) + return None + repo = Repo.create_from_repodir(repo_dir) + # 获取原子内容预览 + content_preview = "" + payload = repo.payload + # 查找对应 ident 的 payload 条目 + for ident_key, content in payload: + if ident_key == fav.ident: + # 截断过长的内容 + if isinstance(content, dict) and "content" in content: + text = content["content"] + else: + text = str(content) + if len(text) > 100: + content_preview = text[:100] + "..." + else: + content_preview = text + break + return { + "title": repo.manifest["title"], + "content_preview": content_preview, + } + except Exception as e: + logger.error("获取仓库信息失败: %s", e) + return None + + def _format_time(self, timestamp: int) -> str: + """格式化时间戳""" + from datetime import datetime + + dt = datetime.fromtimestamp(timestamp) + return dt.strftime("%Y-%m-%d %H:%M") + + def on_button_pressed(self, event: Button.Pressed) -> None: + """处理按钮点击事件""" + button_id = event.button.id + if button_id and button_id.startswith("remove-"): + # 提取编码后的键 + key = button_id[7:] # 去掉 "remove-" 前缀 + try: + repo_path, ident = self._decode_favorite_key(key) + self._remove_favorite(repo_path, ident) + except Exception as e: + logger.error("解析按钮 ID 失败: %s", e) + self.app.notify("操作失败: 无效的按钮标识", severity="error") + + def _remove_favorite(self, repo_path: str, ident: str) -> None: + """移除收藏项""" + if favorite_manager.remove(repo_path, ident): + self.app.notify(f"已移除收藏: {ident}", severity="information") + # 重新加载列表 + self._load_favorites() + # 刷新界面 + self._refresh_list() + else: + self.app.notify(f"移除失败: {ident}", severity="error") + + def _refresh_list(self) -> None: + """刷新列表显示""" + container = self.query_one("#favorites-container") + # 清空容器 + for child in container.children: + child.remove() + # 重新组合 + if not self.favorites: + container.mount(Label("暂无收藏", classes="empty-label")) + container.mount(Static("使用 * 键在记忆界面中添加收藏。")) + else: + container.mount( + Label(f"共 {len(self.favorites)} 个收藏项", classes="count-label") + ) + list_view = ListView(id="favorites-list") + container.mount(list_view) + for fav in self.favorites: + list_view.append(self._create_favorite_item(fav)) + + def action_go_back(self) -> None: + """返回上一屏幕""" + self.app.pop_screen() + + def action_toggle_dark(self) -> None: + """切换暗黑模式""" + self.app.dark = not self.app.dark # type: ignore diff --git a/src/heurams/interface/screens/intelinote.py b/src/heurams/interface/screens/intelinote.py deleted file mode 100644 index 4a489df..0000000 --- a/src/heurams/interface/screens/intelinote.py +++ /dev/null @@ -1 +0,0 @@ -"""笔记界面""" diff --git a/src/heurams/interface/screens/llmchat.py b/src/heurams/interface/screens/llmchat.py new file mode 100644 index 0000000..b444de7 --- /dev/null +++ b/src/heurams/interface/screens/llmchat.py @@ -0,0 +1,333 @@ +"""LLM 聊天界面""" + +import asyncio +from pathlib import Path +from typing import Optional + +from textual.app import ComposeResult +from textual.containers import Container, Horizontal +from textual.screen import Screen +from textual.widgets import Button, Footer, Header, Input, Label, RichLog, Static + +from heurams.context import * +from heurams.services.llm_service import ChatSession, get_chat_manager +from heurams.services.logger import get_logger + +logger = get_logger(__name__) + + +class LLMChatScreen(Screen): + """LLM 聊天屏幕""" + + SUB_TITLE = "AI 聊天" + BINDINGS = [ + ("q", "go_back", "返回"), + ("ctrl+s", "save_session", "保存会话"), + ("ctrl+l", "load_session", "加载会话"), + ("ctrl+n", "new_session", "新建会话"), + ("ctrl+c", "clear_history", "清空历史"), + ("escape", "focus_input", "聚焦输入"), + ] + + def __init__( + self, + session_id: Optional[str] = None, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + super().__init__(name, id, classes) + self.session_id = session_id + self.chat_manager = get_chat_manager() + self.current_session: Optional[ChatSession] = None + self.is_streaming = False + + def compose(self) -> ComposeResult: + """组合界面组件""" + yield Header(show_clock=True) + + with Container(id="chat-container"): + # 顶部工具栏 + with Horizontal(id="toolbar"): + yield Button("新建会话", id="new-session", variant="primary") + yield Button("保存会话", id="save-session", variant="default") + yield Button("加载会话", id="load-session", variant="default") + yield Button("清空历史", id="clear-history", variant="default") + yield Button("设置系统提示", id="set-system-prompt", variant="default") + yield Static(" | ", classes="separator") + yield Label("当前会话:", classes="label") + yield Static(id="current-session-label", classes="session-label") + + # 聊天记录显示区域 + yield RichLog( + id="chat-log", + wrap=True, + highlight=True, + markup=True, + classes="chat-log", + ) + + # 输入区域 + with Horizontal(id="input-container"): + yield Input( + id="message-input", + placeholder="输入消息... (按 Ctrl+Enter 发送, Esc 聚焦)", + classes="message-input", + ) + yield Button( + "发送", id="send-button", variant="primary", classes="send-button" + ) + + # 状态栏 + yield Static(id="status-bar", classes="status-bar") + + yield Footer() + + def on_mount(self) -> None: + """挂载组件时初始化""" + # 获取或创建会话 + self.current_session = self.chat_manager.get_session(self.session_id) + if self.current_session is None: + self.notify("无法创建 LLM 会话,请检查配置", severity="error") + return + + # 更新会话标签 + self.query_one("#current-session-label", Static).update( + f"{self.current_session.session_id}" + ) + + # 加载历史消息到聊天记录 + self._display_history() + + # 聚焦输入框 + self.query_one("#message-input", Input).focus() + + # 检查配置 + self._check_config() + + def _check_config(self): + """检查 LLM 配置""" + config = config_var.get() + provider_name = config["services"]["llm"] + provider_config = config["providers"]["llm"][provider_name] + + if provider_name == "openai": + if not provider_config.get("key") and not provider_config.get("url"): + self.notify( + "未配置 OpenAI API key 或 URL,请在 config.toml 中配置 [providers.llm.openai]", + severity="warning", + ) + + def _display_history(self): + """显示当前会话的历史消息""" + if not self.current_session: + return + + chat_log = self.query_one("#chat-log", RichLog) + chat_log.clear() + + for msg in self.current_session.get_history(): + role = msg["role"] + content = msg["content"] + + if role == "user": + chat_log.write(f"[bold cyan]你:[/bold cyan] {content}") + elif role == "assistant": + chat_log.write(f"[bold green]AI:[/bold green] {content}") + elif role == "system": + # 系统消息不显示在聊天记录中 + pass + + def _add_message_to_log(self, role: str, content: str): + """添加消息到聊天记录显示""" + chat_log = self.query_one("#chat-log", RichLog) + if role == "user": + chat_log.write(f"[bold cyan]你:[/bold cyan] {content}") + elif role == "assistant": + chat_log.write(f"[bold green]AI:[/bold green] {content}") + chat_log.scroll_end() + + async def on_input_submitted(self, event: Input.Submitted): + """处理输入提交""" + if event.input.id == "message-input": + await self._send_message() + + async def on_button_pressed(self, event: Button.Pressed): + """处理按钮点击""" + button_id = event.button.id + + if button_id == "send-button": + await self._send_message() + elif button_id == "new-session": + self.action_new_session() + elif button_id == "save-session": + self.action_save_session() + elif button_id == "load-session": + self.action_load_session() + elif button_id == "clear-history": + self.action_clear_history() + elif button_id == "set-system-prompt": + self.action_set_system_prompt() + + async def _send_message(self): + """发送当前输入的消息""" + if not self.current_session or self.is_streaming: + return + + input_widget = self.query_one("#message-input", Input) + message = input_widget.value.strip() + + if not message: + return + + # 清空输入框 + input_widget.value = "" + + # 显示用户消息 + self._add_message_to_log("user", message) + + # 禁用输入和按钮 + self._set_input_state(disabled=True) + self.is_streaming = True + + # 更新状态 + self.query_one("#status-bar", Static).update("AI 正在思考...") + + try: + # 发送消息并获取响应 + response = await self.current_session.send_message(message) + + # 显示AI响应 + self._add_message_to_log("assistant", response) + + except Exception as e: + error_msg = f"请求失败: {str(e)}" + logger.error(error_msg) + self._add_message_to_log("assistant", f"[red]{error_msg}[/red]") + self.notify(error_msg, severity="error") + + finally: + # 恢复输入和按钮 + self._set_input_state(disabled=False) + self.is_streaming = False + self.query_one("#status-bar", Static).update("就绪") + input_widget.focus() + + def _set_input_state(self, disabled: bool): + """设置输入控件状态""" + self.query_one("#message-input", Input).disabled = disabled + self.query_one("#send-button", Button).disabled = disabled + + async def action_save_session(self): + """保存当前会话到文件""" + if not self.current_session: + self.notify("无当前会话", severity="error") + return + + # 默认保存到 data/chat_sessions/ 目录 + save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions" + save_dir.mkdir(exist_ok=True, parents=True) + + file_path = save_dir / f"{self.current_session.session_id}.json" + self.current_session.save_to_file(file_path) + + self.notify(f"会话已保存到 {file_path}", severity="information") + + async def action_load_session(self): + """从文件加载会话""" + # 简化实现:加载默认目录下的第一个会话文件 + save_dir = Path(config_var.get()["paths"]["data"]) / "chat_sessions" + if not save_dir.exists(): + self.notify(f"目录不存在: {save_dir}", severity="error") + return + + session_files = list(save_dir.glob("*.json")) + if not session_files: + self.notify("未找到会话文件", severity="error") + return + + # 使用第一个文件(在实际应用中可以让用户选择) + file_path = session_files[0] + + try: + # 获取 LLM 提供者 + provider_name = config_var.get()["services"]["llm"] + provider_config = config_var.get()["providers"]["llm"][provider_name] + from heurams.providers.llm import providers as prov + + llm_provider = prov[provider_name](provider_config) + + # 加载会话 + self.current_session = ChatSession.load_from_file(file_path, llm_provider) + + # 更新聊天管理器 + self.chat_manager.sessions[self.current_session.session_id] = ( + self.current_session + ) + + # 更新UI + self.query_one("#current-session-label", Static).update( + f"{self.current_session.session_id}" + ) + self._display_history() + + self.notify(f"已加载会话: {file_path.name}", severity="information") + + except Exception as e: + logger.error("加载会话失败: %s", e) + self.notify(f"加载失败: {str(e)}", severity="error") + + async def action_new_session(self): + """创建新会话""" + # 简单实现:使用时间戳作为会话ID + import time + + new_session_id = f"session_{int(time.time())}" + + self.current_session = self.chat_manager.get_session(new_session_id) + + # 更新UI + self.query_one("#current-session-label", Static).update( + f"{self.current_session.session_id}" + ) + self._display_history() + + self.notify(f"已创建新会话: {new_session_id}", severity="information") + self.query_one("#message-input", Input).focus() + + async def action_clear_history(self): + """清空当前会话历史""" + if not self.current_session: + return + + self.current_session.clear_history() + self._display_history() + self.notify("历史已清空", severity="information") + + async def action_set_system_prompt(self): + """设置系统提示词""" + if not self.current_session: + return + + # 使用输入框获取新提示词 + input_widget = self.query_one("#message-input", Input) + current_value = input_widget.value + + # 临时修改输入框提示 + input_widget.placeholder = "输入系统提示词... (按 Enter 确认, Esc 取消)" + input_widget.value = self.current_session.system_prompt + + # 等待用户输入 + self.notify("请输入系统提示词,按 Enter 确认", severity="information") + + # 实际应用中需要更复杂的交互,这里简化处理 + # 用户手动输入后按 Enter 会触发 on_input_submitted + # 这里我们只修改占位符,实际系统提示词设置需要额外界面 + + def action_focus_input(self): + """聚焦到输入框""" + self.query_one("#message-input", Input).focus() + + def action_go_back(self): + """返回上级屏幕""" + self.app.pop_screen() diff --git a/src/heurams/interface/screens/memoqueue.py b/src/heurams/interface/screens/memoqueue.py index 9ec5a38..a880d89 100644 --- a/src/heurams/interface/screens/memoqueue.py +++ b/src/heurams/interface/screens/memoqueue.py @@ -1,6 +1,7 @@ """队列式记忆工作界面""" from enum import Enum, auto +from pathlib import Path from typing import Callable from textual.app import ComposeResult @@ -9,10 +10,11 @@ from textual.reactive import reactive from textual.screen import Screen from textual.widgets import Button, Footer, Header, Label, Static -import heurams.kernel.puzzles as pz import heurams.kernel.particles as pt +import heurams.kernel.puzzles as pz from heurams.context import config_var from heurams.kernel.reactor import * +from heurams.services.favorite_service import favorite_manager from heurams.services.logger import get_logger from .. import shim @@ -32,6 +34,7 @@ class MemScreen(Screen): ("p", "prev", "查看上一个"), ("d", "toggle_dark", ""), ("v", "play_voice", "朗读"), + ("*", "toggle_favorite", "收藏"), ("0,1,2,3", "app.push_screen('about')", ""), ] @@ -44,6 +47,7 @@ class MemScreen(Screen): self, phaser: Phaser, save_func: Callable, + repo=None, name=None, id=None, classes=None, @@ -51,9 +55,9 @@ class MemScreen(Screen): super().__init__(name, id, classes) self.phaser = phaser self.save_func = save_func + self.repo = repo self.update_state() self.fission: Fission - def compose(self) -> ComposeResult: yield Header(show_clock=True) @@ -84,6 +88,10 @@ class MemScreen(Screen): def _get_progress_text(self): s = f"阶段: {self.procession.phase.name}\n" + # 收藏状态 + if self.repo is not None: + fav_status = "★" if self._is_current_atom_favorited() else "☆" + s += f"收藏: {fav_status}\n" if config_var.get().get("debug_topline", 0): try: alia = self.fission.get_current_puzzle_inf()["alia"] # type: ignore @@ -129,6 +137,7 @@ class MemScreen(Screen): for i in container.children: i.remove() from heurams.interface.widgets.finished import Finished + if config_var.get().get("persist_to_file", 0): self.save_func() container.mount(Finished(is_saved=config_var.get().get("persist_to_file", 0))) @@ -208,3 +217,40 @@ class MemScreen(Screen): def action_quick_fail(self): self.rating = 3 + + def _get_repo_rel_path(self) -> str: + """获取仓库相对路径(相对于 data/repo)""" + if self.repo is None: + return "" + # self.repo.source 是 Path 对象,指向仓库目录 + repo_full_path = self.repo.source + data_repo_path = Path(config_var.get()["paths"]["data"]) / "repo" + try: + rel_path = repo_full_path.relative_to(data_repo_path) + return str(rel_path) + except ValueError: + # 如果不在 data/repo 下,则返回完整路径(字符串形式) + return str(repo_full_path) + + def _is_current_atom_favorited(self) -> bool: + """检查当前原子是否已收藏""" + if self.repo is None: + return False + repo_path = self._get_repo_rel_path() + return favorite_manager.has(repo_path, self.atom.ident) + + def action_toggle_favorite(self): + """切换收藏状态""" + if self.repo is None: + self.app.notify("无法收藏:未关联仓库", severity="error") + return + repo_path = self._get_repo_rel_path() + ident = self.atom.ident + if favorite_manager.has(repo_path, ident): + favorite_manager.remove(repo_path, ident) + self.app.notify(f"已取消收藏:{ident}", severity="information") + else: + favorite_manager.add(repo_path, ident) + self.app.notify(f"已收藏:{ident}", severity="information") + # 更新显示(如果需要) + self.update_display() diff --git a/src/heurams/interface/screens/navigator.py b/src/heurams/interface/screens/navigator.py new file mode 100644 index 0000000..4739a1a --- /dev/null +++ b/src/heurams/interface/screens/navigator.py @@ -0,0 +1,93 @@ +import webbrowser + +from textual.app import ComposeResult +from textual.containers import Grid, ScrollableContainer +from textual.screen import ModalScreen +from textual.widgets import Button, Footer, Header, Label, ListItem, ListView, Static + +from heurams.context import * +from heurams.services.logger import get_logger + +from .favmgr import FavoriteManagerScreen + +logger = get_logger(__name__) + + +class NavigatorScreen(ModalScreen): + """导航器模态窗口""" + + BINDINGS = [ + ("q", "go_back", "返回"), + ("escape", "go_back", "返回"), + ("n", "go_back", "切换"), + ] + + SCREENS = [ + ("仪表盘", "dashboard"), + ("电台", "radio"), + ("语言模型集成", "llmchat"), + # ("创建仓库", "repo_creator"), + ("缓存管理器", "precache_all"), + ("收藏夹管理器", FavoriteManagerScreen), + ("关于此软件", "about"), + ("调试日志", "logviewer"), + # ("同步工具", "synctool"), + # ("仓库编辑器", "repo_editor"), + ] + + OTHERS = [ + ("退出程序", "self.app.exit()"), + ("项目主页", "webbrowser.open('https://ams.imwangzhiyu.xyz')"), + ] + + def compose(self) -> ComposeResult: + """组合界面组件""" + with Grid(id="dialog"): + yield Label( + "[b]请选择要跳转的功能\n或记忆会话实例[/b]\n\n将在此处显示提示", + classes="title-label", + ) + yield ListView( + *[ListItem(Label(title)) for title, _ in (self.SCREENS + self.OTHERS)], + id="nav-list", + classes="nav-list-view", + ) + yield Static("按下回车以完成切换\n所有会话将被保存") + yield Button( + "关闭 (n)", id="close_button", variant="primary", classes="close-button", flat=True + ) + + def on_mount(self) -> None: + # 设置焦点到列表 + nav_list = self.query_one("#nav-list", ListView) + nav_list.focus() + + def on_list_view_selected(self, event) -> None: + if not isinstance(event.item, ListItem): + return + selected_label = event.item.query_one(Label) + label_text = str(selected_label.render()) + # 查找对应的屏幕标识 + for title, screen_id in self.SCREENS: + if title == label_text: + self.app.pop_screen() + # 跳转到目标屏幕 + if isinstance(screen_id, str): + # 已注册的字符串标识符 + self.app.push_screen(screen_id) + else: + self.app.push_screen(screen_id()) + return + for title, cmd in self.OTHERS: + if title == label_text: + exec(cmd) + return + return + + def on_button_pressed(self, event) -> None: + event.stop() + if event.button.id == "close_button": + self.action_go_back() + + def action_go_back(self) -> None: + self.app.pop_screen() diff --git a/src/heurams/interface/screens/precache.py b/src/heurams/interface/screens/precache.py index 0bb54d5..300690b 100644 --- a/src/heurams/interface/screens/precache.py +++ b/src/heurams/interface/screens/precache.py @@ -3,7 +3,7 @@ import pathlib from textual.app import ComposeResult -from textual.containers import Horizontal, ScrollableContainer +from textual.containers import Horizontal, ScrollableContainer, Container from textual.screen import Screen from textual.widgets import Button, Footer, Header, Label, ProgressBar, Static from textual.worker import get_current_worker @@ -12,7 +12,18 @@ import heurams.kernel.particles as pt import heurams.services.hasher as hasher from heurams.context import * -cache_dir = pathlib.Path(config_var.get()["paths"]["data"]) / "cache" / "voice" +# 兼容性缓存路径:优先使用 paths.cache,否则使用 data/cache +paths = config_var.get()["paths"] +cache_dir = pathlib.Path(paths.get("cache", paths["data"] + "/cache")) / "voice" + + +def format_size(bytes_num: int) -> str: + """将字节数格式化为人类可读的字符串""" + for unit in ['B', 'KB', 'MB', 'GB', 'TB']: + if bytes_num < 1024.0: + return f"{bytes_num:.2f} {unit}" + bytes_num /= 1024.0 # type: ignore + return f"{bytes_num:.2f} PB" class PrecachingScreen(Screen): @@ -26,7 +37,9 @@ class PrecachingScreen(Screen): """ SUB_TITLE = "缓存管理器" - BINDINGS = [("q", "go_back", "返回")] + BINDINGS = [ + ("q", "go_back", "返回"), + ] def __init__(self, nucleons: list = [], desc: str = ""): super().__init__(name=None, id=None, classes=None) @@ -40,21 +53,70 @@ class PrecachingScreen(Screen): self.precache_worker = None self.cancel_flag = 0 self.desc = desc + # 不再需要缓存配置,保留配置读取以兼容 + self.cache_stats = {"total_size": 0, "file_count": 0, "human_size": "0 B", "cached_units": 0, "total_units": 0, "cache_rate": 0} + self._update_cache_stats() + + def _get_total_units(self) -> int: + """获取所有仓库的总单元数""" + from heurams.context import config_var + from heurams.kernel.repolib import Repo + repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo" + repo_dirs = Repo.probe_valid_repos_in_dir(repo_path) + repos = map(Repo.create_from_repodir, repo_dirs) + total = 0 + for repo in repos: + try: + total += len(repo.ident_index) + except: + continue + return total + + def _update_cache_stats(self) -> None: + """更新缓存统计信息""" + total_size = 0 + file_count = 0 + cached_units = 0 + if cache_dir.exists(): + for file in cache_dir.rglob("*"): + if file.is_file(): + total_size += file.stat().st_size + file_count += 1 + if file.suffix.lower() == ".wav": + cached_units += 1 + total_units = self._get_total_units() + cache_rate = (cached_units / total_units * 100) if total_units > 0 else 0 + + self.cache_stats["total_size"] = total_size + self.cache_stats["file_count"] = file_count + self.cache_stats["human_size"] = format_size(total_size) + self.cache_stats["cached_units"] = cached_units + self.cache_stats["total_units"] = total_units + self.cache_stats["cache_rate"] = cache_rate def compose(self) -> ComposeResult: yield Header(show_clock=True) with ScrollableContainer(id="precache_container"): yield Label("[b]音频预缓存[/b]", classes="title-label") + with Container(classes="cache-info"): + yield Static(f"缓存路径: {cache_dir}", classes="cache-path") + yield Static(f"文件数: {self.cache_stats['file_count']}", classes="cache-count") + yield Static(f"总大小: {self.cache_stats['human_size']}", classes="cache-size") + yield Button("刷新", id="refresh_cache_stats", variant="default") + with Container(): + yield Static( + f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% (已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)", + classes="cache-usage-text" + ) + if self.nucleons: + yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info") + yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info") + else: + yield Static("目标: 所有单元", classes="target-info") - if self.nucleons: - yield Static(f"目标单元归属: [b]{self.desc}[/b]", classes="target-info") - yield Static(f"单元数量: {len(self.nucleons)}", classes="target-info") - else: - yield Static("目标: 所有单元", classes="target-info") - - yield Static(id="status", classes="status-info") - yield Static(id="current_item", classes="current-item") - yield ProgressBar(total=100, show_eta=False, id="progress_bar") + yield Static(id="status", classes="status-info") + yield Static(id="current_item", classes="current-item") + yield ProgressBar(total=100, show_eta=False, id="progress_bar") with Horizontal(classes="button-group"): if not self.is_precaching: @@ -72,6 +134,7 @@ class PrecachingScreen(Screen): def on_mount(self): """挂载时初始化状态""" self.update_status("就绪", "等待开始...") + self._update_cache_display() def update_status(self, status, current_item="", progress=None): """更新状态显示""" @@ -86,6 +149,25 @@ class PrecachingScreen(Screen): progress_bar.progress = progress progress_bar.advance(0) # 刷新显示 + def _update_cache_display(self) -> None: + """更新缓存信息显示""" + # 更新统计信息 + self._update_cache_stats() + # 更新缓存率进度条 + # 更新缓存大小和文件数显示 + cache_count_widget = self.query_one(".cache-count", Static) + cache_size_widget = self.query_one(".cache-size", Static) + cache_usage_text = self.query_one(".cache-usage-text", Static) + if cache_count_widget: + cache_count_widget.update(f"文件数: {self.cache_stats['file_count']}") + if cache_size_widget: + cache_size_widget.update(f"总大小: {self.cache_stats['human_size']}") + if cache_usage_text: + cache_usage_text.update( + f"缓存率: {self.cache_stats.get('cache_rate', 0):.1f}% " + f"(已缓存 {self.cache_stats.get('cached_units', 0)} / {self.cache_stats.get('total_units', 0)} 个单元)" + ) + def precache_by_text(self, text: str): """预缓存单段文本的音频""" from heurams.context import config_var, rootdir, workdir @@ -151,7 +233,7 @@ class PrecachingScreen(Screen): from heurams.kernel.repolib import Repo repo_path = pathlib.Path(config_var.get()["paths"]["data"]) / "repo" - repo_dirs = Repo.probe_vaild_repos_in_dir(repo_path) + repo_dirs = Repo.probe_valid_repos_in_dir(repo_path) repos = map(Repo.create_from_repodir, repo_dirs) # 计算总项目数 @@ -207,12 +289,17 @@ class PrecachingScreen(Screen): shutil.rmtree(cache_dir, ignore_errors=True) self.update_status("已清空", "音频缓存已清空", 0) + self._update_cache_display() # 更新缓存统计显示 except Exception as e: self.update_status("错误", f"清空缓存失败: {e}") self.cancel_flag = 1 self.processed = 0 self.progress = 0 + elif event.button.id == "refresh_cache_stats": + # 刷新缓存统计信息 + self._update_cache_display() + self.app.notify("缓存信息已刷新", severity="information") elif event.button.id == "go_back": self.action_go_back() @@ -220,8 +307,3 @@ class PrecachingScreen(Screen): if self.is_precaching and self.precache_worker: self.precache_worker.cancel() self.app.pop_screen() - - def action_quit_app(self): - if self.is_precaching and self.precache_worker: - self.precache_worker.cancel() - self.app.exit() diff --git a/src/heurams/interface/screens/preparation.py b/src/heurams/interface/screens/preparation.py index ee36806..79acb11 100644 --- a/src/heurams/interface/screens/preparation.py +++ b/src/heurams/interface/screens/preparation.py @@ -11,8 +11,8 @@ import heurams.kernel.particles as pt import heurams.services.hasher as hasher from heurams.context import * from heurams.context import config_var -from heurams.services.logger import get_logger from heurams.kernel.repolib import * +from heurams.services.logger import get_logger logger = get_logger(__name__) @@ -59,7 +59,8 @@ class PreparationScreen(Screen): ) yield Static(f"\n单元预览:\n") - yield Markdown(self._get_full_content().replace("/", ""), classes="full") + for i in self._get_full_content().replace("/", "").splitlines(): + yield Static(i, classes="full") yield Footer() # def watch_scheduled_num(self, old_scheduled_num, new_scheduled_num): @@ -76,7 +77,7 @@ class PreparationScreen(Screen): n = pt.Nucleon.create_on_nucleonic_data( nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(i) ) - content += f"- {n['content']} \n" + content += f" • {n['content']} \n" return content def action_go_back(self): @@ -126,14 +127,14 @@ class PreparationScreen(Screen): left_new -= 1 if left_new >= 0: atoms_to_provide.append(i) - from .memoqueue import MemScreen import heurams.kernel.reactor as rt + from .memoqueue import MemScreen + pheser = rt.Phaser(atoms_to_provide) save_func = self.repo.persist_to_repodir - memscreen = MemScreen(pheser, save_func) + memscreen = MemScreen(pheser, save_func, repo=self.repo) self.app.push_screen(memscreen) - elif event.button.id == "precache_button": self.action_precache() diff --git a/src/heurams/interface/screens/radio.py b/src/heurams/interface/screens/radio.py index dd05b12..000fb2a 100644 --- a/src/heurams/interface/screens/radio.py +++ b/src/heurams/interface/screens/radio.py @@ -1 +1,218 @@ +"""用于筛选当日记忆的条目 以音频形式重放""" + """ "前进电台" 界面""" +import os +from pathlib import Path +from typing import List, Optional + +from matplotlib.cbook import ls_mapper +from textual.app import ComposeResult +from textual.containers import Container, ScrollableContainer +from textual.reactive import reactive +from textual.screen import Screen +from textual.widgets import Button, Footer, Header, Label, Static + +import heurams.kernel.particles as pt +from heurams.kernel.repolib import Repo +from heurams.context import config_var +from heurams.services.audio_service import play_by_path +from heurams.services.hasher import get_md5 +from heurams.services.logger import get_logger +from heurams.services.tts_service import convertor + +logger = get_logger(__name__) + + +class RadioScreen(Screen): + SUB_TITLE = "电台" + + BINDINGS = [ + ("q", "go_back", "返回"), + ("space", "toggle_play", "播放/暂停"), + ] + + # 当前播放的原子索引 + current_index = reactive(0) + # 播放状态: 'stopped', 'playing', 'paused' + play_state = reactive("stopped") + + def __init__( + self, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + super().__init__(name, id, classes) + self._organizer() + + def _organizer(self): + repodirs = Repo.probe_valid_repos_in_dir(Path(config_var.get()['paths']['data']) / 'repo') + repos = list(map(lambda repodir: Repo.create_from_repodir(repodir), repodirs)) + for repo in repos: + last_modify = 0.0 + for i in repo.ident_index: + e = pt.Electron.create_on_electonic_data( + electronic_data=repo.electronic_data_lict.get_itemic_unit(i) + ) + last_modify = max(last_modify, e.las()) + + + def compose(self) -> ComposeResult: + yield Header(show_clock=True) + with Container(id="main"): + yield Label("[b]前进电台[/b]", classes="title") + yield Static(f"共 {len(self.atoms)} 条当日记忆", id="status") + with Container(id="controls"): + yield Button("播放", id="play", variant="success") + yield Button("暂停", id="pause", variant="primary") + yield Button("上一首", id="prev", variant="default") + yield Button("下一首", id="next", variant="default") + yield Button("停止", id="stop", variant="error") + yield ScrollableContainer(id="playlist") + yield Footer() + + def on_mount(self) -> None: + """挂载后更新播放列表显示""" + self._update_playlist() + + def _filter_due_atoms(self) -> List[pt.Atom]: + """筛选当日需要复习的原子(已激活且到期)""" + atoms = [] + for ident in self.repo.ident_index: + n = pt.Nucleon.create_on_nucleonic_data( + nucleonic_data=self.repo.nucleonic_data_lict.get_itemic_unit(ident) + ) + e = pt.Electron.create_on_electonic_data( + electronic_data=self.repo.electronic_data_lict.get_itemic_unit(ident) + ) + a = pt.Atom(n, e, self.repo.orbitic_data) + # 仅选择已激活且到期的原子 + if ( + a.registry["electron"].is_activated() + and a.registry["electron"].is_due() + ): + atoms.append(a) + return atoms + + def _update_playlist(self) -> None: + """更新播放列表显示""" + container = self.query_one("#playlist") + container.remove_children() + for idx, atom in enumerate(self.atoms): + content = atom.registry["nucleon"].get("content", "无内容") + prefix = "▶ " if idx == self.current_index else " " + widget = Static(f"{prefix}{idx+1}. {content[:50]}...") + widget.set_class(idx == self.current_index, "current") + container.mount(widget) + + def _get_audio_path(self, atom: pt.Atom) -> Path: + """返回音频文件路径,若不存在则生成""" + tts_text = atom.registry["nucleon"].get("tts_text", "") + if not tts_text: + tts_text = atom.registry["nucleon"].get("content", "") + voice_dir = Path(config_var.get()["paths"]["data"]) / "cache" / "voice" + voice_dir.mkdir(parents=True, exist_ok=True) + path = voice_dir / f"{get_md5(tts_text)}.wav" + if not path.exists(): + convertor(tts_text, path) + return path + + async def _play_atom(self, idx: int) -> None: + """播放指定索引的原子(异步)""" + if idx < 0 or idx >= len(self.atoms): + return + atom = self.atoms[idx] + try: + path = self._get_audio_path(atom) + self._current_path = path + # 在后台线程中播放,避免阻塞UI + await self.run_worker( + lambda: play_by_path(path), exclusive=True, thread=True + ) + except Exception as e: + logger.error("播放失败: %s", e) + + def _stop_playback(self) -> None: + """停止当前播放""" + if self._play_task and not self._play_task.done(): + self._play_task.cancel() + self._play_task = None + self._current_path = None + self.play_state = "stopped" + + async def _play_current(self) -> None: + """播放当前索引的原子""" + self._stop_playback() + self.play_state = "playing" + self._play_task = asyncio.create_task(self._play_atom(self.current_index)) + try: + await self._play_task + except asyncio.CancelledError: + pass + finally: + if self.play_state == "playing": + self.play_state = "stopped" + + # 按钮事件处理 + def on_button_pressed(self, event: Button.Pressed) -> None: + button_id = event.button.id + if button_id == "play": + self.action_toggle_play() + elif button_id == "pause": + self.action_pause() + elif button_id == "prev": + self.action_prev() + elif button_id == "next": + self.action_next() + elif button_id == "stop": + self.action_stop() + + # 键盘动作 + def action_toggle_play(self) -> None: + if self.play_state == "playing": + self.action_pause() + else: + self.action_play() + + def action_play(self) -> None: + if self.play_state != "playing": + if self.play_state == "paused": + # 恢复播放(目前暂停功能简单实现为停止) + self.play_state = "playing" + else: + asyncio.create_task(self._play_current()) + + def action_pause(self) -> None: + if self.play_state == "playing": + self._stop_playback() + self.play_state = "paused" + + def action_stop(self) -> None: + self._stop_playback() + self.play_state = "stopped" + + def action_next(self) -> None: + if self.current_index < len(self.atoms) - 1: + self.current_index += 1 + self._update_playlist() + if self.play_state == "playing": + asyncio.create_task(self._play_current()) + + def action_prev(self) -> None: + if self.current_index > 0: + self.current_index -= 1 + self._update_playlist() + if self.play_state == "playing": + asyncio.create_task(self._play_current()) + + def action_go_back(self) -> None: + self._stop_playback() + self.app.pop_screen() + + # 响应式更新 + def watch_current_index(self, old: int, new: int) -> None: + self._update_playlist() + + def watch_play_state(self, old: str, new: str) -> None: + # 更新按钮状态(可在此添加样式变化) + pass diff --git a/src/heurams/interface/screens/repocreator.py b/src/heurams/interface/screens/repocreator.py index 6b2e916..d60a9ec 100644 --- a/src/heurams/interface/screens/repocreator.py +++ b/src/heurams/interface/screens/repocreator.py @@ -24,7 +24,7 @@ class RepoCreatorScreen(Screen): from heurams.context import config_var - template_dir = Path(config_var.get()["paths"]["template_dir"]) + template_dir = Path(config_var.get()["paths"]["data"]) / "templates" templates = list() for i in template_dir.iterdir(): if i.name.endswith(".toml"): diff --git a/src/heurams/interface/screens/repoeditor.py b/src/heurams/interface/screens/repoeditor.py new file mode 100644 index 0000000..ec3f92c --- /dev/null +++ b/src/heurams/interface/screens/repoeditor.py @@ -0,0 +1,267 @@ +"""仓库编辑器, 使用TextArea控件等实现仓库配置编辑""" + +import json +from pathlib import Path +from typing import Optional + +import toml +from textual.app import ComposeResult +from textual.containers import Container, Horizontal, ScrollableContainer, Vertical +from textual.reactive import reactive +from textual.screen import Screen +from textual.widgets import ( + Button, + Footer, + Header, + Label, + ListItem, + ListView, + Static, + TextArea, +) + +from heurams.context import config_var +from heurams.kernel.repolib import Repo +from heurams.services.logger import get_logger + +logger = get_logger(__name__) + + +class RepoEditorScreen(Screen): + """仓库编辑器屏幕""" + + SUB_TITLE = "仓库编辑器" + + BINDINGS = [ + ("q", "go_back", "返回"), + ("s", "save_file", "保存"), + ("r", "reload_file", "重载"), + ("d", "toggle_dark", ""), + ] + + # 当前选择的仓库路径 + selected_repo_path: reactive[Optional[Path]] = reactive(None) + # 当前选择的文件名 + selected_filename: reactive[Optional[str]] = reactive(None) + # 文件内容 + file_content: reactive[str] = reactive("") + + def __init__( + self, + repo: Optional[Repo] = None, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ) -> None: + super().__init__(name, id, classes) + self.repo = repo + self.repo_dir: Optional[Path] = None + self.file_list = [] + if repo is not None and repo.source is not None: + self.repo_dir = repo.source + self._load_file_list() + # selected_repo_path 将在 on_mount 中设置,避免触发watch时组件未就绪 + + def _load_file_list(self) -> None: + """加载仓库目录下的文件列表""" + if self.repo_dir is None: + return + self.file_list = [] + for fname in Repo.file_mapping.values(): + fpath = self.repo_dir / fname + if fpath.exists(): + self.file_list.append(fname) + # 也可能存在其他文件,但暂时只支持标准文件 + self.file_list.sort() + + def compose(self) -> ComposeResult: + """组合界面组件""" + yield Header(show_clock=True) + with Container(id="main_container"): + with Horizontal(id="top_panel"): + # 左侧: 仓库选择 + with Vertical(id="repo_selector", classes="panel"): + yield Label("仓库列表", classes="panel-title") + yield ListView( + *[ + ListItem(Label(repo_dir.name)) + for repo_dir in self._get_repo_dirs() + ], + id="repo_list", + classes="list-view", + ) + # 中间: 文件列表 + with Vertical(id="file_selector", classes="panel"): + yield Label("文件列表", classes="panel-title") + yield ListView( + *[ListItem(Label(fname)) for fname in self.file_list], + id="file_list", + classes="list-view", + ) + # 右侧: 编辑区域 + with Vertical(id="editor_panel", classes="panel"): + yield Label("编辑文件", classes="panel-title") + yield TextArea( + id="text_editor", + language="plaintext", + classes="text-editor", + ) + with Horizontal(id="button_bar"): + yield Button("保存", id="save_button", variant="primary") + yield Button("重载", id="reload_button", variant="default") + yield Button("返回", id="back_button", variant="error") + yield Footer() + + def _get_repo_dirs(self) -> list[Path]: + """获取data/repo/下所有有效仓库目录""" + repo_root = Path(config_var.get()["paths"]["data"]) / "repo" + repo_dirs = [] + if repo_root.exists(): + for entry in repo_root.iterdir(): + if entry.is_dir(): + # 检查是否存在 manifest.toml + if (entry / "manifest.toml").exists(): + repo_dirs.append(entry) + return repo_dirs + + def on_mount(self) -> None: + """挂载组件时初始化""" + # 如果已有仓库,设置 selected_repo_path 以触发watch(此时组件已就绪) + if self.repo_dir is not None: + self.selected_repo_path = self.repo_dir + # 焦点放在仓库列表 + self.query_one("#repo_list", ListView).focus() + + def watch_selected_repo_path( + self, old_path: Optional[Path], new_path: Optional[Path] + ) -> None: + """当选择的仓库路径变化时,加载文件列表""" + if new_path is None: + self.file_list = [] + self.selected_filename = None + self.file_content = "" + return + self.repo_dir = new_path + self._load_file_list() + # 如果组件已挂载,更新UI + if self.is_mounted: + file_list_view = self.query_one("#file_list", ListView) + file_list_view.clear() + for fname in self.file_list: + file_list_view.append(ListItem(Label(fname))) + # 清空编辑器 + self.query_one("#text_editor", TextArea).text = "" + self.selected_filename = None + + def watch_selected_filename( + self, old_name: Optional[str], new_name: Optional[str] + ) -> None: + """当选择的文件名变化时,加载文件内容""" + if new_name is None or self.repo_dir is None: + self.file_content = "" + return + file_path = self.repo_dir / new_name + if not file_path.exists(): + self.notify(f"文件不存在: {new_name}", severity="error") + return + try: + content = file_path.read_text(encoding="utf-8") + self.file_content = content + # 如果组件已挂载,更新编辑器 + if self.is_mounted: + editor = self.query_one("#text_editor", TextArea) + editor.text = content + # 根据文件后缀设置语言 + if new_name.endswith(".toml"): + editor.language = "toml" + elif new_name.endswith(".json"): + editor.language = "json" + else: + editor.language = "plaintext" + except Exception as e: + logger.error(f"读取文件失败: {e}") + self.notify(f"读取文件失败: {e}", severity="error") + + def watch_file_content(self, old_content: str, new_content: str) -> None: + """当文件内容变化时更新编辑器(仅当外部改变时)""" + # 目前不需要做任何事情,因为编辑器内容已绑定 + pass + + def on_list_view_selected(self, event) -> None: + """处理列表项选择事件""" + if not isinstance(event.item, ListItem): + return + list_id = event.list_view.id + selected_label = event.item.query_one(Label) + selected_text = str(selected_label.render()) + + if list_id == "repo_list": + # 用户选择了仓库 + repo_root = Path(config_var.get()["paths"]["data"]) / "repo" + selected_dir = repo_root / selected_text + if selected_dir.exists(): + self.selected_repo_path = selected_dir + elif list_id == "file_list": + # 用户选择了文件 + if self.repo_dir is None: + self.notify("请先选择仓库", severity="warning") + return + self.selected_filename = selected_text + + def on_button_pressed(self, event) -> None: + """处理按钮点击事件""" + event.stop() + if event.button.id == "save_button": + self.action_save_file() + elif event.button.id == "reload_button": + self.action_reload_file() + elif event.button.id == "back_button": + self.action_go_back() + + def action_save_file(self) -> None: + """保存当前编辑的文件""" + if self.repo_dir is None or self.selected_filename is None: + self.notify("未选择仓库或文件", severity="warning") + return + file_path = self.repo_dir / self.selected_filename + editor = self.query_one("#text_editor", TextArea) + new_content = editor.text + # 验证格式 + try: + if self.selected_filename.endswith(".toml"): + toml.loads(new_content) # 验证TOML + elif self.selected_filename.endswith(".json"): + json.loads(new_content) # 验证JSON + except Exception as e: + self.notify(f"格式错误: {e}", severity="error") + return + # 写入文件 + try: + file_path.write_text(new_content, encoding="utf-8") + self.notify("保存成功", severity="information") + except Exception as e: + logger.error(f"保存文件失败: {e}") + self.notify(f"保存文件失败: {e}", severity="error") + + def action_reload_file(self) -> None: + """重新加载当前文件(放弃修改)""" + if self.repo_dir is None or self.selected_filename is None: + self.notify("未选择仓库或文件", severity="warning") + return + file_path = self.repo_dir / self.selected_filename + try: + content = file_path.read_text(encoding="utf-8") + editor = self.query_one("#text_editor", TextArea) + editor.text = content + self.notify("已重载", severity="information") + except Exception as e: + logger.error(f"重载文件失败: {e}") + self.notify(f"重载文件失败: {e}", severity="error") + + def action_go_back(self) -> None: + """返回上一屏幕""" + self.app.pop_screen() + + def action_toggle_dark(self) -> None: + """切换暗色模式""" + self.app.dark = not self.app.dark diff --git a/src/heurams/interface/widgets/cloze_puzzle.py b/src/heurams/interface/widgets/cloze_puzzle.py index a669418..c140dba 100644 --- a/src/heurams/interface/widgets/cloze_puzzle.py +++ b/src/heurams/interface/widgets/cloze_puzzle.py @@ -7,8 +7,8 @@ from textual.message import Message from textual.widget import Widget from textual.widgets import Button, Label -import heurams.kernel.puzzles as pz import heurams.kernel.particles as pt +import heurams.kernel.puzzles as pz from heurams.services.logger import get_logger from .base_puzzle_widget import BasePuzzleWidget diff --git a/src/heurams/interface/widgets/finished.py b/src/heurams/interface/widgets/finished.py index bb25739..db1c31e 100644 --- a/src/heurams/interface/widgets/finished.py +++ b/src/heurams/interface/widgets/finished.py @@ -7,12 +7,12 @@ class Finished(Widget): self, *children: Widget, alia="", - is_saved = 0, + is_saved=0, name: str | None = None, id: str | None = None, classes: str | None = None, disabled: bool = False, - markup: bool = True + markup: bool = True, ) -> None: self.alia = alia self.is_saved = is_saved @@ -22,7 +22,7 @@ class Finished(Widget): id=id, classes=classes, disabled=disabled, - markup=markup + markup=markup, ) def compose(self): diff --git a/src/heurams/interface/widgets/mcq_puzzle.py b/src/heurams/interface/widgets/mcq_puzzle.py index 9cdb832..8499b94 100644 --- a/src/heurams/interface/widgets/mcq_puzzle.py +++ b/src/heurams/interface/widgets/mcq_puzzle.py @@ -5,8 +5,8 @@ from textual.containers import Container, ScrollableContainer from textual.widget import Widget from textual.widgets import Button, Label -import heurams.kernel.puzzles as pz import heurams.kernel.particles as pt +import heurams.kernel.puzzles as pz from heurams.services.hasher import hash from heurams.services.logger import get_logger diff --git a/src/heurams/interface/widgets/recognition.py b/src/heurams/interface/widgets/recognition.py index 233dc80..3917a7f 100644 --- a/src/heurams/interface/widgets/recognition.py +++ b/src/heurams/interface/widgets/recognition.py @@ -90,7 +90,7 @@ class Recognition(BasePuzzleWidget): for item in cfg["secondary"]: if isinstance(item, list): for j in item: - yield Markdown(f"### {metadata['annotation'][item]}: {j}") + yield Markdown(f"### {j}") #TODO ANNOTATION continue if isinstance(item, Dict): total = "" diff --git a/src/heurams/utils/__init__.py b/src/heurams/kernel/auxiliary/__init__.py similarity index 100% rename from src/heurams/utils/__init__.py rename to src/heurams/kernel/auxiliary/__init__.py diff --git a/src/heurams/utils/evalizor.py b/src/heurams/kernel/auxiliary/evalizor.py similarity index 100% rename from src/heurams/utils/evalizor.py rename to src/heurams/kernel/auxiliary/evalizor.py diff --git a/src/heurams/utils/lict.py b/src/heurams/kernel/auxiliary/lict.py similarity index 100% rename from src/heurams/utils/lict.py rename to src/heurams/kernel/auxiliary/lict.py diff --git a/src/heurams/utils/refvar.py b/src/heurams/kernel/auxiliary/refvar.py similarity index 100% rename from src/heurams/utils/refvar.py rename to src/heurams/kernel/auxiliary/refvar.py diff --git a/src/heurams/kernel/particles/__init__.py b/src/heurams/kernel/particles/__init__.py index 211e9d5..48accff 100644 --- a/src/heurams/kernel/particles/__init__.py +++ b/src/heurams/kernel/particles/__init__.py @@ -3,8 +3,8 @@ from .electron import Electron from .nucleon import Nucleon from .placeholders import ( AtomPlaceholder, - NucleonPlaceholder, ElectronPlaceholder, + NucleonPlaceholder, orbital_placeholder, ) diff --git a/src/heurams/kernel/particles/atom.py b/src/heurams/kernel/particles/atom.py index 3cd22a0..463add0 100644 --- a/src/heurams/kernel/particles/atom.py +++ b/src/heurams/kernel/particles/atom.py @@ -1,6 +1,5 @@ from typing import TypedDict - from heurams.services.logger import get_logger from .electron import Electron diff --git a/src/heurams/kernel/particles/electron.py b/src/heurams/kernel/particles/electron.py index ba58db5..4833b12 100644 --- a/src/heurams/kernel/particles/electron.py +++ b/src/heurams/kernel/particles/electron.py @@ -57,6 +57,10 @@ class Electron: result = self.algodata[self.algo.algo_name]["is_activated"] return result + def last_modify(self): + result = self.algodata[self.algo.algo_name]["last_modify"] + return result + def get_rating(self): try: result = self.algo.get_rating(self.algodata) @@ -68,6 +72,10 @@ class Electron: result = self.algo.nextdate(self.algodata) return result + def lastdate(self) -> int: + result = self.algodata[self.algo.algo_name]["lastdate"] + return result + def revisor(self, quality: int = 5, is_new_activation: bool = False): """算法迭代决策机制实现 diff --git a/src/heurams/kernel/particles/nucleon.py b/src/heurams/kernel/particles/nucleon.py index e34875f..9369bba 100644 --- a/src/heurams/kernel/particles/nucleon.py +++ b/src/heurams/kernel/particles/nucleon.py @@ -1,9 +1,9 @@ from copy import deepcopy from logging import config -from heurams.services.logger import get_logger -from heurams.utils.evalizor import Evalizer from heurams.context import config_var +from heurams.services.logger import get_logger +from heurams.kernel.auxiliary.evalizor import Evalizer logger = get_logger(__name__) diff --git a/src/heurams/kernel/particles/placeholders.py b/src/heurams/kernel/particles/placeholders.py index 6190d94..2e0b8c4 100644 --- a/src/heurams/kernel/particles/placeholders.py +++ b/src/heurams/kernel/particles/placeholders.py @@ -1,7 +1,8 @@ from heurams.kernel.particles import orbital + +from .atom import Atom from .electron import Electron from .nucleon import Nucleon -from .atom import Atom orbital_placeholder = { "schedule": ["quick_review", "recognition", "final_review"], diff --git a/src/heurams/kernel/reactor/fission.py b/src/heurams/kernel/reactor/fission.py index 4f23ccb..7d5812b 100644 --- a/src/heurams/kernel/reactor/fission.py +++ b/src/heurams/kernel/reactor/fission.py @@ -1,12 +1,13 @@ -from functools import reduce import random - -import heurams.kernel.puzzles as puz -import heurams.kernel.particles as pt -from heurams.services.logger import get_logger +from functools import reduce from tabulate import tabulate as tabu from transitions import Machine + +import heurams.kernel.particles as pt +import heurams.kernel.puzzles as puz +from heurams.services.logger import get_logger + from .states import FissionState, PhaserState logger = get_logger(__name__) diff --git a/src/heurams/kernel/reactor/phaser.py b/src/heurams/kernel/reactor/phaser.py index 9afdccd..19cf55e 100644 --- a/src/heurams/kernel/reactor/phaser.py +++ b/src/heurams/kernel/reactor/phaser.py @@ -1,8 +1,9 @@ from click import style +from transitions import Machine + import heurams.kernel.particles as pt from heurams.kernel.particles.placeholders import AtomPlaceholder from heurams.services.logger import get_logger -from transitions import Machine from .procession import Procession from .states import PhaserState, ProcessionState @@ -133,9 +134,10 @@ class Phaser(Machine): return Procession([AtomPlaceholder()], PhaserState.FINISHED) def __repr__(self, style="pipe", ends="\n"): - from heurams.services.textproc import truncate from tabulate import tabulate as tabu + from heurams.services.textproc import truncate + lst = [ { "Type": "Phaser", diff --git a/src/heurams/kernel/reactor/procession.py b/src/heurams/kernel/reactor/procession.py index 29d7fca..3f37ca5 100644 --- a/src/heurams/kernel/reactor/procession.py +++ b/src/heurams/kernel/reactor/procession.py @@ -1,7 +1,8 @@ +from tabulate import tabulate as tabu +from transitions import Machine + import heurams.kernel.particles as pt from heurams.services.logger import get_logger -from transitions import Machine -from tabulate import tabulate as tabu from .fission import Fission from .states import PhaserState, ProcessionState diff --git a/src/heurams/kernel/repolib/repo.py b/src/heurams/kernel/repolib/repo.py index ec5cbbc..b863a34 100644 --- a/src/heurams/kernel/repolib/repo.py +++ b/src/heurams/kernel/repolib/repo.py @@ -7,7 +7,7 @@ import toml import heurams.kernel.particles as pt -from ...utils.lict import Lict +from heurams.kernel.auxiliary.lict import Lict class RepoManifest(TypedDict): @@ -167,7 +167,7 @@ class Repo: return 0 @classmethod - def probe_vaild_repos_in_dir(cls, folder: Path): + def probe_valid_repos_in_dir(cls, folder: Path): lst = list() for i in folder.iterdir(): if i.is_dir(): diff --git a/src/heurams/providers/audio/protocol.py b/src/heurams/providers/audio/protocol.py deleted file mode 100644 index 664d32e..0000000 --- a/src/heurams/providers/audio/protocol.py +++ /dev/null @@ -1,13 +0,0 @@ -import pathlib -from typing import Protocol - -from heurams.services.logger import get_logger - -logger = get_logger(__name__) - - -class PlayFunctionProtocol(Protocol): - def __call__(self, path: pathlib.Path) -> None: ... - - -logger.debug("音频协议模块已加载") diff --git a/src/heurams/providers/llm/__init__.py b/src/heurams/providers/llm/__init__.py index e87befa..7cd42b7 100644 --- a/src/heurams/providers/llm/__init__.py +++ b/src/heurams/providers/llm/__init__.py @@ -1,6 +1,19 @@ # 大语言模型 from heurams.services.logger import get_logger +from .base import BaseLLM +from .openai import OpenAILLM + logger = get_logger(__name__) -logger.debug("LLM providers 模块已加载") +__all__ = [ + "BaseLLM", + "OpenAILLM", +] + +providers = { + "base": BaseLLM, + "openai": OpenAILLM, +} + +logger.debug("LLM providers 已注册: %s", list(providers.keys())) diff --git a/src/heurams/providers/llm/base.py b/src/heurams/providers/llm/base.py index b7c50c8..6426588 100644 --- a/src/heurams/providers/llm/base.py +++ b/src/heurams/providers/llm/base.py @@ -1,5 +1,55 @@ +"""LLM 提供者基类""" + +import asyncio +from typing import Any, Dict, List, Optional + from heurams.services.logger import get_logger logger = get_logger(__name__) -logger.debug("LLM 基类模块已加载") + +class BaseLLM: + """LLM 提供者基类""" + + name = "BaseLLM" + + def __init__(self, config: Dict[str, Any]): + """初始化 LLM 提供者 + + Args: + config: 提供者配置字典 + """ + self.config = config + logger.debug("BaseLLM 初始化完成") + + async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str: + """发送聊天消息并获取响应 + + Args: + messages: 消息列表,每个消息为 {"role": "user"|"assistant"|"system", "content": "消息内容"} + **kwargs: 其他参数,如 temperature, max_tokens 等 + + Returns: + 模型返回的文本响应 + """ + logger.debug("BaseLLM.chat: messages=%d, kwargs=%s", len(messages), kwargs) + logger.warning("BaseLLM.chat 是基类方法,未实现具体功能") + await asyncio.sleep(0) # 避免未使用异步的警告 + return "BaseLLM 未实现具体功能" + + async def chat_stream(self, messages: List[Dict[str, str]], **kwargs): + """流式聊天(可选实现) + + Args: + messages: 消息列表 + **kwargs: 其他参数 + + Yields: + 流式响应的文本块 + """ + logger.debug( + "BaseLLM.chat_stream: messages=%d, kwargs=%s", len(messages), kwargs + ) + logger.warning("BaseLLM.chat_stream 是基类方法,未实现具体功能") + await asyncio.sleep(0) + yield "BaseLLM 未实现流式功能" diff --git a/src/heurams/providers/llm/openai.py b/src/heurams/providers/llm/openai.py index 43a74f7..252f45d 100644 --- a/src/heurams/providers/llm/openai.py +++ b/src/heurams/providers/llm/openai.py @@ -1,5 +1,96 @@ +"""OpenAI 兼容 LLM 提供者""" + +import asyncio +from typing import Any, AsyncGenerator, Dict, List, Optional + from heurams.services.logger import get_logger +from .base import BaseLLM + logger = get_logger(__name__) -logger.debug("OpenAI provider 模块已加载(未实现)") + +class OpenAILLM(BaseLLM): + """OpenAI 兼容 LLM 提供者""" + + name = "OpenAI" + + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.api_key = config.get("key", "") + self.base_url = config.get("url", "https://api.openai.com/v1") + self._client = None + logger.debug("OpenAILLM 初始化完成: base_url=%s", self.base_url) + + def _get_client(self): + """获取 OpenAI 客户端(延迟导入)""" + if self._client is None: + try: + from openai import AsyncOpenAI + except ImportError: + logger.error("未安装 openai 库,请运行: pip install openai") + raise ImportError("未安装 openai 库,请运行: pip install openai") + + self._client = AsyncOpenAI( + api_key=self.api_key if self.api_key else None, + base_url=self.base_url if self.base_url else None, + ) + return self._client + + async def chat(self, messages: List[Dict[str, str]], **kwargs) -> str: + """发送聊天消息并获取响应""" + logger.debug("OpenAILLM.chat: messages=%d", len(messages)) + + client = self._get_client() + + # 默认参数 + default_kwargs = { + "model": kwargs.get("model", "gpt-3.5-turbo"), + "temperature": kwargs.get("temperature", 0.7), + "max_tokens": kwargs.get("max_tokens", 1000), + } + + # 合并参数,优先使用传入的 kwargs + request_kwargs = {**default_kwargs, **kwargs} + request_kwargs["messages"] = messages + + try: + response = await client.chat.completions.create(**request_kwargs) + content = response.choices[0].message.content + logger.debug( + "OpenAILLM.chat 成功: response length=%d", + len(content) if content else 0, + ) + return content or "" + except Exception as e: + logger.error("OpenAILLM.chat 失败: %s", e) + raise + + async def chat_stream( + self, messages: List[Dict[str, str]], **kwargs + ) -> AsyncGenerator[str, None]: + """流式聊天""" + logger.debug("OpenAILLM.chat_stream: messages=%d", len(messages)) + + client = self._get_client() + + # 默认参数 + default_kwargs = { + "model": kwargs.get("model", "gpt-3.5-turbo"), + "temperature": kwargs.get("temperature", 0.7), + "max_tokens": kwargs.get("max_tokens", 1000), + "stream": True, + } + + # 合并参数 + request_kwargs = {**default_kwargs, **kwargs} + request_kwargs["messages"] = messages + + try: + stream = await client.chat.completions.create(**request_kwargs) + async for chunk in stream: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + except Exception as e: + logger.error("OpenAILLM.chat_stream 失败: %s", e) + raise diff --git a/src/heurams/services/favorite_service.py b/src/heurams/services/favorite_service.py new file mode 100644 index 0000000..a925a34 --- /dev/null +++ b/src/heurams/services/favorite_service.py @@ -0,0 +1,163 @@ +# 收藏服务 +import json +import shutil +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +from heurams.context import config_var +from heurams.services.logger import get_logger + +logger = get_logger(__name__) + + +@dataclass +class FavoriteItem: + """收藏项""" + + repo_path: str # 仓库相对路径 (相对于 data/repo) + ident: str # 原子标识符 + added: int # 添加时间戳 (UNIX 秒) + # 可选标签 + tags: List[str] | None = None + + def __post_init__(self): + if self.tags is None: + self.tags = [] + + def to_dict(self) -> dict: + return { + "repo_path": self.repo_path, + "ident": self.ident, + "added": self.added, + "tags": self.tags, + } + + @classmethod + def from_dict(cls, data: dict) -> "FavoriteItem": + return cls( + repo_path=data["repo_path"], + ident=data["ident"], + added=data["added"], + tags=data.get("tags", []), + ) + + +class FavoriteManager: + """收藏管理器""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if not hasattr(self, "_loaded"): + self._loaded = True + self._favorites: List[FavoriteItem] = [] + self._file_path = self._get_file_path() + self.load() + + def _get_file_path(self) -> Path: + """获取收藏文件路径""" + config_path = Path(config_var.get()["paths"]["data"]) + fav_path = config_path / "global" / "favorites.json" + fav_path.parent.mkdir(parents=True, exist_ok=True) + return fav_path + + def load(self) -> None: + """从文件加载收藏列表""" + if self._file_path.exists(): + try: + with open(self._file_path, "r", encoding="utf-8") as f: + data = json.load(f) + self._favorites = [FavoriteItem.from_dict(item) for item in data] + logger.debug("收藏列表加载成功,共 %d 项", len(self._favorites)) + except Exception as e: + logger.error("加载收藏列表失败: %s", e) + self._favorites = [] + else: + self._favorites = [] + + def save(self) -> None: + """保存收藏列表到文件""" + try: + data = [item.to_dict() for item in self._favorites] + with open(self._file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + logger.debug("收藏列表保存成功,共 %d 项", len(self._favorites)) + except Exception as e: + logger.error("保存收藏列表失败: %s", e) + + def add(self, repo_path: str, ident: str, tags: List[str] | None = None) -> bool: + """添加收藏 + + Args: + repo_path: 仓库相对路径 + ident: 原子标识符 + tags: 标签列表 + Returns: + 是否成功添加 (若已存在则返回 False) + """ + # 检查是否已存在 + for item in self._favorites: + if item.repo_path == repo_path and item.ident == ident: + logger.debug("收藏已存在: %s/%s", repo_path, ident) + return False + item = FavoriteItem( + repo_path=repo_path, + ident=ident, + added=int(time.time()), + tags=tags if tags else [], + ) + self._favorites.append(item) + self.save() + logger.info("添加收藏: %s/%s", repo_path, ident) + return True + + def remove(self, repo_path: str, ident: str) -> bool: + """移除收藏 + + Returns: + 是否成功移除 (若不存在则返回 False) + """ + for idx, item in enumerate(self._favorites): + if item.repo_path == repo_path and item.ident == ident: + del self._favorites[idx] + self.save() + logger.info("移除收藏: %s/%s", repo_path, ident) + return True + logger.debug("收藏不存在: %s/%s", repo_path, ident) + return False + + def has(self, repo_path: str, ident: str) -> bool: + """检查是否已收藏""" + for item in self._favorites: + if item.repo_path == repo_path and item.ident == ident: + return True + return False + + def get_all(self) -> List[FavoriteItem]: + """获取所有收藏项(按添加时间倒序)""" + return sorted(self._favorites, key=lambda x: x.added, reverse=True) + + def get_by_repo(self, repo_path: str) -> List[FavoriteItem]: + """获取指定仓库的所有收藏项""" + return [item for item in self._favorites if item.repo_path == repo_path] + + def clear(self) -> None: + """清空收藏列表""" + self._favorites = [] + self.save() + logger.info("清空收藏列表") + + def count(self) -> int: + """收藏总数""" + return len(self._favorites) + + +# 全局单例实例 +favorite_manager = FavoriteManager() diff --git a/src/heurams/services/llm_service.py b/src/heurams/services/llm_service.py new file mode 100644 index 0000000..2fd730e --- /dev/null +++ b/src/heurams/services/llm_service.py @@ -0,0 +1,228 @@ +"""LLM 聊天服务""" + +import asyncio +import json +from pathlib import Path +from typing import Any, Dict, List, Optional + +from heurams.context import config_var +from heurams.providers.llm import providers as prov +from heurams.services.logger import get_logger + +logger = get_logger(__name__) + + +class ChatSession: + """聊天会话,管理单个对话的历史和参数""" + + def __init__( + self, session_id: str, llm_provider, system_prompt: str = "", **default_params + ): + """初始化聊天会话 + + Args: + session_id: 会话唯一标识符 + llm_provider: LLM 提供者实例 + system_prompt: 系统提示词 + **default_params: 默认参数(temperature, max_tokens, model 等) + """ + self.session_id = session_id + self.llm_provider = llm_provider + self.system_prompt = system_prompt + self.default_params = default_params + + # 消息历史 + self.messages: List[Dict[str, str]] = [] + if system_prompt: + self.messages.append({"role": "system", "content": system_prompt}) + + logger.debug("创建聊天会话: id=%s", session_id) + + def add_message(self, role: str, content: str): + """添加消息到历史""" + self.messages.append({"role": role, "content": content}) + logger.debug( + "会话 %s 添加消息: role=%s, length=%d", self.session_id, role, len(content) + ) + + def clear_history(self): + """清空消息历史(保留系统提示)""" + self.messages = [] + if self.system_prompt: + self.messages.append({"role": "system", "content": self.system_prompt}) + logger.debug("会话 %s 清空历史", self.session_id) + + def set_system_prompt(self, prompt: str): + """设置系统提示词""" + self.system_prompt = prompt + # 更新消息历史中的系统消息 + if self.messages and self.messages[0]["role"] == "system": + self.messages[0]["content"] = prompt + elif prompt: + self.messages.insert(0, {"role": "system", "content": prompt}) + logger.debug("会话 %s 设置系统提示: length=%d", self.session_id, len(prompt)) + + async def send_message(self, message: str, **override_params) -> str: + """发送消息并获取响应 + + Args: + message: 用户消息内容 + **override_params: 覆盖默认参数 + + Returns: + 模型响应内容 + """ + # 添加用户消息 + self.add_message("user", message) + + # 合并参数 + params = {**self.default_params, **override_params} + + # 发送请求 + logger.debug("会话 %s 发送消息: length=%d", self.session_id, len(message)) + response = await self.llm_provider.chat(self.messages, **params) + + # 添加助手响应 + self.add_message("assistant", response) + + return response + + async def send_message_stream(self, message: str, **override_params): + """流式发送消息 + + Args: + message: 用户消息内容 + **override_params: 覆盖默认参数 + + Yields: + 流式响应的文本块 + """ + # 添加用户消息 + self.add_message("user", message) + + # 合并参数 + params = {**self.default_params, **override_params} + + # 发送流式请求 + logger.debug("会话 %s 发送流式消息: length=%d", self.session_id, len(message)) + + full_response = "" + async for chunk in self.llm_provider.chat_stream(self.messages, **params): + yield chunk + full_response += chunk + + # 添加完整的助手响应到历史 + self.add_message("assistant", full_response) + + def get_history(self) -> List[Dict[str, str]]: + """获取消息历史(不包括系统消息)""" + # 返回用户和助手的消息,可选排除系统消息 + return [msg for msg in self.messages if msg["role"] != "system"] + + def save_to_file(self, file_path: Path): + """保存会话到文件""" + data = { + "session_id": self.session_id, + "system_prompt": self.system_prompt, + "default_params": self.default_params, + "messages": self.messages, + } + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + logger.debug("会话 %s 保存到: %s", self.session_id, file_path) + + @classmethod + def load_from_file(cls, file_path: Path, llm_provider): + """从文件加载会话""" + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + session = cls( + session_id=data["session_id"], + llm_provider=llm_provider, + system_prompt=data.get("system_prompt", ""), + **data.get("default_params", {}) + ) + session.messages = data["messages"] + logger.debug("从文件加载会话: %s", file_path) + return session + + +class ChatManager: + """聊天管理器,管理多个会话""" + + def __init__(self): + self.sessions: Dict[str, ChatSession] = {} + self.default_session_id = "default" + logger.debug("聊天管理器初始化完成") + + def get_session( + self, + session_id: Optional[str] = None, + create_if_missing: bool = True, + **session_params + ) -> Optional[ChatSession]: + """获取或创建聊天会话 + + Args: + session_id: 会话标识符,None 则使用默认会话 + create_if_missing: 如果会话不存在则创建 + **session_params: 传递给 ChatSession 的参数 + + Returns: + 聊天会话实例,如果不存在且不创建则返回 None + """ + if session_id is None: + session_id = self.default_session_id + + if session_id in self.sessions: + return self.sessions[session_id] + + if create_if_missing: + # 获取 LLM 提供者 + provider_name = config_var.get()["services"]["llm"] + provider_config = config_var.get()["providers"]["llm"][provider_name] + llm_provider = prov[provider_name](provider_config) + + session = ChatSession( + session_id=session_id, llm_provider=llm_provider, **session_params + ) + self.sessions[session_id] = session + logger.debug("创建新会话: id=%s", session_id) + return session + + return None + + def delete_session(self, session_id: str): + """删除会话""" + if session_id in self.sessions: + del self.sessions[session_id] + logger.debug("删除会话: id=%s", session_id) + + def list_sessions(self) -> List[str]: + """列出所有会话ID""" + return list(self.sessions.keys()) + + +# 全局聊天管理器实例 +_chat_manager: Optional[ChatManager] = None + + +def get_chat_manager() -> ChatManager: + """获取全局聊天管理器实例""" + global _chat_manager + if _chat_manager is None: + _chat_manager = ChatManager() + logger.debug("创建全局聊天管理器") + return _chat_manager + + +def create_chat_session( + session_id: Optional[str] = None, **session_params +) -> ChatSession: + """创建或获取聊天会话(便捷函数)""" + manager = get_chat_manager() + return manager.get_session(session_id, True, **session_params) + + +logger.debug("LLM 服务初始化完成")