建立自定義 Spring Cloud Gateway Filter

工程 | Fredrich Ombico | 2022 年 8 月 27 日 | ...

在本文中,我們將探討如何為 Spring Cloud Gateway 編寫自定義擴充套件。在我們開始之前,先回顧一下 Spring Cloud Gateway 的工作原理。

Spring Cloud Gateway diagram

  1. 首先,客戶端向閘道器發起網路請求
  2. 閘道器定義了許多路由,每個路由都有謂詞 (Predicates) 來將請求與路由匹配。例如,您可以基於 URL 的路徑段或請求的 HTTP 方法進行匹配。
  3. 匹配成功後,閘道器會在應用於該路由的每個過濾器上執行請求前邏輯。例如,您可能希望向請求中新增查詢引數。
  4. 代理過濾器將請求路由到被代理的服務
  5. 服務執行並返回響應
  6. 閘道器接收到響應並在返回響應之前對每個過濾器執行請求後邏輯。例如,您可以在返回客戶端之前移除不想要的響應頭。

我們的擴充套件將對請求體進行雜湊計算,並將計算出的值新增為一個名為 X-Hash 的請求頭。這對應於上面圖示中的步驟 3。注意:由於我們需要讀取請求體,閘道器將受到記憶體限制。

首先,我們在 start.spring.io 建立一個包含 Gateway 依賴的專案。在本例中,我們將使用 Java 和 JDK 17 以及 Spring Boot 2.7.3 的 Gradle 專案。下載、解壓並在您喜歡的 IDE 中開啟專案並執行,以確保您已完成本地開發的設定。

接下來,我們建立一個 GatewayFilter Factory,它是一個作用域限定於特定路由的過濾器,允許我們以某種方式修改傳入的 HTTP 請求或傳出的 HTTP 響應。在我們的例子中,我們將透過新增額外的請求頭來修改傳入的 HTTP 請求。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;

import org.bouncycastle.util.encoders.Hex;
import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR;

/**
 * This filter hashes the request body, placing the value in the X-Hash header.
 * Note: This causes the gateway to be memory constrained.
 * Sample usage: RequestHashing=SHA-256
 */
@Component
public class RequestHashingGatewayFilterFactory extends
        AbstractGatewayFilterFactory<RequestHashingGatewayFilterFactory.Config> {

    private static final String HASH_ATTR = "hash";
    private static final String HASH_HEADER = "X-Hash";
    private final List<HttpMessageReader<?>> messageReaders =
            HandlerStrategies.withDefaults().messageReaders();

    public RequestHashingGatewayFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        MessageDigest digest = config.getMessageDigest();
        return (exchange, chain) -> ServerWebExchangeUtils
                .cacheRequestBodyAndRequest(exchange, (httpRequest) -> ServerRequest
                    .create(exchange.mutate().request(httpRequest).build(),
                            messageReaders)
                    .bodyToMono(String.class)
                    .doOnNext(requestPayload -> exchange
                            .getAttributes()
                            .put(HASH_ATTR, computeHash(digest, requestPayload)))
                    .then(Mono.defer(() -> {
                        ServerHttpRequest cachedRequest = exchange.getAttribute(
                                CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
                        Assert.notNull(cachedRequest, 
                                "cache request shouldn't be null");
                        exchange.getAttributes()
                                .remove(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);

                        String hash = exchange.getAttribute(HASH_ATTR);
                        cachedRequest = cachedRequest.mutate()
                                .header(HASH_HEADER, hash)
                                .build();
                        return chain.filter(exchange.mutate()
                                .request(cachedRequest)
                                .build());
                    })));
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("algorithm");
    }

    private String computeHash(MessageDigest messageDigest, String requestPayload) {
        return Hex.toHexString(messageDigest.digest(requestPayload.getBytes()));
    }

    static class Config {

        private MessageDigest messageDigest;

        public MessageDigest getMessageDigest() {
            return messageDigest;
        }

        public void setAlgorithm(String algorithm) throws NoSuchAlgorithmException {
            messageDigest = MessageDigest.getInstance(algorithm);
        }
    }
}

讓我們更詳細地看一下程式碼

  • 我們為類添加了 @Component 註解。Spring Cloud Gateway 需要能夠檢測到這個類才能使用它。或者,我們可以使用 @Bean 定義一個例項。
  • 在我們的類名中,我們使用 GatewayFilterFactory 作為字尾。在 application.yaml 中新增此過濾器時,我們不包含字尾,只寫 RequestHashing。這是 Spring Cloud Gateway 過濾器命名約定。
  • 我們的類也像所有其他 Spring Cloud Gateway 過濾器一樣,擴充套件了 AbstractGatewayFilterFactory。我們還指定了一個類來配置我們的過濾器,一個巢狀的靜態類 Config 有助於保持簡單。配置類允許我們設定使用哪種雜湊演算法。
  • 重寫的 apply 方法是所有工作發生的地方。在引數中,我們得到了一個配置類的例項,在這裡我們可以訪問用於雜湊計算的 MessageDigest 例項。接下來,我們看到 (exchange, chain),這是返回的 GatewayFilter 介面類的一個 lambda 表示式。exchange 是 ServerWebExchange 的一個例項,它為 Gateway 過濾器提供了訪問 HTTP 請求和響應的能力。對於我們的情況,我們希望修改 HTTP 請求,這需要我們修改 (mutate) exchange。
  • 我們需要讀取請求體來生成雜湊,然而,由於請求體儲存在位元組緩衝區中,它在過濾器中只能被讀取一次。透過使用 ServerWebExchangeUtils,我們將請求作為屬性快取在 exchange 中。屬性提供了一種在過濾器鏈中共享特定請求資料的方式。我們還將儲存計算出的請求體雜湊。
  • 我們使用 exchange 屬性來獲取快取的請求和計算出的雜湊。然後,透過新增雜湊頭來修改 exchange,最後將其傳送到鏈中的下一個過濾器。
  • shortcutFieldOrder 方法有助於將引數的數量和順序對映到過濾器。字串 algorithm 匹配到 Config 類中的 setter 方法。

為了測試程式碼,我們將使用 WireMock。將依賴新增到您的 build.gradle 檔案中。

testImplementation 'com.github.tomakehurst:wiremock:2.27.2'

這裡我們有一個測試檢查請求頭的存在和值,另一個測試檢查如果請求體不存在,請求頭是否也不存在。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import org.bouncycastle.jcajce.provider.digest.SHA512;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient;

import static com.example.demo.RequestHashingGatewayFilterFactory.*;
import static com.example.demo.RequestHashingGatewayFilterFactoryTest.*;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

@SpringBootTest(
        webEnvironment = RANDOM_PORT,
        classes = RequestHashingFilterTestConfig.class)
@AutoConfigureWebTestClient
class RequestHashingGatewayFilterFactoryTest {

    @TestConfiguration
    static class RequestHashingFilterTestConfig {

        @Autowired
        RequestHashingGatewayFilterFactory requestHashingGatewayFilter;

        @Bean(destroyMethod = "stop")
        WireMockServer wireMockServer() {
            WireMockConfiguration options = wireMockConfig().dynamicPort();
            WireMockServer wireMock = new WireMockServer(options);
            wireMock.start();
            return wireMock;
        }

        @Bean
        RouteLocator testRoutes(RouteLocatorBuilder builder, WireMockServer wireMock)
                throws NoSuchAlgorithmException {
            Config config = new Config();
            config.setAlgorithm("SHA-512");

            GatewayFilter gatewayFilter = requestHashingGatewayFilter.apply(config);
            return builder
                    .routes()
                    .route(predicateSpec -> predicateSpec
                            .path("/post")
                            .filters(spec -> spec.filter(gatewayFilter))
                            .uri(wireMock.baseUrl()))
                    .build();
        }
    }

    @Autowired
    WebTestClient webTestClient;

    @Autowired
    WireMockServer wireMockServer;

    @AfterEach
    void afterEach() {
        wireMockServer.resetAll();
    }

    @Test
    void shouldAddHeaderWithComputedHash() {
        MessageDigest messageDigest = new SHA512.Digest();
        String body = "hello world";
        String expectedHash = Hex.toHexString(messageDigest.digest(body.getBytes()));

        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .bodyValue(body)
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withHeader("X-Hash", equalTo(expectedHash)));
    }

    @Test
    void shouldNotAddHeaderIfNoBody() {
        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withoutHeader("X-Hash"));
    }
}

要在閘道器中使用該過濾器,我們將 RequestHashing 過濾器新增到 application.yaml 中的一個路由中,使用 SHA-256 作為演算法。

spring:
  cloud:
    gateway:
      routes:
        - id: demo
          uri: https://httpbin.org
          predicates:
            - Path=/post/**
          filters:
            - RequestHashing=SHA-256

我們使用 https://httpbin.org,因為它在返回的響應中顯示了我們的請求頭。執行應用程式併發送一個 curl 請求檢視結果。

$> curl --request POST 'https://:8080/post' \
--header 'Content-Type: application/json' \
--data-raw '{
    "data": {
        "hello": "world"
    }
}'

{
  ...
  "data": "{\n    \"data\": {\n        \"hello\": \"world\"\n    }\n}",
  "headers": {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Content-Length": "48",
        "Content-Type": "application/json",
        "Forwarded": "proto=http;host=\"localhost:8080\";for=\"[0:0:0:0:0:0:0:1]:55647\"",
        "Host": "httpbin.org",
        "User-Agent": "PostmanRuntime/7.29.0",
        "X-Forwarded-Host": "localhost:8080",
        "X-Hash": "1bd93d38735501b5aec7a822f8bc8136d9f1f71a30c2020511bdd5df379772b8"
    },
  ...
}

總之,我們瞭解瞭如何為 Spring Cloud Gateway 編寫自定義擴充套件。我們的過濾器讀取了請求體以生成一個雜湊值,並將其新增為請求頭。我們還使用 WireMock 編寫了過濾器的測試,以檢查請求頭的值。最後,我們運行了一個帶有該過濾器的閘道器來驗證結果。

如果您計劃在 Kubernetes 叢集上部署 Spring Cloud Gateway,請務必查閱 VMware Spring Cloud Gateway for Kubernetes。除了支援開源 Spring Cloud Gateway 過濾器和自定義過濾器(例如我們在上面編寫的過濾器)之外,它還提供了 更多內建過濾器 來操作您的請求和響應。Spring Cloud Gateway for Kubernetes 代表 API 開發團隊處理橫切關注點,例如:單點登入 (SSO)、訪問控制、速率限制、彈性、安全性等。

訂閱 Spring 新聞通訊

透過 Spring 新聞通訊保持聯絡

訂閱

搶先一步

VMware 提供培訓和認證,助力您加速發展。

瞭解更多

獲取支援

Tanzu Spring 透過一個簡單的訂閱提供對 OpenJDK™、Spring 和 Apache Tomcat® 的支援和二進位制檔案。

瞭解更多

即將舉行的活動

檢視 Spring 社群所有即將舉行的活動。

檢視全部