支持gemini

This commit is contained in:
IndieKKY
2024-01-20 12:20:09 +08:00
parent 771a802728
commit 5078baca1b
7 changed files with 104 additions and 49 deletions

View File

@@ -45,7 +45,6 @@ const Body = () => {
const foldAll = useAppSelector(state => state.env.foldAll)
const envData = useAppSelector(state => state.env.envData)
const compact = useAppSelector(state => state.env.tempData.compact)
const apiKey = useAppSelector(state => state.env.envData.apiKey)
const floatKeyPointsSegIdx = useAppSelector(state => state.env.floatKeyPointsSegIdx)
const translateEnable = useAppSelector(state => state.env.envData.translateEnable)
const summarizeEnable = useAppSelector(state => state.env.envData.summarizeEnable)
@@ -76,6 +75,7 @@ const Body = () => {
}, [dispatch])
const onSummarizeAll = useCallback(() => {
const apiKey = envData.aiType === 'gemini'?envData.geminiApiKey:envData.apiKey
if (!apiKey) {
dispatch(setPage(PAGE_SETTINGS))
toast.error('需要先设置ApiKey!')
@@ -98,7 +98,7 @@ const Body = () => {
}
toast.success(`已添加${segments_.length}个总结任务!`)
}
}, [addSummarizeTask, apiKey, curSummaryType, dispatch, segments])
}, [addSummarizeTask, curSummaryType, dispatch, envData.aiType, envData.apiKey, envData.geminiApiKey, segments])
const onFoldAll = useCallback(() => {
dispatch(setFoldAll(!foldAll))
@@ -111,13 +111,14 @@ const Body = () => {
}, [dispatch, foldAll, segments])
const toggleAutoTranslateCallback = useCallback(() => {
if (envData.apiKey) {
const apiKey = envData.aiType === 'gemini'?envData.geminiApiKey:envData.apiKey
if (apiKey) {
dispatch(setAutoTranslate(!autoTranslate))
} else {
dispatch(setPage(PAGE_SETTINGS))
toast.error('需要先设置ApiKey!')
}
}, [autoTranslate, dispatch, envData.apiKey])
}, [autoTranslate, dispatch, envData.aiType, envData.apiKey, envData.geminiApiKey])
const onEnableAutoScroll = useCallback(() => {
dispatch(setAutoScroll(true))

View File

@@ -66,19 +66,20 @@ const Summarize = (props: {
const {segment, segmentIdx, summary, float} = props
const dispatch = useAppDispatch()
const apiKey = useAppSelector(state => state.env.envData.apiKey)
const envData = useAppSelector(state => state.env.envData)
const fontSize = useAppSelector(state => state.env.envData.fontSize)
const curSummaryType = useAppSelector(state => state.env.tempData.curSummaryType)
const {addSummarizeTask} = useTranslate()
const onGenerate = useCallback(() => {
const apiKey = envData.aiType === 'gemini'?envData.geminiApiKey:envData.apiKey
if (apiKey) {
addSummarizeTask(curSummaryType, segment).catch(console.error)
} else {
dispatch(setPage(PAGE_SETTINGS))
toast.error('需要先设置ApiKey!')
}
}, [addSummarizeTask, apiKey, curSummaryType, dispatch, segment])
}, [addSummarizeTask, curSummaryType, dispatch, envData.aiType, envData.apiKey, envData.geminiApiKey, segment])
const onCopy = useCallback(() => {
if (summary != null) {

View File

@@ -262,18 +262,21 @@ const Settings = () => {
{aiTypeValue === 'gemini' && <Section title='gemini配置'>
<FormItem title='ApiKey' htmlFor='geminiApiKey'>
<input id='geminiApiKey' type='text' className='input input-sm input-bordered w-full' placeholder='xxx' value={geminiApiKeyValue} onChange={onChangeGeminiApiKeyValue}/>
<input id='geminiApiKey' type='text' className='input input-sm input-bordered w-full' placeholder='xxx'
value={geminiApiKeyValue} onChange={onChangeGeminiApiKeyValue}/>
</FormItem>
<div className='flex justify-center'>
<a className='link text-xs' onClick={toggleMoreFold}>{moreFold?'点击查看说明':'点击折叠说明'}</a>
<a className='link text-xs' onClick={toggleMoreFold}>{moreFold ? '点击查看说明' : '点击折叠说明'}</a>
</div>
{!moreFold && <div>
<ul className='pl-3 list-decimal desc text-xs'>
<li><a className='link' href='https://makersuite.google.com/app/apikey' target='_blank' rel="noreferrer">Google AI Studio</a></li>
<li><a className='link' href='https://makersuite.google.com/app/apikey' target='_blank'
rel="noreferrer">Google AI Studio</a> ()</li>
</ul>
</div>}
<div className='flex justify-center'>
<a className='link text-xs' onClick={togglePromptsFold}>{promptsFold?'点击查看提示词':'点击折叠提示词'}</a>
<a className='link text-xs'
onClick={togglePromptsFold}>{promptsFold ? '点击查看提示词' : '点击折叠提示词'}</a>
</div>
{!promptsFold && <div>
{PROMPT_TYPES.map((item, idx) => <FormItem key={item.type} title={<div>
@@ -282,16 +285,18 @@ const Settings = () => {
setPromptsValue({
...promptsValue,
// @ts-expect-error
[item.type]: PROMPT_DEFAULTS[item.type]??''
[item.type]: PROMPT_DEFAULTS[item.type] ?? ''
})
}}></div>
}}>
</div>
</div>} htmlFor={`prompt-${item.type}`}>
<textarea id={`prompt-${item.type}`} className='mt-2 textarea input-bordered w-full' placeholder='留空使用默认提示词' value={promptsValue[item.type]??''} onChange={(e) => {
setPromptsValue({
...promptsValue,
[item.type]: e.target.value
})
}}/>
<textarea id={`prompt-${item.type}`} className='mt-2 textarea input-bordered w-full'
placeholder='留空使用默认提示词' value={promptsValue[item.type] ?? ''} onChange={(e) => {
setPromptsValue({
...promptsValue,
[item.type]: e.target.value
})
}}/>
</FormItem>)}
</div>}
</Section>}

View File

@@ -18,3 +18,16 @@ export const handleChatCompleteTask = async (task: Task) => {
throw new Error(`${task.resp.error.code as string??''} ${task.resp.error.message as string ??''}`)
}
}
export const handleGeminiChatCompleteTask = async (task: Task) => {
const data = task.def.data
const resp = await fetch('https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-goog-api-key': task.def.extra.geminiApiKey,
},
body: JSON.stringify(data),
})
task.resp = await resp.json()
}

View File

@@ -1,5 +1,5 @@
import {TASK_EXPIRE_TIME} from '../const'
import {handleChatCompleteTask} from './openaiService'
import {handleChatCompleteTask, handleGeminiChatCompleteTask} from './openaiService'
export const tasksMap = new Map<string, Task>()
@@ -11,6 +11,9 @@ export const handleTask = async (task: Task) => {
case 'chatComplete':
await handleChatCompleteTask(task)
break
case 'geminiChatComplete':
await handleGeminiChatCompleteTask(task)
break
default:
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw new Error(`任务类型不支持: ${task.def.type}`)

View File

@@ -82,23 +82,39 @@ const useTranslate = () => {
prompt = prompt.replaceAll('{{subtitles}}', lineStr)
const taskDef: TaskDef = {
type: 'chatComplete',
type: envData.aiType === 'gemini'?'geminiChatComplete':'chatComplete',
serverUrl: envData.serverUrl,
data: {
model: envData.model??MODEL_DEFAULT,
messages: [
{
role: 'user',
content: prompt,
data: envData.aiType === 'gemini'
?{
contents: [
{
parts: [
{
text: prompt
}
]
}
],
generationConfig: {
maxOutputTokens: 2048
}
}
],
temperature: 0,
n: 1,
stream: false,
},
:{
model: envData.model??MODEL_DEFAULT,
messages: [
{
role: 'user',
content: prompt,
}
],
temperature: 0,
n: 1,
stream: false,
},
extra: {
type: 'translate',
apiKey: envData.apiKey,
geminiApiKey: envData.geminiApiKey,
startIdx,
size: lines.length,
}
@@ -117,10 +133,10 @@ const useTranslate = () => {
dispatch(addTaskId(task.id))
}
}
}, [data?.body, dispatch, envData.apiKey, envData.fetchAmount, envData.serverUrl, envData.prompts, title, language.name])
}, [data?.body, envData.fetchAmount, envData.prompts, envData.aiType, envData.serverUrl, envData.model, envData.apiKey, envData.geminiApiKey, language.name, title, dispatch])
const addSummarizeTask = useCallback(async (type: SummaryType, segment: Segment) => {
if (segment.text.length >= SUMMARIZE_THRESHOLD && envData.apiKey) {
if (segment.text.length >= SUMMARIZE_THRESHOLD) {
let subtitles = ''
for (const item of segment.items) {
subtitles += formatTime(item.from) + ' ' + item.content + '\n'
@@ -135,25 +151,41 @@ const useTranslate = () => {
prompt = prompt.replaceAll('{{segment}}', segment.text)
const taskDef: TaskDef = {
type: 'chatComplete',
type: envData.aiType === 'gemini'?'geminiChatComplete':'chatComplete',
serverUrl: envData.serverUrl,
data: {
model: envData.model??MODEL_DEFAULT,
messages: [
{
role: 'user',
content: prompt,
data: envData.aiType === 'gemini'
?{
contents: [
{
parts: [
{
text: prompt
}
]
}
],
generationConfig: {
maxOutputTokens: 2048
}
}
],
temperature: 0,
n: 1,
stream: false,
},
:{
model: envData.model??MODEL_DEFAULT,
messages: [
{
role: 'user',
content: prompt,
}
],
temperature: 0,
n: 1,
stream: false,
},
extra: {
type: 'summarize',
summaryType: type,
startIdx: segment.startIdx,
apiKey: envData.apiKey,
geminiApiKey: envData.geminiApiKey,
}
}
console.debug('addSummarizeTask', taskDef)
@@ -162,7 +194,7 @@ const useTranslate = () => {
const task = await chrome.runtime.sendMessage({type: 'addTask', taskDef})
dispatch(addTaskId(task.id))
}
}, [dispatch, envData.apiKey, envData.prompts, envData.serverUrl, summarizeLanguage.name, title])
}, [dispatch, envData.aiType, envData.apiKey, envData.geminiApiKey, envData.model, envData.prompts, envData.serverUrl, summarizeLanguage.name, title])
const handleTranslate = useMemoizedFn((task: Task, content: string) => {
let map: {[key: string]: string} = {}
@@ -221,7 +253,7 @@ const useTranslate = () => {
console.debug('getTask', taskResp.task)
const task: Task = taskResp.task
const taskType: string | undefined = task.def.extra?.type
const content = task.resp?.choices?.[0]?.message?.content?.trim()
const content = envData.aiType === 'gemini'?task.resp?.candidates[0]?.content?.parts[0]?.text?.trim():task.resp?.choices?.[0]?.message?.content?.trim()
if (task.status === 'done') {
// 异常提示
if (task.error) {
@@ -239,7 +271,7 @@ const useTranslate = () => {
} else {
dispatch(delTaskId(taskId))
}
}, [dispatch, handleSummarize, handleTranslate])
}, [dispatch, envData.aiType, handleSummarize, handleTranslate])
return {getFetch, getTask, addTask, addSummarizeTask}
}

2
src/typings.d.ts vendored
View File

@@ -38,7 +38,7 @@ interface TempData {
}
interface TaskDef {
type: 'chatComplete'
type: 'chatComplete' | 'geminiChatComplete'
serverUrl?: string
data: any
extra?: any