From 987b7ecd223ed42096d26fab8f1784fcea68f926 Mon Sep 17 00:00:00 2001 From: yyhhyyyyyy Date: Thu, 30 Apr 2026 15:08:10 +0800 Subject: [PATCH] fix(vertex): honor custom base_url as gateway prefix --- relay/channel/task/vertex/adaptor.go | 48 +++++++--------- relay/channel/vertex/adaptor.go | 57 +++--------------- relay/channel/vertex/url_builder.go | 86 ++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 76 deletions(-) create mode 100644 relay/channel/vertex/url_builder.go diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index b76364ee..a296d4cc 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -95,20 +95,7 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, erro if strings.TrimSpace(region) == "" { region = "global" } - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning", - adc.ProjectID, - modelName, - ), nil - } - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning", - region, - adc.ProjectID, - region, - modelName, - ), nil + return vertexcore.BuildGoogleModelURL(a.baseURL, vertexcore.DefaultAPIVersion, adc.ProjectID, region, modelName, "predictLongRunning"), nil } // BuildRequestHeader sets required headers. @@ -238,6 +225,22 @@ func (a *TaskAdaptor) GetModelList() []string { } func (a *TaskAdaptor) GetChannelName() string { return "vertex" } +func buildFetchOperationURL(baseURL, upstreamName string) (string, error) { + region := extractRegionFromOperationName(upstreamName) + if region == "" { + region = "us-central1" + } + project := extractProjectFromOperationName(upstreamName) + modelName := extractModelFromOperationName(upstreamName) + if strings.TrimSpace(modelName) == "" { + return "", fmt.Errorf("cannot extract model from operation name") + } + if strings.TrimSpace(project) == "" { + return "", fmt.Errorf("cannot extract project from operation name") + } + return vertexcore.BuildGoogleModelURL(baseURL, vertexcore.DefaultAPIVersion, project, region, modelName, "fetchPredictOperation"), nil +} + // FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy string) (*http.Response, error) { taskID, ok := body["task_id"].(string) @@ -248,20 +251,9 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } - region := extractRegionFromOperationName(upstreamName) - if region == "" { - region = "us-central1" - } - project := extractProjectFromOperationName(upstreamName) - modelName := extractModelFromOperationName(upstreamName) - if project == "" || modelName == "" { - return nil, fmt.Errorf("cannot extract project/model from operation name") - } - var url string - if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) - } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) + url, err := buildFetchOperationURL(baseUrl, upstreamName) + if err != nil { + return nil, err } payload := fetchOperationPayload{OperationName: upstreamName} data, err := common.Marshal(payload) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 7e56c52b..93114d6e 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -133,47 +133,11 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s a.AccountCredentials = *adc if a.RequestMode == RequestModeGemini { - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", - adc.ProjectID, - modelName, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s", - region, - adc.ProjectID, - region, - modelName, - suffix, - ), nil - } + return BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil } else if a.RequestMode == RequestModeClaude { - if region == "global" { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s", - adc.ProjectID, - modelName, - suffix, - ), nil - } else { - return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s", - region, - adc.ProjectID, - region, - modelName, - suffix, - ), nil - } + return BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, adc.ProjectID, region, modelName, suffix), nil } else if a.RequestMode == RequestModeOpenSource { - return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", - adc.ProjectID, - region, - ), nil + return BuildOpenSourceChatCompletionsURL(info.ChannelBaseUrl, adc.ProjectID, region), nil } } else { var keyPrefix string @@ -182,20 +146,17 @@ func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix s } else { keyPrefix = "?" } - if region == "global" { + if a.RequestMode == RequestModeGemini { return fmt.Sprintf( - "https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s", - modelName, - suffix, + "%s%skey=%s", + BuildGoogleModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix), keyPrefix, info.ApiKey, ), nil - } else { + } else if a.RequestMode == RequestModeClaude { return fmt.Sprintf( - "https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s%skey=%s", - region, - modelName, - suffix, + "%s%skey=%s", + BuildAnthropicModelURL(info.ChannelBaseUrl, DefaultAPIVersion, "", region, modelName, suffix), keyPrefix, info.ApiKey, ), nil diff --git a/relay/channel/vertex/url_builder.go b/relay/channel/vertex/url_builder.go new file mode 100644 index 00000000..0fa83439 --- /dev/null +++ b/relay/channel/vertex/url_builder.go @@ -0,0 +1,86 @@ +package vertex + +import ( + "fmt" + "strings" +) + +const ( + DefaultAPIVersion = "v1" + OpenSourceAPIVersion = "v1beta1" + PublisherGoogle = "google" + PublisherAnthropic = "anthropic" +) + +func normalizeVertexBaseURL(baseURL string) string { + return strings.TrimRight(strings.TrimSpace(baseURL), "/") +} + +func normalizeVertexRegion(region string) string { + region = strings.TrimSpace(region) + if region == "" { + return "global" + } + return region +} + +func appendVertexAPIVersion(baseURL, version string) string { + version = strings.Trim(strings.TrimSpace(version), "/") + if version == "" { + return baseURL + } + if strings.HasSuffix(baseURL, "/"+version) { + return baseURL + } + return baseURL + "/" + version +} + +func BuildAPIBaseURL(baseURL, version, projectID, region string) string { + if normalized := normalizeVertexBaseURL(baseURL); normalized != "" { + normalized = appendVertexAPIVersion(normalized, version) + + region = normalizeVertexRegion(region) + if strings.TrimSpace(projectID) != "" { + normalized = fmt.Sprintf("%s/projects/%s/locations/%s", normalized, projectID, region) + } + return normalized + } + + region = normalizeVertexRegion(region) + if strings.TrimSpace(projectID) == "" { + if region == "global" { + return fmt.Sprintf("https://aiplatform.googleapis.com/%s", version) + } + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s", region, version) + } + + if region == "global" { + return fmt.Sprintf("https://aiplatform.googleapis.com/%s/projects/%s/locations/global", version, projectID) + } + return fmt.Sprintf("https://%s-aiplatform.googleapis.com/%s/projects/%s/locations/%s", region, version, projectID, region) +} + +func BuildPublisherModelURL(baseURL, version, projectID, region, publisher, modelName, action string) string { + return fmt.Sprintf( + "%s/publishers/%s/models/%s:%s", + BuildAPIBaseURL(baseURL, version, projectID, region), + publisher, + modelName, + action, + ) +} + +func BuildGoogleModelURL(baseURL, version, projectID, region, modelName, action string) string { + return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherGoogle, modelName, action) +} + +func BuildAnthropicModelURL(baseURL, version, projectID, region, modelName, action string) string { + return BuildPublisherModelURL(baseURL, version, projectID, region, PublisherAnthropic, modelName, action) +} + +func BuildOpenSourceChatCompletionsURL(baseURL, projectID, region string) string { + return fmt.Sprintf( + "%s/endpoints/openapi/chat/completions", + BuildAPIBaseURL(baseURL, OpenSourceAPIVersion, projectID, region), + ) +}