diff --git a/Database/HDBC/PostgreSQL/Parser.hs b/Database/HDBC/PostgreSQL/Parser.hs index 4c50d2e..6f77999 100644 --- a/Database/HDBC/PostgreSQL/Parser.hs +++ b/Database/HDBC/PostgreSQL/Parser.hs @@ -3,64 +3,97 @@ {- PostgreSQL uses $1, $2, etc. instead of ? in query strings. So we have to do some basic parsing on these things to fix 'em up. -} - module Database.HDBC.PostgreSQL.Parser where -import Text.ParserCombinators.Parsec - -escapeseq :: GenParser Char st String -escapeseq = (try $ string "''") <|> - (try $ string "\\'") +data ParserState = Clear + | Literal Bool + | QIdentifier + | CComment Int + | LineComment Int -literal :: GenParser Char st [Char] -literal = do _ <- char '\'' - s <- many (escapeseq <|> (noneOf "'" >>= (\x -> return [x]))) - _ <- char '\'' - return $ "'" ++ (concat s) ++ "'" +convertSQL :: String -> String +convertSQL = convertSQLAux "" 1 Clear -qidentifier :: GenParser Char st [Char] -qidentifier = do _ <- char '"' - s <- many (noneOf "\"") - _ <- char '"' - return $ "\"" ++ s ++ "\"" +ungetLiteral :: Bool -> String -> (String, String) +ungetLiteral prevBackSlash = aux "" + where aux :: String -> String -> (String, String) + aux acc "" = (reverse acc, "") -comment :: GenParser Char st [Char] -comment = ccomment <|> linecomment + aux acc ('\'':'\'':xs) = aux ('\'':'\'':acc) xs -ccomment :: GenParser Char st [Char] -ccomment = do _ <- string "/*" - c <- manyTill ((try ccomment) <|> - (anyChar >>= (\x -> return [x]))) - (try (string "*/")) - return $ "/*" ++ concat c ++ "*/" + aux acc s@('\'':'\\':xs) = + if prevBackSlash then + (reverse acc, s) + else + aux ('\\':'\'':acc) xs -linecomment :: GenParser Char st [Char] -linecomment = do _ <- string "--" - c <- many (noneOf "\n") - _ <- char '\n' - return $ "--" ++ c ++ "\n" + aux acc s@('\'':_) = (reverse acc, s) --- FIXME: handle pgsql dollar-quoted constants + aux acc (x:xs) = aux (x:acc) xs -qmark :: (Num st, Show st) => GenParser Char st [Char] -qmark = do _ <- char '?' - n <- getState - updateState (+1) - return $ "$" ++ show n +ungetQIdentifier :: String -> (String, String) +ungetQIdentifier = aux "" + where aux :: String -> String -> (String, String) + aux acc "" = (reverse acc, "") + aux acc s@('"':_) = (reverse acc, s) + aux acc (x:xs) = aux (x:acc) xs -escapedQmark :: GenParser Char st [Char] -escapedQmark = do _ <- try (char '\\' >> char '?') - return "?" +ungetLineComment :: Int -> String -> (String, String) +ungetLineComment = aux "" + where aux :: String -> Int -> String -> (String, String) + aux acc _ "" = (reverse acc, "") + aux acc 1 s@('-':'-':_) = (reverse acc, s) + aux acc level ('-':'-':xs) = aux ('-':'-':acc) (level - 1) xs + aux acc level (x:xs) = aux (x:acc) level xs -statement :: (Num st, Show st) => GenParser Char st [Char] -statement = - do s <- many ((try escapedQmark) <|> - (try qmark) <|> - (try comment) <|> - (try literal) <|> - (try qidentifier) <|> - (anyChar >>= (\x -> return [x]))) - return $ concat s +ungetCComment :: Int -> String -> (String, String) +ungetCComment = aux "" + where aux :: String -> Int -> String -> (String, String) + aux acc _ "" = (reverse acc, "") + aux acc 1 s@('*':'/':_) = (reverse acc, s) + aux acc level ('*':'/':xs) = aux ('/':'*':acc) (level - 1) xs + aux acc level (x:xs) = aux (x:acc) level xs -convertSQL :: String -> Either ParseError String -convertSQL input = runParser statement (1::Integer) "" input +convertSQLAux :: String -> Int -> ParserState -> String -> String +convertSQLAux acc _ Clear "" = reverse acc +convertSQLAux acc paramCount state input = + case state of + Clear -> + case input of + '?':xs -> convertSQLAux ((reverse $ show paramCount) ++ ('$':acc)) (paramCount + 1) Clear xs + '\\':'\'':xs -> convertSQLAux ('\'':'\\':acc) paramCount (Literal True) xs + '\'':xs -> convertSQLAux ('\'':acc) paramCount (Literal False) xs + '"':xs -> convertSQLAux ('"':acc) paramCount QIdentifier xs + '-':'-':xs -> convertSQLAux ('-':'-':acc) paramCount (LineComment 1) xs + '/':'*':xs -> convertSQLAux ('*':'/':acc) paramCount (CComment 1) xs + '\\':'?':xs -> convertSQLAux ('?':acc) paramCount Clear xs + x:xs -> convertSQLAux (x:acc) paramCount Clear xs + "" -> reverse acc + Literal prevBackSlash -> + case input of + '\'':'\'':xs -> convertSQLAux ('\'':'\'':acc) paramCount state xs + '\\':'\'':xs -> convertSQLAux ('\'':'\\':acc) paramCount state xs + '\'':xs -> convertSQLAux ('\'':acc) paramCount Clear xs + x:xs -> convertSQLAux (x:acc) paramCount state xs + "" -> let (literal, acc') = ungetLiteral prevBackSlash acc + in convertSQLAux acc' paramCount Clear $ reverse literal + QIdentifier -> + case input of + '"':xs -> convertSQLAux ('"':acc) paramCount Clear xs + x:xs -> convertSQLAux (x:acc) paramCount QIdentifier xs + "" -> let (qidentifier, acc') = ungetQIdentifier acc + in convertSQLAux acc' paramCount Clear $ reverse qidentifier + LineComment level -> + case input of + '\n':xs -> convertSQLAux ('\n':acc) paramCount Clear xs + '-':'-':xs -> convertSQLAux ('-':'-':acc) paramCount (LineComment $ level + 1) xs + x:xs -> convertSQLAux (x:acc) paramCount (LineComment level) xs + "" -> let (lineComment, acc') = ungetLineComment level acc + in convertSQLAux acc' paramCount Clear $ reverse lineComment + CComment level -> + case input of + '*':'/':xs -> convertSQLAux ('/':'*':acc) paramCount (if level == 1 then Clear else CComment $ level - 1) xs + '/':'*':xs -> convertSQLAux ('*':'/':acc) paramCount (CComment $ level + 1) xs + x:xs -> convertSQLAux (x:acc) paramCount state xs + "" -> let (cComment, acc') = ungetCComment level acc + in convertSQLAux acc' paramCount Clear $ reverse cComment diff --git a/Database/HDBC/PostgreSQL/Statement.hsc b/Database/HDBC/PostgreSQL/Statement.hsc index 2ee49ed..26fc0f8 100644 --- a/Database/HDBC/PostgreSQL/Statement.hsc +++ b/Database/HDBC/PostgreSQL/Statement.hsc @@ -45,15 +45,8 @@ newSth indbo mchildren query = newstomv <- newMVar Nothing newnextrowmv <- newMVar (-1) newcoldefmv <- newMVar [] - usequery <- case convertSQL query of - Left errstr -> throwSqlError $ SqlError - {seState = "", - seNativeError = (-1), - seErrorMsg = "hdbc prepare: " ++ - show errstr} - Right converted -> return converted let sstate = SState {stomv = newstomv, nextrowmv = newnextrowmv, - dbo = indbo, squery = usequery, + dbo = indbo, squery = convertSQL query, coldefmv = newcoldefmv} let retval = Statement {execute = fexecute sstate, diff --git a/HDBC-postgresql.cabal b/HDBC-postgresql.cabal index b0d7902..3268d6f 100644 --- a/HDBC-postgresql.cabal +++ b/HDBC-postgresql.cabal @@ -35,7 +35,7 @@ Library Database.HDBC.PostgreSQL.PTypeConv, Database.HDBC.PostgreSQL.ErrorCodes Extensions: ExistentialQuantification, ForeignFunctionInterface - Build-Depends: base >= 3 && < 5, mtl, HDBC>=2.2.0, parsec, utf8-string, + Build-Depends: base >= 3 && < 5, mtl, HDBC>=2.2.0, utf8-string, bytestring, old-time, old-locale, time, convertible if impl(ghc >= 6.9) Build-Depends: base >= 4 @@ -47,8 +47,8 @@ Library Executable runtests if flag(buildtests) Buildable: True - Build-Depends: HUnit, QuickCheck, testpack, containers, - convertible, time, old-locale, parsec, utf8-string, + Build-Depends: HUnit, QuickCheck < 2, testpack, containers, + convertible, time, old-locale, utf8-string, bytestring, old-time, base, HDBC>=2.2.6 else Buildable: False