v2 / vlib / veb / middleware.v
511 lines · 451 sloc · 15.57 KB · e632a84cd573bb05f3f72a0ae0cb9bbcaae404da
Raw
1module veb
2
3import compress.gzip
4import compress.zstd
5import net.http
6
// MiddlewareHandler is the signature of a middleware handler: it receives the
// mutable context and returns `true` to let the remaining handlers of the chain
// run, or `false` to stop the chain at this handler.
pub type MiddlewareHandler[T] = fn (mut ctx T) bool
8
// MiddlewareApp describes the lookups used to collect middleware handlers:
// route-specific handlers and global handlers, each split into a regular list
// and an "after" list. `Middleware[T]` implements these methods.
interface MiddlewareApp {
	// route handlers registered without `after` that apply to `route_path`
	get_handlers_for_route(route_path string) []RouteMiddleware
	// route handlers registered with `after: true` that apply to `route_path`
	get_handlers_for_route_after(route_path string) []RouteMiddleware
	// type-erased global handlers registered without `after`
	get_global_handlers() []voidptr
	// type-erased global handlers registered with `after: true`
	get_global_handlers_after() []voidptr
}
15
// RouteMiddleware is a middleware handler bound to a route.
struct RouteMiddleware {
	// non-empty path segments of the route the handler was registered for
	url_parts []string
	// HTTP methods the handler applies to; an empty list matches every method
	methods []http.Method
	// the handler, stored type-erased; it is cast back to `MiddlewareHandler[T]`
	// before being called
	handler voidptr
}
21
// Middleware holds the middleware handlers registered on an app. It is found by
// the app lookups as an embedded field named `Middleware`; register handlers on
// it with `use` (global) and `route_use` (route-specific).
pub struct Middleware[T] {
mut:
	// global handlers registered without / with `after: true`
	global_handlers       []voidptr
	global_handlers_after []voidptr
	// route-specific handlers registered without / with `after: true`
	route_handlers       []RouteMiddleware
	route_handlers_after []RouteMiddleware
}
29
// MiddlewareOptions describes one middleware handler passed to `use` or `route_use`.
@[params]
pub struct MiddlewareOptions[T] {
pub:
	// the handler; returning `false` stops the remaining handlers of the chain
	handler fn (mut ctx T) bool @[required]
	// put the handler in the "after" list instead of the regular list
	after bool
	// HTTP methods the handler is limited to (only used by `route_use`);
	// an empty list means every method
	methods []http.Method
}
37
// str returns a summary of the Middleware: the generic type name and the number
// of handlers registered in each of the four handler lists.
pub fn (m &Middleware[T]) str() string {
	globals := m.global_handlers.len
	globals_after := m.global_handlers_after.len
	routes := m.route_handlers.len
	routes_after := m.route_handlers_after.len
	return 'veb.Middleware[${T.name}]{
	global_handlers: [${globals}]
	global_handlers_after: [${globals_after}]
	route_handlers: [${routes}]
	route_handlers_after: [${routes_after}]
	}'
}
47
// use registers `options.handler` as a global middleware handler, i.e. one that is
// not bound to a route. With `after: true` it is appended to the "after" list,
// otherwise to the regular list.
pub fn (mut m Middleware[T]) use(options MiddlewareOptions[T]) {
	erased := voidptr(options.handler)
	if !options.after {
		m.global_handlers << erased
		return
	}
	m.global_handlers_after << erased
}
56
// route_use registers `options.handler` for the route(s) described by `route`,
// optionally restricted to `options.methods`. Empty path segments of `route`
// are dropped. With `after: true` the handler goes to the route "after" list.
pub fn (mut m Middleware[T]) route_use(route string, options MiddlewareOptions[T]) {
	mut parts := []string{}
	for segment in route.split('/') {
		if segment != '' {
			parts << segment
		}
	}
	entry := RouteMiddleware{
		url_parts: parts
		methods:   options.methods.clone()
		handler:   voidptr(options.handler)
	}
	if options.after {
		m.route_handlers_after << entry
	} else {
		m.route_handlers << entry
	}
}
71
// get_handlers_for_route returns the route handlers registered without `after`
// that apply to `route_path`: every entry whose url parts match the non-empty
// path segments of `route_path` (as decided by `route_matches`), plus entries
// registered for the root route (no url parts) when `route_path` is `/index`.
fn (m &Middleware[T]) get_handlers_for_route(route_path string) []RouteMiddleware {
	requested := route_path.split('/').filter(it != '')
	mut matched := []RouteMiddleware{}
	for entry in m.route_handlers {
		if _ := route_matches(requested, entry.url_parts) {
			matched << entry
			continue
		}
		if entry.url_parts.len == 0 && route_path == '/index' {
			matched << entry
		}
	}
	return matched
}
86
// get_handlers_for_route_after returns the route handlers registered with
// `after: true` that apply to `route_path`: every entry whose url parts match the
// non-empty path segments of `route_path` (as decided by `route_matches`), plus
// entries registered for the root route (no url parts) when `route_path` is `/index`.
fn (m &Middleware[T]) get_handlers_for_route_after(route_path string) []RouteMiddleware {
	requested := route_path.split('/').filter(it != '')
	mut matched := []RouteMiddleware{}
	for entry in m.route_handlers_after {
		if _ := route_matches(requested, entry.url_parts) {
			matched << entry
			continue
		}
		if entry.url_parts.len == 0 && route_path == '/index' {
			matched << entry
		}
	}
	return matched
}
101
// get_global_handlers returns the type-erased global handlers registered without `after`.
fn (m &Middleware[T]) get_global_handlers() []voidptr {
	handlers := m.global_handlers
	return handlers
}
105
// get_global_handlers_after returns the type-erased global handlers registered
// with `after: true`.
fn (m &Middleware[T]) get_global_handlers_after() []voidptr {
	handlers := m.global_handlers_after
	return handlers
}
109
// app_route_handlers returns the route handlers (regular list) for `route_path`
// from the `Middleware` embedded in `app`. Embedded fields are visited in
// declaration order at compile time. An embedded field named `Middleware` answers
// directly (even with an empty result). Other embedded structs are searched
// recursively, and their result is used only when it is non-empty. Returns an
// empty list when nothing is found.
fn app_route_handlers[A](app &A, route_path string) []RouteMiddleware {
	$if A is $struct {
		$for field in A.fields {
			$if field.is_embed {
				$if field.name == 'Middleware' {
					return app.$(field.name).get_handlers_for_route(route_path)
				} $else $if field.typ is $struct {
					nested := app_route_handlers(app.$(field.name), route_path)
					if nested.len != 0 {
						return nested
					}
				}
			}
		}
	}
	return []RouteMiddleware{}
}
127
// app_route_handlers_after returns the route "after" handlers for `route_path`
// from the `Middleware` embedded in `app`. Embedded fields are visited in
// declaration order at compile time. An embedded field named `Middleware` answers
// directly (even with an empty result). Other embedded structs are searched
// recursively, and their result is used only when it is non-empty. Returns an
// empty list when nothing is found.
fn app_route_handlers_after[A](app &A, route_path string) []RouteMiddleware {
	$if A is $struct {
		$for field in A.fields {
			$if field.is_embed {
				$if field.name == 'Middleware' {
					return app.$(field.name).get_handlers_for_route_after(route_path)
				} $else $if field.typ is $struct {
					nested := app_route_handlers_after(app.$(field.name), route_path)
					if nested.len != 0 {
						return nested
					}
				}
			}
		}
	}
	return []RouteMiddleware{}
}
145
// app_global_handlers returns the global handlers (regular list) from the
// `Middleware` embedded in `app`. Embedded fields are visited in declaration
// order at compile time. An embedded field named `Middleware` answers directly
// (even with an empty result). Other embedded structs are searched recursively,
// and their result is used only when it is non-empty. Returns an empty list when
// nothing is found.
fn app_global_handlers[A](app &A) []voidptr {
	$if A is $struct {
		$for field in A.fields {
			$if field.is_embed {
				$if field.name == 'Middleware' {
					return app.$(field.name).get_global_handlers()
				} $else $if field.typ is $struct {
					nested := app_global_handlers(app.$(field.name))
					if nested.len != 0 {
						return nested
					}
				}
			}
		}
	}
	return []voidptr{}
}
163
// app_global_handlers_after returns the global "after" handlers from the
// `Middleware` embedded in `app`. Embedded fields are visited in declaration
// order at compile time. An embedded field named `Middleware` answers directly
// (even with an empty result). Other embedded structs are searched recursively,
// and their result is used only when it is non-empty. Returns an empty list when
// nothing is found.
fn app_global_handlers_after[A](app &A) []voidptr {
	$if A is $struct {
		$for field in A.fields {
			$if field.is_embed {
				$if field.name == 'Middleware' {
					return app.$(field.name).get_global_handlers_after()
				} $else $if field.typ is $struct {
					nested := app_global_handlers_after(app.$(field.name))
					if nested.len != 0 {
						return nested
					}
				}
			}
		}
	}
	return []voidptr{}
}
181
// validate_middleware calls each type-erased handler of `raw_handlers` in order
// with `ctx`. It stops at the first handler that returns `false` and returns
// `false`. It returns `true` when every handler returned `true`, or when the list
// is empty.
fn validate_middleware[T](mut ctx T, raw_handlers []voidptr) bool {
	for raw in raw_handlers {
		next := MiddlewareHandler[T](raw)
		if !next(mut ctx) {
			return false
		}
	}
	return true
}
192
// route_middleware_matches_method reports whether `route_middleware` applies to
// `request_method`. A handler registered without methods applies to all of them.
fn route_middleware_matches_method(route_middleware RouteMiddleware, request_method http.Method) bool {
	if route_middleware.methods.len == 0 {
		return true
	}
	return request_method in route_middleware.methods
}
196
// get_handlers_for_method returns, in registration order, the type-erased handlers
// of the `route_middlewares` entries that apply to `request_method`.
fn get_handlers_for_method(route_middlewares []RouteMiddleware, request_method http.Method) []voidptr {
	applicable := route_middlewares.filter(route_middleware_matches_method(it, request_method))
	return applicable.map(it.handler)
}
206
// ContentEncoding names the compression algorithms that the compression
// middleware can apply to an HTTP response body.
enum ContentEncoding {
	gzip // `Content-Encoding: gzip`
	zstd // `Content-Encoding: zstd`
}
212
// send_compressed_response compresses `ctx.res.body` with `encoding`. On success
// it does the following:
// - replaces the body with the compressed bytes;
// - adds `Content-Encoding` and updates `Content-Length`;
// - adds `Accept-Encoding` to the `Vary` header;
// - sets `ctx.already_compressed`.
// The normal response path then sends the updated response.
//
// The return value is the middleware "continue" flag. It is `true` when
// compression failed: the error is logged, the response is left untouched and the
// chain continues. It is `false` once the body has been compressed.
fn send_compressed_response(mut ctx Context, encoding ContentEncoding) bool {
	body := ctx.res.body.bytes()
	compressed, encoding_name := match encoding {
		.zstd {
			data := zstd.compress(body) or {
				eprintln('[veb] error while compressing with zstd: ${err.msg()}')
				return true
			}
			data, 'zstd'
		}
		.gzip {
			data := gzip.compress(body) or {
				eprintln('[veb] error while compressing with gzip: ${err.msg()}')
				return true
			}
			data, 'gzip'
		}
	}

	ctx.res.header.add(.content_encoding, encoding_name)
	// Merge `Accept-Encoding` into an existing `Vary` instead of overwriting it.
	// Other middleware (e.g. CORS sets `Vary: Origin, ...`) may already have set it.
	// Dropping those entries would let shared caches serve one origin's response to
	// another. `Vary: *` already covers everything.
	vary := ctx.res.header.get(.vary) or { '' }
	if vary.trim_space() == '' {
		ctx.res.header.set(.vary, 'Accept-Encoding')
	} else if vary.trim_space() != '*' && !vary.to_lower().contains('accept-encoding') {
		ctx.res.header.set(.vary, '${vary}, Accept-Encoding')
	}

	ctx.res.body = compressed.bytestr()
	ctx.res.header.set(.content_length, compressed.len.str())
	ctx.already_compressed = true

	return false
}
245
// should_skip_compression reports whether the compression middleware must leave
// the response of `ctx` untouched. That is the case when:
// - the body was already compressed (e.g. a static file compressed in send_file);
// - the response is a file served in streaming mode (no takeover). Files in
//   takeover mode (small files loaded in memory) are still compressed;
// - the body is empty. Compressing it would only add encoding overhead, and
//   bodiless responses such as 204/304 must not gain a body;
// - the response already carries a `Content-Encoding`, so compressing it again
//   would double-encode the payload.
fn should_skip_compression(ctx Context) bool {
	if ctx.already_compressed {
		return true
	}
	if ctx.return_type == .file && ctx.takeover_mode == .none {
		return true
	}
	if ctx.res.body.len == 0 {
		return true
	}
	if ctx.res.header.contains(.content_encoding) {
		return true
	}
	return false
}
259
// encode_gzip returns an "after" middleware that gzip-compresses the response body.
// It covers dynamic routes and static files loaded in memory (takeover mode).
// Static files in streaming mode are compressed by send_file() when static
// compression is enabled. This middleware skips them (via the already_compressed
// flag) to avoid double compression.
// Register this middleware as last!
// Usage example: app.use(veb.encode_gzip[Context]())
pub fn encode_gzip[T]() MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		after:   true
		handler: fn [T](mut ctx T) bool {
			if !should_skip_compression(ctx.Context) {
				return send_compressed_response(mut ctx.Context, .gzip)
			}
			return true
		}
	}
}
277
// encode_zstd returns an "after" middleware that zstd-compresses the response body.
// It covers dynamic routes and static files loaded in memory (takeover mode).
// Static files in streaming mode are compressed by send_file() when static
// compression is enabled. This middleware skips them (via the already_compressed
// flag) to avoid double compression.
// Register this middleware as last!
// Usage example: app.route_use('/api', veb.encode_zstd[Context]())
pub fn encode_zstd[T]() MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		after:   true
		handler: fn [T](mut ctx T) bool {
			if !should_skip_compression(ctx.Context) {
				return send_compressed_response(mut ctx.Context, .zstd)
			}
			return true
		}
	}
}
295
// encode_auto adds automatic content encoding (zstd or gzip) based on the client's
// Accept-Encoding header. It compresses with zstd when the client accepts it,
// otherwise with gzip, and leaves the body alone when neither is accepted. A coding
// listed with `q=0` counts as refused.
// Static files in streaming mode are compressed by send_file() when static
// compression is enabled. This middleware skips them (via the already_compressed
// flag) to avoid double compression.
// Register this middleware as last!
// Usage example: app.use(veb.encode_auto[Context]())
pub fn encode_auto[T]() MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		after:   true
		handler: fn [T](mut ctx T) bool {
			if should_skip_compression(ctx.Context) {
				return true
			}

			accept_encoding := ctx.req.header.get(.accept_encoding) or { '' }

			// Try zstd first (better compression ratio), fall back to gzip.
			if accept_encoding_allows(accept_encoding, 'zstd') {
				return send_compressed_response(mut ctx.Context, .zstd)
			}
			if accept_encoding_allows(accept_encoding, 'gzip') {
				return send_compressed_response(mut ctx.Context, .gzip)
			}

			// No supported compression
			return true
		}
	}
}

// accept_encoding_allows reports whether the `Accept-Encoding` header value
// `header` lists the content coding `coding` with a non-zero quality value.
// Codings are compared case-insensitively. `x-gzip` is treated as `gzip`, and
// `coding;q=0` means the client explicitly refuses that coding. Wildcards (`*`)
// are not interpreted.
fn accept_encoding_allows(header string, coding string) bool {
	wanted := coding.to_lower()
	for item in header.split(',') {
		params := item.split(';')
		if params.len == 0 {
			continue
		}
		mut name := params[0].trim_space().to_lower()
		if name == 'x-gzip' {
			name = 'gzip'
		}
		if name != wanted {
			continue
		}
		for param in params[1..] {
			kv := param.trim_space().to_lower()
			if kv.starts_with('q=') {
				return kv[2..].trim_space().f64() > 0
			}
		}
		// no quality value means q=1
		return true
	}
	return false
}
328
// decode_gzip decodes the body of a gzip'ed HTTP request. It acts when the
// request's `Content-Encoding` is `gzip` or its alias `x-gzip`, compared
// case-insensitively. On invalid gzip data it responds with `request_error` and
// stops the middleware chain.
// Register this middleware before you do anything with the request body!
// Usage example: app.use(veb.decode_gzip[Context]())
pub fn decode_gzip[T]() MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		handler: fn [T](mut ctx T) bool {
			if encoding := ctx.req.header.get(.content_encoding) {
				// Content codings are case-insensitive, and recipients should treat
				// `x-gzip` as `gzip` (RFC 9110, section 8.4.1).
				if encoding.trim_space().to_lower() in ['gzip', 'x-gzip'] {
					decompressed := gzip.decompress(ctx.req.data.bytes()) or {
						ctx.request_error('invalid gzip encoding')
						return false
					}
					ctx.req.data = decompressed.bytestr()
				}
			}
			return true
		}
	}
}
348
// decode_zstd decodes the body of a zstd-compressed HTTP request. It acts when
// the request's `Content-Encoding` is `zstd`, compared case-insensitively. On
// invalid zstd data it responds with `request_error` and stops the middleware chain.
// Register this middleware before you do anything with the request body!
// Usage example: app.use(veb.decode_zstd[Context]())
pub fn decode_zstd[T]() MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		handler: fn [T](mut ctx T) bool {
			if encoding := ctx.req.header.get(.content_encoding) {
				// Content codings are case-insensitive (RFC 9110, section 8.4.1).
				if encoding.trim_space().to_lower() == 'zstd' {
					decompressed := zstd.decompress(ctx.req.data.bytes()) or {
						ctx.request_error('invalid zstd encoding')
						return false
					}
					ctx.req.data = decompressed.bytestr()
				}
			}
			return true
		}
	}
}
368
// cors_safelisted_response_headers is the comma-separated list of the
// CORS-safelisted response header names. `CorsOptions.set_headers` sends it as
// `Access-Control-Allow-Headers` when a preflight asks for headers but no
// `allowed_headers` are configured.
pub const cors_safelisted_response_headers = [
	http.CommonHeader.cache_control,
	.content_language,
	.content_length,
	.content_type,
	.expires,
	.last_modified,
	.pragma,
].map(it.str()).join(',')
371
// CorsOptions configures the CORS response headers and the validation of
// cross-origin requests.
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS#the_http_response_headers
@[params]
pub struct CorsOptions {
pub:
	// origin(s) allowed to make cross-origin requests; `['*']` allows any origin.
	// Sent as `Access-Control-Allow-Origin`.
	origins []string @[required]
	// whether the server allows credentials, e.g. cookies, in cross-origin requests.
	// Sent as `Access-Control-Allow-Credentials`.
	allow_credentials bool
	// HTTP headers allowed in a cross-origin request.
	// Sent as `Access-Control-Allow-Headers`.
	allowed_headers []string
	// HTTP methods allowed for a cross-origin request.
	// Sent as `Access-Control-Allow-Methods`.
	allowed_methods []http.Method
	// response headers, beyond the "CORS-safelisted" ones, that clients may access.
	// Sent as `Access-Control-Expose-Headers`.
	expose_headers []string
	// how long (in seconds) the result of a preflight request may be cached.
	// Sent as `Access-Control-Max-Age`.
	max_age ?int
}
393
// set_headers adds the CORS headers to the response of `ctx`. Nothing is set when
// the request has no `Origin` header or the origin is not allowed. The browser
// rejects a CORS request without `Access-Control-Allow-Origin`, so leaving the
// headers out makes it reject both the preflight and the actual request.
pub fn (options &CorsOptions) set_headers(mut ctx Context) {
	origin := ctx.req.header.get(.origin) or { return }
	origin_allowed := options.origins == ['*'] || origin in options.origins
	if !origin_allowed {
		return
	}

	ctx.set_header(.access_control_allow_origin, origin)
	ctx.set_header(.vary, 'Origin, Access-Control-Request-Headers')

	// `Access-Control-Allow-Credentials` is omitted rather than sent as 'false'.
	if options.allow_credentials {
		ctx.set_header(.access_control_allow_credentials, 'true')
	}

	if options.allowed_headers.len == 0 {
		// A preflight carrying `Access-Control-Request-Headers` must be answered
		// with `Access-Control-Allow-Headers`.
		if _ := ctx.req.header.get(.access_control_request_headers) {
			ctx.set_header(.access_control_allow_headers, cors_safelisted_response_headers)
		}
	} else {
		ctx.set_header(.access_control_allow_headers, options.allowed_headers.join(','))
	}

	if options.allowed_methods.len != 0 {
		methods := options.allowed_methods.str().trim('[]')
		ctx.set_header(.access_control_allow_methods, methods)
	}

	if options.expose_headers.len != 0 {
		ctx.set_header(.access_control_expose_headers, options.expose_headers.join(','))
	}

	if seconds := options.max_age {
		ctx.set_header(.access_control_max_age, seconds.str())
	}
}
434
// validate_request checks a cross-origin request, i.e. one that carries an
// `Origin` header, against `options`. Requests without `Origin` are accepted as-is.
//
// For an allowed origin it sets `Access-Control-Allow-Origin`, `Vary` and (when
// enabled) `Access-Control-Allow-Credentials`. It then checks the request method
// and request headers. If any check fails, it writes an error response to `ctx`
// (403 for a bad origin or header, 405 for a bad method) and returns `false`.
// Compile with `-d veb_trace_cors` to log every decision.
pub fn (options &CorsOptions) validate_request(mut ctx Context) bool {
	origin := ctx.req.header.get(.origin) or { return true }
	if options.origins != ['*'] && origin !in options.origins {
		ctx.res.set_status(.forbidden)
		ctx.text('invalid CORS origin')

		$if veb_trace_cors ? {
			eprintln('[veb]: rejected CORS request from "${origin}". Reason: invalid origin')
		}
		return false
	}

	ctx.set_header(.access_control_allow_origin, origin)
	ctx.set_header(.vary, 'Origin, Access-Control-Request-Headers')

	if options.allow_credentials {
		ctx.set_header(.access_control_allow_credentials, 'true')
	}

	// validate request method
	// NOTE(review): an empty `allowed_methods` rejects every cross-origin request.
	if ctx.req.method !in options.allowed_methods {
		ctx.res.set_status(.method_not_allowed)
		ctx.text('${ctx.req.method} requests are not allowed')

		$if veb_trace_cors ? {
			eprintln('[veb]: rejected CORS request from "${origin}". Reason: invalid request method: ${ctx.req.method}')
		}
		return false
	}

	if options.allowed_headers.len > 0 && options.allowed_headers != ['*'] {
		// Header names are case-insensitive (RFC 9110, section 5.1), so compare
		// them in lowercase.
		allowed := options.allowed_headers.map(it.to_lower())
		content_type := ctx.req.header.get(.content_type) or { '' }
		// validate request headers
		for header in ctx.req.header.keys() {
			name := header.to_lower()
			if name in allowed || !cors_request_header_needs_allowance(name, content_type) {
				continue
			}
			ctx.res.set_status(.forbidden)
			ctx.text('invalid Header "${header}"')

			$if veb_trace_cors ? {
				eprintln('[veb]: rejected CORS request from "${origin}". Reason: invalid header "${header}"')
			}
			return false
		}
	}

	$if veb_trace_cors ? {
		eprintln('[veb]: received CORS request from "${origin}": HTTP ${ctx.req.method} ${ctx.req.url}')
	}

	return true
}

// cors_exempt_request_headers lists, in lowercase, request headers that never
// need to be listed in `Access-Control-Allow-Headers`. These are the Fetch
// "forbidden request-header" names, which browsers set themselves and scripts
// cannot set (e.g. `Host`, `Origin`), and the CORS-safelisted `Accept`,
// `Accept-Language` and `Content-Language`. `User-Agent` is included because
// browsers send it on every request.
const cors_exempt_request_headers = ['accept', 'accept-charset', 'accept-encoding',
	'accept-language', 'access-control-request-headers', 'access-control-request-method',
	'connection', 'content-language', 'content-length', 'cookie', 'cookie2', 'date', 'dnt',
	'expect', 'host', 'keep-alive', 'origin', 'referer', 'te', 'trailer', 'transfer-encoding',
	'upgrade', 'user-agent', 'via']

// cors_request_header_needs_allowance reports whether the lowercase request header
// `name` must appear in `CorsOptions.allowed_headers` for a cross-origin request
// to be accepted. `content_type` is the request's `Content-Type` value. It is only
// consulted for `content-type`, which is safelisted for three MIME types.
fn cors_request_header_needs_allowance(name string, content_type string) bool {
	if name in cors_exempt_request_headers || name.starts_with('sec-')
		|| name.starts_with('proxy-') {
		return false
	}
	if name == 'content-type' {
		essence := content_type.all_before(';').trim_space().to_lower()
		return essence !in ['application/x-www-form-urlencoded', 'multipart/form-data', 'text/plain']
	}
	return true
}
489
// cors handles cross-origin requests. It answers preflight (`OPTIONS`) requests
// with the Access-Control-* headers from `options`, then stops the middleware
// chain. All other requests are checked with `options.validate_request`.
// Usage example:
// ```v
// app.use(veb.cors[Context](veb.CorsOptions{
//     origins: ['*']
//     allowed_methods: [.get, .head, .patch, .put, .post, .delete]
// }))
// ```
pub fn cors[T](options CorsOptions) MiddlewareOptions[T] {
	return MiddlewareOptions[T]{
		handler: fn [options] [T](mut ctx T) bool {
			if ctx.req.method != .options {
				return options.validate_request(mut ctx.Context)
			}
			// preflight: answer it here and do not run further handlers
			options.set_headers(mut ctx.Context)
			ctx.text('ok')
			return false
		}
	}
}
512