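# ASP.NET Core Middleware Development Basics
# Middleware Fundamentals
Middleware are components assembled into the application pipeline in a specific order to handle requests and responses. Each middleware component can:
- Choose whether to pass the request to the next component in the pipeline
- Perform work before and after the next component in the pipeline is invoked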
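The sketch below makes both capabilities concrete. It models a pipeline with plain delegates instead of ASP.NET Core types. `Handler` and `MiniPipeline` are hypothetical names. In ASP.NET Core the equivalents are `RequestDelegate` and `IApplicationBuilder.Use(Func<RequestDelegate, RequestDelegate>)`. The "request" here is only a list that each step appends to.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical stand-in for RequestDelegate: the "request" is a log that each step appends to.
public delegate Task Handler(List<string> log);

public static class MiniPipeline
{
    // Each middleware receives the next handler and returns a handler that wraps it,
    // the same shape as IApplicationBuilder.Use(Func<RequestDelegate, RequestDelegate>).
    public static Handler Build(Handler endpoint, params Func<Handler, Handler>[] middlewares)
    {
        var app = endpoint;
        // Wrap from last to first so the first registered middleware ends up outermost.
        for (int i = middlewares.Length - 1; i >= 0; i--)
        {
            app = middlewares[i](app);
        }
        return app;
    }
}
```

Build the pipeline with two middlewares, A and B, that each log before and after calling `next`. The log then reads A before, B before, endpoint, B after, A after. If a middleware never calls `next`, it short-circuits everything registered after it.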
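# Middleware Development Tools and Skills
Tools:
- Visual Studio or VS Code
- .NET SDK 6.0+
- Postman/Swagger (API testing)
- Application Insights (performance monitoring)
- Profiling tools (dotTrace, PerfView)
- Docker (containerization)
Required skills:
- Understanding of the HTTP protocol
- Asynchronous programming
- Dependency injection
- The pipeline pattern
- Thread safety
- Performance optimization techniques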
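# 1. Simple Request Logging Middleware
Purpose: logs every HTTP request that reaches the application, which helps with debugging and monitoring.
Implementation:
```csharp
public class RequestLoggingMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ILogger<RequestLoggingMiddleware> _logger;

    public RequestLoggingMiddleware(RequestDelegate next, ILogger<RequestLoggingMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        // Log the incoming request before passing it on
        _logger.LogInformation(
            "HTTP {Method} {Path} started",
            context.Request.Method,
            context.Request.Path);

        await _next(context);

        // The status code is only meaningful once downstream components have set it
        _logger.LogInformation(
            "HTTP {Method} {Path} responded {StatusCode}",
            context.Request.Method,
            context.Request.Path,
            context.Response.StatusCode);
    }
}

// Extension method for more concise registration
public static class RequestLoggingMiddlewareExtensions
{
    public static IApplicationBuilder UseRequestLogging(this IApplicationBuilder builder)
    {
        return builder.UseMiddleware<RequestLoggingMiddleware>();
    }
}
```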
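Registration:
```csharp
app.UseRequestLogging();
```
How it works: the middleware intercepts every HTTP request, logs its method and path, and then passes the request to the next middleware. The first log entry is written before _next() is called, so it records the request as it arrived. The response status code is only reliable after _next() returns, because downstream components set it.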
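You can see why the status code must be read after `_next()` without running ASP.NET Core. In this sketch, `FakeResponse` is a hypothetical stand-in for `HttpResponse`, whose `StatusCode` also defaults to 200. The `next` delegate plays a downstream component that sets 404.

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical stand-in for HttpResponse; like HttpResponse, StatusCode starts at 200.
public sealed class FakeResponse
{
    public int StatusCode { get; set; } = 200;
}

public static class StatusObserver
{
    // Reads the status code before and after the downstream delegate runs,
    // mirroring the two log calls in RequestLoggingMiddleware.
    public static async Task<(int Before, int After)> ObserveAsync(FakeResponse response, Func<FakeResponse, Task> next)
    {
        var before = response.StatusCode;
        await next(response);
        return (before, response.StatusCode);
    }
}
```

The value read before `next` runs is always the default 200. A log statement placed before `_next()` would therefore never report the real status.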
# 2. Request Timing Middleware
Purpose: measures request processing time to help identify performance bottlenecks.
Implementation:
```csharp
using System.Diagnostics;

public class RequestTimingMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ILogger<RequestTimingMiddleware> _logger;

    public RequestTimingMiddleware(RequestDelegate next, ILogger<RequestTimingMiddleware> logger)
    {
        _next = next;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        var sw = Stopwatch.StartNew();
        try
        {
            await _next(context);
        }
        finally
        {
            // Runs even if a downstream component throws, so failed requests are timed too
            sw.Stop();
            _logger.LogInformation(
                "Request {Method} {Path} completed in {ElapsedMilliseconds}ms",
                context.Request.Method,
                context.Request.Path,
                sw.ElapsedMilliseconds);
        }
    }
}

public static class RequestTimingMiddlewareExtensions
{
    public static IApplicationBuilder UseRequestTiming(this IApplicationBuilder builder)
    {
        return builder.UseMiddleware<RequestTimingMiddleware>();
    }
}
```
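Registration:
```csharp
app.UseRequestTiming();
```
How it works: a Stopwatch starts when the request arrives and stops once the rest of the pipeline has finished, and the elapsed time is then logged. The finally block ensures the time is logged even if an exception is thrown.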
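You can check the try/finally pattern in RequestTimingMiddleware without a web host. This is a minimal sketch. `Timing.TimeAsync` is a hypothetical helper that is not part of the middleware above, and it reports the elapsed time through a callback instead of ILogger.

```csharp
using System;
using System.Diagnostics;
using System.Threading.Tasks;

public static class Timing
{
    // Runs an async step and reports elapsed milliseconds through onCompleted,
    // even when the step throws (same try/finally shape as the middleware).
    public static async Task TimeAsync(Func<Task> step, Action<long> onCompleted)
    {
        var sw = Stopwatch.StartNew();
        try
        {
            await step();
        }
        finally
        {
            sw.Stop();
            onCompleted(sw.ElapsedMilliseconds);
        }
    }
}
```

The callback fires for both successful and failing steps. The exception still propagates to the caller afterwards, just as it does in the middleware.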
# 3. IP Address Filtering Middleware
Purpose: restricts access with an IP whitelist or blacklist to improve security.
Implementation:
```csharp
public class IpFilterOptions
{
    public List<string> AllowedIps { get; set; } = new List<string>();
    public List<string> BlockedIps { get; set; } = new List<string>();
    public bool AllowAllIpsIfListEmpty { get; set; } = true;
}

public class IpFilterMiddleware
{
    private readonly RequestDelegate _next;
    private readonly IpFilterOptions _options;
    private readonly ILogger<IpFilterMiddleware> _logger;

    public IpFilterMiddleware(RequestDelegate next, IpFilterOptions options, ILogger<IpFilterMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        var remoteAddress = context.Connection.RemoteIpAddress;
        if (remoteAddress == null)
        {
            context.Response.StatusCode = StatusCodes.Status400BadRequest;
            await context.Response.WriteAsync("Invalid IP address");
            return;
        }

        // On dual-stack sockets an IPv4 client can appear as ::ffff:127.0.0.1;
        // map it back so it matches IPv4 entries in the lists
        if (remoteAddress.IsIPv4MappedToIPv6)
        {
            remoteAddress = remoteAddress.MapToIPv4();
        }
        var remoteIp = remoteAddress.ToString();

        if (!IsIpAllowed(remoteIp))
        {
            _logger.LogWarning("Request from blocked IP: {IpAddress}", remoteIp);
            context.Response.StatusCode = StatusCodes.Status403Forbidden;
            await context.Response.WriteAsync("Access denied based on IP address");
            return;
        }

        await _next(context);
    }

    private bool IsIpAllowed(string ipAddress)
    {
        // Blacklisted IPs are always denied
        if (_options.BlockedIps.Contains(ipAddress))
            return false;
        // An empty whitelist with AllowAllIpsIfListEmpty lets every other IP through
        if (_options.AllowedIps.Count == 0 && _options.AllowAllIpsIfListEmpty)
            return true;
        // Otherwise the IP must be on the whitelist
        return _options.AllowedIps.Contains(ipAddress);
    }
}

public static class IpFilterMiddlewareExtensions
{
    public static IApplicationBuilder UseIpFiltering(
        this IApplicationBuilder builder,
        Action<IpFilterOptions> configureOptions = null)
    {
        var options = new IpFilterOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<IpFilterMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseIpFiltering(options => {
    options.AllowedIps.Add("127.0.0.1");
    options.AllowedIps.Add("::1");
    options.BlockedIps.Add("192.168.1.100");
});
```
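How it works: the middleware checks the request's remote IP address against the configured whitelist and blacklist and decides whether the request may continue. If the IP is blocked, it returns status code 403. Behind a reverse proxy, RemoteIpAddress is the proxy's address unless the Forwarded Headers middleware (UseForwardedHeaders) runs earlier in the pipeline.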
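IsIpAllowed compares strings, so the textual form of the address matters. The sketch below applies the same precedence rules: blacklist first, then the allow-all switch, then the whitelist. It also normalizes IPv4-mapped IPv6 addresses. `IpRules` is a hypothetical name, and the two sets stand in for the IpFilterOptions lists.

```csharp
using System;
using System.Collections.Generic;
using System.Net;

public static class IpRules
{
    // Turns ::ffff:a.b.c.d into a.b.c.d so it compares equal to plain IPv4 entries.
    public static string Normalize(IPAddress address) =>
        (address.IsIPv4MappedToIPv6 ? address.MapToIPv4() : address).ToString();

    // Same order of checks as IsIpAllowed in the middleware.
    public static bool IsAllowed(IPAddress address, ISet<string> allowed, ISet<string> blocked, bool allowAllIfEmpty)
    {
        var ip = Normalize(address);
        if (blocked.Contains(ip)) return false;
        if (allowed.Count == 0 && allowAllIfEmpty) return true;
        return allowed.Contains(ip);
    }
}
```

With the registration example's lists, a dual-stack connection from 127.0.0.1 arrives as `::ffff:127.0.0.1`. It is accepted only because of the normalization step.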
# 4. Security Headers Middleware
Purpose: adds security-related HTTP headers such as Content Security Policy (CSP) and XSS protection.
Implementation:
```csharp
public class SecurityHeadersOptions
{
    public bool UseHsts { get; set; } = true;
    public bool UseXssProtection { get; set; } = true;
    public bool UseContentTypeOptions { get; set; } = true;
    public bool UseFrameOptions { get; set; } = true;
    public string ContentSecurityPolicy { get; set; } = "default-src 'self'";
}

public class SecurityHeadersMiddleware
{
    private readonly RequestDelegate _next;
    private readonly SecurityHeadersOptions _options;

    public SecurityHeadersMiddleware(RequestDelegate next, SecurityHeadersOptions options)
    {
        _next = next;
        _options = options;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        // The indexer overwrites an existing value; Headers.Add would throw if the header were already present
        var headers = context.Response.Headers;

        // Browsers ignore HSTS received over plain HTTP (RFC 6797), so only send it on HTTPS requests
        if (_options.UseHsts && context.Request.IsHttps)
        {
            headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains";
        }
        // Legacy header: most current browsers have removed the XSS filter it controls
        if (_options.UseXssProtection)
        {
            headers["X-XSS-Protection"] = "1; mode=block";
        }
        if (_options.UseContentTypeOptions)
        {
            headers["X-Content-Type-Options"] = "nosniff";
        }
        if (_options.UseFrameOptions)
        {
            headers["X-Frame-Options"] = "DENY";
        }
        if (!string.IsNullOrEmpty(_options.ContentSecurityPolicy))
        {
            headers["Content-Security-Policy"] = _options.ContentSecurityPolicy;
        }

        await _next(context);
    }
}

public static class SecurityHeadersMiddlewareExtensions
{
    public static IApplicationBuilder UseSecurityHeaders(
        this IApplicationBuilder builder,
        Action<SecurityHeadersOptions> configureOptions = null)
    {
        var options = new SecurityHeadersOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<SecurityHeadersMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseSecurityHeaders(options => {
    options.UseHsts = true;
    options.ContentSecurityPolicy = "default-src 'self'; script-src 'self' https://trusted-cdn.com";
});
```
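How it works: the middleware adds the configured security headers to every response. This helps defend against common web threats such as XSS and clickjacking.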
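The ContentSecurityPolicy option takes the raw header value. Directives are separated by semicolons, and each directive name is followed by its space-separated sources. The small helper below assembles such a value; `CspBuilder` is hypothetical and not part of ASP.NET Core.

```csharp
using System;
using System.Collections.Generic;

public static class CspBuilder
{
    // Joins directives as "name source1 source2; name2 ...".
    // A directive with no sources (e.g. upgrade-insecure-requests) is emitted as its bare name.
    public static string Build(params (string Name, string[] Sources)[] directives)
    {
        var parts = new List<string>();
        foreach (var (name, sources) in directives)
        {
            parts.Add(sources.Length == 0 ? name : name + " " + string.Join(" ", sources));
        }
        return string.Join("; ", parts);
    }
}
```

Building from structured directives makes it harder to drop a semicolon or a quote around keywords such as `'self'` when policies grow longer.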
# 5. Response Compression Middleware
Purpose: compresses HTTP responses to reduce bandwidth usage and improve performance.
Implementation:
```csharp
using System.IO.Compression;

public class ResponseCompressionOptions
{
    public List<string> MimeTypes { get; set; } = new List<string>
    {
        "text/plain",
        "text/html",
        "text/css",
        "text/javascript",
        "application/javascript",
        "application/json",
        "application/xml"
    };

    // Compressing HTTPS responses that mix secrets with attacker-controlled input can enable
    // CRIME/BREACH-style attacks; the built-in response compression middleware defaults this to false
    public bool EnableForHttps { get; set; } = true;
    public CompressionLevel CompressionLevel { get; set; } = CompressionLevel.Fastest;
}

public class ResponseCompressionMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ResponseCompressionOptions _options;
    private readonly ILogger<ResponseCompressionMiddleware> _logger;

    public ResponseCompressionMiddleware(
        RequestDelegate next,
        ResponseCompressionOptions options,
        ILogger<ResponseCompressionMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        // Pick a compression method the client accepts
        var acceptEncoding = context.Request.Headers["Accept-Encoding"].ToString().ToLowerInvariant();
        var encoding = SelectEncoding(acceptEncoding);

        // Only request-side checks are possible here: the response Content-Type
        // is not known until the rest of the pipeline has run
        if (encoding == null || (context.Request.IsHttps && !_options.EnableForHttps))
        {
            await _next(context);
            return;
        }

        var originalBodyStream = context.Response.Body;
        using var buffer = new MemoryStream();
        context.Response.Body = buffer;
        try
        {
            await _next(context);
        }
        finally
        {
            // Always restore the real body so outer middleware never writes into the disposed buffer
            context.Response.Body = originalBodyStream;
        }

        buffer.Position = 0;
        if (buffer.Length > 0 &&
            ShouldCompress(context.Response.ContentType) &&
            !context.Response.Headers.ContainsKey("Content-Encoding"))
        {
            context.Response.Headers["Content-Encoding"] = encoding;
            context.Response.Headers.Append("Vary", "Accept-Encoding");
            // Any Content-Length set downstream describes the uncompressed body
            context.Response.ContentLength = null;
            await CompressAsync(buffer, originalBodyStream, encoding);
        }
        else
        {
            await buffer.CopyToAsync(originalBodyStream);
        }
    }

    private static string? SelectEncoding(string acceptEncoding)
    {
        // Simplified: q-values such as "gzip;q=0" are not honoured
        if (acceptEncoding.Contains("gzip")) return "gzip";
        if (acceptEncoding.Contains("deflate")) return "deflate";
        return null;
    }

    private bool ShouldCompress(string? contentType)
    {
        // Compress only the configured content types
        return !string.IsNullOrEmpty(contentType) &&
            _options.MimeTypes.Any(m => contentType.StartsWith(m, StringComparison.OrdinalIgnoreCase));
    }

    private async Task CompressAsync(Stream source, Stream destination, string encoding)
    {
        // HTTP "deflate" means the zlib format, which ZLibStream (.NET 6+) writes; DeflateStream writes raw deflate.
        // await using flushes the compressor asynchronously; Kestrel rejects synchronous writes by default.
        await using Stream compressor = encoding == "gzip"
            ? new GZipStream(destination, _options.CompressionLevel, leaveOpen: true)
            : new ZLibStream(destination, _options.CompressionLevel, leaveOpen: true);
        await source.CopyToAsync(compressor);
    }
}

public static class ResponseCompressionMiddlewareExtensions
{
    public static IApplicationBuilder UseResponseCompression(
        this IApplicationBuilder builder,
        Action<ResponseCompressionOptions> configureOptions = null)
    {
        var options = new ResponseCompressionOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<ResponseCompressionMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseResponseCompression(options => {
    options.CompressionLevel = System.IO.Compression.CompressionLevel.Optimal;
    options.MimeTypes.Add("application/vnd.ms-excel");
});
```
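How it works: the middleware detects which compression methods the client accepts and compresses the response selectively, based on its content type and the configured compression level. To do this, it swaps the response body for a memory stream and lets the rest of the pipeline write into it. It then compresses the buffered content into the original stream. Because the whole response is buffered in memory, this approach is not suited to very large or streamed responses.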
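Here is a standalone gzip round trip using GZipStream from System.IO.Compression, showing what the middleware sends to the client. The `GzipDemo` class is hypothetical. gzip adds a fixed 18-byte header and trailer, so compressing very small bodies can make them larger. Already-compressed formats such as JPEG also gain little, which is why the middleware limits compression to a list of MIME types.

```csharp
using System;
using System.IO;
using System.IO.Compression;
using System.Text;

public static class GzipDemo
{
    public static byte[] Compress(byte[] data, CompressionLevel level)
    {
        using var output = new MemoryStream();
        // leaveOpen: true keeps 'output' usable after the GZipStream is disposed, which writes the gzip trailer.
        using (var gzip = new GZipStream(output, level, leaveOpen: true))
        {
            gzip.Write(data, 0, data.Length);
        }
        return output.ToArray();
    }

    public static byte[] Decompress(byte[] compressed)
    {
        using var input = new MemoryStream(compressed);
        using var gzip = new GZipStream(input, CompressionMode.Decompress);
        using var output = new MemoryStream();
        gzip.CopyTo(output);
        return output.ToArray();
    }
}
```

A repetitive JSON payload like a typical API response shrinks substantially. Every gzip body starts with the magic bytes `1f 8b`.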
# 6. Health Check Middleware
Purpose: provides a health-status endpoint for load balancers and monitoring tools.
Implementation:
```csharp
using System.Diagnostics;
using Microsoft.EntityFrameworkCore;

public class HealthCheckOptions
{
    public string HealthEndpointPath { get; set; } = "/health";
    public List<Func<HttpContext, Task<HealthCheckResult>>> HealthChecks { get; } = new List<Func<HttpContext, Task<HealthCheckResult>>>();
}

public enum HealthStatus
{
    Healthy,
    Degraded,
    Unhealthy
}

public class HealthCheckResult
{
    public string Name { get; set; }
    public HealthStatus Status { get; set; }
    public string Description { get; set; }
    public TimeSpan Duration { get; set; }
}

public class HealthCheckMiddleware
{
    private readonly RequestDelegate _next;
    private readonly HealthCheckOptions _options;
    private readonly ILogger<HealthCheckMiddleware> _logger;

    public HealthCheckMiddleware(
        RequestDelegate next,
        HealthCheckOptions options,
        ILogger<HealthCheckMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        if (context.Request.Path.Equals(_options.HealthEndpointPath, StringComparison.OrdinalIgnoreCase))
        {
            await HandleHealthCheckAsync(context);
            return;
        }
        await _next(context);
    }

    private async Task HandleHealthCheckAsync(HttpContext context)
    {
        var results = new List<HealthCheckResult>();
        var overallStatus = HealthStatus.Healthy;

        foreach (var check in _options.HealthChecks)
        {
            var sw = Stopwatch.StartNew();
            HealthCheckResult result;
            try
            {
                result = await check(context);
            }
            catch (Exception ex)
            {
                // A check that throws counts as Unhealthy instead of turning the endpoint into a 500
                _logger.LogError(ex, "Health check threw an exception");
                result = new HealthCheckResult
                {
                    Name = "Unknown",
                    Status = HealthStatus.Unhealthy,
                    Description = $"Exception: {ex.Message}"
                };
            }
            sw.Stop();
            result.Duration = sw.Elapsed;
            results.Add(result);

            // Any unhealthy check makes the overall status Unhealthy
            if (result.Status == HealthStatus.Unhealthy)
            {
                overallStatus = HealthStatus.Unhealthy;
            }
            // A degraded check downgrades an otherwise healthy status to Degraded
            else if (overallStatus == HealthStatus.Healthy && result.Status == HealthStatus.Degraded)
            {
                overallStatus = HealthStatus.Degraded;
            }
        }

        var response = new
        {
            Status = overallStatus.ToString(),
            Results = results.Select(r => new
            {
                r.Name,
                Status = r.Status.ToString(),
                r.Description,
                Duration = $"{r.Duration.TotalMilliseconds}ms"
            }),
            Timestamp = DateTime.UtcNow
        };

        // Degraded still answers 200 so load balancers keep routing traffic; only Unhealthy answers 503
        context.Response.StatusCode = overallStatus == HealthStatus.Unhealthy
            ? StatusCodes.Status503ServiceUnavailable
            : StatusCodes.Status200OK;
        // WriteAsJsonAsync also sets Content-Type: application/json
        await context.Response.WriteAsJsonAsync(response);
    }
}

public static class HealthCheckMiddlewareExtensions
{
    public static IApplicationBuilder UseHealthChecks(
        this IApplicationBuilder builder,
        Action<HealthCheckOptions> configureOptions = null)
    {
        var options = new HealthCheckOptions();
        configureOptions?.Invoke(options);

        // Add a default check if none were configured
        if (options.HealthChecks.Count == 0)
        {
            options.HealthChecks.Add(_ => Task.FromResult(
                new HealthCheckResult
                {
                    Name = "System",
                    Status = HealthStatus.Healthy,
                    Description = "Application is running"
                }));
        }
        return builder.UseMiddleware<HealthCheckMiddleware>(options);
    }

    // Convenience method for adding a database health check (requires EF Core)
    public static HealthCheckOptions AddDbContextCheck<TContext>(
        this HealthCheckOptions options,
        IServiceProvider serviceProvider,
        string name = "Database")
        where TContext : DbContext
    {
        options.HealthChecks.Add(async _ =>
        {
            try
            {
                // serviceProvider is the root provider, so a scope is needed to resolve the scoped DbContext
                using var scope = serviceProvider.CreateScope();
                var dbContext = scope.ServiceProvider.GetRequiredService<TContext>();
                var canConnect = await dbContext.Database.CanConnectAsync();
                return new HealthCheckResult
                {
                    Name = name,
                    Status = canConnect ? HealthStatus.Healthy : HealthStatus.Unhealthy,
                    Description = canConnect ? "Database is healthy" : "Database is unhealthy"
                };
            }
            catch (Exception ex)
            {
                return new HealthCheckResult
                {
                    Name = name,
                    Status = HealthStatus.Unhealthy,
                    Description = $"Exception: {ex.Message}"
                };
            }
        });
        return options;
    }
}
```
Registration:
```csharp
app.UseHealthChecks(options => {
    options.HealthEndpointPath = "/api/health";
    // Add a custom health check (CheckRedisConnectionAsync is an application-defined helper)
    options.HealthChecks.Add(async _ => {
        return new HealthCheckResult {
            Name = "Redis",
            Status = await CheckRedisConnectionAsync() ? HealthStatus.Healthy : HealthStatus.Unhealthy,
            Description = "Redis connection check"
        };
    });
    // Add a database check with the extension method
    options.AddDbContextCheck<ApplicationDbContext>(app.ApplicationServices);
});
```
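How it works: the middleware serves a health-status API endpoint at the configured path. It runs the registered health checks, aggregates their results, and returns the health status as JSON. Checks can cover database connectivity, dependent services and so on, so monitoring systems can tell whether the application is running normally.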
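The if/else chain in HandleHealthCheckAsync amounts to taking the worst individual status. The sketch below expresses that rule with a hypothetical standalone enum, `CheckStatus`, ordered from best to worst so that `Max()` returns the worst value. Note that the built-in `Microsoft.Extensions.Diagnostics.HealthChecks.HealthStatus` enum uses a different numeric order, so do not apply this trick to it unchanged.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical enum ordered from best to worst, so the numerically largest value is the worst status.
public enum CheckStatus { Healthy = 0, Degraded = 1, Unhealthy = 2 }

public static class HealthAggregation
{
    // Overall status is the worst individual status; no checks at all counts as Healthy.
    public static CheckStatus Overall(IEnumerable<CheckStatus> statuses) =>
        statuses.DefaultIfEmpty(CheckStatus.Healthy).Max();

    // Same mapping as the middleware: only Unhealthy answers 503.
    public static int ToHttpStatus(CheckStatus status) =>
        status == CheckStatus.Unhealthy ? 503 : 200;
}
```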
# 7. API Key Validation Middleware
Purpose: protects API endpoints by validating an API key on each request.
Implementation:
```csharp
public class ApiKeyOptions
{
    public string HeaderName { get; set; } = "X-API-Key";
    // Keys passed in the query string tend to show up in server and proxy logs; prefer the header
    public string QueryParamName { get; set; } = "api_key";
    public List<string> ValidApiKeys { get; set; } = new List<string>();
    public Func<string, Task<bool>> ValidateApiKeyAsync { get; set; }
    public List<string> ExcludedPaths { get; set; } = new List<string>();
}

public class ApiKeyMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ApiKeyOptions _options;
    private readonly ILogger<ApiKeyMiddleware> _logger;

    public ApiKeyMiddleware(
        RequestDelegate next,
        ApiKeyOptions options,
        ILogger<ApiKeyMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        // Skip validation for excluded paths
        if (ShouldSkipValidation(context.Request.Path))
        {
            await _next(context);
            return;
        }

        // Read the API key from the header or the query string
        if (!TryGetApiKey(context, out var apiKey))
        {
            _logger.LogWarning("API key missing in request");
            context.Response.StatusCode = StatusCodes.Status401Unauthorized;
            await context.Response.WriteAsJsonAsync(new { error = "API key is required" });
            return;
        }

        // Validate the key; with no validator and no keys configured, every key is rejected
        bool isValid = false;
        // Use the custom validation function if one is defined
        if (_options.ValidateApiKeyAsync != null)
        {
            isValid = await _options.ValidateApiKeyAsync(apiKey);
        }
        // Otherwise check against the list of valid keys
        else if (_options.ValidApiKeys.Count > 0)
        {
            isValid = _options.ValidApiKeys.Contains(apiKey);
        }

        if (!isValid)
        {
            // Never log the key itself: it is a credential
            _logger.LogWarning("Invalid API key for {Path}", context.Request.Path);
            context.Response.StatusCode = StatusCodes.Status401Unauthorized;
            await context.Response.WriteAsJsonAsync(new { error = "Invalid API key" });
            return;
        }

        // The key is valid: continue processing the request
        await _next(context);
    }

    private bool ShouldSkipValidation(PathString path)
    {
        return _options.ExcludedPaths.Any(p =>
            path.StartsWithSegments(p, StringComparison.OrdinalIgnoreCase));
    }

    private bool TryGetApiKey(HttpContext context, out string apiKey)
    {
        // Try the header first
        if (context.Request.Headers.TryGetValue(_options.HeaderName, out var headerValues))
        {
            apiKey = headerValues.FirstOrDefault();
            if (!string.IsNullOrEmpty(apiKey))
            {
                return true;
            }
        }
        // Then the query string
        if (context.Request.Query.TryGetValue(_options.QueryParamName, out var queryValues))
        {
            apiKey = queryValues.FirstOrDefault();
            if (!string.IsNullOrEmpty(apiKey))
            {
                return true;
            }
        }
        apiKey = null;
        return false;
    }
}

public static class ApiKeyMiddlewareExtensions
{
    public static IApplicationBuilder UseApiKeyValidation(
        this IApplicationBuilder builder,
        Action<ApiKeyOptions> configureOptions = null)
    {
        var options = new ApiKeyOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<ApiKeyMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseApiKeyValidation(options => {
    options.ValidApiKeys.Add("sk_test_12345");
    options.ValidApiKeys.Add("sk_live_67890");
    options.ExcludedPaths.Add("/api/public");
    options.ExcludedPaths.Add("/health");
});
```
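How it works: the middleware looks for an API key in the HTTP header or the URL query string and validates it against the configured rules. If validation fails, it returns a 401 Unauthorized response. Specific paths can be excluded so that public endpoints are reachable without a key.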
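`List<string>.Contains` uses ordinary string equality, which can return as soon as one character differs. The response time can therefore hint at how much of a guessed key was correct. The sketch below is a constant-time alternative built on `CryptographicOperations.FixedTimeEquals`, available since .NET Core 2.1. The `ApiKeyComparer` class is hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Security.Cryptography;
using System.Text;

public static class ApiKeyComparer
{
    // Compares the candidate with every configured key without stopping at the first match,
    // so timing does not depend on which key matched or how many leading bytes were right.
    public static bool IsValid(string candidate, IEnumerable<string> validKeys)
    {
        var candidateBytes = Encoding.UTF8.GetBytes(candidate);
        var match = false;
        foreach (var key in validKeys)
        {
            // FixedTimeEquals returns false when the lengths differ; for equal lengths its timing
            // does not depend on the contents.
            match |= CryptographicOperations.FixedTimeEquals(candidateBytes, Encoding.UTF8.GetBytes(key));
        }
        return match;
    }
}
```

In the middleware, `isValid = ApiKeyComparer.IsValid(apiKey, _options.ValidApiKeys);` would replace the `Contains` call.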
# 8. Rate Limiting Middleware
Purpose: limits the request rate to prevent API abuse and DoS attacks.
Implementation:
```csharp
using System.Collections.Concurrent;
using System.Diagnostics;

public class RateLimitOptions
{
    public int PermitsPerSecond { get; set; } = 10;
    public int BurstSize { get; set; } = 20;
    // Clients can set this header to any value, so trust it only when an upstream gateway enforces it
    public string ClientIdHeader { get; set; } = "X-ClientId";
    public Func<HttpContext, string> ClientIdResolver { get; set; }
    public List<string> ExcludedPaths { get; set; } = new List<string>();
    public int StatusCode { get; set; } = 429; // Too Many Requests
}

public class RateLimitMiddleware
{
    private readonly RequestDelegate _next;
    private readonly RateLimitOptions _options;
    private readonly ILogger<RateLimitMiddleware> _logger;
    // One bucket per client ID; entries are never evicted, so memory grows with the number of distinct clients
    private readonly ConcurrentDictionary<string, TokenBucket> _buckets = new();

    public RateLimitMiddleware(
        RequestDelegate next,
        RateLimitOptions options,
        ILogger<RateLimitMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        // Skip rate limiting for excluded paths
        if (ShouldSkipRateLimit(context.Request.Path))
        {
            await _next(context);
            return;
        }

        // Identify the client; all unidentified clients share one bucket
        var clientId = ResolveClientId(context);
        if (string.IsNullOrEmpty(clientId))
        {
            clientId = "anonymous";
        }

        // Get or create the client's token bucket
        var bucket = _buckets.GetOrAdd(clientId, _ => new TokenBucket(
            _options.PermitsPerSecond,
            _options.BurstSize));

        // Try to take one token
        if (!bucket.TryConsume(1))
        {
            _logger.LogWarning("Rate limit exceeded for client: {ClientId}", clientId);
            context.Response.StatusCode = _options.StatusCode;
            context.Response.Headers["Retry-After"] = "1";
            await context.Response.WriteAsJsonAsync(new { error = "Rate limit exceeded. Try again later." });
            return;
        }

        await _next(context);
    }

    private bool ShouldSkipRateLimit(PathString path)
    {
        return _options.ExcludedPaths.Any(p =>
            path.StartsWithSegments(p, StringComparison.OrdinalIgnoreCase));
    }

    private string ResolveClientId(HttpContext context)
    {
        // Use the custom resolver if one is provided
        if (_options.ClientIdResolver != null)
        {
            return _options.ClientIdResolver(context);
        }
        // Otherwise read the client ID header
        if (context.Request.Headers.TryGetValue(_options.ClientIdHeader, out var headerValue))
        {
            return headerValue.ToString();
        }
        // Fall back to the IP address
        return context.Connection.RemoteIpAddress?.ToString();
    }

    // Token bucket implementation
    private class TokenBucket
    {
        private readonly double _refillRate; // tokens added per second
        private readonly double _maxTokens;  // capacity, i.e. the burst size
        private double _currentTokens;
        private long _lastRefillTimestamp;   // Stopwatch ticks
        private readonly object _sync = new object();

        public TokenBucket(double refillRate, double maxTokens)
        {
            _refillRate = refillRate;
            _maxTokens = maxTokens;
            _currentTokens = maxTokens; // start full so a burst is allowed immediately
            _lastRefillTimestamp = Stopwatch.GetTimestamp();
        }

        public bool TryConsume(int tokens)
        {
            // The critical section is short and never awaits, so a plain lock is sufficient
            lock (_sync)
            {
                RefillTokens();
                if (_currentTokens < tokens)
                {
                    return false;
                }
                _currentTokens -= tokens;
                return true;
            }
        }

        private void RefillTokens()
        {
            // Stopwatch timestamps are monotonic, unlike wall-clock time, which can jump
            var now = Stopwatch.GetTimestamp();
            var elapsedSeconds = (now - _lastRefillTimestamp) / (double)Stopwatch.Frequency;
            if (elapsedSeconds <= 0)
            {
                return;
            }
            // Add tokens in proportion to the elapsed time, capped at the bucket capacity
            _currentTokens = Math.Min(_maxTokens, _currentTokens + elapsedSeconds * _refillRate);
            _lastRefillTimestamp = now;
        }
    }
}

public static class RateLimitMiddlewareExtensions
{
    public static IApplicationBuilder UseRateLimit(
        this IApplicationBuilder builder,
        Action<RateLimitOptions> configureOptions = null)
    {
        var options = new RateLimitOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<RateLimitMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseRateLimit(options => {
    options.PermitsPerSecond = 5;
    options.BurstSize = 10;
    options.ClientIdResolver = context => {
        // The client ID can come from the authenticated user or the API key
        return context.User?.Identity?.Name ??
            context.Request.Headers["X-API-Key"].FirstOrDefault() ??
            context.Connection.RemoteIpAddress?.ToString();
    };
    options.ExcludedPaths.Add("/health");
});
```
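How it works: rate limiting uses the token bucket algorithm. Each client gets its own bucket, refilled at a fixed rate (PermitsPerSecond) up to a capacity of BurstSize tokens. Every request must consume a token to proceed. When no token is available, the request is rejected with status 429. This allows short bursts of traffic and supports custom client identification strategies.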
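TokenBucket reads the clock internally, which makes it awkward to test. The sketch below implements the same algorithm but takes the current time as a parameter. `ManualTokenBucket` is a hypothetical class. The numbers in the usage match the registration example: 5 permits per second and a burst of 10.

```csharp
using System;

// Hypothetical token bucket with an explicit clock, so behaviour can be checked without sleeping.
public sealed class ManualTokenBucket
{
    private readonly double _ratePerSecond;
    private readonly double _capacity;
    private double _tokens;
    private TimeSpan _last;

    public ManualTokenBucket(double ratePerSecond, double capacity, TimeSpan start)
    {
        _ratePerSecond = ratePerSecond;
        _capacity = capacity;
        _tokens = capacity; // start full: the whole burst is available at once
        _last = start;
    }

    public bool TryConsume(TimeSpan now)
    {
        // Refill in proportion to elapsed time, capped at capacity.
        var elapsed = (now - _last).TotalSeconds;
        if (elapsed > 0)
        {
            _tokens = Math.Min(_capacity, _tokens + elapsed * _ratePerSecond);
            _last = now;
        }
        if (_tokens < 1) return false;
        _tokens -= 1;
        return true;
    }
}
```

Ten requests at t = 0 drain the full burst, and the eleventh is rejected. One second later, exactly five more tokens are available.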
# 9. Global Exception Handling Middleware
Purpose: catches and handles application exceptions in one place, producing consistent error responses.
Implementation:
```csharp
using System.Diagnostics;

public class ErrorHandlingOptions
{
    public bool IncludeExceptionDetails { get; set; } = false;
    public bool LogExceptions { get; set; } = true;
    public Dictionary<Type, Func<Exception, HttpContext, Task>> ExceptionHandlers { get; } =
        new Dictionary<Type, Func<Exception, HttpContext, Task>>();
}

public class ErrorHandlingMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ErrorHandlingOptions _options;
    private readonly ILogger<ErrorHandlingMiddleware> _logger;
    private readonly IWebHostEnvironment _environment;

    public ErrorHandlingMiddleware(
        RequestDelegate next,
        ErrorHandlingOptions options,
        ILogger<ErrorHandlingMiddleware> logger,
        IWebHostEnvironment environment)
    {
        _next = next;
        _options = options;
        _logger = logger;
        _environment = environment;
    }

    public async Task InvokeAsync(HttpContext context)
    {
        try
        {
            await _next(context);
        }
        catch (Exception ex)
        {
            // Once the response has started, the status code and headers can no longer be changed
            if (context.Response.HasStarted)
            {
                _logger.LogWarning("The response has already started, the error handling middleware will not be executed.");
                throw;
            }
            await HandleExceptionAsync(context, ex);
        }
    }

    private async Task HandleExceptionAsync(HttpContext context, Exception exception)
    {
        // Log the exception
        if (_options.LogExceptions)
        {
            _logger.LogError(exception, "An unhandled exception occurred");
        }

        // Discard any status code and headers set downstream before the exception was thrown
        context.Response.Clear();

        // Use a custom handler if one is registered for this exact exception type
        if (_options.ExceptionHandlers.TryGetValue(exception.GetType(), out var handler))
        {
            await handler(exception, context);
            return;
        }

        // Default handling; WriteAsJsonAsync sets Content-Type: application/json
        context.Response.StatusCode = DetermineStatusCode(exception);
        await context.Response.WriteAsJsonAsync(CreateErrorResponse(exception, context));
    }

    private static int DetermineStatusCode(Exception exception)
    {
        return exception switch
        {
            ArgumentException => StatusCodes.Status400BadRequest,
            UnauthorizedAccessException => StatusCodes.Status401Unauthorized,
            FileNotFoundException => StatusCodes.Status404NotFound,
            NotImplementedException => StatusCodes.Status501NotImplemented,
            _ => StatusCodes.Status500InternalServerError
        };
    }

    // The HttpContext is passed in so the generic response can include a trace ID
    private object CreateErrorResponse(Exception exception, HttpContext context)
    {
        if (_options.IncludeExceptionDetails || _environment.IsDevelopment())
        {
            return new
            {
                error = new
                {
                    message = exception.Message,
                    type = exception.GetType().Name,
                    stackTrace = exception.StackTrace,
                    innerException = exception.InnerException != null
                        ? CreateErrorResponse(exception.InnerException, context)
                        : null
                }
            };
        }

        return new
        {
            error = new
            {
                message = "An error occurred. Please try again later.",
                traceId = Activity.Current?.Id ?? context.TraceIdentifier
            }
        };
    }
}

public static class ErrorHandlingMiddlewareExtensions
{
    public static IApplicationBuilder UseGlobalErrorHandling(
        this IApplicationBuilder builder,
        Action<ErrorHandlingOptions> configureOptions = null)
    {
        var options = new ErrorHandlingOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<ErrorHandlingMiddleware>(options);
    }

    // Convenience method for registering a handler for a specific exception type
    public static ErrorHandlingOptions AddExceptionHandler<TException>(
        this ErrorHandlingOptions options,
        Func<TException, HttpContext, Task> handler)
        where TException : Exception
    {
        options.ExceptionHandlers[typeof(TException)] = (ex, context) =>
            handler((TException)ex, context);
        return options;
    }
}
```
Registration:
```csharp
app.UseGlobalErrorHandling(options => {
    options.IncludeExceptionDetails = env.IsDevelopment();
    // Add a custom exception handler. ValidationException is assumed to be FluentValidation's,
    // whose Errors property lists the validation failures.
    options.AddExceptionHandler<ValidationException>(async (ex, context) => {
        context.Response.StatusCode = StatusCodes.Status400BadRequest;
        context.Response.ContentType = "application/json";
        await context.Response.WriteAsJsonAsync(new {
            errors = ex.Errors
        });
    });
});
```
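How it works: the middleware catches unhandled exceptions from every component registered after it and converts them into a uniform JSON error response. Register it first in the pipeline so that it covers everything else. It maps exception types to HTTP status codes and supports custom handlers for specific exception types. Configuration controls whether the response includes detailed exception information and whether exceptions are logged.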
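ExceptionHandlers is looked up by exact runtime type. A handler registered for ArgumentException therefore does not fire for ArgumentNullException, even though the type patterns in DetermineStatusCode do match subclasses. The sketch below walks the base-type chain instead. `ExceptionMapping.TryFind` is a hypothetical helper.

```csharp
using System;
using System.Collections.Generic;
using System.IO;

public static class ExceptionMapping
{
    // Walks from the thrown type up through its base classes and returns the first registered entry,
    // so an entry for ArgumentException also covers ArgumentNullException.
    public static bool TryFind<T>(IReadOnlyDictionary<Type, T> map, Exception ex, out T value)
    {
        for (Type? type = ex.GetType(); type != null; type = type.BaseType)
        {
            if (map.TryGetValue(type, out value)) return true;
        }
        value = default!;
        return false;
    }
}
```

Because the most derived type is checked first, a more specific registration still takes precedence over a general one.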
# 10. Localization Middleware
Purpose: supports multiple languages by automatically setting the correct culture for each request.
Implementation:
```csharp
using System.Globalization;

public class LocalizationOptions
{
    public List<CultureInfo> SupportedCultures { get; set; } = new List<CultureInfo>();
    public string DefaultCulture { get; set; } = "en-US";
    public string CookieName { get; set; } = ".Culture";
    public TimeSpan CookieExpiration { get; set; } = TimeSpan.FromDays(30);
    public bool UseCookies { get; set; } = true;
    public bool UseQueryString { get; set; } = true;
    public bool UseRequestHeader { get; set; } = true;
    public string QueryStringKey { get; set; } = "culture";
}

public class LocalizationMiddleware
{
    private readonly RequestDelegate _next;
    private readonly LocalizationOptions _options;
    private readonly ILogger<LocalizationMiddleware> _logger;

    public LocalizationMiddleware(
        RequestDelegate next,
        LocalizationOptions options,
        ILogger<LocalizationMiddleware> logger)
    {
        _next = next;
        _options = options;
        _logger = logger;

        // Make sure at least the default culture is supported
        if (_options.SupportedCultures.Count == 0)
        {
            _options.SupportedCultures.Add(new CultureInfo(_options.DefaultCulture));
        }
    }

    public async Task InvokeAsync(HttpContext context)
    {
        var cultureName = DetermineCultureName(context);
        var culture = GetSupportedCulture(cultureName);
        SetCurrentCulture(culture);

        // With cookies enabled, persist the resolved culture if it differs from the cookie value
        if (_options.UseCookies && !string.IsNullOrEmpty(cultureName))
        {
            var currentCookieCulture = GetCultureFromCookie(context);
            if (currentCookieCulture != culture.Name)
            {
                SetCultureCookie(context, culture.Name);
            }
        }

        await _next(context);
    }

    private string DetermineCultureName(HttpContext context)
    {
        // Check the culture sources in priority order
        // 1. Query string
        if (_options.UseQueryString &&
            context.Request.Query.TryGetValue(_options.QueryStringKey, out var queryCulture))
        {
            return queryCulture;
        }
        // 2. Cookie
        if (_options.UseCookies)
        {
            var cookieCulture = GetCultureFromCookie(context);
            if (!string.IsNullOrEmpty(cookieCulture))
            {
                return cookieCulture;
            }
        }
        // 3. Accept-Language header
        if (_options.UseRequestHeader)
        {
            var headerCulture = GetCultureFromHeader(context);
            if (!string.IsNullOrEmpty(headerCulture))
            {
                return headerCulture;
            }
        }
        // Default culture
        return _options.DefaultCulture;
    }

    private string GetCultureFromCookie(HttpContext context)
    {
        if (context.Request.Cookies.TryGetValue(_options.CookieName, out var cookieValue))
        {
            return cookieValue;
        }
        return null;
    }

    private string GetCultureFromHeader(HttpContext context)
    {
        if (context.Request.Headers.TryGetValue("Accept-Language", out var values))
        {
            // Parse Accept-Language, e.g. en-US,en;q=0.9,zh-CN;q=0.8,zh;q=0.7
            // Simplified: q-values are ignored and the first supported tag in header order wins
            var languages = values.ToString().Split(',');
            foreach (var language in languages)
            {
                var code = language.Split(';')[0].Trim();
                // Check whether this culture is supported
                if (IsSupportedCulture(code))
                {
                    return code;
                }
            }
        }
        return null;
    }

    private bool IsSupportedCulture(string cultureName) => FindSupportedCulture(cultureName) != null;

    private CultureInfo FindSupportedCulture(string cultureName) =>
        _options.SupportedCultures.FirstOrDefault(c =>
            string.Equals(c.Name, cultureName, StringComparison.OrdinalIgnoreCase));

    private CultureInfo GetSupportedCulture(string cultureName)
    {
        // Use the matching supported culture; otherwise fall back to the default culture,
        // or to the first supported culture if the default is not in the list
        return FindSupportedCulture(cultureName)
            ?? FindSupportedCulture(_options.DefaultCulture)
            ?? _options.SupportedCultures.First();
    }

    private void SetCurrentCulture(CultureInfo culture)
    {
        // In .NET these values flow with the async context, so components invoked by _next see them
        CultureInfo.CurrentCulture = culture;
        CultureInfo.CurrentUICulture = culture;
    }

    private void SetCultureCookie(HttpContext context, string cultureName)
    {
        context.Response.Cookies.Append(
            _options.CookieName,
            cultureName,
            new CookieOptions
            {
                Expires = DateTimeOffset.UtcNow.Add(_options.CookieExpiration),
                IsEssential = true,
                SameSite = SameSiteMode.Lax
            });
    }
}

public static class LocalizationMiddlewareExtensions
{
    public static IApplicationBuilder UseCustomLocalization(
        this IApplicationBuilder builder,
        Action<LocalizationOptions> configureOptions = null)
    {
        var options = new LocalizationOptions();
        configureOptions?.Invoke(options);
        return builder.UseMiddleware<LocalizationMiddleware>(options);
    }
}
```
Registration:
```csharp
app.UseCustomLocalization(options => {
    options.SupportedCultures = new List<CultureInfo>
    {
        new CultureInfo("en-US"),
        new CultureInfo("fr-FR"),
        new CultureInfo("zh-CN")
    };
    options.DefaultCulture = "en-US";
});
```
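How it works: the middleware decides which culture to use from information in the request. It checks the query string, then the cookie, then the Accept-Language header, in that order of priority. It sets the current culture (CultureInfo.CurrentCulture and CurrentUICulture) for the rest of the request and can persist the choice in a cookie. This lets the application render content that matches the user's language preference automatically.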
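GetCultureFromHeader takes the first supported tag in header order and ignores q-values. Browsers usually list languages in preference order, but HTTP semantics (RFC 9110) define preference through the q weights. The sketch below is a minimal parser that honours q-values. `AcceptLanguage.Parse` is a hypothetical helper, and it does not validate tag syntax.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;

public static class AcceptLanguage
{
    // Parses "en-US,en;q=0.9,zh-CN;q=0.8" into tags ordered by descending q (default 1.0).
    // q=0 means "not acceptable", so those entries are dropped.
    // OrderByDescending is a stable sort, so tags with equal q keep their header order.
    public static IReadOnlyList<string> Parse(string header)
    {
        var entries = new List<(string Tag, double Q)>();
        foreach (var part in header.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries))
        {
            var pieces = part.Split(';', StringSplitOptions.TrimEntries);
            var q = 1.0;
            foreach (var param in pieces.Skip(1))
            {
                if (param.StartsWith("q=", StringComparison.OrdinalIgnoreCase) &&
                    double.TryParse(param.AsSpan(2), NumberStyles.AllowDecimalPoint, CultureInfo.InvariantCulture, out var parsed))
                {
                    q = parsed;
                }
            }
            if (q > 0 && pieces[0].Length > 0)
            {
                entries.Add((pieces[0], q));
            }
        }
        return entries.OrderByDescending(e => e.Q).Select(e => e.Tag).ToList();
    }
}
```

The middleware could then loop over `AcceptLanguage.Parse(values.ToString())` and return the first tag for which `IsSupportedCulture` is true.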
# 11. 请求上下文中间件
用途: 存储请求特定的上下文信息,在整个请求处理过程中可用。
实现:
// 请求上下文类,包含请求相关信息
public class RequestContext
{
public string TraceId { get; set; }
public string UserId { get; set; }
public bool IsAuthenticated { get; set; }
public string UserAgent { get; set; }
public string IpAddress { get; set; }
public DateTime RequestStartTime { get; set; }
public Dictionary<string, object> Items { get; } = new Dictionary<string, object>();
}
public class RequestContextMiddleware
{
private readonly RequestDelegate _next;
private readonly ILogger<RequestContextMiddleware> _logger;
public RequestContextMiddleware(
RequestDelegate next,
ILogger<RequestContextMiddleware> logger)
{
_next = next;
_logger = logger;
}
public async Task InvokeAsync(HttpContext httpContext, IRequestContextAccessor requestContextAccessor)
{
// Build the request context
var requestContext = new RequestContext
{
TraceId = Activity.Current?.Id ?? httpContext.TraceIdentifier,
IsAuthenticated = httpContext.User?.Identity?.IsAuthenticated ?? false,
UserId = httpContext.User?.FindFirst(ClaimTypes.NameIdentifier)?.Value,
UserAgent = httpContext.Request.Headers["User-Agent"],
IpAddress = httpContext.Connection.RemoteIpAddress?.ToString(),
RequestStartTime = DateTime.UtcNow
};
// Store the context in HttpContext.Items so it can be read during request processing
httpContext.Items["RequestContext"] = requestContext;
// Publish the context through the accessor. Extra InvokeAsync parameters are resolved
// from the request's own service scope, so there is no need to create another scope here.
requestContextAccessor.Context = requestContext;
try
{
// Continue processing the request
await _next(httpContext);
}
finally
{
// Clean up when the request ends
requestContextAccessor.Context = null;
}
}
}
// Accessor interface for reading the current request context anywhere in the application
public interface IRequestContextAccessor
{
RequestContext Context { get; set; }
}
// Accessor implementation: keeps the context in an AsyncLocal so it flows with the request's async calls
public class RequestContextAccessor : IRequestContextAccessor
{
private static readonly AsyncLocal<RequestContextHolder> _requestContextCurrent = new();
public RequestContext Context
{
get => _requestContextCurrent.Value?.Context;
set
{
if (_requestContextCurrent.Value == null)
{
_requestContextCurrent.Value = new RequestContextHolder();
}
_requestContextCurrent.Value.Context = value;
}
}
private class RequestContextHolder
{
public RequestContext Context;
}
}
public static class RequestContextMiddlewareExtensions
{
public static IApplicationBuilder UseRequestContext(this IApplicationBuilder builder)
{
return builder.UseMiddleware<RequestContextMiddleware>();
}
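public static IServiceCollection AddRequestContext(this IServiceCollection services)
{
// The accessor stores its state in a static AsyncLocal, so one shared instance is sufficient
services.AddSingleton<IRequestContextAccessor, RequestContextAccessor>();
return services;
}
// Helper method for retrieving the current request context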
public static RequestContext GetRequestContext(this HttpContext httpContext)
{
if (httpContext.Items.TryGetValue("RequestContext", out var context))
{
return context as RequestContext;
}
return null;
}
}
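Registration:
// Configure services in Startup.cs
public void ConfigureServices(IServiceCollection services)
{
services.AddRequestContext();
// Other services...
}
// Configure the request pipeline
public void Configure(IApplicationBuilder app)
{
// Register early so that later middleware can use the context. Because it reads HttpContext.User,
// place it after app.UseAuthentication() if UserId and IsAuthenticated are needed.
app.UseRequestContext();
// Other middleware...
}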
Usage:
// In a controller
public class SampleController : ControllerBase
{
private readonly IRequestContextAccessor _contextAccessor;
public SampleController(IRequestContextAccessor contextAccessor)
{
_contextAccessor = contextAccessor;
}
[HttpGet]
public IActionResult Get()
{
var context = _contextAccessor.Context;
// Alternatively, read it directly from HttpContext
var contextFromHttp = HttpContext.GetRequestContext();
return Ok(new {
TraceId = context.TraceId,
UserId = context.UserId,
UserAgent = context.UserAgent,
RequestTime = DateTime.UtcNow - context.RequestStartTime
});
}
}
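How it works: This middleware creates a context object holding information about the current request and makes it available throughout request processing. It stores the object in HttpContext.Items and publishes it through IRequestContextAccessor, whose implementation keeps it in an AsyncLocal<T> so that the value flows along the request's asynchronous call chain without leaking into other requests.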
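RequestContextAccessor can keep its state in a static field only because AsyncLocal<T> gives every asynchronous flow its own value; the framework's HttpContextAccessor relies on the same mechanism. The standalone sketch below (plain .NET, no ASP.NET Core; AsyncLocalDemo and its methods are hypothetical names) simulates two concurrent requests to show that each sees its own value across await points, and that a value set inside an async method does not flow back to the caller.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Standalone sketch: SimulatedRequest stands in for a middleware that sets
// an ambient value and then calls downstream code.
public static class AsyncLocalDemo
{
    // Static, like the AsyncLocal field in RequestContextAccessor
    private static readonly AsyncLocal<string> CurrentTraceId = new AsyncLocal<string>();

    public static async Task<string> SimulatedRequest(string traceId)
    {
        CurrentTraceId.Value = traceId;   // what the middleware does
        await Task.Delay(10);             // the value survives the await
        return await DownstreamServiceAsync();
    }

    // Plays the role of a controller or service that reads the ambient value
    private static async Task<string> DownstreamServiceAsync()
    {
        await Task.Yield();
        return CurrentTraceId.Value;
    }

    // Reads the value from the caller's flow, outside any simulated request
    public static string ReadOutsideRequest() => CurrentTraceId.Value;
}
```

Running two SimulatedRequest calls side by side returns each call's own trace ID, while ReadOutsideRequest in the calling code still sees no value.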
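# 12. API Versioning Middleware
Purpose: Implements API versioning so that several API versions can be supported at the same time.
Implementation:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.Logging;
public enum VersionSource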
{
Header,
QueryString,
MediaType,
UrlPath
}
public class ApiVersionOptions
{
public string DefaultVersion { get; set; } = "1.0";
public string HeaderName { get; set; } = "X-API-Version";
public string QueryStringParam { get; set; } = "api-version";
public string MediaTypeParameter { get; set; } = "v";
public List<VersionSource> Sources { get; set; } = new List<VersionSource>
{
VersionSource.Header,
VersionSource.QueryString,
VersionSource.MediaType
};
}
public class ApiVersionMiddleware
{
private readonly RequestDelegate _next;
private readonly ApiVersionOptions _options;
private readonly ILogger<ApiVersionMiddleware> _logger;
public ApiVersionMiddleware(
RequestDelegate next,
ApiVersionOptions options,
ILogger<ApiVersionMiddleware> logger)
{
_next = next;
_options = options;
_logger = logger;
}
public async Task InvokeAsync(HttpContext context)
{
var version = ExtractVersion(context);
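// Store the extracted version in HttpContext.Items for later reads
context.Items["ApiVersion"] = version;
// Also expose it as a route value (the indexer adds or replaces the entry). Note that with
// endpoint routing, UseRouting may replace route values when it matches an endpoint,
// so HttpContext.Items is the more reliable place to read the version from.
context.GetRouteData().Values["version"] = version;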
await _next(context);
}
private string ExtractVersion(HttpContext context)
{
string version = null;
foreach (var source in _options.Sources)
{
switch (source)
{
case VersionSource.Header:
version = ExtractFromHeader(context);
break;
case VersionSource.QueryString:
version = ExtractFromQueryString(context);
break;
case VersionSource.MediaType:
version = ExtractFromMediaType(context);
break;
case VersionSource.UrlPath:
version = ExtractFromUrlPath(context);
break;
}
if (!string.IsNullOrEmpty(version))
{
break;
}
}
return string.IsNullOrEmpty(version) ? _options.DefaultVersion : version;
}
private string ExtractFromHeader(HttpContext context)
{
if (context.Request.Headers.TryGetValue(_options.HeaderName, out var values))
{
return values.FirstOrDefault();
}
return null;
}
private string ExtractFromQueryString(HttpContext context)
{
if (context.Request.Query.TryGetValue(_options.QueryStringParam, out var values))
{
return values.FirstOrDefault();
}
return null;
}
private string ExtractFromMediaType(HttpContext context)
{
// Check the Accept header first, then Content-Type
var acceptHeader = context.Request.Headers["Accept"].FirstOrDefault();
if (!string.IsNullOrEmpty(acceptHeader))
{
var version = ParseVersionFromMediaType(acceptHeader);
if (!string.IsNullOrEmpty(version))
{
return version;
}
}
var contentTypeHeader = context.Request.Headers["Content-Type"].FirstOrDefault();
if (!string.IsNullOrEmpty(contentTypeHeader))
{
return ParseVersionFromMediaType(contentTypeHeader);
}
return null;
}
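private string ExtractFromUrlPath(HttpContext context)
{
// Extract the version from the URL path: look for a segment of the form
// v{major} or v{major}.{minor}, e.g. /api/v2/products or /v1.0/orders
var path = context.Request.Path.Value ?? string.Empty;
var segments = path.Split('/', StringSplitOptions.RemoveEmptyEntries);
foreach (var segment in segments)
{
// A strict pattern avoids treating ordinary segments such as "videos" as versions
var match = Regex.Match(segment, @"^v(\d+(?:\.\d+)?)$", RegexOptions.IgnoreCase);
if (match.Success)
{
return match.Groups[1].Value;
}
}
return null;
}
private string ParseVersionFromMediaType(string mediaType)
{
// Parse media types such as application/json;v=2.0. Accept may list several
// comma-separated media types, so examine each one separately.
foreach (var mediaRange in mediaType.Split(','))
{
foreach (var part in mediaRange.Split(';'))
{
var trimmedPart = part.Trim();
if (trimmedPart.StartsWith(_options.MediaTypeParameter + "=", StringComparison.OrdinalIgnoreCase))
{
return trimmedPart.Substring(_options.MediaTypeParameter.Length + 1);
}
}
}
return null;
}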
}
public static class ApiVersionMiddlewareExtensions
{
public static IApplicationBuilder UseApiVersioning(
this IApplicationBuilder builder,
Action<ApiVersionOptions> configureOptions = null)
{
var options = new ApiVersionOptions();
configureOptions?.Invoke(options);
return builder.UseMiddleware<ApiVersionMiddleware>(options);
}
// Helper method for reading the current API version
public static string GetApiVersion(this HttpContext context)
{
if (context.Items.TryGetValue("ApiVersion", out var version))
{
return version as string;
}
return null;
}
}
Registration:
app.UseApiVersioning(options => {
// Use the same "major.minor" format that the controllers compare against
options.DefaultVersion = "1.0";
options.Sources = new List<VersionSource> {
VersionSource.UrlPath,
VersionSource.Header,
VersionSource.QueryString
};
});
Usage:
// In a controller
[Route("api/[controller]")]
public class ProductsController : ControllerBase
{
[HttpGet]
public IActionResult Get()
{
var version = HttpContext.GetApiVersion();
// Exact string comparison: "2" and "2.0" are different values unless they are normalized first
if (version == "1.0")
{
return Ok(new { Version = "1.0", Products = GetV1Products() });
}
else if (version == "2.0")
{
return Ok(new { Version = "2.0", Products = GetV2Products() });
}
return NotFound();
}
}
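// Alternative: attribute routing with the version in the URL.
// Use one style or the other; two classes named ProductsController cannot coexist in one namespace.
[ApiController]
public class ProductsController : ControllerBase
{
[HttpGet]
[Route("api/v1/products")]
public IActionResult GetV1()
{
// V1 implementation (placeholder response)
return Ok(new { Version = "1.0" });
}
[HttpGet]
[Route("api/v2/products")]
public IActionResult GetV2()
{
// V2 implementation (placeholder response)
return Ok(new { Version = "2.0" });
}
}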
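How it works: This middleware extracts the API version from several possible sources (headers, the query string, the media type, and the URL path), checking them in the order configured in Sources and falling back to DefaultVersion. It stores the extracted version in the request context (HttpContext.Items) so that controllers and actions can provide a different implementation per version. This makes it possible to support several API versions side by side and to migrate clients gradually.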
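Because the version is compared as a plain string, the same version can arrive in different spellings: a URL segment v1 yields "1", while the controllers compare against "1.0". A small normalization step avoids that mismatch. The sketch below is a hypothetical helper (ApiVersionNormalizer is not part of the middleware) and assumes versions have the form major or major.minor, optionally prefixed with v.

```csharp
using System;
using System.Text.RegularExpressions;

// Hypothetical helper: normalizes "1", "v1" and "1.0" to the same "major.minor" string.
public static class ApiVersionNormalizer
{
    // Optional "v" prefix, a major number and an optional minor number
    private static readonly Regex VersionPattern =
        new Regex(@"^v?(\d+)(?:\.(\d+))?$", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);

    // Returns "major.minor", or null when the input is not a recognizable version
    public static string Normalize(string raw)
    {
        if (string.IsNullOrWhiteSpace(raw)) return null;

        var match = VersionPattern.Match(raw.Trim());
        if (!match.Success) return null;

        // TryParse guards against digit strings too long for an int
        if (!int.TryParse(match.Groups[1].Value, out var major)) return null;
        var minor = 0;
        if (match.Groups[2].Success && !int.TryParse(match.Groups[2].Value, out minor)) return null;

        return $"{major}.{minor}";
    }
}
```

One possible use is to pass the result of ExtractVersion through Normalize (falling back to DefaultVersion when it returns null) before storing it in HttpContext.Items, so that a v1 path segment, a header value of 1 and the default 1.0 all reach the controller as "1.0".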
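# 13. Response Caching Middleware
Purpose: Caches HTTP responses to improve performance and reduce server load.
Implementation:
// Note: ASP.NET Core also ships a built-in response caching middleware with similarly named
// types (namespace Microsoft.AspNetCore.ResponseCaching); avoid importing that namespace here.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
public class ResponseCachingOptions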
{
public TimeSpan DefaultDuration { get; set; } = TimeSpan.FromMinutes(10);
public int MaxCacheSize { get; set; } = 100 * 1024 * 1024; // 100MB
public bool UseCacheControlHeaders { get; set; } = true;
public List<string> VaryByHeaders { get; set; } = new List<string> { "Accept", "Accept-Encoding" };
public List<string> ExcludedPaths { get; set; } = new List<string>();
}
public class CacheEntry
{
public string Key { get; set; }
public byte[] Content { get; set; }
public string ContentType { get; set; }
public List<KeyValuePair<string, string>> Headers { get; set; }
public DateTime Expires { get; set; }
}
public class ResponseCachingMiddleware
{
private readonly RequestDelegate _next;
private readonly ResponseCachingOptions _options;
private readonly ILogger<ResponseCachingMiddleware> _logger;
private readonly MemoryCache _cache;
public ResponseCachingMiddleware(
RequestDelegate next,
ResponseCachingOptions options,
ILogger<ResponseCachingMiddleware> logger)
{
_next = next;
_options = options;
_logger = logger;
var memoryCacheOptions = new MemoryCacheOptions
{
SizeLimit = _options.MaxCacheSize
};
_cache = new MemoryCache(memoryCacheOptions);
}
public async Task InvokeAsync(HttpContext context)
{
// Only cache GET requests
if (!HttpMethods.IsGet(context.Request.Method))
{
await _next(context);
return;
}
// Skip excluded paths
if (IsPathExcluded(context.Request.Path))
{
await _next(context);
return;
}
// Build the cache key
var cacheKey = GenerateCacheKey(context);
// Try to serve the response from the cache
if (_cache.TryGetValue(cacheKey, out CacheEntry cachedResponse))
{
_logger.LogDebug("Cache hit for {CacheKey}", cacheKey);
await ServeCachedResponseAsync(context, cachedResponse);
return;
}
// Cache miss: continue processing the request
_logger.LogDebug("Cache miss for {CacheKey}", cacheKey);
// Capture the response body in a buffer
using var responseBodyStream = new MemoryStream();
var originalBodyStream = context.Response.Body;
context.Response.Body = responseBodyStream;
try
{
await _next(context);
// Only cache successful responses
if (context.Response.StatusCode == 200 && ShouldCacheResponse(context))
{
// Read the buffered body for the cache entry
responseBodyStream.Position = 0;
var responseContent = await ReadStreamAsync(responseBodyStream);
// Create the cache entry
var cacheDuration = GetCacheDuration(context);
var cacheEntry = new CacheEntry
{
Key = cacheKey,
Content = responseContent,
ContentType = context.Response.ContentType,
Headers = context.Response.Headers
.Where(h => !h.Key.Equals("Set-Cookie", StringComparison.OrdinalIgnoreCase))
.Select(h => new KeyValuePair<string, string>(h.Key, h.Value))
.ToList(),
Expires = DateTime.UtcNow.Add(cacheDuration)
};
// Add it to the cache; the entry size (in bytes) counts against SizeLimit
var entryOptions = new MemoryCacheEntryOptions()
.SetSize(responseContent.Length)
.SetAbsoluteExpiration(cacheDuration);
_cache.Set(cacheKey, cacheEntry, entryOptions);
_logger.LogDebug("Response cached with key {CacheKey} for {Duration}", cacheKey, cacheDuration);
// Set cache-related headers
if (_options.UseCacheControlHeaders)
{
// Responses cached per user must not be stored by shared caches or proxies
var visibility = context.User?.Identity?.IsAuthenticated == true ? "private" : "public";
// max-age must be a whole number of seconds
context.Response.Headers["Cache-Control"] = $"{visibility}, max-age={(int)cacheDuration.TotalSeconds}";
context.Response.Headers["Expires"] = cacheEntry.Expires.ToString("R");
}
}
// Copy the buffered response to the original stream
responseBodyStream.Position = 0;
await responseBodyStream.CopyToAsync(originalBodyStream);
}
finally
{
context.Response.Body = originalBodyStream;
}
}
private bool IsPathExcluded(string path)
{
return _options.ExcludedPaths.Any(p =>
path.StartsWith(p, StringComparison.OrdinalIgnoreCase));
}
private string GenerateCacheKey(HttpContext context)
{
// Base the key on the path and the query string
var key = context.Request.Path.Value + context.Request.QueryString.Value;
// Append the configured Vary headers
foreach (var header in _options.VaryByHeaders)
{
if (context.Request.Headers.TryGetValue(header, out var value))
{
key += $"|{header}={value}";
}
}
// If the user is authenticated, include the user ID so personalized responses are cached per user
if (context.User?.Identity?.IsAuthenticated == true)
{
var userId = context.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
if (!string.IsNullOrEmpty(userId))
{
key += $"|user={userId}";
}
}
return key;
}
private bool ShouldCacheResponse(HttpContext context)
{
// Check the Cache-Control header
if (context.Response.Headers.TryGetValue("Cache-Control", out var cacheControl))
{
if (cacheControl.ToString().Contains("no-store") ||
cacheControl.ToString().Contains("no-cache"))
{
return false;
}
}
// Check which headers were set
if (context.Response.Headers.ContainsKey("Set-Cookie"))
{
// Responses that set cookies should generally not be cached
return false;
}
return true;
}
private TimeSpan GetCacheDuration(HttpContext context)
{
// Try to read max-age from the Cache-Control header
if (context.Response.Headers.TryGetValue("Cache-Control", out var cacheControl))
{
var match = Regex.Match(cacheControl, @"max-age=(\d+)");
if (match.Success && int.TryParse(match.Groups[1].Value, out var seconds))
{
return TimeSpan.FromSeconds(seconds);
}
}
return _options.DefaultDuration;
}
private async Task<byte[]> ReadStreamAsync(Stream stream)
{
using var ms = new MemoryStream();
await stream.CopyToAsync(ms);
return ms.ToArray();
}
private async Task ServeCachedResponseAsync(HttpContext context, CacheEntry cachedResponse)
{
context.Response.StatusCode = 200;
context.Response.ContentType = cachedResponse.ContentType;
// Copy the cached headers
foreach (var header in cachedResponse.Headers)
{
context.Response.Headers[header.Key] = header.Value;
}
// Add a header indicating a cache hit
context.Response.Headers["X-Cache"] = "HIT";
// Write the cached response body
await context.Response.Body.WriteAsync(cachedResponse.Content, 0, cachedResponse.Content.Length);
}
}
public static class ResponseCachingMiddlewareExtensions
{
public static IApplicationBuilder UseResponseCaching(
this IApplicationBuilder builder,
Action<ResponseCachingOptions> configureOptions = null)
{
var options = new ResponseCachingOptions();
configureOptions?.Invoke(options);
return builder.UseMiddleware<ResponseCachingMiddleware>(options);
}
}
Registration:
app.UseResponseCaching(options => {
options.DefaultDuration = TimeSpan.FromMinutes(5);
options.MaxCacheSize = 50 * 1024 * 1024; // 50MB
options.ExcludedPaths.Add("/api/dynamic");
options.ExcludedPaths.Add("/api/personal");
});
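How it works: This middleware implements response caching for GET requests. It intercepts the response, stores the content in an in-memory cache, and serves subsequent matching requests directly from the cache. The cache key is built from the path, the query string, the configured request headers (VaryByHeaders), and, for authenticated users, the user ID; the middleware also honors the no-store, no-cache, and max-age directives of the response's Cache-Control header. This can substantially reduce requests to backend services and improve application performance.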
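The key-building and duration rules are easy to get subtly wrong, so the sketch below restates them as two pure functions that can be tested without a web server. CachePolicySketch is a hypothetical name, and plain dictionaries stand in for HttpContext headers (an assumption made so it runs on plain .NET); the logic mirrors GenerateCacheKey, the Cache-Control checks in ShouldCacheResponse, and GetCacheDuration.

```csharp
using System;
using System.Collections.Generic;
using System.Text;
using System.Text.RegularExpressions;

// Hypothetical sketch of the cache decisions, independent of ASP.NET Core
public static class CachePolicySketch
{
    // Mirrors GenerateCacheKey: path + query, then each Vary header that is present, then the user
    public static string BuildKey(string path, string query,
        IReadOnlyDictionary<string, string> requestHeaders,
        IEnumerable<string> varyByHeaders, string userId = null)
    {
        var key = new StringBuilder(path).Append(query);
        foreach (var header in varyByHeaders)
        {
            if (requestHeaders.TryGetValue(header, out var value))
                key.Append('|').Append(header).Append('=').Append(value);
        }
        if (!string.IsNullOrEmpty(userId))
            key.Append("|user=").Append(userId);
        return key.ToString();
    }

    // Mirrors the Cache-Control rules: null means "do not cache"
    public static TimeSpan? GetDuration(string cacheControl, TimeSpan defaultDuration)
    {
        if (string.IsNullOrEmpty(cacheControl)) return defaultDuration;
        if (cacheControl.Contains("no-store") || cacheControl.Contains("no-cache")) return null;

        var match = Regex.Match(cacheControl, @"max-age=(\d+)");
        return match.Success && int.TryParse(match.Groups[1].Value, out var seconds)
            ? TimeSpan.FromSeconds(seconds)
            : defaultDuration;
    }
}
```

Every distinct value of a Vary header produces a separate key, so a header with many variants can fragment the cache; keeping VaryByHeaders short limits that effect.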
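# 14. JWT Authentication Middleware
Purpose: Validates and processes JWT (JSON Web Token) tokens for API authentication.
Implementation:
// Requires the System.IdentityModel.Tokens.Jwt NuGet package
using System;
using System.Collections.Generic;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Security.Claims;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.IdentityModel.Tokens;
public class JwtOptions
{
// Shared HMAC secret; RFC 7518 requires a key of at least 256 bits (32 bytes) for HS256
public string SecretKey { get; set; }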
public string Issuer { get; set; }
public string Audience { get; set; }
public TimeSpan TokenLifetime { get; set; } = TimeSpan.FromHours(1);
public bool ValidateIssuer { get; set; } = true;
public bool ValidateAudience { get; set; } = true;
public bool ValidateLifetime { get; set; } = true;
public bool ValidateIssuerSigningKey { get; set; } = true;
public string AuthorizationHeaderName { get; set; } = "Authorization";
public string AuthorizationScheme { get; set; } = "Bearer";
public List<string> ExcludedPaths { get; set; } = new List<string>();
}
public class JwtMiddleware
{
private readonly RequestDelegate _next;
private readonly JwtOptions _options;
private readonly ILogger<JwtMiddleware> _logger;
public JwtMiddleware(
RequestDelegate next,
JwtOptions options,
ILogger<JwtMiddleware> logger)
{
_next = next;
_options = options;
_logger = logger;
}
public async Task InvokeAsync(HttpContext context)
{
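// Skip validation for excluded paths
if (ShouldSkipAuthentication(context.Request.Path))
{
await _next(context);
return;
}
// Extract the token
if (!TryExtractToken(context, out var token))
{
// No token found: continue processing the request.
// Downstream authorization middleware or [Authorize] attributes will reject unauthorized requests
await _next(context);
return;
}
try
{
// Validate the token and build the principal
var principal = ValidateToken(token);
// Set the current user
context.User = principal;
}
catch (Exception ex)
{
_logger.LogWarning(ex, "JWT token validation failed");
// The token is invalid, but keep processing the request;
// the authorization middleware will reject it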
}
await _next(context);
}
private bool ShouldSkipAuthentication(PathString path)
{
return _options.ExcludedPaths.Any(p =>
path.StartsWithSegments(p, StringComparison.OrdinalIgnoreCase));
}
private bool TryExtractToken(HttpContext context, out string token)
{
token = null;
// Extract from the Authorization header
if (context.Request.Headers.TryGetValue(_options.AuthorizationHeaderName, out var authHeader))
{
var authHeaderVal = authHeader.FirstOrDefault();
if (!string.IsNullOrEmpty(authHeaderVal) && authHeaderVal.StartsWith($"{_options.AuthorizationScheme} ", StringComparison.OrdinalIgnoreCase))
{
token = authHeaderVal.Substring(_options.AuthorizationScheme.Length + 1).Trim();
return true;
}
}
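// Fall back to the query string (commonly needed for WebSocket clients, which cannot set headers).
// URLs, and therefore tokens in them, often end up in logs, so consider limiting this to the paths that need it.
if (context.Request.Query.TryGetValue("access_token", out var queryToken))
{
token = queryToken.ToString();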
return true;
}
return false;
}
private ClaimsPrincipal ValidateToken(string token)
{
var tokenValidationParameters = new TokenValidationParameters
{
ValidateIssuer = _options.ValidateIssuer,
ValidateAudience = _options.ValidateAudience,
ValidateLifetime = _options.ValidateLifetime,
ValidateIssuerSigningKey = _options.ValidateIssuerSigningKey,
ValidIssuer = _options.Issuer,
ValidAudience = _options.Audience,
IssuerSigningKey = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(_options.SecretKey)),
ClockSkew = TimeSpan.Zero
};
var tokenHandler = new JwtSecurityTokenHandler();
var principal = tokenHandler.ValidateToken(token, tokenValidationParameters, out var securityToken);
// Additionally check the signing algorithm, rejecting tokens signed with anything other than HS256
if (!(securityToken is JwtSecurityToken jwtSecurityToken) ||
!jwtSecurityToken.Header.Alg.Equals(SecurityAlgorithms.HmacSha256, StringComparison.InvariantCultureIgnoreCase))
{
throw new SecurityTokenException("Invalid token");
}
return principal;
}
// Generates a JWT for the given claims
public static string GenerateJwtToken(IEnumerable<Claim> claims, JwtOptions options)
{
var key = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(options.SecretKey));
var credentials = new SigningCredentials(key, SecurityAlgorithms.HmacSha256);
var token = new JwtSecurityToken(
issuer: options.Issuer,
audience: options.Audience,
claims: claims,
notBefore: DateTime.UtcNow,
expires: DateTime.UtcNow.Add(options.TokenLifetime),
signingCredentials: credentials
);
return new JwtSecurityTokenHandler().WriteToken(token);
}
}
public static class JwtMiddlewareExtensions
{
public static IApplicationBuilder UseJwtAuthentication(
this IApplicationBuilder builder,
Action<JwtOptions> configureOptions = null)
{
var options = new JwtOptions();
configureOptions?.Invoke(options);
if (string.IsNullOrEmpty(options.SecretKey))
{
throw new ArgumentException("JWT secret key must be configured");
}
return builder.UseMiddleware<JwtMiddleware>(options);
}
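// Extension method for generating a token (reads the JwtOptions registered with services.Configure<JwtOptions>)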
public static string GenerateJwtToken(this IServiceProvider services, IEnumerable<Claim> claims)
{
var options = services.GetRequiredService<IOptions<JwtOptions>>().Value;
return JwtMiddleware.GenerateJwtToken(claims, options);
}
}
Registration:
// In Startup.ConfigureServices
services.Configure<JwtOptions>(Configuration.GetSection("Jwt"));
// In Startup.Configure
app.UseJwtAuthentication(options => {
options.SecretKey = Configuration["Jwt:SecretKey"];
options.Issuer = Configuration["Jwt:Issuer"];
options.Audience = Configuration["Jwt:Audience"];
options.ExcludedPaths.Add("/api/auth/login");
options.ExcludedPaths.Add("/api/public");
});
// Register the custom JWT middleware before the built-in authentication and authorization
// middleware; those need their services registered as well (e.g. services.AddAuthorization()).
app.UseAuthentication();
app.UseAuthorization();
Example of issuing a token:
[HttpPost("login")]
public IActionResult Login([FromBody] LoginModel model)
{
// Validate the user's credentials
var user = _userService.Authenticate(model.Username, model.Password);
if (user == null)
return Unauthorized();
// Create the claims
var claims = new List<Claim>
{
new Claim(ClaimTypes.Name, user.Username),
new Claim(ClaimTypes.NameIdentifier, user.Id.ToString()),
new Claim(ClaimTypes.Role, user.Role)
};
// Generate the token
var token = HttpContext.RequestServices.GenerateJwtToken(claims);
return Ok(new { token });
}
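How it works: This middleware validates the JWT in the request, builds an identity from it, and attaches that identity to the current HTTP context. It extracts the token from the Authorization header or the access_token query parameter and validates its signature, issuer, audience, and lifetime. On success, the claims extracted from the token become the current user's identity (HttpContext.User), so authorization decisions can be based on them. Combined with the authorization middleware, it protects API endpoints so that only authorized users can access them.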
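Signature validation is the check that makes the claims trustworthy. The educational sketch below shows, using only the .NET base class library, what verifying an HS256 token involves: base64url-encode the header and payload, compute HMAC-SHA256 over "header.payload" with the shared secret, and compare the result with the third segment in constant time. Hs256Sketch is a hypothetical name; it does not check the alg header, issuer, audience, or expiry, so it illustrates the signature step only and is not a replacement for JwtSecurityTokenHandler.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

// Educational sketch of HS256 signing and signature verification
public static class Hs256Sketch
{
    // Base64url (RFC 4648, section 5) without padding, as used by JWT
    public static string Base64UrlEncode(byte[] data) =>
        Convert.ToBase64String(data).TrimEnd('=').Replace('+', '-').Replace('/', '_');

    public static string Sign(string headerJson, string payloadJson, byte[] key)
    {
        var signingInput = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson)) + "." +
                           Base64UrlEncode(Encoding.UTF8.GetBytes(payloadJson));
        using var hmac = new HMACSHA256(key);
        var signature = hmac.ComputeHash(Encoding.UTF8.GetBytes(signingInput));
        return signingInput + "." + Base64UrlEncode(signature);
    }

    // Recomputes the HMAC over "header.payload" and compares it in constant time
    public static bool HasValidSignature(string token, byte[] key)
    {
        var parts = token.Split('.');
        if (parts.Length != 3) return false;

        using var hmac = new HMACSHA256(key);
        var expected = hmac.ComputeHash(Encoding.UTF8.GetBytes(parts[0] + "." + parts[1]));
        return CryptographicOperations.FixedTimeEquals(
            Encoding.ASCII.GetBytes(Base64UrlEncode(expected)),
            Encoding.ASCII.GetBytes(parts[2]));
    }
}
```

Changing the payload changes the HMAC input, so a modified token fails the comparison unless the attacker knows the secret; this is why the secret must stay on the server and be long enough.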
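# 15. WebSocket Middleware
Purpose: Supports WebSocket connections for real-time, two-way communication.
Implementation:
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
// Custom options (distinct from Microsoft.AspNetCore.Builder.WebSocketOptions, used further below)
public class WebSocketOptions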
{
public string Path { get; set; } = "/ws";
public int ReceiveBufferSize { get; set; } = 4 * 1024; // 4KB
public TimeSpan KeepAliveInterval { get; set; } = TimeSpan.FromMinutes(2);
public bool RequireAuthentication { get; set; } = false;
}
public class WebSocketConnectionManager
{
private readonly ConcurrentDictionary<string, WebSocket> _sockets = new ConcurrentDictionary<string, WebSocket>();
private readonly ConcurrentDictionary<string, string> _socketToUserMap = new ConcurrentDictionary<string, string>();
public string AddSocket(WebSocket socket, string userId = null)
{
var connectionId = Guid.NewGuid().ToString();
_sockets.TryAdd(connectionId, socket);
if (!string.IsNullOrEmpty(userId))
{
_socketToUserMap.TryAdd(connectionId, userId);
}
return connectionId;
}
public async Task RemoveSocket(string id)
{
_sockets.TryRemove(id, out var socket);
_socketToUserMap.TryRemove(id, out _);
if (socket != null)
{
try
{
await socket.CloseAsync(
WebSocketCloseStatus.NormalClosure,
"Connection closed by the server",
CancellationToken.None);
}
catch (Exception)
{
// Ignore exceptions while closing
}
}
}
public WebSocket GetSocketById(string id)
{
return _sockets.TryGetValue(id, out var socket) ? socket : null;
}
public List<string> GetAllConnectionIds()
{
return _sockets.Keys.ToList();
}
public List<string> GetConnectionIdsByUser(string userId)
{
return _socketToUserMap
.Where(x => x.Value == userId)
.Select(x => x.Key)
.ToList();
}
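// Send a message to a specific connection.
// Note: a WebSocket allows only one outstanding SendAsync at a time, so concurrent sends
// to the same connection (for example overlapping broadcasts) need to be serialized.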
public async Task SendMessageAsync(string connectionId, string message)
{
var socket = GetSocketById(connectionId);
if (socket != null && socket.State == WebSocketState.Open)
{
var buffer = Encoding.UTF8.GetBytes(message);
await socket.SendAsync(
new ArraySegment<byte>(buffer),
WebSocketMessageType.Text,
true,
CancellationToken.None);
}
}
// Broadcast a message to all connections
public async Task BroadcastAsync(string message)
{
foreach (var id in GetAllConnectionIds())
{
await SendMessageAsync(id, message);
}
}
// Send a message to all connections of a specific user
public async Task SendToUserAsync(string userId, string message)
{
var connectionIds = GetConnectionIdsByUser(userId);
foreach (var id in connectionIds)
{
await SendMessageAsync(id, message);
}
}
}
public interface IWebSocketHandler
{
Task OnConnected(WebSocket socket, HttpContext context, string connectionId);
Task OnDisconnected(string connectionId);
Task ReceiveAsync(string connectionId, string message, WebSocketMessageType messageType);
}
public class WebSocketMiddleware
{
private readonly RequestDelegate _next;
private readonly WebSocketOptions _options;
private readonly ILogger<WebSocketMiddleware> _logger;
private readonly WebSocketConnectionManager _connectionManager;
private readonly IWebSocketHandler _webSocketHandler;
public WebSocketMiddleware(
RequestDelegate next,
WebSocketOptions options,
ILogger<WebSocketMiddleware> logger,
WebSocketConnectionManager connectionManager,
IWebSocketHandler webSocketHandler)
{
_next = next;
_options = options;
_logger = logger;
_connectionManager = connectionManager;
_webSocketHandler = webSocketHandler;
}
public async Task InvokeAsync(HttpContext context)
{
if (context.Request.Path == _options.Path)
{
if (context.WebSockets.IsWebSocketRequest)
{
// Enforce the authentication requirement
if (_options.RequireAuthentication && !context.User.Identity.IsAuthenticated)
{
context.Response.StatusCode = 401;
return;
}
await HandleWebSocketAsync(context);
return;
}
context.Response.StatusCode = 400;
return;
}
await _next(context);
}
private async Task HandleWebSocketAsync(HttpContext context)
{
using var socket = await context.WebSockets.AcceptWebSocketAsync();
// Get the user ID (if authenticated)
var userId = context.User.Identity.IsAuthenticated ?
context.User.FindFirst(ClaimTypes.NameIdentifier)?.Value : null;
// Register the connection and notify the handler
var connectionId = _connectionManager.AddSocket(socket, userId);
try
{
await _webSocketHandler.OnConnected(socket, context, connectionId);
await ReceiveMessagesAsync(socket, connectionId);
}
catch (Exception ex)
{
_logger.LogError(ex, "Error in WebSocket connection {ConnectionId}", connectionId);
}
finally
{
// Remove the connection and notify the handler
await _connectionManager.RemoveSocket(connectionId);
await _webSocketHandler.OnDisconnected(connectionId);
}
}
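private async Task ReceiveMessagesAsync(WebSocket socket, string connectionId)
{
var buffer = new byte[_options.ReceiveBufferSize];
while (socket.State == WebSocketState.Open)
{
// A message can arrive in several frames, so collect frames until EndOfMessage
// (consider enforcing a maximum message size here)
using var messageStream = new MemoryStream();
WebSocketReceiveResult result;
do
{
result = await socket.ReceiveAsync(
new ArraySegment<byte>(buffer),
CancellationToken.None);
if (result.MessageType == WebSocketMessageType.Close)
{
await socket.CloseAsync(
WebSocketCloseStatus.NormalClosure,
"Connection closed by the client",
CancellationToken.None);
return;
}
messageStream.Write(buffer, 0, result.Count);
}
while (!result.EndOfMessage);
// Decode only the complete message, so multi-byte UTF-8 characters are never split
var message = Encoding.UTF8.GetString(messageStream.GetBuffer(), 0, (int)messageStream.Length);
// Handle the received message
await _webSocketHandler.ReceiveAsync(connectionId, message, result.MessageType);
}
}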
}
// Default WebSocket handler
public class DefaultWebSocketHandler : IWebSocketHandler
{
private readonly WebSocketConnectionManager _connectionManager;
private readonly ILogger<DefaultWebSocketHandler> _logger;
public DefaultWebSocketHandler(
WebSocketConnectionManager connectionManager,
ILogger<DefaultWebSocketHandler> logger)
{
_connectionManager = connectionManager;
_logger = logger;
}
public virtual Task OnConnected(WebSocket socket, HttpContext context, string connectionId)
{
_logger.LogInformation("WebSocket client connected: {ConnectionId}", connectionId);
return Task.CompletedTask;
}
public virtual Task OnDisconnected(string connectionId)
{
_logger.LogInformation("WebSocket client disconnected: {ConnectionId}", connectionId);
return Task.CompletedTask;
}
public virtual async Task ReceiveAsync(string connectionId, string message, WebSocketMessageType messageType)
{
_logger.LogDebug("Message received from {ConnectionId}: {Message}", connectionId, message);
// Default behavior: echo the message back to the sender
if (messageType == WebSocketMessageType.Text)
{
await _connectionManager.SendMessageAsync(connectionId, $"Echo: {message}");
}
}
}
public static class WebSocketMiddlewareExtensions
{
public static IApplicationBuilder UseWebSocketHandler(
this IApplicationBuilder builder,
Action<WebSocketOptions> configureOptions = null)
{
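// Resolve the singleton services registered by AddWebSocketManager
var serviceProvider = builder.ApplicationServices;
// Build the options
var options = new WebSocketOptions();
configureOptions?.Invoke(options);
// Enable ASP.NET Core's built-in WebSockets middleware first, then add the custom handler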
return builder
.UseWebSockets(new Microsoft.AspNetCore.Builder.WebSocketOptions
{
KeepAliveInterval = options.KeepAliveInterval
})
.UseMiddleware<WebSocketMiddleware>(
options,
serviceProvider.GetRequiredService<WebSocketConnectionManager>(),
serviceProvider.GetRequiredService<IWebSocketHandler>());
}
// Register the WebSocket services (singletons shared by all connections)
public static IServiceCollection AddWebSocketManager(this IServiceCollection services)
{
services.AddSingleton<WebSocketConnectionManager>();
services.AddSingleton<IWebSocketHandler, DefaultWebSocketHandler>();
return services;
}
}
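System.Net.WebSockets.WebSocket supports only one send and one receive in parallel, yet BroadcastAsync and SendToUserAsync can run concurrently for the same target connection when several clients send messages at once. One way to serialize sends per connection is a SemaphoreSlim per connection ID, sketched below. PerConnectionSendGate is a hypothetical helper, and passing the send operation as a delegate is an assumption that keeps the sketch testable without a live socket.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: allows at most one in-flight send per connection ID
public sealed class PerConnectionSendGate
{
    private readonly ConcurrentDictionary<string, SemaphoreSlim> _gates =
        new ConcurrentDictionary<string, SemaphoreSlim>();

    public async Task SendAsync(string connectionId, Func<Task> send)
    {
        // A SemaphoreSlim(1, 1) per connection acts as an async-friendly lock
        var gate = _gates.GetOrAdd(connectionId, _ => new SemaphoreSlim(1, 1));
        await gate.WaitAsync();
        try
        {
            await send();
        }
        finally
        {
            gate.Release();
        }
    }

    // Call when the connection is removed so the dictionary does not grow without bound.
    // The semaphore is not disposed because a pending SendAsync may still release it;
    // SemaphoreSlim only holds an OS handle once AvailableWaitHandle has been accessed.
    public void Remove(string connectionId) => _gates.TryRemove(connectionId, out _);
}
```

In WebSocketConnectionManager, SendMessageAsync could wrap its call as gate.SendAsync(connectionId, () => socket.SendAsync(...)), and RemoveSocket could call gate.Remove(connectionId).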
Registration:
// In Startup.ConfigureServices
services.AddWebSocketManager();
// Optional: register a custom handler (the last IWebSocketHandler registration wins)
services.AddSingleton<IWebSocketHandler, ChatWebSocketHandler>();
// In Startup.Configure
app.UseWebSocketHandler(options => {
options.Path = "/chat";
options.RequireAuthentication = true;
});
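Custom handler example:
using System;
using System.Net.WebSockets;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
public class ChatWebSocketHandler : IWebSocketHandler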
{
private readonly WebSocketConnectionManager _connectionManager;
private readonly ILogger<ChatWebSocketHandler> _logger;
public ChatWebSocketHandler(
WebSocketConnectionManager connectionManager,
ILogger<ChatWebSocketHandler> logger)
{
_connectionManager = connectionManager;
_logger = logger;
}
public async Task OnConnected(WebSocket socket, HttpContext context, string connectionId)
{
var username = context.User.Identity.Name ?? "Anonymous";
_logger.LogInformation("User {Username} connected with ID {ConnectionId}",
username, connectionId);
// Broadcast that a new user has joined
await _connectionManager.BroadcastAsync(
JsonSerializer.Serialize(new {
type = "userConnected",
username,
connectionId,
timestamp = DateTime.UtcNow
}));
}
public async Task OnDisconnected(string connectionId)
{
_logger.LogInformation("Client {ConnectionId} disconnected", connectionId);
// Broadcast that the user has left
await _connectionManager.BroadcastAsync(
JsonSerializer.Serialize(new {
type = "userDisconnected",
connectionId,
timestamp = DateTime.UtcNow
}));
}
public async Task ReceiveAsync(string connectionId, string message, WebSocketMessageType messageType)
{
if (messageType == WebSocketMessageType.Text)
{
try
{
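// Parse the message. The client sends camelCase names ("type", "from"), and System.Text.Json
// matches property names case-sensitively by default, so enable case-insensitive matching
// (caching these options in a static field would avoid re-creating them for every message)
var chatMessage = JsonSerializer.Deserialize<ChatMessage>(message,
new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
// Dispatch on the message type (a JSON "null" payload yields a null message)
switch (chatMessage?.Type)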
{
case "message":
// Broadcast the message to all clients
await _connectionManager.BroadcastAsync(
JsonSerializer.Serialize(new {
type = "message",
from = chatMessage.From,
content = chatMessage.Content,
timestamp = DateTime.UtcNow
}));
break;
case "privateMessage":
// 发送私信
if (!string.IsNullOrEmpty(chatMessage.To))
{
await _connectionManager.SendToUserAsync(
chatMessage.To,
JsonSerializer.Serialize(new {
type = "privateMessage",
from = chatMessage.From,
content = chatMessage.Content,
timestamp = DateTime.UtcNow
}));
}
break;
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error processing message from {ConnectionId}", connectionId);
// 发送错误消息回客户端
await _connectionManager.SendMessageAsync(
connectionId,
JsonSerializer.Serialize(new {
type = "error",
message = "Invalid message format"
}));
}
}
}
private class ChatMessage
{
public string Type { get; set; }
public string From { get; set; }
public string To { get; set; }
public string Content { get; set; }
}
}
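The dispatch in ReceiveAsync can be pulled out as a pure function. That makes the message protocol testable without open sockets. The sketch below restates the routing in JavaScript, the language of this section's client example. The names routeChatMessage and the returned { kind, target, payload } shape are hypothetical, and the timestamp is passed in so that results are deterministic. Like the C# handler, it does the following:

- ignores non-text frames
- broadcasts "message"
- delivers "privateMessage" only when "to" is present
- returns an error reply to the sender when the payload is not a JSON object

```javascript
// Hypothetical pure-function restatement of ChatWebSocketHandler.ReceiveAsync's routing.
// It returns the deliveries it would make instead of sending them.
function routeChatMessage(connectionId, raw, isText, now) {
  // Non-text frames are ignored, as in the C# handler.
  if (!isText) return [];

  // Error reply goes back to the sender only.
  const errorReply = [{
    kind: 'direct',
    target: connectionId,
    payload: { type: 'error', message: 'Invalid message format' },
  }];

  let msg;
  try {
    msg = JSON.parse(raw);
  } catch {
    return errorReply;
  }
  // Reject payloads that parse but are not JSON objects (e.g. "null", "42", "[]").
  if (msg === null || typeof msg !== 'object' || Array.isArray(msg)) {
    return errorReply;
  }

  switch (msg.type) {
    case 'message':
      // Broadcast to every connected client.
      return [{
        kind: 'broadcast',
        payload: { type: 'message', from: msg.from, content: msg.content, timestamp: now },
      }];
    case 'privateMessage':
      // Delivered only when a recipient is present.
      if (!msg.to) return [];
      return [{
        kind: 'user',
        target: msg.to,
        payload: { type: 'privateMessage', from: msg.from, content: msg.content, timestamp: now },
      }];
    default:
      // Unknown types produce no delivery.
      return [];
  }
}

// Example: a chat message in the format the client sends.
const deliveries = routeChatMessage(
  'conn-1',
  JSON.stringify({ type: 'message', from: 'User123', content: 'Hello, everyone!' }),
  true,
  '2024-01-01T00:00:00Z'
);
console.log(deliveries[0].kind); // broadcast
```

Separating the routing decision from the I/O keeps the handler thin: the socket-facing code only executes the returned deliveries.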
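Client example (JavaScript):
// Open the WebSocket connection (browsers block ws:// from HTTPS pages, so use wss:// there)
const protocol = window.location.protocol === 'https:' ? 'wss' : 'ws';
const socket = new WebSocket(`${protocol}://${window.location.host}/chat`);
// Connection opened
socket.onopen = () => {
console.log('WebSocket connection established');
// Send a greeting message
socket.send(JSON.stringify({
type: 'message',
from: 'User123',
content: 'Hello, everyone!'
}));
};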
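// Handle incoming messages
socket.onmessage = (event) => {
const message = JSON.parse(event.data);
console.log('Message received:', message);
// Dispatch on message type (the display/notify functions are UI helpers the application must define)
switch (message.type) {
case 'message':
displayChatMessage(message);
break;
case 'privateMessage':
displayPrivateMessage(message);
break;
case 'userConnected':
notifyUserConnected(message);
break;
case 'userDisconnected':
notifyUserDisconnected(message);
break;
case 'error':
console.warn('Server reported an error:', message.message);
break;
}
};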
// Connection closed
socket.onclose = (event) => {
console.log(`WebSocket connection closed, code: ${event.code}, reason: ${event.reason}`);
};
// Connection error
socket.onerror = (error) => {
console.error('WebSocket error:', error);
};
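// Send a chat message to everyone
function sendMessage(content) {
if (socket.readyState !== WebSocket.OPEN) return; // drop messages while not connected
socket.send(JSON.stringify({
type: 'message',
from: 'User123',
content: content
}));
}
// Send a private message
function sendPrivateMessage(toUser, content) {
if (socket.readyState !== WebSocket.OPEN) return; // drop messages while not connected
socket.send(JSON.stringify({
type: 'privateMessage',
from: 'User123',
to: toUser,
content: content
}));
}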
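How it works: This WebSocket middleware provides real-time, bidirectional communication. It keeps track of client connections, handles connects and disconnects, and manages message delivery. It uses a handler pattern (the IWebSocketHandler callbacks OnConnected, ReceiveAsync, and OnDisconnected). Custom message-processing logic can therefore be plugged in without changing the middleware itself. Both broadcast and point-to-point messaging are supported. The middleware also integrates with the ASP.NET Core authentication system, so WebSocket connections can be restricted to authenticated users. This makes it a foundation for chat, real-time notification, and collaborative applications.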
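The connection bookkeeping behind broadcast and point-to-point delivery can be sketched as a small registry keyed by connection ID. The version below is written in JavaScript so that it runs standalone with fake sockets. ConnectionRegistry, the generated "conn-N" IDs, and makeFakeSocket are hypothetical and are not the API of the C# WebSocketConnectionManager. The OPEN and CLOSED values mirror the browser WebSocket.OPEN (1) and WebSocket.CLOSED (3) constants.

```javascript
// Hypothetical in-memory connection registry; illustrates the bookkeeping only.
const OPEN = 1; // mirrors WebSocket.OPEN in browsers

class ConnectionRegistry {
  constructor() {
    this.sockets = new Map(); // connectionId -> socket
    this.nextId = 1;
  }

  // Register a socket and return its generated connection ID.
  add(socket) {
    const id = `conn-${this.nextId++}`;
    this.sockets.set(id, socket);
    return id;
  }

  // Forget a connection; returns true if it was known.
  remove(id) {
    return this.sockets.delete(id);
  }

  // Point-to-point: send to one connection if it is still open.
  sendTo(id, text) {
    const socket = this.sockets.get(id);
    if (!socket || socket.readyState !== OPEN) return false;
    socket.send(text);
    return true;
  }

  // Broadcast: send to every open connection; returns how many received it.
  broadcast(text) {
    let sent = 0;
    for (const id of this.sockets.keys()) {
      if (this.sendTo(id, text)) sent++;
    }
    return sent;
  }
}

// Fake socket that records what it was sent.
const makeFakeSocket = (readyState = OPEN) => ({
  readyState,
  sent: [],
  send(text) { this.sent.push(text); },
});

const registry = new ConnectionRegistry();
const first = registry.add(makeFakeSocket());
registry.add(makeFakeSocket());
registry.add(makeFakeSocket(3)); // 3 mirrors WebSocket.CLOSED
console.log(registry.broadcast('hello')); // 2
registry.remove(first);
console.log(registry.sendTo(first, 'ping')); // false
```

JavaScript runs this single-threaded. In the C# middleware, connections are added and removed from concurrent requests, so the equivalent map must be thread-safe (for example a ConcurrentDictionary).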
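# Summary
The 15 middleware examples above cover practical scenarios ranging from simple to complex, and each can be modified and extended to fit a project's needs. The key skills for developing ASP.NET Core middleware are:
- Understanding the request-response pipeline
- Mastering asynchronous programming patterns
- Being familiar with dependency injection and service lifetimes
- Handling exceptions and error states correctly
- Understanding the HTTP protocol and network communication
- Implementing appropriate thread-safety mechanisms
- Optimizing performance and resource usage
In practice, these middleware components can help you build web applications and APIs that are more secure, more efficient, and easier to maintain. Depending on your requirements, you can combine them or use them as starting points for your own custom middleware.
The most important principle to remember is that a middleware component should follow the single-responsibility principle. It should focus on one specific concern so that it can be composed with other middleware into a complete application pipeline.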