Skip to content

Commit

Permalink
Add new SwiftSymbol features
Browse files Browse the repository at this point in the history
  • Loading branch information
jbelkins committed Aug 30, 2024
1 parent e1621de commit ab14ee8
Show file tree
Hide file tree
Showing 35 changed files with 105 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,19 @@ class SwiftImportContainer : ImportContainer {
fun addImport(
packageName: String,
isTestable: Boolean = false,
internalSPIName: String? = null,
internalSPINames: List<String> = emptyList(),
importOnlyIfCanImport: Boolean = false
) {
if (packageName.isEmpty()) { return }
importStatements.find { it.packageName == packageName }?.let {
// Update isTestable to true if needed
it.isTestable = it.isTestable || isTestable
// If we have an existing import with the same package name, then add the SPI name to the existing list
if (internalSPIName != null) {
it.internalSPINames.add(internalSPIName)
}
internalSPINames.forEach { name -> it.internalSPINames.add(name) }
// Update importOnlyIfCanImport to true if needed
it.importOnlyIfCanImport = it.importOnlyIfCanImport || importOnlyIfCanImport
} ?: run {
val internalSPINames = listOf(internalSPIName).mapNotNull { it }.toMutableSet()
importStatements.add(ImportStatement(packageName, isTestable, internalSPINames, importOnlyIfCanImport))
importStatements.add(ImportStatement(packageName, isTestable, internalSPINames.toMutableSet(), importOnlyIfCanImport))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import software.amazon.smithy.swift.codegen.model.isGeneric
import software.amazon.smithy.swift.codegen.model.isOptional
import software.amazon.smithy.swift.codegen.model.isServiceNestedNamespace
import java.util.function.BiFunction
import kotlin.jvm.optionals.getOrElse
import kotlin.jvm.optionals.getOrNull

/**
Expand Down Expand Up @@ -104,7 +105,7 @@ class SwiftWriter(
fun addImport(symbol: Symbol) {
symbol.references.forEach { addImport(it.symbol) }
if (symbol.isBuiltIn || symbol.isServiceNestedNamespace || symbol.namespace.isEmpty()) return
val spiName = symbol.getProperty("spiName").getOrNull()?.toString()
val spiNames = symbol.getProperty("spiNames").getOrElse { emptyList<String>() } as List<String>
val decl = symbol.getProperty("decl").getOrNull()?.toString()
decl?.let {
// No need to import Foundation types individually because:
Expand All @@ -118,21 +119,23 @@ class SwiftWriter(
if (symbol.namespace == "Foundation") {
addImport(symbol.namespace)
} else {
addImport("$it ${symbol.namespace}.${symbol.name}", internalSPIName = spiName)
addImport("$it ${symbol.namespace}.${symbol.name}", internalSPINames = spiNames)
}
} ?: run {
addImport(symbol.namespace, internalSPIName = spiName)
addImport(symbol.namespace, internalSPINames = spiNames)
}
symbol.dependencies.forEach { addDependency(it) }
val additionalImports = symbol.getProperty("additionalImports").getOrElse { emptyList<Symbol>() } as List<Symbol>
additionalImports.forEach { addImport(it) }
}

fun addImport(
packageName: String,
isTestable: Boolean = false,
internalSPIName: String? = null,
internalSPINames: List<String> = emptyList(),
importOnlyIfCanImport: Boolean = false
) {
importContainer.addImport(packageName, isTestable, internalSPIName, importOnlyIfCanImport)
importContainer.addImport(packageName, isTestable, internalSPINames, importOnlyIfCanImport)
}

fun addImportReferences(symbol: Symbol, vararg options: SymbolReference.ContextOption) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DefaultClientConfiguration : ClientConfiguration {
),
ConfigProperty(
"interceptorProviders",
ClientRuntimeTypes.Core.InterceptorProviders,
ClientRuntimeTypes.Composite.InterceptorProviders,
{ "[]" },
accessModifier = AccessModifier.PublicPrivateSet
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DefaultHttpClientConfiguration : ClientConfiguration {
),
ConfigProperty(
"httpInterceptorProviders",
ClientRuntimeTypes.Core.HttpInterceptorProviders,
ClientRuntimeTypes.Composite.HttpInterceptorProviders,
{ "[]" },
accessModifier = AccessModifier.PublicPrivateSet
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ class MessageMarshallableGenerator(
val nodeInfoUtils = NodeInfoUtils(ctx, writer, ctx.service.responseWireProtocol)
val rootNodeInfo = nodeInfoUtils.nodeInfo(memberShape, true)
val valueWritingClosure = WritingClosureUtils(ctx, writer).writingClosure(memberShape)
writer.addImport(SmithyReadWriteTypes.SmithyReader)
writer.write(
"payload = try \$N.write(value.\$L, rootNodeInfo: \$L, with: \$L)",
ctx.service.writerSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ class MessageUnmarshallableGenerator(

private fun renderReadToValue(writer: SwiftWriter, memberShape: MemberShape) {
val readingClosure = ReadingClosureUtils(ctx, writer).readingClosure(memberShape)
writer.addImport(SmithyReadWriteTypes.SmithyReader)
writer.write(
"let value = try \$N.readFrom(message.payload, with: \$L)",
ctx.service.readerSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class HTTPResponseBindingErrorGenerator(
.map { ctx.model.expectShape(it) as StructureShape }
.toSet()
.sorted()
writer.addImport(SwiftSymbol.make("ClientRuntime", null, SwiftDependency.CLIENT_RUNTIME, "SmithyReadWrite"))
writer.addImport(SwiftSymbol.make("ClientRuntime", null, SwiftDependency.CLIENT_RUNTIME, emptyList(), listOf("SmithyReadWrite")))
writer.write("let data = try await httpResponse.data()")
writer.write("let responseReader = try \$N.from(data: data)", ctx.service.readerSymbol)
val noErrorWrapping = ctx.service.getTrait<RestXmlTrait>()?.isNoErrorWrapping ?: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class HTTPResponseBindingErrorInitGenerator(
errorShape,
) {
if (needsReader) {
writer.addImport(SmithyReadWriteTypes.SmithyReader)
writer.addImport(ctx.service.readerSymbol)
writer.write("let reader = baseError.errorBodyReader")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class HTTPResponseBindingOutputGenerator(
writer.write("return \$N()", outputSymbol)
} else {
if (needsAReader(ctx, responseBindings)) {
writer.addImport(SwiftSymbol.make("ClientRuntime", null, SwiftDependency.CLIENT_RUNTIME, "SmithyReadWrite"))
writer.addImport(SwiftSymbol.make("ClientRuntime", null, SwiftDependency.CLIENT_RUNTIME, emptyList(), listOf("SmithyReadWrite")))
writer.write("let data = try await httpResponse.data()")
writer.write("let responseReader = try \$N.from(data: data)", ctx.service.readerSymbol)
writer.write("let reader = \$L", reader(ctx, op, writer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ open class MemberShapeDecodeGenerator(
}
val memberName = ctx.symbolProvider.toMemberName(member)
if (shapeContainingMembers.isUnionShape) {
writer.addImport(SmithyReadWriteTypes.SmithyReader)
writer.write("return .\$L(\$L)", memberName, readExp)
} else if (shapeContainingMembers.isError) {
writer.write("value.properties.\$L = \$L", memberName, readExp)
Expand All @@ -77,7 +76,6 @@ open class MemberShapeDecodeGenerator(

private fun renderStructOrUnionExp(memberShape: MemberShape, isPayload: Boolean): String {
val readingClosure = readingClosureUtils.readingClosure(memberShape)
writer.addImport(SmithyReadWriteTypes.SmithyReader)
return writer.format(
"try \$L.\$L(with: \$L)",
reader(memberShape, isPayload),
Expand Down Expand Up @@ -121,17 +119,16 @@ open class MemberShapeDecodeGenerator(
private fun renderTimestampExp(memberShape: MemberShape, timestampShape: TimestampShape): String {
val memberTimestampFormatTrait = memberShape.getTrait<TimestampFormatTrait>()
val swiftTimestampFormatCase = TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, timestampShape)
writer.addImport(SmithyTimestampsTypes.TimestampFormat)
return writer.format(
"try \$L.\$L(format: \$L)",
"try \$L.\$L(format: \$N\$L)",
reader(memberShape, false),
readMethodName("readTimestamp"),
swiftTimestampFormatCase
SmithyTimestampsTypes.TimestampFormat,
swiftTimestampFormatCase,
)
}

private fun renderMemberExp(memberShape: MemberShape, isPayload: Boolean): String {
writer.addImport(SmithyReadWriteTypes.SmithyReader)
return writer.format(
"try \$L.\$L()\$L",
reader(memberShape, isPayload),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ abstract class MemberShapeEncodeGenerator(
val memberName = ctx.symbolProvider.toMemberName(memberShape)
val propertyKey = nodeInfoUtils.nodeInfo(memberShape)
val writingClosure = writingClosureUtils.writingClosure(memberShape)
writer.addImport(SmithyReadWriteTypes.SmithyWriter)
writer.write(
"try writer[\$L].write(\$L\$L, with: \$L)",
propertyKey,
Expand All @@ -81,20 +80,19 @@ abstract class MemberShapeEncodeGenerator(
val timestampKey = nodeInfoUtils.nodeInfo(memberShape)
val memberTimestampFormatTrait = memberShape.getTrait<TimestampFormatTrait>()
val swiftTimestampFormatCase = TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, timestampShape)
writer.addImport(SmithyTimestampsTypes.TimestampFormat)
writer.write(
"try writer[\$L].writeTimestamp(\$L\$L, format: \$L)",
"try writer[\$L].writeTimestamp(\$L\$L, format: \$N\$L)",
timestampKey,
prefix,
memberName,
swiftTimestampFormatCase
SmithyTimestampsTypes.TimestampFormat,
swiftTimestampFormatCase,
)
}

private fun writePropertyMember(memberShape: MemberShape, prefix: String) {
val propertyNodeInfo = nodeInfoUtils.nodeInfo(memberShape)
val memberName = ctx.symbolProvider.toMemberName(memberShape)
writer.addImport(SmithyReadWriteTypes.SmithyWriter)
writer.write(
"try writer[\$L].write(\$L\$L)",
propertyNodeInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ class ReadingClosureUtils(
)
}
shape is TimestampShape -> {
writer.addImport(SmithyTimestampsTypes.TimestampFormat)
writer.format(
"\$N(format: \$L)",
"\$N(format: \$N\$L)",
SmithyReadWriteTypes.timestampReadingClosure,
TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, shape)
SmithyTimestampsTypes.TimestampFormat,
TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, shape),
)
}
shape is EnumShape || shape is IntEnumShape || shape.hasTrait<EnumTrait>() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ class WritingClosureUtils(
)
}
shape is TimestampShape -> {
writer.addImport(SmithyTimestampsTypes.TimestampFormat)
writer.format(
"\$N(format: \$L)",
"\$N(format: \$N\$L)",
SmithyReadWriteTypes.timestampWritingClosure,
TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, shape)
SmithyTimestampsTypes.TimestampFormat,
TimestampUtils.timestampFormat(ctx, memberTimestampFormatTrait, shape),
)
}
shape is EnumShape || shape is IntEnumShape || shape.hasTrait<EnumTrait>() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class InitialRequestIntegration : SwiftIntegration {
val nodeInfoUtils = NodeInfoUtils(protocolGenerationContext, writer, protocolGenerationContext.service.requestWireProtocol)
val rootNodeInfo = nodeInfoUtils.nodeInfo(it, true)
val valueWritingClosure = WritingClosureUtils(protocolGenerationContext, writer).writingClosure(it)
writer.addImport(SmithyReadWriteTypes.SmithyWriter)
writer.write("let writer = \$N(nodeInfo: \$L)", protocolGenerationContext.service.writerSymbol, rootNodeInfo)
writer.write("try writer.write(self, with: \$L)", valueWritingClosure)
writer.write("let initialRequestPayload = try writer.data()")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package software.amazon.smithy.swift.codegen.swiftmodules
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.swift.codegen.SwiftDeclaration
import software.amazon.smithy.swift.codegen.SwiftDependency
import software.amazon.smithy.swift.codegen.swiftmodules.ClientRuntimeTypes.Core.HttpInterceptorProvider
import software.amazon.smithy.swift.codegen.swiftmodules.ClientRuntimeTypes.Core.InterceptorProvider

/**
* Commonly used runtime types. Provides a single definition of a runtime symbol such that codegen isn't littered
Expand All @@ -23,7 +25,7 @@ object ClientRuntimeTypes {
val ContentLengthMiddleware = runtimeSymbol("ContentLengthMiddleware", SwiftDeclaration.STRUCT)
val ContentTypeMiddleware = runtimeSymbol("ContentTypeMiddleware", SwiftDeclaration.STRUCT)
val ContentMD5Middleware = runtimeSymbol("ContentMD5Middleware", SwiftDeclaration.STRUCT)
val DeserializeMiddleware = runtimeSymbol("DeserializeMiddleware", SwiftDeclaration.STRUCT, "SmithyReadWrite")
val DeserializeMiddleware = runtimeSymbol("DeserializeMiddleware", SwiftDeclaration.STRUCT, emptyList(), listOf("SmithyReadWrite"))
val MutateHeadersMiddleware = runtimeSymbol("MutateHeadersMiddleware", SwiftDeclaration.STRUCT)
val URLHostMiddleware = runtimeSymbol("URLHostMiddleware", SwiftDeclaration.STRUCT)
val URLPathMiddleware = runtimeSymbol("URLPathMiddleware", SwiftDeclaration.STRUCT)
Expand All @@ -33,7 +35,7 @@ object ClientRuntimeTypes {
runtimeSymbol("IdempotencyTokenMiddleware", SwiftDeclaration.STRUCT)
val SignerMiddleware = runtimeSymbol("SignerMiddleware", SwiftDeclaration.STRUCT)
val AuthSchemeMiddleware = runtimeSymbol("AuthSchemeMiddleware", SwiftDeclaration.STRUCT)
val BodyMiddleware = runtimeSymbol("BodyMiddleware", SwiftDeclaration.STRUCT, "SmithyReadWrite")
val BodyMiddleware = runtimeSymbol("BodyMiddleware", SwiftDeclaration.STRUCT, emptyList(), listOf("SmithyReadWrite"))
val PayloadBodyMiddleware = runtimeSymbol("PayloadBodyMiddleware", SwiftDeclaration.STRUCT)
val EventStreamBodyMiddleware = runtimeSymbol("EventStreamBodyMiddleware", SwiftDeclaration.STRUCT)
val BlobStreamBodyMiddleware = runtimeSymbol("BlobStreamBodyMiddleware", SwiftDeclaration.STRUCT)
Expand Down Expand Up @@ -86,28 +88,26 @@ object ClientRuntimeTypes {
val splitHeaderListValues = runtimeSymbol("splitHeaderListValues", SwiftDeclaration.FUNC)
val splitHttpDateHeaderListValues = runtimeSymbol("splitHttpDateHeaderListValues", SwiftDeclaration.FUNC)
val OrchestratorBuilder = runtimeSymbol("OrchestratorBuilder", SwiftDeclaration.CLASS)
val InterceptorProviders = runtimeSymbolWithoutNamespace("[ClientRuntime.InterceptorProvider]")
val InterceptorProvider = runtimeSymbol("InterceptorProvider", SwiftDeclaration.PROTOCOL)
val HttpInterceptorProviders = runtimeSymbolWithoutNamespace("[ClientRuntime.HttpInterceptorProvider]")
val HttpInterceptorProvider = runtimeSymbol("HttpInterceptorProvider", SwiftDeclaration.PROTOCOL)
val HttpInterceptor = runtimeSymbol("HttpInterceptor", SwiftDeclaration.PROTOCOL)
}

object Composite {
val InterceptorProviders = runtimeSymbol("[ClientRuntime.InterceptorProvider]", null, listOf(InterceptorProvider))
val HttpInterceptorProviders = runtimeSymbol("[ClientRuntime.HttpInterceptorProvider]", null, listOf(HttpInterceptorProvider))
}
}

private fun runtimeSymbol(
name: String,
declaration: SwiftDeclaration,
spiName: String? = null,
declaration: SwiftDeclaration?,
additionalImports: List<Symbol> = emptyList(),
spiName: List<String> = emptyList(),
): Symbol = SwiftSymbol.make(
name,
declaration,
SwiftDependency.CLIENT_RUNTIME,
SwiftDependency.CLIENT_RUNTIME.takeIf { additionalImports.isEmpty() },
additionalImports,
spiName,
)

private fun runtimeSymbolWithoutNamespace(name: String, declaration: SwiftDeclaration? = null): Symbol = SwiftSymbol.make(
name,
declaration,
null,
null,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_EVENT_STREAMS_API,
null,
emptyList(),
emptyList(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_EVENT_STREAMS,
null,
emptyList(),
emptyList(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ import software.amazon.smithy.swift.codegen.SwiftDeclaration
import software.amazon.smithy.swift.codegen.SwiftDependency

object SmithyFormURLTypes {
val Writer = runtimeSymbol("Writer", SwiftDeclaration.CLASS)
val Writer = runtimeSymbol("Writer", SwiftDeclaration.CLASS, listOf(SmithyReadWriteTypes.SmithyWriter))
}

private fun runtimeSymbol(
name: String,
declaration: SwiftDeclaration,
): Symbol = SwiftSymbol.make(
additionalImports: List<Symbol> = emptyList(),
): Symbol = SwiftSymbol.make(
name,
declaration,
SwiftDependency.SMITHY_FORM_URL,
"SmithyReadWrite",
additionalImports,
listOf("SmithyReadWrite"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_HTTP_API,
null,
emptyList(),
emptyList(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_HTTP_AUTH_API,
null,
)

private fun runtimeSymbolWithoutNamespace(name: String, declaration: SwiftDeclaration? = null): Symbol = SwiftSymbol.make(
name,
declaration,
null,
null,
emptyList(),
emptyList()
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_HTTP_AUTH,
null,
emptyList(),
emptyList(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ private fun runtimeSymbol(name: String, declaration: SwiftDeclaration? = null):
name,
declaration,
SwiftDependency.SMITHY_IDENTITY,
null,
emptyList(),
emptyList(),
)
Loading

0 comments on commit ab14ee8

Please sign in to comment.