fix(vertex): honor custom base_url as gateway prefix
This commit is contained in:
parent
ed7f839911
commit
987b7ecd22
@ -95,20 +95,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
return vertexcore.BuildGoogleModelURL(a.baseURL, vertexcore.DefaultAPIVersion, adc.ProjectID, region, modelName, "predictLongRunning"), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
@ -238,6 +225,22 @@ func (a *TaskAdaptor) GetModelList() []string {
|
||||
}
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
func buildFetchOperationURL(baseURL, upstreamName string) (string, error) {
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
return "", fmt.Errorf("cannot extract model from operation name")
|
||||
}
|
||||
if strings.TrimSpace(project) == "" {
|
||||
return "", fmt.Errorf("cannot extract project from operation name")
|
||||
}
|
||||
return vertexcore.BuildGoogleModelURL(baseURL, vertexcore.DefaultAPIVersion, project, region, modelName, "fetchPredictOperation"), nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
@ -248,20 +251,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
url, err := buildFetchOperationURL(baseUrl, upstreamName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := fetchOperationPayload{OperationName: upstreamName}
|
||||
data, err := common.Marshal(payload)
|
||||
|
||||
@ -133,47 +133,11 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil
|
||||
} else if a.RequestMode == RequestModeOpenSource {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return BuildOpenSourceChatCompletionsURL(info.ChannelBaseUrl, adc.ProjectID, region), nil
|
||||
}
|
||||
} else {
|
||||
var keyPrefix string
|
||||
@ -182,20 +146,17 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s
|
||||
} else {
|
||||
keyPrefix = "?"
|
||||
}
|
||||
if region == "global" {
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
"%s%skey=%s",
|
||||
BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
|
||||
keyPrefix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
"%s%skey=%s",
|
||||
BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix),
|
||||
keyPrefix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
|
||||
86
relay/channel/vertex/url_builder.go
Normal file
86
relay/channel/vertex/url_builder.go
Normal file
@ -0,0 +1,86 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAPIVersion = "v1"
|
||||
OpenSourceAPIVersion = "v1beta1"
|
||||
PublisherGoogle = "google"
|
||||
PublisherAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
func normalizeVertexBaseURL(baseURL string) string {
|
||||
return strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
}
|
||||
|
||||
func normalizeVertexRegion(region string) string {
|
||||
region = strings.TrimSpace(region)
|
||||
if region == "" {
|
||||
return "global"
|
||||
}
|
||||
return region
|
||||
}
|
||||
|
||||
func appendVertexAPIVersion(baseURL, version string) string {
|
||||
version = strings.Trim(strings.TrimSpace(version), "/")
|
||||
if version == "" {
|
||||
return baseURL
|
||||
}
|
||||
if strings.HasSuffix(baseURL, "/"+version) {
|
||||
return baseURL
|
||||
}
|
||||
return baseURL + "/" + version
|
||||
}
|
||||
|
||||
func BuildAPIBaseURL(baseURL, version, projectID, region string) string {
|
||||
if normalized := normalizeVertexBaseURL(baseURL); normalized != "" {
|
||||
normalized = appendVertexAPIVersion(normalized, version)
|
||||
|
||||
region = normalizeVertexRegion(region)
|
||||
if strings.TrimSpace(projectID) != "" {
|
||||
normalized = fmt.Sprintf("%s/projects/%s/locations/%s", normalized, projectID, region)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
region = normalizeVertexRegion(region)
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf("https://aiplatform.googleapis.com/%s", version)
|
||||
}
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s", region, version)
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf("https://aiplatform.googleapis.com/%s/projects/%s/locations/global", version, projectID)
|
||||
}
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s/projects/%s/locations/%s", region, version, projectID, region)
|
||||
}
|
||||
|
||||
func BuildPublisherModelURL(baseURL, version, projectID, region, publisher, modelName, action string) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/publishers/%s/models/%s:%s",
|
||||
BuildAPIBaseURL(baseURL, version, projectID, region),
|
||||
publisher,
|
||||
modelName,
|
||||
action,
|
||||
)
|
||||
}
|
||||
|
||||
func BuildGoogleModelURL(baseURL, version, projectID, region, modelName, action string) string {
|
||||
return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherGoogle, modelName, action)
|
||||
}
|
||||
|
||||
func BuildAnthropicModelURL(baseURL, version, projectID, region, modelName, action string) string {
|
||||
return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherAnthropic, modelName, action)
|
||||
}
|
||||
|
||||
func BuildOpenSourceChatCompletionsURL(baseURL, projectID, region string) string {
|
||||
return fmt.Sprintf(
|
||||
"%s/endpoints/openapi/chat/completions",
|
||||
BuildAPIBaseURL(baseURL, OpenSourceAPIVersion, projectID, region),
|
||||
)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user