From e87dd41a9e6a00e052fb37c90b41979d5d866bcd Mon Sep 17 00:00:00 2001 From: Chan <55515281+sichanyoo@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:49:11 -0700 Subject: [PATCH] feat: Add default trait support (#803) * Add default value handling to SwiftSymbolProvider. * Add default value handling for generating expected body when generating expected HTTP response in protocol tests. * Add more comprehensive default value handling for deserialization, also add error correction for server failing to send a required value in response by filling with zero or zero-equivalents. * Modify ShapeExt::defaultValue logic to handle an optional value having a default value. * Handle edgecase where floating point value has integer default value given in trait. * Provide zero-equivalent error correction default value for enums by using .sdkUnknown case. * Handle JSON number value equality with loosened restriction (JSON handles 2 and 2.0 the same, as numbers) * Add enum trait handling * Fix enum case name codegen for default value * Fix int enum value handling & address future Swift 6 error warning by using Foundation.Data() to convert string to data. Use symbol's property bag to set flag for importing Data outside of SwiftSymbolProvider. * Refactor blob and document shape type default value codegen to reduce duplication. * Use flag for importing Foundation.Data set by SwiftSymbolProvider (when handling blob shape) and add import. * Ktlint * Refactor to use when{} as expression; return no default value instead of throwing error in MemberShapeDecodeGenerator for null node case. * ktlint * Address PR comments * Address PR comments * Address PR comment * Use closure to handle dependency import for the default value resolved by SwiftSymbolProvider at a later time when the resolved symbol gets used by SwiftWriter. Also, fix timestamp value handling and add logic for date-time case as well. * ktlint --------- Co-authored-by: Sichan Yoo --- Sources/SmithyTestUtil/JSONComparator.swift | 2 + .../swift/codegen/SwiftSymbolProvider.kt | 160 ++++++++++++++++-- .../smithy/swift/codegen/SwiftWriter.kt | 12 +- .../amazon/smithy/swift/codegen/Utils.kt | 2 +- .../HttpProtocolUnitTestResponseGenerator.kt | 61 ++++++- .../member/MemberShapeDecodeGenerator.kt | 103 ++++++++++- .../smithy/swift/codegen/model/SymbolExt.kt | 44 ++++- .../swift/codegen/swiftmodules/SmithyTypes.kt | 1 + 8 files changed, 351 insertions(+), 34 deletions(-) diff --git a/Sources/SmithyTestUtil/JSONComparator.swift b/Sources/SmithyTestUtil/JSONComparator.swift index 6eb4349ad..702e6a9a0 100644 --- a/Sources/SmithyTestUtil/JSONComparator.swift +++ b/Sources/SmithyTestUtil/JSONComparator.swift @@ -47,6 +47,8 @@ fileprivate func anyValuesAreEqual(_ lhs: Any?, _ rhs: Any?) -> Bool { return anyDictsAreEqual(lhsDict, rhsDict) } else if let lhsArray = lhs as? [Any], let rhsArray = rhs as? [Any] { return anyArraysAreEqual(lhsArray, rhsArray) + } else if let lhn = lhs as? NSNumber, let rhn = rhs as? NSNumber { + return lhn == rhn } else { return type(of: lhs) == type(of: rhs) && "\(lhs)" == "\(rhs)" } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftSymbolProvider.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftSymbolProvider.kt index 3b9e10142..80dcf6299 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftSymbolProvider.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftSymbolProvider.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.codegen.core.SymbolReference import software.amazon.smithy.model.Model import software.amazon.smithy.model.knowledge.NullableIndex +import software.amazon.smithy.model.node.Node import software.amazon.smithy.model.shapes.BigDecimalShape import software.amazon.smithy.model.shapes.BigIntegerShape import software.amazon.smithy.model.shapes.BlobShape @@ -23,6 +24,7 @@ import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.FloatShape +import software.amazon.smithy.model.shapes.IntEnumShape import software.amazon.smithy.model.shapes.IntegerShape import software.amazon.smithy.model.shapes.ListShape import software.amazon.smithy.model.shapes.LongShape @@ -39,8 +41,11 @@ import software.amazon.smithy.model.shapes.StringShape import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.ClientOptionalTrait +import software.amazon.smithy.model.traits.DefaultTrait import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.ErrorTrait +import software.amazon.smithy.model.traits.InputTrait import software.amazon.smithy.model.traits.SparseTrait import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.swift.codegen.customtraits.NestedTrait @@ -49,8 +54,15 @@ import software.amazon.smithy.swift.codegen.model.SymbolProperty import software.amazon.smithy.swift.codegen.model.boxed import software.amazon.smithy.swift.codegen.model.buildSymbol import software.amazon.smithy.swift.codegen.model.defaultName +import software.amazon.smithy.swift.codegen.model.defaultValue +import software.amazon.smithy.swift.codegen.model.defaultValueClosure +import software.amazon.smithy.swift.codegen.model.getTrait import software.amazon.smithy.swift.codegen.model.hasTrait import software.amazon.smithy.swift.codegen.model.nestedNamespaceType +import software.amazon.smithy.swift.codegen.swiftmodules.FoundationTypes +import software.amazon.smithy.swift.codegen.swiftmodules.SmithyReadWriteTypes +import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTimestampsTypes +import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTypes import software.amazon.smithy.swift.codegen.swiftmodules.SwiftTypes import software.amazon.smithy.swift.codegen.utils.ModelFileUtils import software.amazon.smithy.swift.codegen.utils.clientName @@ -104,21 +116,21 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett return escaper.escapeMemberName(shape.memberName.toLowerCamelCase()) } - override fun integerShape(shape: IntegerShape): Symbol = numberShape(shape, "Int", "0") + override fun integerShape(shape: IntegerShape): Symbol = numberShape(shape, "Int") - override fun floatShape(shape: FloatShape): Symbol = numberShape(shape, "Float", "0.0") + override fun floatShape(shape: FloatShape): Symbol = numberShape(shape, "Float") - override fun longShape(shape: LongShape): Symbol = numberShape(shape, "Int", "0") + override fun longShape(shape: LongShape): Symbol = numberShape(shape, "Int") - override fun doubleShape(shape: DoubleShape): Symbol = numberShape(shape, "Double", "0.0") + override fun doubleShape(shape: DoubleShape): Symbol = numberShape(shape, "Double") - override fun byteShape(shape: ByteShape): Symbol = numberShape(shape, "Int8", "0") + override fun byteShape(shape: ByteShape): Symbol = numberShape(shape, "Int8") - override fun shortShape(shape: ShortShape): Symbol = numberShape(shape, "Int16", "0") + override fun shortShape(shape: ShortShape): Symbol = numberShape(shape, "Int16") - override fun bigIntegerShape(shape: BigIntegerShape): Symbol = numberShape(shape, "Int", defaultValue = "0") + override fun bigIntegerShape(shape: BigIntegerShape): Symbol = numberShape(shape, "Int") - override fun bigDecimalShape(shape: BigDecimalShape): Symbol = numberShape(shape, "Double", "0.0") + override fun bigDecimalShape(shape: BigDecimalShape): Symbol = numberShape(shape, "Double") override fun stringShape(shape: StringShape): Symbol { val enumTrait = shape.getTrait(EnumTrait::class.java) @@ -149,7 +161,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett } override fun booleanShape(shape: BooleanShape): Symbol { - return createSymbolBuilder(shape, "Bool", namespace = "Swift", SwiftDeclaration.STRUCT).putProperty(SymbolProperty.DEFAULT_VALUE_KEY, "false").build() + return createSymbolBuilder(shape, "Bool", namespace = "Swift", SwiftDeclaration.STRUCT).build() } override fun structureShape(shape: StructureShape): Symbol { @@ -205,7 +217,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett .putProperty(SymbolProperty.NESTED_SYMBOL, symbol) .build() } - return symbol + return handleDefaultValue(shape, symbol.toBuilder()).build() } override fun timestampShape(shape: TimestampShape): Symbol { @@ -243,17 +255,22 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett .build() } - private fun numberShape(shape: Shape?, typeName: String, defaultValue: String = "0"): Symbol { - if (shape != null && shape.isIntEnumShape()) { + private fun numberShape(shape: Shape, typeName: String): Symbol { + if (shape.isIntEnumShape()) { return createEnumSymbol(shape) } - return createSymbolBuilder(shape, typeName, "Swift", SwiftDeclaration.STRUCT).putProperty(SymbolProperty.DEFAULT_VALUE_KEY, defaultValue).build() + return createSymbolBuilder(shape, typeName, "Swift", SwiftDeclaration.STRUCT).build() } /** * Creates a symbol builder for the shape with the given type name in the root namespace. */ - private fun createSymbolBuilder(shape: Shape?, typeName: String, declaration: SwiftDeclaration, boxed: Boolean = false): Symbol.Builder { + private fun createSymbolBuilder( + shape: Shape, + typeName: String, + declaration: SwiftDeclaration, + boxed: Boolean = false + ): Symbol.Builder { val builder = Symbol.builder() .putProperty("shape", shape) .putProperty("decl", declaration.keyword) @@ -261,7 +278,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett if (boxed) { builder.boxed() } - return builder + return handleDefaultValue(shape, builder) } /** @@ -270,7 +287,7 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett * the namespace (and ultimately the package name) to `foo.bar` for the symbol. */ private fun createSymbolBuilder( - shape: Shape?, + shape: Shape, typeName: String, namespace: String, declaration: SwiftDeclaration, @@ -284,6 +301,117 @@ class SwiftSymbolProvider(private val model: Model, val swiftSettings: SwiftSett return ModelFileUtils.filename(swiftSettings, name) } + /** + * Resolve default value for a given shape and save it as a property in symbol builder if needed. + * + * The default trait can be applied to list shape, map shape, and all simple types as per Smithy spec. + * Both the member shape and the target shape may have the default trait. + * + * There exist default value restrictions for the following shapes: + * - enum: can be set to any valid string value of the enum. + * - intEnum: can be set to any valid integer value of the enum. + * - document: can be set to null, true, false, string, numbers, an empty list, or an empty map. + * - list: can only be set to an empty list. + * - map: can only be set to an empty map. + */ + private fun handleDefaultValue(shape: Shape, builder: Symbol.Builder): Symbol.Builder { + // Skip if the current shape is a member shape with @clientOptional trait + if (shape.hasTrait()) return builder + // Skip if the current shape doesn't have default trait. Otherwise, get the default value as literal string + val defaultValueLiteral = shape.getTrait()?.toNode()?.toString() ?: return builder + // If default value is "null", it is explicit notation for no default value. Return unmodified builder. + if (defaultValueLiteral == "null") return builder + + // The current shape may be a member shape or a root level shape. + val targetShape = when (shape) { + is MemberShape -> { + // If containing shape is an input shape, return unmodified builder. + if (model.expectShape(shape.container).hasTrait()) return builder + model.expectShape(shape.target) + } + else -> shape + } + val node = shape.getTrait()!!.toNode() + + return when (targetShape) { + is ListShape -> builder.defaultValue("[]") + is EnumShape -> builder.defaultValue(".${swiftEnumCaseName(null, defaultValueLiteral)}") + is IntEnumShape -> { + // Get the corresponding enum member name (enum case name) for the int value from default trait + val enumMemberName = targetShape.enumValues.entries.firstOrNull { + it.value == defaultValueLiteral.toInt() + }!!.key + builder.defaultValue(".${swiftEnumCaseName(enumMemberName, defaultValueLiteral)}") + } + is StringShape -> builder.defaultValue("\"$defaultValueLiteral\"") + is MapShape -> builder.defaultValue("[:]") + is BlobShape -> handleBlobDefaultValue(defaultValueLiteral, targetShape, builder) + is DocumentShape -> { + handleDocumentDefaultValue(defaultValueLiteral, node, builder) + } + is TimestampShape -> handleTimestampDefaultValue(defaultValueLiteral, node, builder) + is FloatShape, is DoubleShape -> { + val decimal = ".0".takeIf { !defaultValueLiteral.contains(".") } ?: "" + builder.defaultValue(defaultValueLiteral + decimal) + } + // For: boolean, byte, short, integer, long, bigInteger, bigDecimal, + // just take the literal string value from the trait. + else -> builder.defaultValue(defaultValueLiteral) + } + } + + // Document: default value can be set to null, true, false, string, numbers, an empty list, or an empty map. + private fun handleDocumentDefaultValue(literal: String, node: Node, builder: Symbol.Builder): Symbol.Builder { + var formatString = when { + node.isObjectNode -> "\$N.object([:])" + node.isArrayNode -> "\$N.array([])" + node.isBooleanNode -> "\$N.boolean($literal)" + node.isStringNode -> "\$N.string(\"$literal\")" + node.isNumberNode -> "\$N.number($literal)" + else -> return builder // no-op + } + return builder.defaultValueClosure { writer -> + writer.format(formatString, SmithyReadWriteTypes.Document) + } + } + + private fun handleBlobDefaultValue(literal: String, shape: Shape, builder: Symbol.Builder): Symbol.Builder { + return builder.defaultValueClosure( + if (shape.hasTrait()) { + { writer -> + writer.format( + "\$N.data(\$N(\"$literal\".utf8))", + SmithyTypes.ByteStream, + FoundationTypes.Data + ) + } + } else { + { writer -> + writer.format("\$N(\"$literal\".utf8)", FoundationTypes.Data) + } + } + ) + } + + private fun handleTimestampDefaultValue(literal: String, node: Node, builder: Symbol.Builder): Symbol.Builder { + // Smithy validates that default value given to timestamp shape must either be a + // number (for epoch-seconds) or a date-time string compliant with RFC3339. + return builder.defaultValueClosure( + if (node.isNumberNode) { + { writer -> + writer.format("\$N(timeIntervalSince1970: $literal)", FoundationTypes.Date) + } + } else { + { writer -> + writer.format( + "\$N(format: .dateTime).date(from: \"$literal\")", + SmithyTimestampsTypes.TimestampFormatter + ) + } + } + ) + } + /** * Add all the [members] as references needed to declare the given symbol being built. */ diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftWriter.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftWriter.kt index 3abaa7057..ed8bc8e06 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftWriter.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/SwiftWriter.kt @@ -21,7 +21,9 @@ import software.amazon.smithy.swift.cod.DocumentationConverter import software.amazon.smithy.swift.codegen.integration.SectionId import software.amazon.smithy.swift.codegen.integration.SectionWriter import software.amazon.smithy.swift.codegen.integration.SwiftIntegration +import software.amazon.smithy.swift.codegen.model.SymbolProperty import software.amazon.smithy.swift.codegen.model.defaultValue +import software.amazon.smithy.swift.codegen.model.defaultValueFromClosure import software.amazon.smithy.swift.codegen.model.isBoxed import software.amazon.smithy.swift.codegen.model.isBuiltIn import software.amazon.smithy.swift.codegen.model.isGeneric @@ -190,7 +192,7 @@ class SwiftWriter( } if (shouldSetDefault) { - type.defaultValue()?.let { + getDefaultValue(type)?.let { formatted += " = $it" } } @@ -200,6 +202,14 @@ class SwiftWriter( else -> throw CodegenException("Invalid type provided for \$T. Expected a Symbol, but found `$type`") } } + + private fun getDefaultValue(symbol: Symbol): String? { + return if (symbol.properties.containsKey(SymbolProperty.DEFAULT_VALUE_CLOSURE_KEY)) { + symbol.defaultValueFromClosure(writer) + } else { + symbol.defaultValue() + } + } } /** diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/Utils.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/Utils.kt index 43fd82421..f85e08019 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/Utils.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/Utils.kt @@ -19,7 +19,7 @@ fun String.removeSurroundingBackticks() = removeSurrounding("`", "`") */ fun swiftEnumCaseName(name: String?, value: String, shouldBeEscaped: Boolean = true): String { val resolvedName = name ?: value - var enumCaseName = CaseUtils.toCamelCase(resolvedName.replace(Regex("[^a-zA-Z0-9_ ]"), "")) + var enumCaseName = CaseUtils.toCamelCase(resolvedName.replace(Regex("[^a-zA-Z0-9_ -]"), "")) if (!SwiftSymbolProvider.isValidSwiftIdentifier(enumCaseName)) { enumCaseName = "_$enumCaseName" } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolUnitTestResponseGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolUnitTestResponseGenerator.kt index 61f16a350..541c6defe 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolUnitTestResponseGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolUnitTestResponseGenerator.kt @@ -7,11 +7,16 @@ package software.amazon.smithy.swift.codegen.integration import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.model.shapes.Shape import software.amazon.smithy.model.shapes.StructureShape +import software.amazon.smithy.model.traits.DefaultTrait +import software.amazon.smithy.model.traits.HttpHeaderTrait +import software.amazon.smithy.model.traits.HttpPayloadTrait +import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait import software.amazon.smithy.protocoltests.traits.HttpResponseTestCase import software.amazon.smithy.swift.codegen.ShapeValueGenerator import software.amazon.smithy.swift.codegen.hasStreamingMember import software.amazon.smithy.swift.codegen.integration.serde.readwrite.ResponseClosureUtils import software.amazon.smithy.swift.codegen.model.RecursiveShapeBoxer +import software.amazon.smithy.swift.codegen.model.hasTrait import software.amazon.smithy.swift.codegen.swiftmodules.SmithyStreamsTypes /** @@ -55,16 +60,31 @@ open class HttpProtocolUnitTestResponseGenerator protected constructor(builder: } open fun renderExpectedBody(test: HttpResponseTestCase) { - if (test.body.isPresent && test.body.get().isNotBlank()) { + if (test.body.isPresent) { operation.output.ifPresent { - val outputShape = model.expectShape(it) as StructureShape - val data = writer.format("Data(\"\"\"\n\$L\n\"\"\".utf8)", test.body.get().replace("\\\"", "\\\\\"")) - // depending on the shape of the output, we may need to wrap the body in a stream - if (outputShape.hasStreamingMember(model)) { - // wrapping to CachingStream required for test asserts which reads body multiple times - writer.write("content: .stream(\$N(data: \$L, isClosed: true))", SmithyStreamsTypes.Core.BufferedStream, data) + if (test.body.get().isNotBlank()) { + val outputShape = model.expectShape(it) as StructureShape + val data = writer.format( + "Data(\"\"\"\n\$L\n\"\"\".utf8)", + test.body.get().replace("\\\"", "\\\\\"") + ) + // depending on the shape of the output, we may need to wrap the body in a stream + if (outputShape.hasStreamingMember(model)) { + // wrapping to CachingStream required for test asserts which reads body multiple times + writer.write( + "content: .stream(\$N(data: \$L, isClosed: true))", + SmithyStreamsTypes.Core.BufferedStream, + data + ) + } else { + writer.write("content: .data(\$L)", data) + } + } else if (test.body.get().isBlank() && bodyHasDefaultValue()) { + // Expected body is blank but not because it's nil, but because it's a default empty blob value + writer.write("content: .data(Data(\"\".utf8))") } else { - writer.write("content: .data(\$L)", data) + // Expected body is blank and underlying member shape does not have default values. + writer.write("content: nil") } } } else { @@ -72,6 +92,31 @@ open class HttpProtocolUnitTestResponseGenerator protected constructor(builder: } } + private fun bodyHasDefaultValue(): Boolean { + var result = false + operation.output.ifPresent { + val outputShape = model.expectShape(it) as StructureShape + outputShape.allMembers.forEach { + val member = it.value + val target = model.expectShape(member.target) + val defaultValueExists = member.hasTrait() || target.hasTrait() + // If a top level input member shape has the payload trait, it's a bound payload member + val isBoundPayloadMember = member.hasTrait() + // If a top level input member doesn't have payload trait, header trait nor prefix header trait, + // it is an unbound payload member. + val isUnboundPayloadMember = !member.hasTrait() && + !member.hasTrait() && + !member.hasTrait() + // If a member has default value and goes in payload, return true + if (defaultValueExists && (isBoundPayloadMember || isUnboundPayloadMember)) { + result = true + return@ifPresent + } + } + } + return result + } + private fun renderBuildHttpResponseParams(test: HttpResponseTestCase) { writer.write("code: \$L,", test.code) renderExpectedHeaders(test) diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/member/MemberShapeDecodeGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/member/MemberShapeDecodeGenerator.kt index 9548555af..c758c9cee 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/member/MemberShapeDecodeGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/serde/member/MemberShapeDecodeGenerator.kt @@ -10,8 +10,10 @@ import software.amazon.smithy.model.node.NumberNode import software.amazon.smithy.model.node.StringNode import software.amazon.smithy.model.shapes.BigDecimalShape import software.amazon.smithy.model.shapes.BigIntegerShape +import software.amazon.smithy.model.shapes.BlobShape import software.amazon.smithy.model.shapes.BooleanShape import software.amazon.smithy.model.shapes.ByteShape +import software.amazon.smithy.model.shapes.DocumentShape import software.amazon.smithy.model.shapes.DoubleShape import software.amazon.smithy.model.shapes.EnumShape import software.amazon.smithy.model.shapes.FloatShape @@ -29,8 +31,11 @@ import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.shapes.TimestampShape import software.amazon.smithy.model.shapes.UnionShape import software.amazon.smithy.model.traits.DefaultTrait +import software.amazon.smithy.model.traits.EnumTrait import software.amazon.smithy.model.traits.EnumValueTrait +import software.amazon.smithy.model.traits.RequiredTrait import software.amazon.smithy.model.traits.SparseTrait +import software.amazon.smithy.model.traits.StreamingTrait import software.amazon.smithy.model.traits.TimestampFormatTrait import software.amazon.smithy.model.traits.XmlFlattenedTrait import software.amazon.smithy.swift.codegen.SwiftWriter @@ -44,7 +49,10 @@ import software.amazon.smithy.swift.codegen.model.getTrait import software.amazon.smithy.swift.codegen.model.hasTrait import software.amazon.smithy.swift.codegen.model.isError import software.amazon.smithy.swift.codegen.swiftEnumCaseName +import software.amazon.smithy.swift.codegen.swiftmodules.FoundationTypes +import software.amazon.smithy.swift.codegen.swiftmodules.SmithyReadWriteTypes import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTimestampsTypes +import software.amazon.smithy.swift.codegen.swiftmodules.SmithyTypes open class MemberShapeDecodeGenerator( private val ctx: ProtocolGenerator.GenerationContext, @@ -89,12 +97,13 @@ open class MemberShapeDecodeGenerator( val memberNodeInfo = nodeInfoUtils.nodeInfo(listShape.member) val isFlattened = memberShape.hasTrait() return writer.format( - "try \$L.\$L(memberReadingClosure: \$L, memberNodeInfo: \$L, isFlattened: \$L)", + "try \$L.\$L(memberReadingClosure: \$L, memberNodeInfo: \$L, isFlattened: \$L)\$L", reader(memberShape, false), readMethodName("readList"), memberReadingClosure, memberNodeInfo, - isFlattened + isFlattened, + default(memberShape) ) } @@ -105,13 +114,14 @@ open class MemberShapeDecodeGenerator( val valueNodeInfo = nodeInfoUtils.nodeInfo(mapShape.value) val isFlattened = memberShape.hasTrait() return writer.format( - "try \$L.\$L(valueReadingClosure: \$L, keyNodeInfo: \$L, valueNodeInfo: \$L, isFlattened: \$L)", + "try \$L.\$L(valueReadingClosure: \$L, keyNodeInfo: \$L, valueNodeInfo: \$L, isFlattened: \$L)\$L", reader(memberShape, false), readMethodName("readMap"), valueReadingClosure, keyNodeInfo, valueNodeInfo, - isFlattened + isFlattened, + default(memberShape) ) } @@ -119,11 +129,12 @@ open class MemberShapeDecodeGenerator( val memberTimestampFormatTrait = memberShape.getTrait() val swiftTimestampFormatCase = TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, timestampShape) return writer.format( - "try \$L.\$L(format: \$N\$L)", + "try \$L.\$L(format: \$N\$L)\$L", reader(memberShape, false), readMethodName("readTimestamp"), SmithyTimestampsTypes.TimestampFormat, swiftTimestampFormatCase, + default(memberShape) ) } @@ -149,6 +160,31 @@ open class MemberShapeDecodeGenerator( private fun default(memberShape: MemberShape): String { val targetShape = ctx.model.expectShape(memberShape.target) val defaultTrait = memberShape.getTrait() ?: targetShape.getTrait() + val requiredTrait = memberShape.getTrait() + // If member is required but there isn't a default value, use zero-equivalents for error correction + if (requiredTrait != null && defaultTrait == null) { + return when (targetShape) { + is EnumShape, is IntEnumShape -> " ?? .sdkUnknown(\"\")" + is StringShape -> { + // Enum trait is deprecated but many services still use it in their models + if (targetShape.hasTrait()) { + " ?? .sdkUnknown(\"\")" + } else { + " ?? \"\"" + } + } + is ByteShape, is ShortShape, is IntegerShape, is LongShape -> " ?? 0" + is FloatShape, is DoubleShape -> " ?? 0.0" + is BooleanShape -> " ?? false" + is ListShape -> " ?? []" + is MapShape -> " ?? [:]" + is TimestampShape -> resolveTimestampDefault(true, requiredTrait.toNode()) + is DocumentShape -> resolveDocumentDefault(true, requiredTrait.toNode()) + is BlobShape -> resolveBlobDefault(targetShape) + // No default provided for other types + else -> "" + } + } return defaultTrait?.toNode()?.let { // If the default value is null, provide no default. if (it.isNullNode) { return "" } @@ -170,6 +206,9 @@ open class MemberShapeDecodeGenerator( is ListShape, is SetShape -> " ?? []" // Maps can only have empty map as default value is MapShape -> " ?? [:]" + is TimestampShape -> resolveTimestampDefault(false, it) + is DocumentShape -> resolveDocumentDefault(false, it) + is BlobShape -> resolveBlobDefault(targetShape, it.toString()) // No default provided for other shapes else -> "" } @@ -195,4 +234,58 @@ open class MemberShapeDecodeGenerator( else -> "" } } + + private fun resolveBlobDefault(targetShape: Shape, value: String = ""): String { + writer.addImport(FoundationTypes.Data) + return if (targetShape.hasTrait()) { + writer.format( + " ?? \$N.data(\$N(\"$value\".utf8))", + SmithyTypes.ByteStream, + FoundationTypes.Data + ) + } else { + writer.format( + " ?? \$N(\"$value\".utf8)", + FoundationTypes.Data + ) + } + } + + private fun resolveDocumentDefault(useZeroValue: Boolean, node: Node): String { + return when { + node.isObjectNode -> writer.format(" ?? \$N.object([:])", SmithyReadWriteTypes.Document) + node.isArrayNode -> writer.format(" ?? \$N.array([])", SmithyReadWriteTypes.Document) + node.isStringNode -> { + val resolvedValue = "".takeIf { useZeroValue } ?: node.expectStringNode().value + writer.format(" ?? \$N.string(\"$resolvedValue\")", SmithyReadWriteTypes.Document) + } + node.isBooleanNode -> { + val resolvedValue = "false".takeIf { useZeroValue } ?: node.expectBooleanNode().value + writer.format(" ?? \$N.boolean($resolvedValue)", SmithyReadWriteTypes.Document) + } + node.isNumberNode -> { + val resolvedValue = "0".takeIf { useZeroValue } ?: node.expectNumberNode().value + writer.format(" ?? \$N.number($resolvedValue)", SmithyReadWriteTypes.Document) + } + else -> "" // null node type means no default value but explicit + } + } + + private fun resolveTimestampDefault(useZeroValue: Boolean, node: Node): String { + // Smithy validates that default value given to timestamp shape must either be a + // number (for epoch-seconds) or a date-time string compliant with RFC3339. + return if (node.isNumberNode) { + val value = "0".takeIf { useZeroValue } ?: node.expectNumberNode().value + writer.format( + " ?? \$N(timeIntervalSince1970: $value)", + FoundationTypes.Date + ) + } else { + val value = "1970-01-01T00:00:00Z".takeIf { useZeroValue } ?: node.expectStringNode().value + writer.format( + " ?? \$N(format: .dateTime).date(from: \"$value\")", + SmithyTimestampsTypes.TimestampFormatter + ) + } + } } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/model/SymbolExt.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/model/SymbolExt.kt index 6d496a165..035fe7649 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/model/SymbolExt.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/model/SymbolExt.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.swift.codegen.model import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.codegen.core.SymbolProvider import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.removeSurroundingBackticks /** @@ -20,6 +21,9 @@ object SymbolProperty { // The key that holds the default value for a type (symbol) as a string const val DEFAULT_VALUE_KEY: String = "defaultValue" + // The key that holds the default value closure for a type (symbol) that returns a string + const val DEFAULT_VALUE_CLOSURE_KEY: String = "defaultValueClosure" + // Boolean property indicating this symbol should be boxed const val BOXED_KEY: String = "boxed" @@ -42,15 +46,36 @@ fun Symbol.isBoxed(): Boolean { * Gets the default value for the symbol if present, else null */ fun Symbol.defaultValue(): String? { - // boxed types should always be defaulted to null - if (isBoxed()) { + val default = getProperty(SymbolProperty.DEFAULT_VALUE_KEY, String::class.java) + + // If shape is boxed (nullable) AND there is no default value set, return nil as default value + if (isBoxed() && !default.isPresent) { return "nil" } - val default = getProperty(SymbolProperty.DEFAULT_VALUE_KEY, String::class.java) + // If default value is present, return default value. Otherwise, return null return if (default.isPresent) default.get() else null } +/** + * Gets the default value for the symbol by processing closure if present, else null + */ +fun Symbol.defaultValueFromClosure(writer: SwiftWriter): String? { + val default = getProperty(SymbolProperty.DEFAULT_VALUE_CLOSURE_KEY) + + // If shape is boxed (nullable) AND there is no default value set, return nil as default value + if (isBoxed() && !default.isPresent) { + return "nil" + } + + // Suppress the warning and force-cast the closure to the expected type + @Suppress("UNCHECKED_CAST") + return if (default.isPresent) { + val closure = default.get() as Function1 + closure(writer) + } else null +} + /** * Mark a symbol as being boxed (nullable) i.e. `T?` */ @@ -61,6 +86,19 @@ fun Symbol.Builder.boxed(): Symbol.Builder = apply { putProperty(SymbolProperty. */ fun Symbol.Builder.defaultValue(value: String): Symbol.Builder = apply { putProperty(SymbolProperty.DEFAULT_VALUE_KEY, value) } +/** + * Set the closure that gets called with a SwiftWriter to import symbols needed for default value + * Used in SwiftSymbolProvider (which doesn't have access to SwiftWriter). + * Allows default value of a symbol X returned by SwiftSymbolProvider to have needed imports + * for symbol Y at the time symbol X is printed by SwiftWriter. + * + * Example: Default value for a symbol called X could be "Data()", which means when the symbol X is printed by SwiftWriter, + * we need to import Foundation.Data. + */ +fun Symbol.Builder.defaultValueClosure(closure: (SwiftWriter) -> String): Symbol.Builder = apply { + putProperty(SymbolProperty.DEFAULT_VALUE_CLOSURE_KEY, closure) +} + fun SymbolProvider.toMemberNames(shape: MemberShape): Pair { val escapedName = toMemberName(shape) return Pair(escapedName, escapedName.removeSurroundingBackticks()) diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/swiftmodules/SmithyTypes.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/swiftmodules/SmithyTypes.kt index 3a7ea73c8..13270177c 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/swiftmodules/SmithyTypes.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/swiftmodules/SmithyTypes.kt @@ -11,6 +11,7 @@ import software.amazon.smithy.swift.codegen.SwiftDependency * NOTE: Not all symbols need be added here but it doesn't hurt to define runtime symbols once. */ object SmithyTypes { + val ByteStream = runtimeSymbol("ByteStream", SwiftDeclaration.ENUM) val Attributes = runtimeSymbol("Attributes", SwiftDeclaration.STRUCT) val AttributeKey = runtimeSymbol("AttributeKey", SwiftDeclaration.STRUCT) val ClientError = runtimeSymbol("ClientError", SwiftDeclaration.ENUM)