Skip to content

Commit

Permalink
Related to #119:
Browse files Browse the repository at this point in the history
– Detection of simple byte overflow cases.
– Optimization of 8×8→16 multiplication on 6809.
– Multiplication optimizations on Z80.
  • Loading branch information
KarolS committed Aug 6, 2021
1 parent 7f6a0c6 commit 90e5360
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 9 deletions.
3 changes: 3 additions & 0 deletions src/main/scala/millfork/CompilationOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ object Cpu extends Enumeration {
EnableBreakpoints,
UseOptimizationHints,
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
FallbackValueUseWarning,
Expand Down Expand Up @@ -585,6 +586,7 @@ object CompilationFlag extends Enumeration {
SingleThreaded,
// warning options
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
DeprecationWarning,
Expand All @@ -603,6 +605,7 @@ object CompilationFlag extends Enumeration {

val allWarnings: Set[CompilationFlag.Value] = Set(
GenericWarnings,
ByteOverflowWarning,
UselessCodeWarning,
BuggyCodeWarning,
DeprecationWarning,
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/millfork/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,10 @@ object Main {
c.changeFlag(CompilationFlag.RorWarning, v)
}.description("Whether should warn about the ROR instruction (6502 only). Default: disabled.")

boolean("-Woverflow", "-Wno-overflow").repeatable().action { (c, v) =>
c.changeFlag(CompilationFlag.ByteOverflowWarning, v)
}.description("Whether should warn about byte overflow. Default: enabled.")

boolean("-Wuseless", "-Wno-useless").repeatable().action { (c, v) =>
c.changeFlag(CompilationFlag.UselessCodeWarning, v)
}.description("Whether should warn about code that does nothing. Default: enabled.")
Expand Down
29 changes: 29 additions & 0 deletions src/main/scala/millfork/compiler/AbstractExpressionCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,42 @@ import millfork.error.{ConsoleLogger, Logger}
import millfork.assembly.AbstractCode
import millfork.output.NoAlignment

import scala.collection.mutable.ListBuffer

/**
* @author Karol Stasiak
*/
class AbstractExpressionCompiler[T <: AbstractCode] {

def getExpressionType(ctx: CompilationContext, expr: Expression): Type = AbstractExpressionCompiler.getExpressionType(ctx, expr)

def extractWordExpandedBytes(ctx: CompilationContext, params:List[Expression]): Option[List[Expression]] = {
val result = ListBuffer[Expression]()
for(param <- params) {
if (ctx.env.eval(param).isDefined) return None
AbstractExpressionCompiler.getExpressionType(ctx, param) match {
case t: PlainType if t.size == 1 && !t.isSigned =>
result += param
case t: PlainType if t.size == 2 =>
param match {
case FunctionCallExpression(functionName, List(inner)) =>
AbstractExpressionCompiler.getExpressionType(ctx, inner) match {
case t: PlainType if t.size == 1 && !t.isSigned =>
ctx.env.maybeGet[Type](functionName) match {
case Some(tw: PlainType) if tw.size == 2 =>
result += inner
case _ => return None
}
case _ => return None
}
case _ => return None
}
case _ => return None
}
}
Some(result.toList)
}

def assertAllArithmetic(ctx: CompilationContext,expressions: List[Expression], booleanHint: String = ""): Unit = {
for(e <- expressions) {
val typ = getExpressionType(ctx, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ abstract class AbstractStatementPreprocessor(protected val ctx: CompilationConte

case _ =>
}
new OverflowDetector(ctx).detectOverflow(stmt)
stmt match {
case Assignment(ve@VariableExpression(v), arg) if trackableVars(v) =>
cv = search(arg, cv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import millfork.assembly.m6809.{Absolute, DAccumulatorIndexed, Immediate, Indexe
import millfork.compiler.{AbstractExpressionCompiler, BranchIfFalse, BranchIfTrue, BranchSpec, ComparisonType, CompilationContext, NoBranching}
import millfork.node.{DerefExpression, Expression, FunctionCallExpression, GeneratedConstantExpression, IndexedExpression, LhsExpression, LiteralExpression, M6809Register, SeparateBytesExpression, SumExpression, VariableExpression}
import millfork.assembly.m6809.MOpcode._
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}
import millfork.env.{AssemblyOrMacroParamSignature, BuiltInBooleanType, Constant, ConstantBooleanType, ConstantPointy, ExternFunction, FatBooleanType, FlagBooleanType, FunctionInMemory, FunctionPointerType, KernalInterruptPointerType, Label, M6809RegisterVariable, MacroFunction, MathOperator, MemoryAddressConstant, MemoryVariable, NonFatalCompilationException, NormalFunction, NormalParamSignature, NumericConstant, PlainType, StackOffsetThing, StackVariable, StackVariablePointy, StructureConstant, Thing, ThingInMemory, Type, Variable, VariableInMemory, VariableLikeThing, VariablePointy, VariableType}

import scala.collection.GenTraversableOnce

Expand Down Expand Up @@ -292,7 +292,13 @@ object M6809ExpressionCompiler extends AbstractExpressionCompiler[MLine] {
assertSizesForMultiplication(ctx, params, inPlace = false)
getArithmeticParamMaxSize(ctx, params) match {
case 1 => M6809MulDiv.compileByteMultiplication(ctx, params, updateDerefX = false) ++ targetifyB(ctx, target, isSigned = false)
case 2 => M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
case 2 =>
extractWordExpandedBytes(ctx, params) match {
case Some(byteParams) if byteParams.size == 2 =>
M6809MulDiv.compileByteMultiplication(ctx, byteParams, updateDerefX = false) ++ targetifyD(ctx, target)
case _ =>
M6809MulDiv.compileWordMultiplication(ctx, params, updateDerefX = false) ++ targetifyD(ctx, target)
}
case 0 => Nil
case _ =>
ctx.log.error("Multiplication of variables larger than 2 bytes is not supported", expr.position)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/millfork/compiler/m6809/M6809MulDiv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ import scala.collection.mutable.ListBuffer
*/
object M6809MulDiv {

def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean): List[MLine] = {
def compileByteMultiplication(ctx: CompilationContext, params: List[Expression], updateDerefX: Boolean, forceMul: Boolean = false): List[MLine] = {
var constant = Constant.One
val variablePart = params.flatMap { p =>
val variablePart = if(forceMul) params else params.flatMap { p =>
ctx.env.eval(p) match {
case Some(c) =>
constant = CompoundConstant(MathOperator.Times, constant, c).quickSimplify
Expand Down
15 changes: 12 additions & 3 deletions src/main/scala/millfork/compiler/mos/PseudoregisterBuiltIns.scala
Original file line number Diff line number Diff line change
Expand Up @@ -730,21 +730,30 @@ object PseudoregisterBuiltIns {
case (1, 1) => // ok
case _ => ctx.log.fatal("Invalid code path", param2.position)
}
val b = ctx.env.get[Type]("byte")
val w = ctx.env.get[Type]("word")
val reg = ctx.env.get[VariableInMemory]("__reg")
if (!storeInRegLo && param1OrRegister.isDefined) {
(ctx.env.eval(param1OrRegister.get), ctx.env.eval(param2)) match {
case (Some(l), Some(r)) =>
val product = CompoundConstant(MathOperator.Times, l, r).quickSimplify
return List(AssemblyLine.immediate(LDA, product.loByte), AssemblyLine.immediate(LDX, product.hiByte))
case (Some(NumericConstant(2, _)), _) =>
val evalParam2 = MosExpressionCompiler.compile(ctx, param2, Some(b -> RegisterVariable(MosRegister.A, b)), BranchSpec.None)
val label = ctx.nextLabel("sh")
return evalParam2 ++ List(
AssemblyLine.implied(ASL),
AssemblyLine.immediate(LDX, 0),
AssemblyLine.relative(BCC, label),
AssemblyLine.implied(INX),
AssemblyLine.label(label))
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c)=>
return compileWordShiftOps(left = true, ctx, param2, LiteralExpression(java.lang.Long.bitCount(c - 1), 1))
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c)=>
return compileWordShiftOps(left = true, ctx, param1OrRegister.get, LiteralExpression(java.lang.Long.bitCount(c - 1), 1))
case _ =>
}
}
val b = ctx.env.get[Type]("byte")
val w = ctx.env.get[Type]("word")
val reg = ctx.env.get[VariableInMemory]("__reg")
val load: List[AssemblyLine] = param1OrRegister match {
case Some(param1) =>
val code1 = MosExpressionCompiler.compile(ctx, param1, Some(w -> RegisterVariable(MosRegister.AX, w)), BranchSpec.None)
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/millfork/compiler/z80/Z80Multiply.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,14 @@ object Z80Multiply {
case (1, 1) => // ok
case _ => ctx.log.fatal("Invalid code path", l.position)
}
ctx.env.eval(r) match {
case Some(c) =>
(ctx.env.eval(l), ctx.env.eval(r)) match {
case (Some(p), Some(q)) =>
List(ZLine.ldImm16(ZRegister.HL, CompoundConstant(MathOperator.Times, p, q).quickSimplify))
case (Some(NumericConstant(c, _)), _) if isPowerOfTwoUpTo15(c) =>
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
case (_, Some(NumericConstant(c, _))) if isPowerOfTwoUpTo15(c) =>
Z80ExpressionCompiler.compileToHL(ctx, l) ++ List.fill(Integer.numberOfTrailingZeros(c.toInt))(ZLine.registers(ZOpcode.ADD_16, ZRegister.HL, ZRegister.HL))
case (_, Some(c)) =>
Z80ExpressionCompiler.compileToDE(ctx, l) ++ List(ZLine.ldImm8(ZRegister.A, c)) ++ multiplication16And8(ctx)
case _ =>
val lw = Z80ExpressionCompiler.compileToDE(ctx, l)
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/millfork/env/Environment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}

def registerArray(stmt: ArrayDeclarationStatement, options: CompilationOptions): Unit = {
new OverflowDetector(this, options).detectOverflow(stmt)
if (options.flag(CompilationFlag.LUnixRelocatableCode) && stmt.alignment.exists(_.isMultiplePages)) {
log.error("Invalid alignment for LUnix code", stmt.position)
}
Expand Down Expand Up @@ -2090,6 +2091,7 @@ class Environment(val parent: Option[Environment], val prefix: String, val cpuFa
}

def registerVariable(stmt: VariableDeclarationStatement, options: CompilationOptions, isPointy: Boolean): Unit = {
new OverflowDetector(this, options).detectOverflow(stmt)
val name = stmt.name
val position = stmt.position
if (name == "" || name.contains(".") && !name.contains(".return")) {
Expand Down
134 changes: 134 additions & 0 deletions src/main/scala/millfork/env/OverflowDetector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package millfork.env

import millfork.{CompilationFlag, CompilationOptions}
import millfork.compiler.{AbstractExpressionCompiler, CompilationContext}
import millfork.error.Logger
import millfork.node._

/**
* @author Karol Stasiak
*/
class OverflowDetector(env: Environment, options: CompilationOptions) {

def this(ctx: CompilationContext) {
this(ctx.env, ctx.options)
}

private def log: Logger = options.log

private def isWord(e: Expression): Boolean =
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
case t: PlainType => t.size == 2
case _ => false
}

private def isWord(typeName: String): Boolean =
env.maybeGet[Thing](typeName) match {
case Some(t: PlainType) => t.size == 2
case _ => false
}

private def isWord(typ: Type): Boolean =
typ match {
case t: PlainType => t.size == 2
case _ => false
}

private def isByte(e: Expression): Boolean =
AbstractExpressionCompiler.getExpressionType(env, log, e) match {
case t: PlainType => t.size == 1
case _ => false
}

def warnConstantOverflow(e: Expression, op: String): Unit = {
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
log.warn(s"Constant byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
}
}

def warnDynamicOverflow(e: Expression, op: String): Unit = {
if (options.flag(CompilationFlag.ByteOverflowWarning)) {
log.warn(s"Potential byte overflow. Consider wrapping one of the arguments of $op with word( )", e.position)
}
}

def scanExpression(e: Expression, willBeAssignedToWord: Boolean): Unit = {
if (willBeAssignedToWord) {
e match {
case FunctionCallExpression("<<", List(l, r)) =>
if (isByte(l) && isByte(r)) {
(env.eval(l), env.eval(r)) match {
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
if (lc >= 0 && rc >= 0 && (lc << rc) > 255) {
warnConstantOverflow(e, "<<")
}
case (_, Some(NumericConstant(0, _))) =>
case _ =>
warnDynamicOverflow(e, "<<")
}
}
case FunctionCallExpression("*", List(l, r)) =>
if (isByte(l) && isByte(r)) {
(env.eval(l), env.eval(r)) match {
case (Some(NumericConstant(lc, 1)), Some(NumericConstant(rc, 1))) =>
if (lc >= 0 && rc >= 0 && (lc * rc) > 255) {
warnConstantOverflow(e, "*")
}
case (_, Some(NumericConstant(0, _))) =>
case (_, Some(NumericConstant(1, _))) =>
case (Some(NumericConstant(0, _)), _) =>
case (Some(NumericConstant(1, _)), _) =>
case _ =>
warnDynamicOverflow(e, "*")
}
}
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", List(SumExpression(expressions, _))) =>
if (expressions.map(_._2).forall(isByte)) {

}
case _ =>
}
}
e match {
case SumExpression(expressions, decimal) =>
if (willBeAssignedToWord && !decimal && isByte(e)) env.eval(e) match {
case Some(NumericConstant(n, _)) if n < -128 || n > 255 =>
warnConstantOverflow(e, "+")
case _ =>
}
for ((_, e) <- expressions) {
scanExpression(e, willBeAssignedToWord = willBeAssignedToWord)
}
case FunctionCallExpression("word" | "unsigned16" | "signed16" | "pointer", expressions) =>
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = true))
case FunctionCallExpression("|" | "^" | "&" | "not", expressions) =>
expressions.foreach(x => scanExpression(x, willBeAssignedToWord = false))
case FunctionCallExpression(fname, expressions) =>
env.maybeGet[Thing](fname) match {
case Some(f: FunctionInMemory) if f.params.length == expressions.length =>
for ((e, t) <- expressions zip f.params.types) {
scanExpression(e, willBeAssignedToWord = isWord(t))
}
case _ =>
for (e <- expressions) {
scanExpression(e, willBeAssignedToWord = false)
}
}
case _ =>
}
}

def detectOverflow(stmt: Statement): Unit = {
stmt match {
case Assignment(lhs, rhs) =>
if (isWord(lhs)) scanExpression(rhs, willBeAssignedToWord = true)
case v: VariableDeclarationStatement =>
v.initialValue match {
case Some(e) => scanExpression(e, willBeAssignedToWord = isWord(v.typ))
case _ =>
}
case s =>
s.getAllExpressions.foreach(e => scanExpression(e, willBeAssignedToWord = false))
}
}
}
17 changes: 17 additions & 0 deletions src/test/scala/millfork/test/ByteMathSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,4 +465,21 @@ class ByteMathSuite extends FunSuite with Matchers with AppendedClues {
m.readByte(0xc000) should equal(125)
}
}

test("Optimal multiplication detection") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos, Cpu.Z80, Cpu.Motorola6809)(
"""
| import zp_reg
| word output @$c000
| noinline void run(byte a, byte b) {
| output = word(a) * b
| }
| void main () {
| run(100, 42)
| }
""".
stripMargin) { m =>
m.readWord(0xc000) should equal(4200)
}
}
}
32 changes: 32 additions & 0 deletions src/test/scala/millfork/test/WarningSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,36 @@ class WarningSuite extends FunSuite with Matchers {
""".stripMargin) { m =>
}
}

test("Warn about unintended byte overflow") {
EmuUnoptimizedCrossPlatformRun(Cpu.Mos)(
"""
| import zp_reg
| const word screenOffset = (10*40)+5
| noinline void func(byte x, byte y) {
| word screenOffset
| screenOffset = (x*40) + y
| }
| noinline word getNESScreenOffset(byte x, byte y) {
| word temp
| temp = (y << 5) +x
| }
| noinline word getSomeFunc(byte x, byte y, byte z) {
| word temp
| temp = ((x + z) << 2) + (y << 5)
| temp = byte((x + z) << 2) + (y << 5)
| }
|
| noinline byte someFunc(byte x, byte y) {
| return (x*y)-24
| }
| void main() {
| func(0,0)
| getNESScreenOffset(0,0)
| getSomeFunc(0,screenOffset.lo,5)
| someFunc(0,0)
| }
""".stripMargin) { m =>
}
}
}

0 comments on commit 90e5360

Please sign in to comment.