Skip to content

Commit

Permalink
Fixes XML encoding issues (#488)
Browse files Browse the repository at this point in the history
* Verifies serialized data matches expected data in tests

* Handle xmlName trait

- Adds code to properly set the xml root key if the object has a custom xml name set
- Sets the xmlName trait for operation shapes since we modify operation shape names

* No longer encodes floats and doubles as strings

* Updates tests

* ktlintformat

* Addresses PR feedback

* Adds explicit FoundationXML import for linux

* Checks Any objects for equality

* Cleans up code and adds docs
  • Loading branch information
epau committed Dec 7, 2022
1 parent 0cca10a commit 581320d
Show file tree
Hide file tree
Showing 17 changed files with 295 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ public struct SerializableBodyMiddleware<OperationStackInput: Encodable,
OperationStackOutput: HttpResponseBinding>: Middleware {
public let id: Swift.String = "\(String(describing: OperationStackInput.self))BodyMiddleware"

public init() {}
let xmlName: String?

public init(xmlName: String? = nil) {
self.xmlName = xmlName
}

public func handle<H>(context: Context,
input: SerializeStepInput<OperationStackInput>,
Expand All @@ -20,7 +24,12 @@ public struct SerializableBodyMiddleware<OperationStackInput: Encodable,
Self.Context == H.Context {
do {
let encoder = context.getEncoder()
let data = try encoder.encode(input.operationInput)
let data: Data
if let xmlName = xmlName, let xmlEncoder = encoder as? XMLEncoder {
data = try xmlEncoder.encode(input.operationInput, withRootKey: xmlName)
} else {
data = try encoder.encode(input.operationInput)
}
let body = HttpBody.data(data)
input.builder.withBody(body)
} catch let err {
Expand Down
53 changes: 53 additions & 0 deletions Packages/SmithyTestUtil/Sources/JSONComparator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation

public struct JSONComparator {
/// Returns true if the JSON documents, for the corresponding data objects, are equal.
/// - Parameters:
/// - dataA: The first data object to compare to the second data object.
/// - dataB: The second data object to compare to the first data object.
/// - Returns: Returns true if the JSON documents, for the corresponding data objects, are equal.
public static func jsonData(_ dataA: Data, isEqualTo dataB: Data) throws -> Bool {
let jsonDictA = try JSONSerialization.jsonObject(with: dataA)
let jsonDictB = try JSONSerialization.jsonObject(with: dataB)
return anyValuesAreEqual(jsonDictA, jsonDictB)
}
}

fileprivate func anyDictsAreEqual(_ lhs: [String: Any], _ rhs: [String: Any]) -> Bool {
guard lhs.keys == rhs.keys else { return false }
for key in lhs.keys {
if !anyValuesAreEqual(lhs[key], rhs[key]) {
return false
}
}
return true
}

fileprivate func anyArraysAreEqual(_ lhs: [Any], _ rhs: [Any]) -> Bool {
guard lhs.count == rhs.count else { return false }
for i in 0..<lhs.count {
if !anyValuesAreEqual(lhs[i], rhs[i]) {
return false
}
}
return true
}

fileprivate func anyValuesAreEqual(_ lhs: Any?, _ rhs: Any?) -> Bool {
if lhs == nil && rhs == nil { return true }
guard let lhs = lhs, let rhs = rhs else { return false }
if let lhsDict = lhs as? [String: Any], let rhsDict = rhs as? [String: Any] {
return anyDictsAreEqual(lhsDict, rhsDict)
} else if let lhsArray = lhs as? [Any], let rhsArray = rhs as? [Any] {
return anyArraysAreEqual(lhsArray, rhsArray)
} else {
return type(of: lhs) == type(of: rhs) && "\(lhs)" == "\(rhs)"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ open class HttpRequestTestBase: XCTestCase {
public func genericAssertEqualHttpBodyData(
_ expected: HttpBody,
_ actual: HttpBody,
_ encoder: Any,
_ callback: (Data, Data) -> Void,
file: StaticString = #filePath,
line: UInt = #line
Expand All @@ -244,6 +245,11 @@ open class HttpRequestTestBase: XCTestCase {
return
}
if shouldCompareData(expectedData, actualData) {
if encoder is XMLEncoder {
XCTAssertXMLDataEqual(actualData!, expectedData!, file: file, line: line)
} else if encoder is JSONEncoder {
XCTAssertJSONDataEqual(actualData!, expectedData!, file: file, line: line)
}
callback(expectedData!, actualData!)
}
}
Expand Down
48 changes: 48 additions & 0 deletions Packages/SmithyTestUtil/Sources/XCTAssertions.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation
import XCTest

public func XCTAssertJSONDataEqual(
_ expression1: @autoclosure () throws -> Data,
_ expression2: @autoclosure () throws -> Data,
_ message: @autoclosure () -> String = "",
file: StaticString = #filePath,
line: UInt = #line
) {
do {
let data1 = try expression1()
let data2 = try expression2()
guard data1 != data2 else { return }
XCTAssertTrue(
try JSONComparator.jsonData(data1, isEqualTo: data2),
message(),
file: file,
line: line
)
} catch {
XCTFail("Failed to evaluate JSON with error: \(error)", file: file, line: line)
}
}

public func XCTAssertXMLDataEqual(
_ expression1: @autoclosure () throws -> Data,
_ expression2: @autoclosure () throws -> Data,
_ message: @autoclosure () -> String = "",
file: StaticString = #filePath,
line: UInt = #line
) {
do {
let data1 = try expression1()
let data2 = try expression2()
guard data1 != data2 else { return }
XCTAssertTrue(XMLComparator.xmlData(data1, isEqualTo: data2), message(), file: file, line: line)
} catch {
XCTFail("Failed to evaluate XML with error: \(error)", file: file, line: line)
}
}
95 changes: 95 additions & 0 deletions Packages/SmithyTestUtil/Sources/XMLComparator.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation
#if canImport(FoundationXML)
// As of Swift 5.1, the Foundation module on Linux only has the same set of dependencies as the Swift standard library itself
// Therefore, we need to explicitly import FoundationXML on linux.
// The preferred way to do this, is to check if FoundationXML can be imported.
// https://github.com/apple/swift-corelibs-foundation/blob/main/Docs/ReleaseNotes_Swift5.md
import FoundationXML
#endif

public struct XMLComparator {
/// Returns true if the XML documents, for the corresponding data objects, are equal.
/// Order of elements within the document do not affect equality.
/// - Parameters:
/// - dataA: The first data object to compare to the second data object.
/// - dataB: The second data object to compare to the first data object.
/// - Returns: Returns true if the XML documents, for the corresponding data objects, are equal.
public static func xmlData(_ dataA: Data, isEqualTo dataB: Data) -> Bool {
let rootA = XMLConverter.xmlTree(with: dataA)
let rootB = XMLConverter.xmlTree(with: dataB)
return rootA == rootB
}
}

private struct XMLElement: Hashable {
var name: String?
var attributes: [String : String]?
var string: String?
var elements: Set<XMLElement> = []
}

private class XMLConverter: NSObject {
/// Keeps track of the value since `foundCharacters` can be called multiple times for the same element
private var valueBuffer = ""
private var stack: [XMLElement] = []

static func xmlTree(with data: Data) -> XMLElement {
let converter = XMLConverter()
converter.stack.append(XMLElement())

let parser = XMLParser(data: data)
parser.delegate = converter
parser.parse()

return converter.stack.first!
}
}

extension XMLConverter: XMLParserDelegate {
func parser(
_ parser: XMLParser,
didStartElement elementName: String,
namespaceURI: String?,
qualifiedName qName: String?,
attributes attributeDict: [String : String] = [:]
) {
let parent = stack.last!
let element = XMLElement(
name: elementName,
attributes: attributeDict
)

stack.append(element)
}

func parser(_ parser: XMLParser, foundCharacters string: String) {
let trimmedString = string.trimmingCharacters(in: .whitespacesAndNewlines)
valueBuffer.append(trimmedString)
}

func parser(
_ parser: XMLParser, didEndElement
elementName: String,
namespaceURI: String?,
qualifiedName qName: String?
) {
var element = stack.popLast()!
var parent = stack.last!

element.string = valueBuffer

var elements = parent.elements
elements.insert(element)
parent.elements = elements

stack[stack.endIndex - 1] = parent
valueBuffer = ""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class HttpRequestTestBaseTests: HttpRequestTestBase {
self.assertEqual(expected, actual, { (expectedHttpBody, actualHttpBody) -> Void in
XCTAssertNotNil(actualHttpBody, "The actual HttpBody is nil")
XCTAssertNotNil(expectedHttpBody, "The expected HttpBody is nil")
self.genericAssertEqualHttpBodyData(expectedHttpBody!, actualHttpBody!) { (expectedData, actualData) in
self.genericAssertEqualHttpBodyData(expectedHttpBody!, actualHttpBody!, JSONEncoder()) { (expectedData, actualData) in
do {
let decoder = JSONDecoder()
let expectedObj = try decoder.decode(SayHelloInputBody.self, from: expectedData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ open class HttpProtocolUnitTestRequestGenerator protected constructor(builder: B
writer.write("XCTAssertNotNil(expectedHttpBody, \"The expected HttpBody is nil\")")
val expectedData = "expectedData"
val actualData = "actualData"
writer.openBlock("self.genericAssertEqualHttpBodyData(expectedHttpBody!, actualHttpBody!) { $expectedData, $actualData in ", "}") {
writer.openBlock("self.genericAssertEqualHttpBodyData(expectedHttpBody!, actualHttpBody!, encoder) { $expectedData, $actualData in ", "}") {
val httpPayloadShape = inputShape.members().firstOrNull { it.hasTrait(HttpPayloadTrait::class.java) }

httpPayloadShape?.let {
Expand Down Expand Up @@ -204,6 +204,11 @@ open class HttpProtocolUnitTestRequestGenerator protected constructor(builder: B
writer.write("}")
}

private fun renderDataComparison(writer: SwiftWriter, expectedData: String, actualData: String) {
val assertionMethod = "XCTAssertJSONDataEqual"
writer.write("\$L(\$L, \$L, \"Some error message\")", assertionMethod, actualData, expectedData)
}

protected open fun renderAssertions(test: HttpRequestTestCase, outputShape: Shape) {
val members = outputShape.members().filterNot { it.hasTrait(HttpQueryTrait::class.java) }
.filterNot { it.hasTrait(HttpHeaderTrait::class.java) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package software.amazon.smithy.swift.codegen.integration.middlewares
import software.amazon.smithy.codegen.core.SymbolProvider
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.traits.XmlNameTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
import software.amazon.smithy.swift.codegen.model.getTrait

class OperationInputBodyMiddleware(
val model: Model,
Expand All @@ -27,15 +29,38 @@ class OperationInputBodyMiddleware(
op: OperationShape,
operationStackName: String,
) {
val inputShape = MiddlewareShapeUtils.inputShape(model, op)
val inputShapeName = MiddlewareShapeUtils.inputSymbol(symbolProvider, model, op).name
val outputShapeName = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op).name
val xmlName = inputShape.getTrait<XmlNameTrait>()?.value

if (alwaysSendBody) {
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<$inputShapeName, $outputShapeName>())", ClientRuntimeTypes.Middleware.SerializableBodyMiddleware)
if (xmlName != null) {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N<\$L, \$L>(xmlName: \"\$L\"))",
operationStackName, middlewareStep.stringValue(), position.stringValue(), ClientRuntimeTypes.Middleware.SerializableBodyMiddleware, inputShapeName, outputShapeName, xmlName
)
} else {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N<\$L, \$L>())",
operationStackName, middlewareStep.stringValue(), position.stringValue(), ClientRuntimeTypes.Middleware.SerializableBodyMiddleware, inputShapeName, outputShapeName
)
}
} else if (MiddlewareShapeUtils.hasHttpBody(model, op)) {
if (MiddlewareShapeUtils.bodyIsHttpPayload(model, op)) {
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: ${inputShapeName}BodyMiddleware())")
} else {
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<$inputShapeName, $outputShapeName>())", ClientRuntimeTypes.Middleware.SerializableBodyMiddleware)
if (xmlName != null) {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N<\$L, \$L>(xmlName: \"\$L\"))",
operationStackName, middlewareStep.stringValue(), position.stringValue(), ClientRuntimeTypes.Middleware.SerializableBodyMiddleware, inputShapeName, outputShapeName, xmlName
)
} else {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N<\$L, \$L>())",
operationStackName, middlewareStep.stringValue(), position.stringValue(), ClientRuntimeTypes.Middleware.SerializableBodyMiddleware, inputShapeName, outputShapeName
)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

package software.amazon.smithy.swift.codegen.integration.middlewares.handlers

import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.XmlNameTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.Middleware
import software.amazon.smithy.swift.codegen.MiddlewareGenerator
Expand All @@ -21,6 +23,7 @@ import software.amazon.smithy.swift.codegen.integration.HttpBindingDescriptor
import software.amazon.smithy.swift.codegen.integration.HttpBindingResolver
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.steps.OperationSerializeStep
import software.amazon.smithy.swift.codegen.model.getTrait
import software.amazon.smithy.swift.codegen.model.hasTrait

class HttpBodyMiddleware(
Expand Down Expand Up @@ -107,7 +110,19 @@ class HttpBodyMiddleware(
writer.openBlock("do {", "} catch let err {") {
writer.write("let encoder = context.getEncoder()")
writer.openBlock("if let $memberName = input.operationInput.$memberName {", "} else {") {
writer.write("let $dataDeclaration = try encoder.encode(\$L)", memberName)

val xmlNameTrait = binding.member.getTrait<XmlNameTrait>() ?: target.getTrait<XmlNameTrait>()
if (ctx.protocol == RestXmlTrait.ID && xmlNameTrait != null) {
val xmlName = xmlNameTrait.value
writer.write("let xmlEncoder = encoder as! XMLEncoder")
writer.write(
"let $dataDeclaration = try xmlEncoder.encode(\$L, withRootKey: \"\$L\")",
memberName, xmlName
)
} else {
writer.write("let $dataDeclaration = try encoder.encode(\$L)", memberName)
}

renderEncodedBodyAddedToRequest(bodyDeclaration, dataDeclaration)
}
writer.indent()
Expand Down
Loading

0 comments on commit 581320d

Please sign in to comment.