diff --git a/extensions/sql-database-projects/src/common/constants.ts b/extensions/sql-database-projects/src/common/constants.ts index 3e76de1ae0..1b3f65ee1f 100644 --- a/extensions/sql-database-projects/src/common/constants.ts +++ b/extensions/sql-database-projects/src/common/constants.ts @@ -119,6 +119,8 @@ export const exampleUsage = localize('exampleUsage', "Example Usage"); export const enterSystemDbName = localize('enterSystemDbName', "Enter a database name for this system database"); export const databaseNameRequiredVariableOptional = localize('databaseNameRequiredVariableOptional', "A database name is required. The database variable is optional."); export const databaseNameServerNameVariableRequired = localize('databaseNameServerNameVariableRequired', "A database name, server name, and server variable are required. The database variable is optional"); +export const otherServer = 'OtherServer'; +export const otherSeverVariable = '$(OtherServer)'; export const databaseProject = localize('databaseProject', "Database project"); // Error messages @@ -166,6 +168,7 @@ export function unexpectedProjectContext(uri: string) { return localize('unexpec export function unableToPerformAction(action: string, uri: string) { return localize('unableToPerformAction', "Unable to locate '{0}' target: '{1}'", action, uri); } export function unableToFindObject(path: string, objType: string) { return localize('unableToFindFile', "Unable to find {1} with path '{0}'", path, objType); } export function deployScriptExists(scriptType: string) { return localize('deployScriptExists', "A {0} script already exists. The new script will not be included in build.", scriptType); } +export function notValidVariableName(name: string) { return localize('notValidVariableName', "The variable name '{0}' is not valid.", name); } export function cantAddCircularProjectReference(project: string) { return localize('cantAddCircularProjectReference', "A reference to project '{0} cannot be added. Adding this project as a reference would cause a circular dependency", project); } // Action types @@ -280,5 +283,5 @@ export const systemDbs = ['master', 'msdb', 'tempdb', 'model']; // SQL queries export const sameDatabaseExampleUsage = 'SELECT * FROM [Schema1].[Table1]'; -export function differentDbSameServerExampleUsage(db: string) { return `SELECT * FROM [${db}].[Schema1].[Table1]"`; } +export function differentDbSameServerExampleUsage(db: string) { return `SELECT * FROM [${db}].[Schema1].[Table1]`; } export function differentDbDifferentServerExampleUsage(server: string, db: string) { return `SELECT * FROM [${server}].[${db}].[Schema1].[Table1]`; } diff --git a/extensions/sql-database-projects/src/common/utils.ts b/extensions/sql-database-projects/src/common/utils.ts index 8e0fb1cec1..68cf04a10c 100644 --- a/extensions/sql-database-projects/src/common/utils.ts +++ b/extensions/sql-database-projects/src/common/utils.ts @@ -146,6 +146,77 @@ export function readSqlCmdVariables(xmlDoc: any): Record { } /** + * Removes $() around a sqlcmd variable + * @param name + */ +export function removeSqlCmdVariableFormatting(name: string | undefined): string { + if (!name || name === '') { + return ''; + } + + if (name.length > 3) { + // Trim in case we get " $(x)" + name = name.trim(); + let indexStart = name.startsWith('$(') ? 2 : 0; + let indexEnd = name.endsWith(')') ? 1 : 0; + if (indexStart > 0 || indexEnd > 0) { + name = name.substr(indexStart, name.length - indexEnd - indexStart); + } + } + + // Trim in case the customer types " $(x )" + return name.trim(); +} + +/** + * Format as sqlcmd variable by adding $() if necessary + * if the variable already starts with $(, then add ) + * @param name + */ +export function formatSqlCmdVariable(name: string): string { + if (!name || name === '') { + return name; + } + + // Trim in case we get " $(x)" + name = name.trim(); + + if (!name.startsWith('$(') && !name.endsWith(')')) { + name = `$(${name})`; + } else if (name.startsWith('$(') && !name.endsWith(')')) { + // add missing end parenthesis, same behavior as SSDT + name = `${name})`; + } + + return name; +} + +/** + * Checks if it's a valid sqlcmd variable name + * https://docs.microsoft.com/en-us/sql/ssms/scripting/sqlcmd-use-with-scripting-variables?redirectedfrom=MSDN&view=sql-server-ver15#guidelines-for-scripting-variable-names-and-values + * @param name variable name to validate + */ +export function isValidSqlCmdVariableName(name: string | undefined): boolean { + // remove $() around named if it's there + name = removeSqlCmdVariableFormatting(name); + + // can't contain whitespace + if (!name || name.trim() === '' || name.includes(' ')) { + return false; + } + + // can't contain these characters + if (name.includes('$') || name.includes('@') || name.includes('#') || name.includes('"') || name.includes('\'') || name.includes('-')) { + return false; + } + + // TODO: tsql parsing to check if it's a reserved keyword or invalid tsql https://github.com/microsoft/azuredatastudio/issues/12204 + // TODO: give more detail why variable name was invalid https://github.com/microsoft/azuredatastudio/issues/12231 + + return true; +} + +/* * Recursively gets all the sqlproj files at any depth in a folder * @param folderPath */ diff --git a/extensions/sql-database-projects/src/dialogs/addDatabaseReferenceDialog.ts b/extensions/sql-database-projects/src/dialogs/addDatabaseReferenceDialog.ts index fb135c61d1..e5db5ca839 100644 --- a/extensions/sql-database-projects/src/dialogs/addDatabaseReferenceDialog.ts +++ b/extensions/sql-database-projects/src/dialogs/addDatabaseReferenceDialog.ts @@ -7,12 +7,12 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import * as path from 'path'; import * as constants from '../common/constants'; +import * as utils from '../common/utils'; import { Project, SystemDatabase } from '../models/project'; import { cssStyles } from '../common/uiConstants'; import { IconPathHelper } from '../common/iconHelper'; import { ISystemDatabaseReferenceSettings, IDacpacReferenceSettings, IProjectReferenceSettings } from '../models/IDatabaseReferenceSettings'; -import { getSqlProjectFilesInFolder } from '../common/utils'; export enum ReferenceType { project, @@ -146,9 +146,9 @@ export class AddDatabaseReferenceDialog { referenceSettings = { databaseName: this.databaseNameTextbox?.value, dacpacFileLocation: vscode.Uri.file(this.dacpacTextbox?.value), - databaseVariable: this.databaseVariableTextbox?.value, + databaseVariable: utils.removeSqlCmdVariableFormatting(this.databaseVariableTextbox?.value), serverName: this.serverNameTextbox?.value, - serverVariable: this.serverVariableTextbox?.value, + serverVariable: utils.removeSqlCmdVariableFormatting(this.serverVariableTextbox?.value), suppressMissingDependenciesErrors: this.suppressMissingDependenciesErrorsCheckbox?.checked }; } @@ -260,13 +260,13 @@ export class AddDatabaseReferenceDialog { // get projects in workspace const workspaceFolders = vscode.workspace.workspaceFolders; if (workspaceFolders?.length) { - let projectFiles = await getSqlProjectFilesInFolder(workspaceFolders[0].uri.fsPath); + let projectFiles = await utils.getSqlProjectFilesInFolder(workspaceFolders[0].uri.fsPath); // check if current project is in same open folder (should only be able to add a reference to another project in // the folder if the current project is also in the folder) - if (projectFiles.find(p => p === this.project.projectFilePath)) { + if (projectFiles.find(p => p === utils.getPlatformSafeFileEntryPath(this.project.projectFilePath))) { // filter out current project - projectFiles = projectFiles.filter(p => p !== this.project.projectFilePath); + projectFiles = projectFiles.filter(p => p !== utils.getPlatformSafeFileEntryPath(this.project.projectFilePath)); projectFiles.forEach(p => { projectFiles[projectFiles.indexOf(p)] = path.parse(p).name; @@ -291,6 +291,10 @@ export class AddDatabaseReferenceDialog { ariaLabel: constants.databaseNameLabel }).component(); + this.systemDatabaseDropdown.onValueChanged(() => { + this.setDefaultDatabaseValues(); + }); + // only master is a valid system db reference for projects targetting Azure if (this.project.getProjectTargetPlatform().toLowerCase().includes('azure')) { this.systemDatabaseDropdown.values?.splice(1); @@ -310,6 +314,7 @@ export class AddDatabaseReferenceDialog { }).component(); this.dacpacTextbox.onTextChanged(() => { + this.setDefaultDatabaseValues(); this.tryEnableAddReferenceButton(); this.updateExampleUsage(); }); @@ -404,11 +409,39 @@ export class AddDatabaseReferenceDialog { this.databaseVariableTextbox!.value = isSystemDb ? '' : this.databaseVariableTextbox!.value; this.serverNameTextbox!.value = ''; this.serverVariableTextbox!.value = ''; + + // add default values in enabled fields + this.setDefaultDatabaseValues(); } else if (this.locationDropdown?.value === constants.differentDbDifferentServer) { this.databaseNameTextbox!.enabled = true; this.databaseVariableTextbox!.enabled = true; this.serverNameTextbox!.enabled = true; this.serverVariableTextbox!.enabled = true; + + // add default values in enabled fields + this.setDefaultDatabaseValues(); + this.serverNameTextbox!.value = constants.otherServer; + this.serverVariableTextbox!.value = constants.otherSeverVariable; + } + } + + private setDefaultDatabaseValues(): void { + switch (this.currentReferenceType) { + case ReferenceType.project: { + this.databaseNameTextbox!.value = this.projectDropdown?.value; + this.databaseVariableTextbox!.value = `$(${this.projectDropdown?.value})`; + break; + } + case ReferenceType.systemDb: { + this.databaseNameTextbox!.value = this.systemDatabaseDropdown?.value; + break; + } + case ReferenceType.dacpac: { + const dacpacName = this.dacpacTextbox!.value ? path.parse(this.dacpacTextbox!.value!).name : ''; + this.databaseNameTextbox!.value = dacpacName; + this.databaseVariableTextbox!.value = dacpacName ? `$(${dacpacName})` : ''; + break; + } } } @@ -430,6 +463,7 @@ export class AddDatabaseReferenceDialog { const serverVariableRow = this.view!.modelBuilder.flexContainer().withItems([this.createLabel(constants.serverVariable, true), this.serverVariableTextbox], { flex: '0 0 auto' }).withLayout({ flexFlow: 'row', alignItems: 'center' }).component(); const variableSection = this.view!.modelBuilder.flexContainer().withItems([databaseNameRow, databaseVariableRow, serverNameRow, serverVariableRow]).withLayout({ flexFlow: 'column' }).withProperties({ CSSStyles: { 'margin-bottom': '25px' } }).component(); + this.setDefaultDatabaseValues(); return { component: variableSection, @@ -490,7 +524,7 @@ export class AddDatabaseReferenceDialog { newText = this.currentReferenceType === ReferenceType.systemDb ? constants.enterSystemDbName : constants.databaseNameRequiredVariableOptional; fontStyle = cssStyles.fontStyle.italics; } else { - const db = this.databaseVariableTextbox?.value ? this.databaseVariableTextbox?.value : this.databaseNameTextbox.value; + const db = this.databaseVariableTextbox?.value ? utils.formatSqlCmdVariable(this.databaseVariableTextbox?.value) : this.databaseNameTextbox.value; newText = constants.differentDbSameServerExampleUsage(db); } break; @@ -500,18 +534,34 @@ export class AddDatabaseReferenceDialog { newText = constants.databaseNameServerNameVariableRequired; fontStyle = cssStyles.fontStyle.italics; } else { - const server = this.serverVariableTextbox.value; - const db = this.databaseVariableTextbox?.value ? this.databaseVariableTextbox?.value : this.databaseNameTextbox.value; + const server = utils.formatSqlCmdVariable(this.serverVariableTextbox.value); + const db = this.databaseVariableTextbox?.value ? utils.formatSqlCmdVariable(this.databaseVariableTextbox?.value) : this.databaseNameTextbox.value; newText = constants.differentDbDifferentServerExampleUsage(server, db); } break; } } + // check for invalid variables + if (!this.validSqlCmdVariables()) { + let invalidName = !utils.isValidSqlCmdVariableName(this.databaseVariableTextbox?.value) ? this.databaseVariableTextbox!.value! : this.serverVariableTextbox!.value!; + invalidName = utils.removeSqlCmdVariableFormatting(invalidName); + newText = constants.notValidVariableName(invalidName); + } + this.exampleUsage!.value = newText; this.exampleUsage?.updateCssStyles({ 'font-style': fontStyle }); } + private validSqlCmdVariables(): boolean { + if (this.databaseVariableTextbox?.enabled && this.databaseVariableTextbox?.value && !utils.isValidSqlCmdVariableName(this.databaseVariableTextbox?.value) + || this.serverVariableTextbox?.enabled && this.serverVariableTextbox?.value && !utils.isValidSqlCmdVariableName(this.serverVariableTextbox?.value)) { + return false; + } + + return true; + } + /** * Only enable Add reference button if all enabled fields are filled */ @@ -533,8 +583,9 @@ export class AddDatabaseReferenceDialog { } private dacpacRequiredFieldsFilled(): boolean { - return !!this.dacpacTextbox?.value && - ((this.locationDropdown?.value === constants.sameDatabase) + return !!this.dacpacTextbox?.value + && this.validSqlCmdVariables() + && ((this.locationDropdown?.value === constants.sameDatabase) || (this.locationDropdown?.value === constants.differentDbSameServer && this.differentDatabaseSameServerRequiredFieldsFilled()) || ((this.locationDropdown?.value === constants.differentDbDifferentServer && this.differentDatabaseDifferentServerRequiredFieldsFilled()))); } diff --git a/extensions/sql-database-projects/src/models/project.ts b/extensions/sql-database-projects/src/models/project.ts index 493e623424..a4322ea442 100644 --- a/extensions/sql-database-projects/src/models/project.ts +++ b/extensions/sql-database-projects/src/models/project.ts @@ -849,7 +849,7 @@ export class FileProjectEntry extends ProjectEntry { } public pathForSqlProj(): string { - return utils.convertSlashesForSqlProj(this.fsUri.path); + return utils.convertSlashesForSqlProj(this.fsUri.fsPath); } } diff --git a/extensions/sql-database-projects/src/models/publishProfile/publishProfile.ts b/extensions/sql-database-projects/src/models/publishProfile/publishProfile.ts index a758fe84bb..d9a38589c0 100644 --- a/extensions/sql-database-projects/src/models/publishProfile/publishProfile.ts +++ b/extensions/sql-database-projects/src/models/publishProfile/publishProfile.ts @@ -58,7 +58,7 @@ async function readConnectionString(xmlDoc: any): Promise<{ connectionId: string if (xmlDoc.documentElement.getElementsByTagName(constants.targetConnectionString).length > 0) { const targetConnectionString = xmlDoc.documentElement.getElementsByTagName(constants.TargetConnectionString)[0].textContent; - const dataSource = new SqlConnectionDataSource('temp', targetConnectionString); + const dataSource = new SqlConnectionDataSource('', targetConnectionString); let server: string = ''; let username: string = ''; const connectionProfile = dataSource.getConnectionProfile(); @@ -74,7 +74,7 @@ async function readConnectionString(xmlDoc: any): Promise<{ connectionId: string const connection = await azdata.connection.openConnectionDialog(undefined, connectionProfile); connId = connection.connectionId; server = connection.options['server']; - username = connection.options['username']; + username = connection.options['user']; } targetConnection = `${server} (${username})`; diff --git a/extensions/sql-database-projects/src/test/dialogs/addDatabaseReferenceDialog.test.ts b/extensions/sql-database-projects/src/test/dialogs/addDatabaseReferenceDialog.test.ts index 506b44fc2e..4e7cf495b5 100644 --- a/extensions/sql-database-projects/src/test/dialogs/addDatabaseReferenceDialog.test.ts +++ b/extensions/sql-database-projects/src/test/dialogs/addDatabaseReferenceDialog.test.ts @@ -32,10 +32,15 @@ describe('Add Database Reference Dialog', () => { should(dialog.dialog.okButton.enabled).equal(false); should(dialog.currentReferenceType).equal(ReferenceType.systemDb); dialog.tryEnableAddReferenceButton(); - should(dialog.dialog.okButton.enabled).equal(false); + should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled because there is a default value in the database name textbox'); + + // empty db name textbox + dialog.databaseNameTextbox!.value = ''; + dialog.tryEnableAddReferenceButton(); + should(dialog.dialog.okButton.enabled).equal(false, 'Ok button should be disabled after clearing the database name textbox'); // fill in db name and ok button should be enabled - dialog.databaseNameTextbox!.value = 'dbName'; + dialog.databaseNameTextbox!.value = 'master'; dialog.tryEnableAddReferenceButton(); should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled after the database name textbox is filled'); @@ -43,10 +48,12 @@ describe('Add Database Reference Dialog', () => { dialog.dacpacRadioButtonClick(); should(dialog.currentReferenceType).equal(ReferenceType.dacpac); should(dialog.locationDropdown?.value).equal(constants.differentDbSameServer); + should(dialog.databaseNameTextbox!.value).equal('', 'database name text box should be empty because no dacpac has been selected'); should(dialog.dialog.okButton.enabled).equal(false, 'Ok button should not be enabled because dacpac input box is not filled'); - // fill in dacpac textbox + // fill in dacpac textbox and database name text box dialog.dacpacTextbox!.value = 'testDb.dacpac'; + dialog.databaseNameTextbox!.value = 'testDb'; dialog.tryEnableAddReferenceButton(); should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled after the dacpac textbox is filled'); @@ -54,13 +61,7 @@ describe('Add Database Reference Dialog', () => { dialog.locationDropdown!.value = constants.differentDbDifferentServer; dialog.updateEnabledInputBoxes(); dialog.tryEnableAddReferenceButton(); - should(dialog.dialog.okButton.enabled).equal(false, 'Ok button should not be enabled because server fields are not filled'); - - // fill in server fields - dialog.serverNameTextbox!.value = 'serverName'; - dialog.serverVariableTextbox!.value = '$(serverName)'; - dialog.tryEnableAddReferenceButton(); - should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled after server fields are filled'); + should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled because server fields are filled'); // change location to same database dialog.locationDropdown!.value = constants.sameDatabase; @@ -74,8 +75,9 @@ describe('Add Database Reference Dialog', () => { // change reference type back to system db dialog.systemDbRadioButtonClick(); - should(dialog.databaseNameTextbox?.value).equal('', `Database name textbox should be empty. Actual:${dialog.databaseNameTextbox?.value}`); - should(dialog.dialog.okButton.enabled).equal(false, 'Ok button should not be enabled because database name is not filled out'); + should(dialog.locationDropdown?.value).equal(constants.differentDbSameServer); + should(dialog.databaseNameTextbox?.value).equal('master', `Database name textbox should be set to master. Actual:${dialog.databaseNameTextbox?.value}`); + should(dialog.dialog.okButton.enabled).equal(true, 'Ok button should be enabled because database name is filled'); }); it('Should enable and disable input boxes depending on the reference type', async function (): Promise { diff --git a/extensions/sql-database-projects/src/test/publishProfile.test.ts b/extensions/sql-database-projects/src/test/publishProfile.test.ts index daf0a7d000..ef84c710a9 100644 --- a/extensions/sql-database-projects/src/test/publishProfile.test.ts +++ b/extensions/sql-database-projects/src/test/publishProfile.test.ts @@ -62,7 +62,7 @@ describe('Publish profile tests', function (): void { connectionId: 'connId', options: { 'server': 'testserver', - 'username': 'testUser' + 'user': 'testUser' } }; testContext.dacFxService.setup(x => x.getOptionsFromProfile(TypeMoq.It.isAny())).returns(async () => { diff --git a/extensions/sql-database-projects/src/test/utils.test.ts b/extensions/sql-database-projects/src/test/utils.test.ts index 0a7de5ac48..8004c53499 100644 --- a/extensions/sql-database-projects/src/test/utils.test.ts +++ b/extensions/sql-database-projects/src/test/utils.test.ts @@ -7,7 +7,7 @@ import * as should from 'should'; import * as path from 'path'; import * as os from 'os'; import { createDummyFileStructure } from './testUtils'; -import { exists, trimUri } from '../common/utils'; +import { exists, trimUri, removeSqlCmdVariableFormatting, formatSqlCmdVariable, isValidSqlCmdVariableName } from '../common/utils'; import { Uri } from 'vscode'; describe('Tests to verify utils functions', function (): void { @@ -38,5 +38,45 @@ describe('Tests to verify utils functions', function (): void { fileUri = Uri.file(path.join(root, 'forked', 'from', 'top', 'file.sql')); should(trimUri(projectUri, fileUri)).equal('../../forked/from/top/file.sql'); }); + + it('Should remove $() from sqlcmd variables', () => { + should(removeSqlCmdVariableFormatting('$(test)')).equal('test', '$() surrounding the variable should have been removed'); + should(removeSqlCmdVariableFormatting('$(test')).equal('test', '$( at the beginning of the variable should have been removed'); + should(removeSqlCmdVariableFormatting('test')).equal('test', 'string should not have been changed because it is not in sqlcmd variable format'); + }); + + it('Should make variable be in sqlcmd variable format with $()', () => { + should(formatSqlCmdVariable('$(test)')).equal('$(test)', 'string should not have been changed because it was already in the correct format'); + should(formatSqlCmdVariable('test')).equal('$(test)', 'string should have been changed to be in sqlcmd variable format'); + should(formatSqlCmdVariable('$(test')).equal('$(test)', 'string should have been changed to be in sqlcmd variable format'); + should(formatSqlCmdVariable('')).equal('', 'should not do anything to an empty string'); + }); + + it('Should determine invalid sqlcmd variable names', () => { + // valid names + should(isValidSqlCmdVariableName('$(test)')).equal(true); + should(isValidSqlCmdVariableName('$(test )')).equal(true, 'trailing spaces should be valid because they will be trimmed'); + should(isValidSqlCmdVariableName('test')).equal(true); + should(isValidSqlCmdVariableName('test ')).equal(true, 'trailing spaces should be valid because they will be trimmed'); + should(isValidSqlCmdVariableName('$(test')).equal(true); + should(isValidSqlCmdVariableName('$(test ')).equal(true, 'trailing spaces should be valid because they will be trimmed'); + + // whitespace + should(isValidSqlCmdVariableName('')).equal(false); + should(isValidSqlCmdVariableName(' ')).equal(false); + should(isValidSqlCmdVariableName(' ')).equal(false); + should(isValidSqlCmdVariableName('test abc')).equal(false); + should(isValidSqlCmdVariableName(' ')).equal(false); + + // invalid characters + should(isValidSqlCmdVariableName('$($test')).equal(false); + should(isValidSqlCmdVariableName('$test')).equal(false); + should(isValidSqlCmdVariableName('$test')).equal(false); + should(isValidSqlCmdVariableName('test@')).equal(false); + should(isValidSqlCmdVariableName('test#')).equal(false); + should(isValidSqlCmdVariableName('test"')).equal(false); + should(isValidSqlCmdVariableName('test\'')).equal(false); + should(isValidSqlCmdVariableName('test-1')).equal(false); + }); });