// Mirror of https://github.com/woodchen-ink/Oapi-Feishu.git
// (synced 2025-07-18 05:42:08 +08:00); 251 lines, 5.7 KiB, Go.
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/url"
|
|
"start-feishubot/initialization"
|
|
"start-feishubot/services/loadbalancer"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// PlatForm names the backend a ChatGPT client talks to. It decides both how
// FullUrl builds endpoint URLs and which auth header is sent.
type PlatForm string

const (
	// OpenAI selects the api.openai.com-style API ("<ApiUrl>/v1/<suffix>",
	// Bearer authorization).
	OpenAI PlatForm = "openai"
	// Azure selects an Azure OpenAI deployment (api-key header).
	Azure PlatForm = "azure"

	// AzureApiUrlV1 is the host/path fragment that follows the resource name
	// in Azure OpenAI endpoint URLs.
	AzureApiUrlV1 = "openai.azure.com/openai/deployments/"
)
|
|
|
|
// AzureConfig carries the settings needed to reach an Azure OpenAI deployment.
// In declaration order:
//   - BaseURL: host/path fragment placed after the resource name
//     (NewChatGPT uses AzureApiUrlV1).
//   - ResourceName: first host label of the endpoint.
//   - DeploymentName: deployment path segment appended to BaseURL.
//   - ApiVersion: value of the api-version query parameter.
//   - ApiToken: credential sent in the "api-key" header.
type AzureConfig struct {
	BaseURL, ResourceName, DeploymentName, ApiVersion, ApiToken string
}
|
|
|
|
type ChatGPT struct {
|
|
Lb *loadbalancer.LoadBalancer
|
|
ApiKey []string
|
|
ApiUrl string
|
|
ApiModel string
|
|
HttpProxy string
|
|
Platform PlatForm
|
|
AzureConfig AzureConfig
|
|
}
|
|
// requestBodyType tells doAPIRequestWithRetry how to encode the request body
// and which request headers to set.
type requestBodyType int

const (
	jsonBody            = requestBodyType(iota) // JSON-marshalled body, sent as application/json
	formVoiceDataBody                           // multipart form built from an AudioToTextRequestBody
	formPictureDataBody                         // multipart form built from an ImageVariantRequestBody
	streamBody                                  // server-sent-events request (adds SSE Accept headers)
	nilBody                                     // request without a body
)
|
|
|
|
func (gpt *ChatGPT) doAPIRequestWithRetry(url, method string,
|
|
bodyType requestBodyType,
|
|
requestBody interface{}, responseBody interface{}, client *http.Client, maxRetries int) error {
|
|
var api *loadbalancer.API
|
|
var requestBodyData []byte
|
|
var err error
|
|
var writer *multipart.Writer
|
|
api = gpt.Lb.GetAPI()
|
|
|
|
switch bodyType {
|
|
case jsonBody:
|
|
requestBodyData, err = json.Marshal(requestBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case formVoiceDataBody:
|
|
formBody := &bytes.Buffer{}
|
|
writer = multipart.NewWriter(formBody)
|
|
err = audioMultipartForm(requestBody.(AudioToTextRequestBody), writer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = writer.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
requestBodyData = formBody.Bytes()
|
|
case formPictureDataBody:
|
|
formBody := &bytes.Buffer{}
|
|
writer = multipart.NewWriter(formBody)
|
|
err = pictureMultipartForm(requestBody.(ImageVariantRequestBody), writer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = writer.Close()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
requestBodyData = formBody.Bytes()
|
|
|
|
case nilBody:
|
|
requestBodyData = nil
|
|
|
|
default:
|
|
return errors.New("unknown request body type")
|
|
}
|
|
|
|
if api == nil {
|
|
return errors.New("no available API")
|
|
}
|
|
|
|
req, err := http.NewRequest(method, url, bytes.NewReader(requestBodyData))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if bodyType == formVoiceDataBody || bodyType == formPictureDataBody {
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
}
|
|
if bodyType == streamBody {
|
|
req.Header.Set("Accept", "text/event-stream")
|
|
req.Header.Set("Connection", "keep-alive")
|
|
req.Header.Set("Cache-Control", "no-cache")
|
|
}
|
|
if gpt.Platform == OpenAI {
|
|
req.Header.Set("Authorization", "Bearer "+api.Key)
|
|
} else {
|
|
req.Header.Set("api-key", gpt.AzureConfig.ApiToken)
|
|
}
|
|
|
|
var response *http.Response
|
|
var retry int
|
|
for retry = 0; retry <= maxRetries; retry++ {
|
|
response, err = client.Do(req)
|
|
//fmt.Println("--------------------")
|
|
//fmt.Println("req", req.Header)
|
|
//fmt.Printf("response: %v", response)
|
|
// read body
|
|
if err != nil || response.StatusCode < 200 || response.StatusCode >= 300 {
|
|
|
|
body, _ := ioutil.ReadAll(response.Body)
|
|
fmt.Println("body", string(body))
|
|
|
|
gpt.Lb.SetAvailability(api.Key, false)
|
|
if retry == maxRetries || bodyType == streamBody {
|
|
break
|
|
}
|
|
time.Sleep(time.Duration(retry+1) * time.Second)
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
if response != nil {
|
|
defer response.Body.Close()
|
|
}
|
|
|
|
if response == nil || response.StatusCode < 200 || response.StatusCode >= 300 {
|
|
return fmt.Errorf("%s api failed after %d retries", strings.ToUpper(method), retry)
|
|
}
|
|
|
|
body, err := ioutil.ReadAll(response.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = json.Unmarshal(body, responseBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
gpt.Lb.SetAvailability(api.Key, true)
|
|
return nil
|
|
}
|
|
|
|
func (gpt *ChatGPT) sendRequestWithBodyType(link, method string,
|
|
bodyType requestBodyType,
|
|
requestBody interface{}, responseBody interface{}) error {
|
|
var err error
|
|
proxyString := gpt.HttpProxy
|
|
|
|
client, parseProxyError := GetProxyClient(proxyString)
|
|
if parseProxyError != nil {
|
|
return parseProxyError
|
|
}
|
|
|
|
err = gpt.doAPIRequestWithRetry(link, method, bodyType,
|
|
requestBody, responseBody, client, 3)
|
|
|
|
return err
|
|
}
|
|
|
|
func GetProxyClient(proxyString string) (*http.Client, error) {
|
|
var client *http.Client
|
|
timeOutDuration := time.Duration(initialization.GetConfig().OpenAIHttpClientTimeOut) * time.Second
|
|
if proxyString == "" {
|
|
client = &http.Client{Timeout: timeOutDuration}
|
|
} else {
|
|
proxyUrl, err := url.Parse(proxyString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
transport := &http.Transport{
|
|
Proxy: http.ProxyURL(proxyUrl),
|
|
}
|
|
client = &http.Client{
|
|
Transport: transport,
|
|
Timeout: timeOutDuration,
|
|
}
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func NewChatGPT(config initialization.Config) *ChatGPT {
|
|
var lb *loadbalancer.LoadBalancer
|
|
if config.AzureOn {
|
|
keys := []string{config.AzureOpenaiToken}
|
|
lb = loadbalancer.NewLoadBalancer(keys)
|
|
} else {
|
|
lb = loadbalancer.NewLoadBalancer(config.OpenaiApiKeys)
|
|
}
|
|
platform := OpenAI
|
|
|
|
if config.AzureOn {
|
|
platform = Azure
|
|
}
|
|
|
|
return &ChatGPT{
|
|
Lb: lb,
|
|
ApiKey: config.OpenaiApiKeys,
|
|
ApiUrl: config.OpenaiApiUrl,
|
|
HttpProxy: config.HttpProxy,
|
|
Platform: platform,
|
|
ApiModel: config.OpenaiModel,
|
|
AzureConfig: AzureConfig{
|
|
BaseURL: AzureApiUrlV1,
|
|
ResourceName: config.AzureResourceName,
|
|
DeploymentName: config.AzureDeploymentName,
|
|
ApiVersion: config.AzureApiVersion,
|
|
ApiToken: config.AzureOpenaiToken,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (gpt *ChatGPT) FullUrl(suffix string) string {
|
|
var url string
|
|
switch gpt.Platform {
|
|
case Azure:
|
|
url = fmt.Sprintf("https://%s.%s%s/%s?api-version=%s",
|
|
gpt.AzureConfig.ResourceName, gpt.AzureConfig.BaseURL,
|
|
gpt.AzureConfig.DeploymentName, suffix, gpt.AzureConfig.ApiVersion)
|
|
case OpenAI:
|
|
url = fmt.Sprintf("%s/v1/%s", gpt.ApiUrl, suffix)
|
|
}
|
|
return url
|
|
}
|