403 lines
12 KiB
Go
Raw Normal View History

2025-06-08 21:40:57 +08:00
package kling
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
2025-08-25 18:01:10 +08:00
"github.com/samber/lo"
2025-06-08 21:40:57 +08:00
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
2025-06-08 21:40:57 +08:00
"github.com/pkg/errors"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/relay/channel"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/service"
2025-06-08 21:40:57 +08:00
)
// ============================
// Request / Response structures
// ============================
// TrajectoryPoint is one coordinate in a DynamicMask trajectory.
// Both coordinates are always serialized, even when zero.
type TrajectoryPoint struct {
	// X is the horizontal coordinate.
	X int `json:"x"`
	// Y is the vertical coordinate.
	Y int `json:"y"`
}
// DynamicMask pairs a mask with an ordered list of trajectory points.
// Both fields are dropped from the JSON body when empty.
type DynamicMask struct {
	// Mask is forwarded to the upstream as-is (NOTE(review): presumably an
	// image URL or Base64 string — confirm against the Kling API docs).
	Mask string `json:"mask,omitempty"`
	// Trajectories lists the points associated with this mask, in order.
	Trajectories []TrajectoryPoint `json:"trajectories,omitempty"`
}
// CameraConfig holds explicit camera-movement amounts for CameraControl.
// Every field is omitted from JSON when it is zero, so only non-zero
// movements are sent upstream.
// NOTE(review): valid ranges/units are defined by the Kling API — confirm there.
type CameraConfig struct {
	Horizontal float64 `json:"horizontal,omitempty"`
	Vertical   float64 `json:"vertical,omitempty"`
	Pan        float64 `json:"pan,omitempty"`
	Tilt       float64 `json:"tilt,omitempty"`
	Roll       float64 `json:"roll,omitempty"`
	Zoom       float64 `json:"zoom,omitempty"`
}
// CameraControl describes camera motion for a generation request.
// Type is sent as-is (NOTE(review): presumably a Kling preset name — confirm
// the allowed values); Config carries explicit amounts and is omitted when nil.
type CameraControl struct {
	Type   string        `json:"type,omitempty"`
	Config *CameraConfig `json:"config,omitempty"`
}
2025-06-08 21:40:57 +08:00
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
ImageTail string `json:"image_tail,omitempty"`
NegativePrompt string `json:"negative_prompt,omitempty"`
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
ModelName string `json:"model_name,omitempty"`
Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model"
CfgScale float64 `json:"cfg_scale,omitempty"`
StaticMask string `json:"static_mask,omitempty"`
DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"`
CameraControl *CameraControl `json:"camera_control,omitempty"`
CallbackUrl string `json:"callback_url,omitempty"`
ExternalTaskId string `json:"external_task_id,omitempty"`
2025-06-08 21:40:57 +08:00
}
// responsePayload mirrors Kling's response envelope. It is decoded for task
// submission (DoResponse) and for task queries (ParseTaskResult,
// ConvertToOpenAIVideo).
//
// A non-zero Code marks an upstream error described by Message. Task details
// live under Data; ParseTaskResult recognises the TaskStatus values
// "submitted", "processing", "succeed" and "failed".
type responsePayload struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	// TaskId is a top-level task id. Kling's native responses carry the id
	// under data.task_id (NOTE(review): this field presumably serves relays
	// that answer with a flat object — confirm).
	TaskId    string `json:"task_id"`
	RequestId string `json:"request_id"`
	Data      struct {
		TaskId        string `json:"task_id"`
		TaskStatus    string `json:"task_status"`
		TaskStatusMsg string `json:"task_status_msg"`
		TaskResult    struct {
			Videos []struct {
				Id       string `json:"id"`
				Url      string `json:"url"`
				Duration string `json:"duration"`
			} `json:"videos"`
		} `json:"task_result"`
		// Timestamps exactly as returned upstream (NOTE(review): Kling
		// documents these as Unix milliseconds — confirm).
		CreatedAt int64 `json:"created_at"`
		UpdatedAt int64 `json:"updated_at"`
	} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
// TaskAdaptor is the Kling video-task channel adaptor. Its fields are
// populated by Init from the relay info.
type TaskAdaptor struct {
	// ChannelType is copied from RelayInfo.ChannelType.
	ChannelType int
	// apiKey is the channel key: "accessKey|secretKey", or an "sk-" prefixed
	// key when relaying through another new-api instance.
	apiKey string
	// baseURL is the upstream base URL that request paths are appended to.
	baseURL string
}
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
2025-06-08 21:40:57 +08:00
a.ChannelType = info.ChannelType
a.baseURL = info.ChannelBaseUrl
2025-07-21 15:06:26 +08:00
a.apiKey = info.ApiKey
2025-06-08 21:40:57 +08:00
// apiKey format: "access_key|secret_key"
2025-06-08 21:40:57 +08:00
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
2025-09-12 21:52:32 +08:00
// Use the standard validation method for TaskSubmitReq
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
2025-06-08 21:40:57 +08:00
}
// BuildRequestURL constructs the upstream URL.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
2025-06-27 22:43:01 +08:00
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
if isNewAPIRelay(info.ApiKey) {
return fmt.Sprintf("%s/kling%s", a.baseURL, path), nil
}
2025-06-23 21:22:01 +08:00
return fmt.Sprintf("%s%s", a.baseURL, path), nil
2025-06-08 21:40:57 +08:00
}
// BuildRequestHeader sets required headers.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
2025-06-08 21:40:57 +08:00
token, err := a.createJWTToken()
if err != nil {
2025-06-17 13:37:07 +08:00
return fmt.Errorf("failed to create JWT token: %w", err)
2025-06-08 21:40:57 +08:00
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return nil
}
// BuildRequestBody converts request into Kling specific format.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
2025-06-20 15:50:00 +08:00
v, exists := c.Get("task_request")
2025-06-08 21:40:57 +08:00
if !exists {
return nil, fmt.Errorf("request not found in context")
}
2025-09-12 21:52:32 +08:00
req := v.(relaycommon.TaskSubmitReq)
2025-06-08 21:40:57 +08:00
2025-06-23 21:22:01 +08:00
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, err
}
if body.Image == "" && body.ImageTail == "" {
c.Set("action", constant.TaskActionTextGenerate)
}
2025-06-08 21:40:57 +08:00
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
2025-06-23 21:22:01 +08:00
if action := c.GetString("action"); action != "" {
info.Action = action
}
2025-06-08 21:40:57 +08:00
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
2025-08-25 18:01:10 +08:00
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
2025-06-08 21:40:57 +08:00
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
var kResp responsePayload
2025-07-21 15:06:26 +08:00
err = json.Unmarshal(responseBody, &kResp)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
2025-06-08 21:40:57 +08:00
return
}
2025-07-21 15:06:26 +08:00
if kResp.Code != 0 {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf(kResp.Message), "task_failed", http.StatusBadRequest)
2025-06-08 21:40:57 +08:00
return
}
ov := dto.NewOpenAIVideo()
2025-10-11 14:37:19 +08:00
ov.ID = kResp.Data.TaskId
ov.TaskID = kResp.Data.TaskId
ov.CreatedAt = time.Now().Unix()
ov.Model = info.OriginModelName
c.JSON(http.StatusOK, ov)
2025-07-21 15:06:26 +08:00
return kResp.Data.TaskId, responseBody, nil
2025-06-08 21:40:57 +08:00
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
2025-06-23 21:22:01 +08:00
action, ok := body["action"].(string)
if !ok {
return nil, fmt.Errorf("invalid action")
}
2025-06-27 22:43:01 +08:00
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
2025-06-23 21:22:01 +08:00
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
if isNewAPIRelay(key) {
url = fmt.Sprintf("%s/kling%s/%s", baseUrl, path, taskID)
}
2025-06-08 21:40:57 +08:00
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
token, err := a.createJWTTokenWithKey(key)
if err != nil {
token = key
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
}
// GetModelList returns the Kling model names advertised by this adaptor.
// A fresh slice is returned on every call.
func (a *TaskAdaptor) GetModelList() []string {
	models := []string{
		"kling-v1",
		"kling-v1-6",
		"kling-v2-master",
	}
	return models
}
// GetChannelName returns the fixed channel identifier "kling".
func (a *TaskAdaptor) GetChannelName() string {
	const channelName = "kling"
	return channelName
}
// ============================
// helpers
// ============================
2025-09-12 21:52:32 +08:00
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
2025-06-23 21:22:01 +08:00
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
ModelName: req.Model,
Model: req.Model, // Keep consistent with model_name, double writing improves compatibility
CfgScale: 0.5,
StaticMask: "",
DynamicMasks: []DynamicMask{},
CameraControl: nil,
CallbackUrl: "",
ExternalTaskId: "",
2025-06-08 21:40:57 +08:00
}
2025-06-23 21:22:01 +08:00
if r.ModelName == "" {
2025-06-08 21:40:57 +08:00
r.ModelName = "kling-v1"
}
2025-06-23 21:22:01 +08:00
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
2025-06-08 21:40:57 +08:00
}
// getAspectRatio maps a "WIDTHxHEIGHT" size string to a Kling aspect ratio.
// 1280x720 and 1920x1080 give "16:9", 720x1280 and 1080x1920 give "9:16".
// Every other input, including the square sizes and "", gives "1:1".
func (a *TaskAdaptor) getAspectRatio(size string) string {
	ratio := "1:1"
	switch size {
	case "1280x720", "1920x1080":
		ratio = "16:9"
	case "720x1280", "1080x1920":
		ratio = "9:16"
	}
	return ratio
}
// defaultString returns def when s is empty or whitespace-only; otherwise it
// returns s unchanged (a non-blank s is not trimmed).
func defaultString(s, def string) string {
	if trimmed := strings.TrimSpace(s); trimmed != "" {
		return s
	}
	return def
}
// defaultInt returns def when v is zero and v otherwise; negative values are
// passed through unchanged.
func defaultInt(v int, def int) int {
	if v != 0 {
		return v
	}
	return def
}
// ============================
// JWT helpers
// ============================
func (a *TaskAdaptor) createJWTToken() (string, error) {
2025-07-21 15:06:26 +08:00
return a.createJWTTokenWithKey(a.apiKey)
2025-06-08 21:40:57 +08:00
}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
if isNewAPIRelay(apiKey) {
return apiKey, nil // new api relay
}
2025-07-21 15:06:26 +08:00
keyParts := strings.Split(apiKey, "|")
if len(keyParts) != 2 {
return "", errors.New("invalid api_key, required format is accessKey|secretKey")
}
2025-07-21 15:06:26 +08:00
accessKey := strings.TrimSpace(keyParts[0])
if len(keyParts) == 1 {
return accessKey, nil
2025-06-08 21:40:57 +08:00
}
2025-07-21 15:06:26 +08:00
secretKey := strings.TrimSpace(keyParts[1])
2025-06-08 21:40:57 +08:00
now := time.Now().Unix()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now + 1800, // 30 minutes
"nbf": now - 5,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["typ"] = "JWT"
return token.SignedString([]byte(secretKey))
}
2025-06-20 15:50:00 +08:00
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
2025-07-21 15:06:26 +08:00
taskInfo := &relaycommon.TaskInfo{}
2025-06-20 15:50:00 +08:00
resPayload := responsePayload{}
err := json.Unmarshal(respBody, &resPayload)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
taskInfo.Code = resPayload.Code
taskInfo.TaskID = resPayload.Data.TaskId
taskInfo.Reason = resPayload.Message
//任务状态枚举值submitted已提交、processing处理中、succeed成功、failed失败
status := resPayload.Data.TaskStatus
switch status {
case "submitted":
taskInfo.Status = model.TaskStatusSubmitted
case "processing":
taskInfo.Status = model.TaskStatusInProgress
case "succeed":
taskInfo.Status = model.TaskStatusSuccess
case "failed":
taskInfo.Status = model.TaskStatusFailure
default:
return nil, fmt.Errorf("unknown task status: %s", status)
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
video := videos[0]
taskInfo.Url = video.Url
2025-06-08 21:40:57 +08:00
}
2025-06-20 15:50:00 +08:00
return taskInfo, nil
2025-06-08 21:40:57 +08:00
}
// isNewAPIRelay reports whether apiKey carries the case-sensitive "sk-"
// prefix of a new-api relay key. Such keys are sent upstream verbatim instead
// of being turned into a Kling JWT.
func isNewAPIRelay(apiKey string) bool {
	const relayKeyPrefix = "sk-"
	// TrimPrefix shortens the string only when the prefix is present.
	rest := strings.TrimPrefix(apiKey, relayKeyPrefix)
	return len(rest) != len(apiKey)
}
2025-10-11 00:05:56 +08:00
func (a *TaskAdaptor) ConvertToOpenAIVideo(originTask *model.Task) ([]byte, error) {
2025-10-11 00:05:56 +08:00
var klingResp responsePayload
if err := json.Unmarshal(originTask.Data, &klingResp); err != nil {
return nil, errors.Wrap(err, "unmarshal kling task data failed")
}
2025-10-11 14:37:19 +08:00
openAIVideo := dto.NewOpenAIVideo()
2025-10-11 14:37:19 +08:00
openAIVideo.ID = originTask.TaskID
openAIVideo.Status = originTask.Status.ToVideoStatus()
2025-10-11 13:52:37 +08:00
openAIVideo.SetProgressStr(originTask.Progress)
2025-10-11 14:37:19 +08:00
openAIVideo.CreatedAt = klingResp.Data.CreatedAt
openAIVideo.CompletedAt = klingResp.Data.UpdatedAt
2025-10-11 00:05:56 +08:00
if len(klingResp.Data.TaskResult.Videos) > 0 {
video := klingResp.Data.TaskResult.Videos[0]
if video.Url != "" {
2025-10-11 14:37:19 +08:00
openAIVideo.SetMetadata("url", video.Url)
2025-10-11 00:05:56 +08:00
}
if video.Duration != "" {
openAIVideo.Seconds = video.Duration
}
}
if klingResp.Code != 0 && klingResp.Message != "" {
openAIVideo.Error = &dto.OpenAIVideoError{
2025-10-11 00:05:56 +08:00
Message: klingResp.Message,
Code: fmt.Sprintf("%d", klingResp.Code),
}
}
jsonData, _ := common.Marshal(openAIVideo)
return jsonData, nil
2025-10-11 00:05:56 +08:00
}