From 359015de2c49e426c27b1d25dbf599b08a9d3ee6 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 15 Sep 2022 18:26:10 +0100 Subject: [PATCH] Correctly validate the bounds of decompression Motivation Currently we don't confirm that the decompression has completed successfully. This means that we can incorrectly spin forever attempting to decompress past the end of a message, and that we can fail to notice that a message is truncated. Neither of these is good. Modifications Propagate the message zlib gives us as to whether or not decompression is done, and keep track of it. Add some tests written by @vojtarylko to validate the behaviour. Result Correctly police the bounds of the messages. Resolves #175 and #176. --- .../HTTPDecompression.swift | 55 ++++++++++++++++--- .../HTTPRequestDecompressor.swift | 20 ++++++- .../HTTPResponseDecompressor.swift | 20 ++++++- .../HTTPRequestDecompressorTest+XCTest.swift | 2 + .../HTTPRequestDecompressorTest.swift | 26 ++++++++- .../HTTPResponseDecompressorTest+XCTest.swift | 2 + .../HTTPResponseDecompressorTest.swift | 25 +++++++++ 7 files changed, 136 insertions(+), 14 deletions(-) diff --git a/Sources/NIOHTTPCompression/HTTPDecompression.swift b/Sources/NIOHTTPCompression/HTTPDecompression.swift index d8b8cad8..f9e5aaf2 100644 --- a/Sources/NIOHTTPCompression/HTTPDecompression.swift +++ b/Sources/NIOHTTPCompression/HTTPDecompression.swift @@ -57,6 +57,29 @@ public enum NIOHTTPDecompression { case initializationError(Int) } + public struct ExtraDecompressionError: Error, Hashable, CustomStringConvertible { + private var backing: Backing + + private enum Backing { + case invalidTrailingData + case truncatedData + } + + private init(_ backing: Backing) { + self.backing = backing + } + + /// Decompression completed but there was invalid trailing data behind the compressed data. + public static let invalidTrailingData = Self(.invalidTrailingData) + + /// The decompressed data was incorrectly truncated. + public static let truncatedData = Self(.truncatedData) + + public var description: String { + return String(describing: self.backing) + } + } + enum CompressionAlgorithm: String { case gzip case deflate @@ -91,12 +114,15 @@ public enum NIOHTTPDecompression { self.limit = limit } - mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws { - self.inflated += try self.stream.inflatePart(input: &part, output: &buffer) + mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, compressedLength: Int) throws -> InflateResult { + let result = try self.stream.inflatePart(input: &part, output: &buffer) + self.inflated += result.written if self.limit.exceeded(compressed: compressedLength, decompressed: self.inflated) { throw NIOHTTPDecompression.DecompressionError.limit } + + return result } mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm) throws { @@ -117,9 +143,10 @@ public enum NIOHTTPDecompression { } extension z_stream { - mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int { + mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> InflateResult { let minimumCapacity = input.readableBytes * 2 - var written = 0 + var inflateResult = InflateResult(written: 0, complete: false) + try input.readWithUnsafeMutableReadableBytes { pointer in self.avail_in = UInt32(pointer.count) self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) @@ -131,24 +158,34 @@ extension z_stream { self.next_out = nil } - written += try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) + inflateResult = try self.inflatePart(to: &output, minimumCapacity: minimumCapacity) return pointer.count - Int(self.avail_in) } - return written + return inflateResult } - private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> Int { - return try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in + private mutating func inflatePart(to buffer: inout ByteBuffer, minimumCapacity: Int) throws -> InflateResult { + var rc = Z_OK + + let written = try buffer.writeWithUnsafeMutableBytes(minimumWritableBytes: minimumCapacity) { pointer in self.avail_out = UInt32(pointer.count) self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!) - let rc = inflate(&self, Z_NO_FLUSH) + rc = inflate(&self, Z_NO_FLUSH) guard rc == Z_OK || rc == Z_STREAM_END else { throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc)) } return pointer.count - Int(self.avail_out) } + + return InflateResult(written: written, complete: rc == Z_STREAM_END) } } + +struct InflateResult { + var written: Int + + var complete: Bool +} diff --git a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift index d3e52fa9..bbbc81a8 100644 --- a/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPRequestDecompressor.swift @@ -34,12 +34,14 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh private var decompressor: NIOHTTPDecompression.Decompressor private var compression: Compression? + private var decompressionComplete: Bool /// Initialise with limits. /// - Parameter limit: Limit to how much inflation can occur to protect against bad cases. public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) self.compression = nil + self.decompressionComplete = false } public func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -68,10 +70,13 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { do { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.contentLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } catch let error { @@ -79,10 +84,21 @@ public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableCh return } } + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) diff --git a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift index 1e7442f3..64c60182 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift @@ -38,11 +38,13 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC private var compression: Compression? = nil private var decompressor: NIOHTTPDecompression.Decompressor + private var decompressionComplete: Bool /// Initialise /// - Parameter limit: Limit on the amount of decompression allowed. public init(limit: NIOHTTPDecompression.DecompressionLimit) { self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) + self.decompressionComplete = false } public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { @@ -84,22 +86,36 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC do { compression.compressedLength += part.readableBytes - while part.readableBytes > 0 { + while part.readableBytes > 0 && !self.decompressionComplete { var buffer = context.channel.allocator.buffer(capacity: 16384) - try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + let result = try self.decompressor.decompress(part: &part, buffer: &buffer, compressedLength: compression.compressedLength) + if result.complete { + self.decompressionComplete = true + } context.fireChannelRead(self.wrapInboundOut(.body(buffer))) } // assign the changed local property back to the class state self.compression = compression + + if part.readableBytes > 0 { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.invalidTrailingData) + } } catch { context.fireErrorCaught(error) } case .end: if self.compression != nil { + let wasDecompressionComplete = self.decompressionComplete + self.decompressor.deinitializeDecoder() self.compression = nil + self.decompressionComplete = false + + if !wasDecompressionComplete { + context.fireErrorCaught(NIOHTTPDecompression.ExtraDecompressionError.truncatedData) + } } context.fireChannelRead(data) } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift index 12649957..04a31a9c 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift @@ -30,6 +30,8 @@ extension HTTPRequestDecompressorTest { ("testDecompressionLimitRatio", testDecompressionLimitRatio), ("testDecompressionLimitSize", testDecompressionLimitSize), ("testDecompression", testDecompression), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift index 38d0701f..8e035785 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift @@ -120,9 +120,33 @@ class HTTPRequestDecompressorTest: XCTestCase { ) XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } + } + + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) - XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPServerRequestPart.end(nil))) } private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift index 30d6d459..9f844d11 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest+XCTest.swift @@ -37,6 +37,8 @@ extension HTTPResponseDecompressorTest { ("testDecompressionLimitRatioWithoutContentLenghtHeaderFails", testDecompressionLimitRatioWithoutContentLenghtHeaderFails), ("testDecompression", testDecompression), ("testDecompressionWithoutContentLength", testDecompressionWithoutContentLength), + ("testDecompressionTrailingData", testDecompressionTrailingData), + ("testDecompressionTruncatedInput", testDecompressionTruncatedInput), ] } } diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift index 1d9ccf79..b42e629f 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseDecompressorTest.swift @@ -239,6 +239,31 @@ class HTTPResponseDecompressorTest: XCTestCase { XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.end(nil))) } + func testDecompressionTrailingData() throws { + // Valid compressed data with some trailing garbage + let compressed = ByteBuffer(bytes: [120, 156, 99, 0, 0, 0, 1, 0, 1] + [1, 2, 3]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + } + + func testDecompressionTruncatedInput() throws { + // Truncated compressed data + let compressed = ByteBuffer(bytes: [120, 156, 99, 0]) + + let channel = EmbeddedChannel() + try channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .none)).wait() + let headers = HTTPHeaders([("Content-Encoding", "deflate"), ("Content-Length", "\(compressed.readableBytes)")]) + try channel.writeInbound(HTTPClientResponsePart.head(.init(version: .init(major: 1, minor: 1), status: .ok, headers: headers))) + + XCTAssertNoThrow(try channel.writeInbound(HTTPClientResponsePart.body(compressed))) + XCTAssertThrowsError(try channel.writeInbound(HTTPClientResponsePart.end(nil))) + } + private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { var stream = z_stream()