3、中间件

中间件简介

在Web应用开发中,中间件(Middleware)是位于应用程序与服务器之间的软件组件,能够拦截HTTP请求和响应,并执行特定的逻辑处理。

可以将中间件想象成一条管道上的过滤器,每个HTTP请求都必须通过这些过滤器才能到达目标处理函数,同样,响应也会通过这些过滤器返回给客户端。

中间件的主要作用

功能类型 说明 实例
预处理请求 在请求到达业务逻辑前进行验证、转换或过滤 身份验证、请求日志记录、输入验证
后处理响应 在响应返回客户端前进行修改或增强 响应压缩、添加安全头、格式转换
横切关注点 处理与业务逻辑无关但必要的功能 性能监控、错误处理、跨域资源共享

在Gin框架中,中间件是一个接收上下文对象*gin.Context的函数,其基本结构为:

func MyMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 中间件逻辑
    }
}

中间件的工作原理

Gin的中间件工作原理基于责任链模式(Chain of Responsibility Pattern)。这种设计模式允许多个处理器依次处理同一个请求,每个处理器专注于自己的职责领域。

在Gin中,每个中间件可以执行以下操作:

注意:如果中间件调用了c.Abort(),则后续的中间件不会执行,但当前中间件的剩余代码仍会继续执行。

Gin中间件的类型

按照特殊性原则组织中间件,将通用功能(如日志、错误处理)设为全局中间件,将特定功能(如认证、权限)设为路由组或单个路由中间件。

全局中间件

应用于所有路由,在任何路由处理之前执行。

// 全局中间件:应用于所有路由
router.Use(Logger())

路由组中间件

只应用于特定路由组内的路由。

// 路由组中间件:只应用于authorized组
authorized := router.Group("/")
authorized.Use(AuthRequired())

单个路由中间件

// 单个路由中间件:只应用于此路由
router.GET("/benchmark", MyBenchLogger(), benchEndpoint)

Gin内置中间件

Gin框架提供了多种内置中间件,用于解决常见的Web开发需求。这些中间件经过了性能优化,可以直接在项目中使用。

Logger

Logger中间件用于记录HTTP请求的详细信息,如请求方法、路径、状态码、响应时间等。

// 创建一个默认的路由器(包含Logger和Recovery中间件)
r := gin.Default()

// 或者手动添加Logger中间件
r := gin.New()
r.Use(gin.Logger())

Logger中间件输出的日志格式例如:

[GIN] 2023/07/15 - 11:48:16 | 200 |   1.234567ms |      ::1 | GET      "/ping"

日志包含以下信息:时间戳、状态码、延迟时间、客户端IP、请求方法、请求路径

在生产环境中,你可能需要自定义日志格式或将日志写入文件,这时可以编写自定义的日志中间件。

Recovery

Recovery中间件用于捕获处理过程中的panic异常,防止服务器崩溃,并返回500状态码。

r := gin.New()
r.Use(gin.Recovery())

工作原理:

在所有生产环境的Gin应用中,Recovery中间件都是必不可少的,它能确保服务的稳定性和可靠性。

BasicAuth

BasicAuth中间件提供了HTTP基本认证功能,用于保护敏感路由。

// 设置授权用户
authorized := r.Group("/admin", gin.BasicAuth(gin.Accounts{
    "admin": "password123",
    "user":  "secret",
}))

// 受保护的路由
authorized.GET("/secrets", func(c *gin.Context) {
    // 获取用户名
    user := c.MustGet(gin.AuthUserKey).(string)
    c.JSON(200, gin.H{
        "user":    user,
        "message": "You have access to the secrets",
    })
})

BasicAuth工作流程:

  1. 检查请求头中的Authorization字段
  2. 如果不存在或格式不正确,返回401 Unauthorized响应,并带有WWW-Authenticate
  3. 如果存在且有效,设置用户信息到上下文并继续处理
  4. 在处理函数中可以通过c.MustGet(gin.AuthUserKey)获取用户名

BasicAuth的凭证在传输过程中只是Base64编码,未加密,因此应始终与HTTPS一起使用。

CORS

可以使用第三方包github.com/gin-contrib/cors

go get -u github.com/gin-contrib/cors
r.Use(cors.New(cors.Config{
    AllowOrigins:     []string{"https://example.com"},
    AllowMethods:     []string{"GET", "POST", "PUT", "DELETE"},
    AllowHeaders:     []string{"Origin", "Content-Type"},
    ExposeHeaders:    []string{"Content-Length"},
    AllowCredentials: true,
    AllowOriginFunc: func(origin string) bool {
        return origin == "https://github.com"
    },
    MaxAge: 12 * time.Hour,
}))

CORS配置选项说明:

自定义中间件开发

中间件的基本结构

Gin中间件本质上是一个返回gin.HandlerFunc的函数,其基本结构如下:

func MiddlewareName() gin.HandlerFunc {
    // 初始化中间件所需的资源或配置
    
    // 返回实际的中间件处理函数
    return func(c *gin.Context) {
        // 前置处理逻辑:在请求到达路由处理函数前执行
        
        // 调用下一个中间件
        c.Next()
        
        // 后置处理逻辑:在所有中间件和处理函数执行完后执行
    }
}

这种结构有两个主要优势:

中间件执行流程控制

在中间件内部,有两个关键函数控制执行流程:

注意:调用c.Abort()后,当前中间件的后续代码仍会执行,如果需要完全退出中间件,还需要使用return语句。

执行流程控制示例:

func MyMiddleware() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 前置逻辑
        fmt.Println("⬇️ 进入中间件")
        
        // 条件性终止
        if unauthorized(c) {
            // 设置错误响应
            c.AbortWithStatusJSON(403, gin.H{"error": "Forbidden"})
            // 注意:即使Abort,下面的代码仍会执行
            fmt.Println("❌ 请求被拒绝")
            return // 使用return阻止执行后续代码
        }
        
        // 继续调用链
        c.Next()
        
        // 后置逻辑,只有未被Abort时才会执行
        fmt.Println("⬆️ 离开中间件")
    }
}

中间件之间的数据传递

中间件可以通过Gin上下文在整个请求生命周期内传递数据:

// 第一个中间件设置数据
func Middleware1() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 使用Set存储数据
        c.Set("requestID", uuid.New().String())
        c.Set("startTime", time.Now())
        
        c.Next()
        
        // 可以在这里访问和修改数据
    }
}

// 第二个中间件或处理器读取数据
func Middleware2() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 使用Get获取数据
        requestID, exists := c.Get("requestID")
        if exists {
            fmt.Printf("处理请求: %v\n", requestID)
        }
        
        // 使用MustGet获取数据(如果不存在会panic)
        startTime := c.MustGet("startTime").(time.Time)
        
        c.Next()
        
        // 计算处理时间
        duration := time.Since(startTime)
        fmt.Printf("请求处理时间: %v\n", duration)
    }
}

c.Get()返回的是interface{}类型,需要进行类型断言才能安全使用。如果确定键一定存在,可以使用c.MustGet(),但要注意它会在键不存在时引发panic。

常见的自定义中间件

日志中间件

// middleware/logger.go

package middleware

import (
    "log"
    "time"
    
    "github.com/gin-gonic/gin"
)

// Logger 返回一个记录请求详情的日志中间件
func Logger() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 开始时间
        startTime := time.Now()
        
        // 设置变量记录请求处理状态
        c.Set("status", 200)
        
        // 请求路径
        path := c.Request.URL.Path
        
        // 请求方法
        method := c.Request.Method
        
        // 处理请求 - 调用下一个中间件或处理函数
        c.Next()
        
        // 结束时间
        endTime := time.Now()
        
        // 执行时间
        latency := endTime.Sub(startTime)
        
        // 获取状态码
        statusCode := c.Writer.Status()
        
        // 请求IP
        clientIP := c.ClientIP()
        
        // 获取错误信息(如果有)
        errorMessage := ""
        if len(c.Errors) > 0 {
            errorMessage = c.Errors.String()
        }
        
        // 日志格式
        log.Printf("[GIN] %s | %3d | %13v | %15s | %-7s | %s | %s",
            endTime.Format("2006/01/02 - 15:04:05"), // 时间
            statusCode,                              // 状态码
            latency,                                 // 耗时
            clientIP,                                // 客户端IP
            method,                                  // 请求方法
            path,                                    // 请求路径
            errorMessage,                            // 错误信息
        )
        
        // 示例:特定条件下记录更详细的信息
        if statusCode >= 500 {
            // 记录请求头和响应体等详细信息
            log.Printf("[ERROR] 服务器错误: %v", c.Errors.String())
        }
    }
}

认证中间件

JWT(JSON Web Token)认证是现代Web应用中常用的认证方式。以下是一个完整的JWT认证中间件实现:

// middleware/auth.go

package middleware

import (
    "errors"
    "net/http"
    "strings"
    "time"
    
    "github.com/gin-gonic/gin"
    "github.com/golang-jwt/jwt/v4"
)

// JWTConfig JWT配置
type JWTConfig struct {
    SecretKey     string        // JWT密钥
    Realm         string        // 认证领域
    TokenLookup   string        // Token查找位置
    TokenHeadName string        // Token头部名称
    Timeout       time.Duration // Token超时时间
}

// DefaultJWTConfig 默认JWT配置
var DefaultJWTConfig = JWTConfig{
    SecretKey:     "secret_key",
    Realm:         "gin jwt",
    TokenLookup:   "header:Authorization",
    TokenHeadName: "Bearer",
    Timeout:       time.Hour * 24,
}

// JWTClaims JWT声明
type JWTClaims struct {
    UserID uint   `json:"user_id"`
    Role   string `json:"role"`
    jwt.RegisteredClaims
}

// JWTAuth 返回JWT认证中间件
func JWTAuth(config ...JWTConfig) gin.HandlerFunc {
    // 使用提供的配置或默认配置
    var conf JWTConfig
    if len(config) > 0 {
        conf = config[0]
    } else {
        conf = DefaultJWTConfig
    }
    
    return func(c *gin.Context) {
        // 从请求中获取Token
        token, err := extractToken(c, conf)
        if err != nil {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
                "error": "认证失败: " + err.Error(),
            })
            return
        }
        
        // 验证Token
        claims, err := validateToken(token, conf.SecretKey)
        if err != nil {
            c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
                "error": "无效的Token: " + err.Error(),
            })
            return
        }
        
        // 将用户信息存储到上下文
        c.Set("user_id", claims.UserID)
        c.Set("user_role", claims.Role)
        c.Set("claims", claims)
        
        c.Next()
    }
}

// 从请求中提取Token
func extractToken(c *gin.Context, conf JWTConfig) (string, error) {
    parts := strings.Split(conf.TokenLookup, ":")
    if len(parts) != 2 {
        return "", errors.New("无效的token查找设置")
    }
    
    switch parts[0] {
    case "header":
        // 从请求头获取Token
        auth := c.GetHeader(parts[1])
        if auth == "" {
            return "", errors.New("认证头不存在")
        }
        
        // 检查Token前缀
        if conf.TokenHeadName != "" {
            if !strings.HasPrefix(auth, conf.TokenHeadName+" ") {
                return "", errors.New("无效的认证头格式")
            }
            return auth[len(conf.TokenHeadName)+1:], nil
        }
        return auth, nil
        
    case "query":
        // 从URL查询参数获取Token
        token := c.Query(parts[1])
        if token == "" {
            return "", errors.New("未提供token参数")
        }
        return token, nil
        
    case "cookie":
        // 从Cookie获取Token
        token, err := c.Cookie(parts[1])
        if err != nil {
            return "", errors.New("未提供token Cookie")
        }
        return token, nil
    }
    
    return "", errors.New("不支持的token获取方法")
}

// 验证Token并返回声明
func validateToken(tokenString, secretKey string) (*JWTClaims, error) {
    // 解析Token
    token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
        // 验证签名算法
        if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
            return nil, errors.New("无效的签名算法")
        }
        return []byte(secretKey), nil
    })
    
    if err != nil {
        return nil, err
    }
    
    // 检查Token是否有效
    if !token.Valid {
        return nil, errors.New("无效的token")
    }
    
    // 类型断言
    claims, ok := token.Claims.(*JWTClaims)
    if !ok {
        return nil, errors.New("无效的token声明")
    }
    
    return claims, nil
}

错误处理中间件

统一的错误处理可以使应用更加健壮和用户友好:

// middleware/error_handler.go

package middleware

import (
    "errors"
    "log"
    "net/http"
    "runtime/debug"
    
    "github.com/gin-gonic/gin"
    "github.com/go-playground/validator/v10"
)

// 自定义错误类型
type AppError struct {
    Code    int         `json:"code"`
    Message string      `json:"message"`
    Details interface{} `json:"details,omitempty"`
}

func (e *AppError) Error() string {
    return e.Message
}

// ValidationError 验证错误
type ValidationError struct {
    Field   string `json:"field"`
    Message string `json:"message"`
}

// ErrorHandler 返回统一的错误处理中间件
func ErrorHandler() gin.HandlerFunc {
    return func(c *gin.Context) {
        // 设置一个recover,防止应用崩溃
        defer func() {
            if err := recover(); err != nil {
                // 记录堆栈信息
                log.Printf("Panic: %v\nStack: %s", err, debug.Stack())
                
                // 返回500错误
                c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
                    "error": gin.H{
                        "code":    http.StatusInternalServerError,
                        "message": "服务器内部错误",
                    },
                })
            }
        }()
        
        // 处理请求
        c.Next()
        
        // 检查是否有错误
        if len(c.Errors) > 0 {
            // 获取最后一个错误
            err := c.Errors.Last().Err
            
            // 根据错误类型响应不同的状态码
            switch e := err.(type) {
            case *AppError:
                // 自定义应用错误
                c.JSON(e.Code, gin.H{
                    "error": gin.H{
                        "code":    e.Code,
                        "message": e.Message,
                        "details": e.Details,
                    },
                })
                
            case validator.ValidationErrors:
                // 验证错误
                var validationErrors []ValidationError
                for _, err := range e {
                    validationErrors = append(validationErrors, ValidationError{
                        Field:   err.Field(),
                        Message: getValidationErrorMsg(err),
                    })
                }
                
                c.JSON(http.StatusBadRequest, gin.H{
                    "error": gin.H{
                        "code":    http.StatusBadRequest,
                        "message": "请求数据验证失败",
                        "details": validationErrors,
                    },
                })
                
            default:
                // 其他错误
                c.JSON(http.StatusInternalServerError, gin.H{
                    "error": gin.H{
                        "code":    http.StatusInternalServerError,
                        "message": "服务器内部错误",
                        "details": err.Error(),
                    },
                })
            }
        }
    }
}

// 获取验证错误的友好消息
func getValidationErrorMsg(err validator.FieldError) string {
    switch err.Tag() {
    case "required":
        return "此字段为必填项"
    case "email":
        return "必须是有效的电子邮箱地址"
    case "min":
        return "值太小"
    case "max":
        return "值太大"
    default:
        return "验证失败"
    }
}

// 在控制器中使用
func SomeController(c *gin.Context) {
    if somethingWrong {
        // 添加自定义错误到上下文
        c.Error(&AppError{
            Code:    http.StatusBadRequest,
            Message: "请求参数错误",
            Details: "无效的ID格式",
        })
        return
    }
    
    // 继续处理...
}

限流中间件

使用令牌桶算法实现API限流,防止服务被过度使用:

// middleware/rate_limiter.go

package middleware

import (
    "net/http"
    "sync"
    "time"
    
    "github.com/gin-gonic/gin"
    "golang.org/x/time/rate"
)

// IPRateLimiter IP限流器
type IPRateLimiter struct {
    ips    map[string]*rate.Limiter
    mu     *sync.RWMutex
    rate   rate.Limit
    burst  int
    expiry time.Duration
    lastSeen map[string]time.Time
}

// NewIPRateLimiter 创建新的IP限流器
func NewIPRateLimiter(r rate.Limit, b int, expiry time.Duration) *IPRateLimiter {
    return &IPRateLimiter{
        ips:    make(map[string]*rate.Limiter),
        mu:     &sync.RWMutex{},
        rate:   r,
        burst:  b,
        expiry: expiry,
        lastSeen: make(map[string]time.Time),
    }
}

// GetLimiter 获取给定IP的限流器
func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
    i.mu.RLock()
    limiter, exists := i.ips[ip]
    now := time.Now()
    
    // 检查上次访问时间,清理过期的限流器
    if exists {
        lastSeen, ok := i.lastSeen[ip]
        if ok && now.Sub(lastSeen) > i.expiry {
            exists = false
        }
    }
    
    i.mu.RUnlock()
    
    if !exists {
        i.mu.Lock()
        // 创建新的限流器
        limiter = rate.NewLimiter(i.rate, i.burst)
        i.ips[ip] = limiter
        i.lastSeen[ip] = now
        
        // 清理过期的限流器
        if len(i.ips) > 10000 { // 避免无限增长
            for ip, t := range i.lastSeen {
                if now.Sub(t) > i.expiry {
                    delete(i.ips, ip)
                    delete(i.lastSeen, ip)
                }
            }
        }
        
        i.mu.Unlock()
    } else {
        // 更新最后访问时间
        i.mu.Lock()
        i.lastSeen[ip] = now
        i.mu.Unlock()
    }
    
    return limiter
}

// RateLimiter 返回IP限流中间件
// r: 每秒请求速率
// b: 突发请求数
func RateLimiter(r float64, b int) gin.HandlerFunc {
    // 创建IP限流器实例,限流器过期时间1小时
    limiter := NewIPRateLimiter(rate.Limit(r), b, time.Hour)
    
    return func(c *gin.Context) {
        // 获取客户端IP
        ip := c.ClientIP()
        
        // 获取该IP的限流器
        ipLimiter := limiter.GetLimiter(ip)
        
        // 检查是否允许请求
        if !ipLimiter.Allow() {
            c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
                "error": "请求频率超限,请稍后再试",
            })
            return
        }
        
        c.Next()
    }
}