diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index add30d1189..61720db948 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -556,6 +556,10 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expr]: # to parse a statement / command to support the macro @IF(condition, statement) index = self._index try: + if self.dialect == "tsql": + if not (self._index >= 2 and self._tokens[self._index - 2].text == "@"): + return self.__parse_if() # type: ignore + return Parser.__parse_if(self) # type: ignore return self.__parse_if() # type: ignore except ParseError: self._retreat(index) @@ -1123,8 +1127,8 @@ def extend_sqlglot() -> None: _override(Parser, _parse_value) _override(Parser, _parse_lambda) _override(Parser, _parse_types) - _override(TSQL.Parser, Parser._parse_if) _override(Parser, _parse_if) + _override(TSQL.Parser, Parser._parse_if) _override(Parser, _parse_id_var) _override(Parser, _warn_unsupported) _override(Snowflake.Parser, _parse_table_parts) diff --git a/tests/core/test_dialect.py b/tests/core/test_dialect.py index 3b8df28f8b..4393bccfad 100644 --- a/tests/core/test_dialect.py +++ b/tests/core/test_dialect.py @@ -707,6 +707,20 @@ def test_conditional_statement(): q = parse_one("@IF(cond, VACUUM ANALYZE);", read="postgres") assert q.sql(dialect="postgres") == "@IF(cond, VACUUM ANALYZE)" + # Verify that the original error case from issue #5823 (Required keyword: 'true' missing) is resolved. + # It must be parsed as a macro function containing an Anonymous expression rather than exp.If. + q = parse_one("@IF(1 = 1, ALTER TABLE x ADD y INT);", read="tsql") + assert q.sql(dialect="tsql") == "@IF(1 = 1, ALTER TABLE x ADD y INTEGER)" + assert isinstance(q.this, exp.Anonymous) + assert q.this.name == "IF" + + # Note: SQLGlot's fallback Command parser strips quotes from string literal tokens when parsing unparsed commands + q = parse_one("@IF(cond, PRINT 'hello');", read="tsql") + assert q.sql(dialect="tsql") == "@IF(cond, PRINT hello)" + + q = parse_one("@IF(@runtime_stage = 'evaluating', SELECT 1);", read="tsql") + assert q.sql(dialect="tsql") == "@IF(@runtime_stage = 'evaluating', SELECT 1)" + def test_model_name_cannot_be_string(): with pytest.raises(ParseError) as parse_error: