diff --git a/extensions/machine-learning/images/dark/warning_notification_inverse.svg b/extensions/machine-learning/images/dark/warning_notification_inverse.svg new file mode 100644 index 0000000000..c2f3c2103d --- /dev/null +++ b/extensions/machine-learning/images/dark/warning_notification_inverse.svg @@ -0,0 +1 @@ +warning_notification_inverse \ No newline at end of file diff --git a/extensions/machine-learning/images/light/warning_notification.svg b/extensions/machine-learning/images/light/warning_notification.svg new file mode 100644 index 0000000000..34a28f8280 --- /dev/null +++ b/extensions/machine-learning/images/light/warning_notification.svg @@ -0,0 +1 @@ +warning_notification \ No newline at end of file diff --git a/extensions/machine-learning/package.json b/extensions/machine-learning/package.json index ae272ca2ab..3a388bb5cc 100644 --- a/extensions/machine-learning/package.json +++ b/extensions/machine-learning/package.json @@ -2,7 +2,7 @@ "name": "machine-learning", "displayName": "%displayName%", "description": "%description%", - "version": "1.0.0", + "version": "0.1.0", "publisher": "Microsoft", "preview": true, "engines": { @@ -84,6 +84,10 @@ { "command": "mls.command.dependencies", "title": "%mls.command.dependencies%" + }, + { + "command": "mls.command.enableExternalScript", + "title": "%mls.command.enableExternalScript%" } ], "dashboard.tabs": [ @@ -113,8 +117,8 @@ "name": "", "row": 0, "col": 1, - "rowspan": 4, - "colspan": 3, + "rowspan": 5, + "colspan": 5, "widget": { "modelview": { "id":"mls.dashboard" diff --git a/extensions/machine-learning/package.nls.json b/extensions/machine-learning/package.nls.json index 29b7e46b11..7d919e0730 100644 --- a/extensions/machine-learning/package.nls.json +++ b/extensions/machine-learning/package.nls.json @@ -12,8 +12,9 @@ "mls.command.importModel": "Import model", "mls.configuration.title": "Machine Learning Configurations", "mls.pythonPath.description": "Local path to a preexisting Python installation used by Machine Learning.", - "mls.enablePython.description": "Enable Python package management.", - "mls.enableR.description": "Enable R package management.", + "mls.enablePython.description": "Enable Python package management in database.", + "mls.enableR.description": "Enable R package management in database.", "mls.rPath.description": "Local path to a preexisting R installation used by Machine Learning.", - "mls.command.dependencies": "Install Machine Learning Dependencies" + "mls.command.dependencies": "Install Machine Learning Dependencies", + "mls.command.enableExternalScript": "Enable External script" } diff --git a/extensions/machine-learning/src/common/apiWrapper.ts b/extensions/machine-learning/src/common/apiWrapper.ts index 9beea070aa..6f3904682f 100644 --- a/extensions/machine-learning/src/common/apiWrapper.ts +++ b/extensions/machine-learning/src/common/apiWrapper.ts @@ -110,6 +110,10 @@ export class ApiWrapper { return azdata.connection.listDatabases(connectionId); } + public getServerInfo(connectionId: string): Thenable { + return azdata.connection.getServerInfo(connectionId); + } + public openTextDocument(options?: { language?: string; content?: string; }): Thenable { return vscode.workspace.openTextDocument(options); } diff --git a/extensions/machine-learning/src/common/constants.ts b/extensions/machine-learning/src/common/constants.ts index 637db20d4f..a497b52b9f 100644 --- a/extensions/machine-learning/src/common/constants.ts +++ b/extensions/machine-learning/src/common/constants.ts @@ -30,6 +30,7 @@ export const mlManageModelsCommand = 'mls.command.manageModels'; export const mlImportModelCommand = 'mls.command.importModel'; export const mlManagePackagesCommand = 'mls.command.managePackages'; export const mlsDependenciesCommand = 'mls.command.dependencies'; +export const mlsEnableExternalScriptCommand = 'mls.command.enableExternalScript'; export const notebookCommandNew = 'notebook.command.new'; // Configurations @@ -45,7 +46,7 @@ export const rPathConfigKey = 'rPath'; // export const msgYes = localize('msgYes', "Yes"); export const msgNo = localize('msgNo', "No"); -export const managePackageCommandError = localize('mls.managePackages.error', "Either no connection is available or the server does not have external script enabled."); +export const managePackageCommandError = localize('mls.managePackages.error', "Package management is not supported for the server. Make sure you have Python or R installed."); export function taskFailedError(taskName: string, err: string): string { return localize('mls.taskFailedError.error', "Failed to complete task '{0}'. Error: {1}", taskName, err); } export const installPackageMngDependenciesMsgTaskName = localize('mls.installPackageMngDependencies.msgTaskName', "Installing package management dependencies"); export const installModelMngDependenciesMsgTaskName = localize('mls.installModelMngDependencies.msgTaskName', "Installing model management dependencies"); @@ -54,9 +55,10 @@ export const requiredPackagesNotInstalled = localize('mls.requiredPackagesNotIns export const confirmEnableExternalScripts = localize('mls.confirmEnableExternalScripts', "External script is required for package management. Are you sure you want to enable that."); export const enableExternalScriptsError = localize('mls.enableExternalScriptsError', "Failed to enable External script."); export const externalScriptsIsRequiredError = localize('mls.externalScriptsIsRequiredError', "External script configuration is required for this action."); -export function confirmInstallPythonPackages(packages: string): string { +export const confirmInstallPythonPackages = localize('mls.confirmInstallPythonPackages', "Are you sure you want to install required packages?"); +export function confirmInstallPythonPackagesDetails(packages: string): string { return localize('mls.installDependencies.confirmInstallPythonPackages' - , "The following Python packages are required to install: {0}. Are you sure you want to install?", packages); + , "The following Python packages are required to install: {0}", packages); } export function confirmDeleteModel(modelName: string): string { return localize('models.confirmDeleteModel' @@ -120,27 +122,30 @@ export const extLangInstallFailedError = localize('extLang.installFailedError', export const extLangUpdateFailedError = localize('extLang.updateFailedError', "Failed to update language"); export const modelUpdateFailedError = localize('models.modelUpdateFailedError', "Failed to update the model"); -export const databaseName = localize('databaseName', "Database name"); -export const tableName = localize('tableName', "Table name"); +export const databaseName = localize('databaseName', "Models database"); +export const tableName = localize('tableName', "Models table"); export const modelName = localize('models.name', "Name"); export const modelFileName = localize('models.fileName', "File"); export const modelDescription = localize('models.description', "Description"); export const modelCreated = localize('models.created', "Date created"); -export const modelDeployed = localize('models.deployed', "Date deployed"); +export const modelImported = localize('models.imported', "Date imported"); export const modelFramework = localize('models.framework', "Framework"); export const modelFrameworkVersion = localize('models.frameworkVersion', "Framework version"); export const modelVersion = localize('models.version', "Version"); export const browseModels = localize('models.browseButton', "..."); export const azureAccount = localize('models.azureAccount', "Azure account"); export const azureSignIn = localize('models.azureSignIn', "Sign in to Azure"); -export const columnDatabase = localize('predict.columnDatabase', "Target database"); -export const columnTable = localize('predict.columnTable', "Target table"); -export const inputColumns = localize('predict.inputColumns', "Model input mapping"); +export const columnDatabase = localize('predict.columnDatabase', "Source database"); +export const columnTable = localize('predict.columnTable', "Source table"); +export const inputColumns = localize('predict.inputColumns', "Model Input mapping"); export const outputColumns = localize('predict.outputColumns', "Model output"); -export const columnName = localize('predict.columnName', "Target columns"); +export const columnName = localize('predict.columnName', "Source columns"); export const dataTypeName = localize('predict.dataTypeName', "Type"); export const displayName = localize('predict.displayName', "Display name"); -export const inputName = localize('predict.inputName', "Required model input features"); +export const inputName = localize('predict.inputName', "Model input"); +export const selectColumnTitle = localize('predict.selectColumnTitle', "Select column..."); +export const selectDatabaseTitle = localize('predict.selectDatabaseTitle', "Select database"); +export const selectTableTitle = localize('predict.selectTableTitle', "Select table"); export const outputName = localize('predict.outputName', "Name"); export const azureSubscription = localize('models.azureSubscription', "Azure subscription"); export const azureGroup = localize('models.azureGroup', "Azure resource group"); @@ -151,10 +156,12 @@ export const azureModelsTitle = localize('models.azureModelsTitle', "Azure model export const localModelsTitle = localize('models.localModelsTitle', "Local models"); export const modelSourcesTitle = localize('models.modelSourcesTitle', "Source location"); export const modelSourcePageTitle = localize('models.modelSourcePageTitle', "Where is your model located?"); -export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Where do you want import models to?"); -export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map predictions target data to model input"); +export const modelImportTargetPageTitle = localize('models.modelImportTargetPageTitle', "Select or enter the location to import the models to"); +export const columnSelectionPageTitle = localize('models.columnSelectionPageTitle', "Map source data to model"); export const modelDetailsPageTitle = localize('models.modelDetailsPageTitle', "Enter model details"); -export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source file"); +export const modelLocalSourceTitle = localize('models.modelLocalSourceTitle', "Source files"); +export const modelLocalSourceTooltip = localize('models.modelLocalSourceTooltip', "File paths of the models to import"); +export const onnxNotSupportedError = localize('models.onnxNotSupportedError', "ONNX runtime is not supported in current server"); export const currentModelsTitle = localize('models.currentModelsTitle', "Models"); export const azureRegisterModel = localize('models.azureRegisterModel', "Deploy"); export const predictModel = localize('models.predictModel', "Predict"); @@ -179,10 +186,14 @@ export const downloadModelMsgTaskName = localize('models.downloadModelMsgTaskNam export const invalidAzureResourceError = localize('models.invalidAzureResourceError', "Invalid Azure resource"); export const invalidModelToRegisterError = localize('models.invalidModelToRegisterError', "Invalid model to register"); export const invalidModelToPredictError = localize('models.invalidModelToPredictError', "Invalid model to predict"); +export const invalidModelParametersError = localize('models.invalidModelParametersError', "Please select valid source table and model parameters"); export const invalidModelToSelectError = localize('models.invalidModelToSelectError', "Please select a valid model"); export const invalidModelImportTargetError = localize('models.invalidModelImportTargetError', "Please select a valid table"); +export const columnDataTypeMismatchWarning = localize('models.columnDataTypeMismatchWarning', "The data type of the source table column does not match the required input field’s type."); export const modelNameRequiredError = localize('models.modelNameRequiredError', "Model name is required."); export const updateModelFailedError = localize('models.updateModelFailedError', "Failed to update the model"); +export const modelSchemaIsAcceptedMessage = localize('models.modelSchemaIsAcceptedMessage', "Table meets requirements!"); +export const modelSchemaIsNotAcceptedMessage = localize('models.modelSchemaIsNotAcceptedMessage', "Invalid table structure"); export function importModelFailedError(modelName: string | undefined, filePath: string | undefined): string { return localize('models.importModelFailedError', "Failed to register the model: {0} ,file: {1}", modelName || '', filePath || ''); } export function invalidImportTableError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableError', "Invalid table for importing models. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); } export function invalidImportTableSchemaError(databaseName: string | undefined, tableName: string | undefined): string { return localize('models.invalidImportTableSchemaError', "Table schema is not supported for model import. database name: {0} ,table name: {1}", databaseName || '', tableName || ''); } diff --git a/extensions/machine-learning/src/common/eventEmitter.ts b/extensions/machine-learning/src/common/eventEmitter.ts index 1fbcda73d5..fec3a45c24 100644 --- a/extensions/machine-learning/src/common/eventEmitter.ts +++ b/extensions/machine-learning/src/common/eventEmitter.ts @@ -13,17 +13,16 @@ export class EventEmitterCollection extends vscode.Disposable { */ constructor() { super(() => this.dispose()); - } - public on(evt: string, listener: (e: any) => any, thisArgs?: any) { + public on(evt: string, listener: (e: any) => any, thisArgs?: any): vscode.EventEmitter { if (!this._events.has(evt)) { this._events.set(evt, []); } let eventEmitter = new vscode.EventEmitter(); eventEmitter.event(listener, thisArgs); this._events.get(evt)?.push(eventEmitter); - return this; + return eventEmitter; } public fire(evt: string, arg?: any) { @@ -35,6 +34,16 @@ export class EventEmitterCollection extends vscode.Disposable { }); } + public disposeEvent(evt: string, emitter: vscode.EventEmitter): void { + if (this._events.has(evt)) { + const emitters = this._events.get(evt); + if (emitters) { + this._events.set(evt, emitters.filter(x => x !== emitter)); + } + } + emitter.dispose(); + } + public dispose(): any { this._events.forEach(events => { events.forEach(event => { diff --git a/extensions/machine-learning/src/controllers/mainController.ts b/extensions/machine-learning/src/controllers/mainController.ts index 50ccb82f4a..f3fe8e5977 100644 --- a/extensions/machine-learning/src/controllers/mainController.ts +++ b/extensions/machine-learning/src/controllers/mainController.ts @@ -95,14 +95,14 @@ export default class MainController implements vscode.Disposable { let modelManagementController = new ModelManagementController(this._apiWrapper, this._rootPath, azureModelsService, registeredModelService, predictService); - let dashboardWidget = new DashboardWidget(this._apiWrapper, this._rootPath); + let dashboardWidget = new DashboardWidget(this._apiWrapper, this._rootPath, predictService); dashboardWidget.register(); this._apiWrapper.registerCommand(constants.mlManageModelsCommand, (async () => { await modelManagementController.manageRegisteredModels(); })); this._apiWrapper.registerCommand(constants.mlImportModelCommand, (async () => { - await modelManagementController.registerModel(undefined); + await modelManagementController.importModel(undefined); })); this._apiWrapper.registerCommand(constants.mlsPredictModelCommand, (async () => { await modelManagementController.predictModel(); @@ -110,6 +110,9 @@ export default class MainController implements vscode.Disposable { this._apiWrapper.registerCommand(constants.mlsDependenciesCommand, (async () => { await packageManager.installDependencies(); })); + this._apiWrapper.registerCommand(constants.mlsEnableExternalScriptCommand, (async () => { + await packageManager.enableExternalScript(); + })); this._apiWrapper.registerTaskHandler(constants.mlManagePackagesCommand, async () => { await packageManager.managePackages(); }); diff --git a/extensions/machine-learning/src/modelManagement/deployedModelService.ts b/extensions/machine-learning/src/modelManagement/deployedModelService.ts index c70e3d41dd..f5533309ff 100644 --- a/extensions/machine-learning/src/modelManagement/deployedModelService.ts +++ b/extensions/machine-learning/src/modelManagement/deployedModelService.ts @@ -61,12 +61,19 @@ export class DeployedModelService { */ public async downloadModel(model: ImportedModel): Promise { let connection = await this.getCurrentConnection(); + let fileContent: string = ''; if (connection) { const query = queries.getModelContentQuery(model); let result = await this._queryRunner.safeRunQuery(connection, query); if (result && result.rows && result.rows.length > 0) { - const content = result.rows[0][0].displayValue; - return await utils.writeFileFromHex(content); + for (let index = 0; index < result.rows[0].length; index++) { + const column = result.rows[0][index]; + let content = column.displayValue; + content = content.startsWith('0x') || content.startsWith('0X') ? content.substr(2) : content; + fileContent = fileContent + content; + } + + return await utils.writeFileFromHex(fileContent); } else { throw Error(constants.invalidModelToSelectError); } @@ -170,6 +177,13 @@ export class DeployedModelService { } } + /** + * Installs the dependencies required for model management + */ + public async installDependencies(): Promise { + await this._modelClient.installDependencies(); + } + public async getRecentImportTable(): Promise { let connection = await this.getCurrentConnection(); let table: DatabaseTable | undefined; @@ -209,6 +223,7 @@ export class DeployedModelService { deploymentTime: row[7].displayValue, deployedBy: row[8].displayValue, runId: row[9].displayValue, + contentLength: +row[10].displayValue, table: table }; } diff --git a/extensions/machine-learning/src/modelManagement/interfaces.ts b/extensions/machine-learning/src/modelManagement/interfaces.ts index 8884401e33..a08d6c1aac 100644 --- a/extensions/machine-learning/src/modelManagement/interfaces.ts +++ b/extensions/machine-learning/src/modelManagement/interfaces.ts @@ -18,6 +18,7 @@ export interface ListWorkspaceModelsResult extends Array { */ export interface WorkspaceModel extends Resource { framework?: string; + description?: string; frameworkVersion?: string; createdBy?: string; createdTime?: string; @@ -52,12 +53,14 @@ export type WorkspacesModelsResponse = ListWorkspaceModelsResult & { export interface ImportedModel extends ImportedModelDetails { id: number; content?: string; + contentLength?: number; table: DatabaseTable; } export interface ModelParameter { name: string; type: string; + originalType?: string; } export interface ModelParameters { diff --git a/extensions/machine-learning/src/modelManagement/modelConfigRecent.ts b/extensions/machine-learning/src/modelManagement/modelConfigRecent.ts index 760bda5e28..9e5afeb56f 100644 --- a/extensions/machine-learning/src/modelManagement/modelConfigRecent.ts +++ b/extensions/machine-learning/src/modelManagement/modelConfigRecent.ts @@ -21,7 +21,12 @@ export class ModelConfigRecent { } public storeModelTable(connection: azdata.connection.ConnectionProfile, table: DatabaseTable): void { - this._memento.update(this.getKey(connection), table); + if (connection && table?.databaseName && table?.tableName && table?.schema) { + const current = this.getModelTable(connection); + if (!current || current.databaseName !== table.databaseName || current.tableName !== table.tableName || current.schema !== table.schema) { + this._memento.update(this.getKey(connection), table); + } + } } private getKey(connection: azdata.connection.ConnectionProfile): string { diff --git a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts index 1efe1780cb..494070c1d8 100644 --- a/extensions/machine-learning/src/modelManagement/modelPythonClient.ts +++ b/extensions/machine-learning/src/modelManagement/modelPythonClient.ts @@ -39,7 +39,7 @@ export class ModelPythonClient { /** * Installs dependencies for python client */ - private async installDependencies(): Promise { + public async installDependencies(): Promise { await utils.executeTasks(this._apiWrapper, constants.installModelMngDependenciesMsgTaskName, [ this._packageManager.installRequiredPythonPackages(this._config.modelsRequiredPythonPackages)], true); } @@ -49,7 +49,6 @@ export class ModelPythonClient { * @param modelPath Loads model parameters */ public async loadModelParameters(modelPath: string): Promise { - await this.installDependencies(); return await this.executeModelParametersScripts(modelPath); } @@ -61,6 +60,9 @@ export class ModelPythonClient { 'import json', `onnx_model_path = '${modelFolderPath}'`, `onnx_model = onnx.load_model(onnx_model_path)`, + `type_list = ['undefined', + 'float', 'uint8', 'int8', 'uint16', 'int16', 'int32', 'int64', 'string', 'bool', 'double', + 'uint32', 'uint64', 'complex64', 'complex128', 'bfloat16']`, `type_map = { onnx.TensorProto.DataType.FLOAT: 'real', onnx.TensorProto.DataType.UINT8: 'tinyint', @@ -76,13 +78,14 @@ export class ModelPythonClient { `def addParameters(list, paramType): for id, p in enumerate(list): p_type = '' - - if p.type.tensor_type.elem_type in type_map: - p_type = type_map[p.type.tensor_type.elem_type] - + value = p.type.tensor_type.elem_type + if value in type_map: + p_type = type_map[value] + name = type_list[value] parameters[paramType].append({ 'name': p.name, - 'type': p_type + 'type': p_type, + 'originalType': name })`, 'addParameters(onnx_model.graph.input, "inputs")', diff --git a/extensions/machine-learning/src/modelManagement/queries.ts b/extensions/machine-learning/src/modelManagement/queries.ts index da19c7228f..3f90c05160 100644 --- a/extensions/machine-learning/src/modelManagement/queries.ts +++ b/extensions/machine-learning/src/modelManagement/queries.ts @@ -144,12 +144,37 @@ export function getInsertModelQuery(model: ImportedModel, table: DatabaseTable): `; } +/** + * Returns the query for loading model content from database + * @param model model information + */ export function getModelContentQuery(model: ImportedModel): string { const threePartTableName = utils.getRegisteredModelsThreePartsName(model.table.databaseName || '', model.table.tableName || '', model.table.schema || ''); + const len = model.contentLength !== undefined ? model.contentLength : 0; + const maxLength = 1000; + let numberOfColumns = len / maxLength; + // The query provider doesn't return the whole file bites if too big. so loading the bites it blocks + // and merge together to load the file + numberOfColumns = numberOfColumns <= 0 ? 1 : numberOfColumns; + let columns: string[] = []; + let fileIndex = 0; + for (let index = 0; index < numberOfColumns; index++) { + const length = fileIndex === 0 ? maxLength + 1 : maxLength; + columns.push(`substring(@str, ${fileIndex}, ${length}) as d${index}`); + fileIndex = fileIndex + length; + } + + if (fileIndex < len) { + columns.push(`substring(@str, ${fileIndex}, ${maxLength}) as d${columns.length}`); + } return ` - SELECT model + DECLARE @str varbinary(max) + + SELECT @str=model FROM ${threePartTableName} WHERE model_id = ${model.id}; + + select ${columns.join(',')} `; } @@ -190,6 +215,6 @@ export function getDeleteModelQuery(model: ImportedModel): string { `; } -export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id'; +export const selectQuery = 'SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id, len(model)'; diff --git a/extensions/machine-learning/src/packageManagement/packageManagementService.ts b/extensions/machine-learning/src/packageManagement/packageManagementService.ts index 093cbb817f..b817b354e8 100644 --- a/extensions/machine-learning/src/packageManagement/packageManagementService.ts +++ b/extensions/machine-learning/src/packageManagement/packageManagementService.ts @@ -59,6 +59,7 @@ export class PackageManagementService { let current = await this._queryRunner.isMachineLearningServiceEnabled(connection); if (current) { + this._apiWrapper.showInfoMessage(constants.mlsEnabledMessage); return current; } let confirmed = await utils.promptConfirm(constants.confirmEnableExternalScripts, this._apiWrapper); diff --git a/extensions/machine-learning/src/packageManagement/packageManager.ts b/extensions/machine-learning/src/packageManagement/packageManager.ts index b474b4deca..b949d2561b 100644 --- a/extensions/machine-learning/src/packageManagement/packageManager.ts +++ b/extensions/machine-learning/src/packageManagement/packageManager.ts @@ -158,7 +158,8 @@ export class PackageManager { }); if (fileContent) { - let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages(fileContent), this._apiWrapper); + this._apiWrapper.showInfoMessage(constants.confirmInstallPythonPackagesDetails(fileContent)); + let confirmed = await utils.promptConfirm(constants.confirmInstallPythonPackages, this._apiWrapper); if (confirmed) { this._outputChannel.appendLine(constants.installDependenciesPackages); let result = await utils.execCommandOnTempFile(fileContent, async (tempFilePath) => { diff --git a/extensions/machine-learning/src/prediction/interfaces.ts b/extensions/machine-learning/src/prediction/interfaces.ts index 2274a0d8b1..f95995d5b5 100644 --- a/extensions/machine-learning/src/prediction/interfaces.ts +++ b/extensions/machine-learning/src/prediction/interfaces.ts @@ -10,6 +10,7 @@ export interface TableColumn { export interface PredictColumn extends TableColumn { paramName?: string; + paramType?: string; } export interface DatabaseTable { diff --git a/extensions/machine-learning/src/prediction/predictService.ts b/extensions/machine-learning/src/prediction/predictService.ts index 4bd3ddc9fe..db7682c5c5 100644 --- a/extensions/machine-learning/src/prediction/predictService.ts +++ b/extensions/machine-learning/src/prediction/predictService.ts @@ -35,6 +35,25 @@ export class PredictService { return []; } + /** + * Returns true if server supports ONNX + */ + public async serverSupportOnnxModel(): Promise { + try { + let connection = await this.getCurrentConnection(); + if (connection) { + const serverInfo = await this._apiWrapper.getServerInfo(connection.connectionId); + // Right now only Azure SQL Edge support Onnx + // + return serverInfo && serverInfo.engineEditionId === 9; + } + return false; + } catch (error) { + console.log(error); + return false; + } + } + /** * Generates prediction script given model info and predict parameters * @param predictParams predict parameters @@ -157,7 +176,7 @@ AS ( FROM [${utils.doubleEscapeSingleBrackets(sourceTable.databaseName)}].[${sourceTable.schema}].[${utils.doubleEscapeSingleBrackets(sourceTable.tableName)}] as pi ) SELECT -${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getInputColumnNames(outputColumns, 'p')} +${this.getPredictColumnNames(columns, 'predict_input')}, ${this.getPredictInputColumnNames(outputColumns, 'p')} FROM PREDICT(MODEL = @model, DATA = predict_input) WITH ( ${this.getOutputParameters(outputColumns)} @@ -186,7 +205,20 @@ WITH ( `; } + private getEscapedColumnName(tableName: string, columnName: string): string { + return `[${utils.doubleEscapeSingleBrackets(tableName)}].[${utils.doubleEscapeSingleBrackets(columnName)}]`; + } private getInputColumnNames(columns: PredictColumn[], tableName: string) { + + return columns.map(c => { + const column = this.getEscapedColumnName(tableName, c.columnName); + let columnName = c.dataType !== c.paramType ? `cast(${column} as ${c.paramType})` + : `${column}`; + return `${columnName} AS ${c.paramName}`; + }).join(',\n'); + } + + private getPredictInputColumnNames(columns: PredictColumn[], tableName: string) { return columns.map(c => { return this.getColumnName(tableName, c.paramName || '', c.columnName); }).join(',\n'); @@ -199,12 +231,15 @@ WITH ( } private getColumnName(tableName: string, columnName: string, displayName: string) { - return columnName && columnName !== displayName ? `${tableName}.${columnName} AS ${displayName}` : `${tableName}.${columnName}`; + const column = this.getEscapedColumnName(tableName, columnName); + return columnName && columnName !== displayName ? + `${column} AS [${utils.doubleEscapeSingleBrackets(displayName)}]` : column; } private getPredictColumnNames(columns: PredictColumn[], tableName: string) { return columns.map(c => { - return c.paramName ? `${tableName}.${c.paramName}` : `${tableName}.${c.columnName}`; + return c.paramName ? `${this.getEscapedColumnName(tableName, c.paramName)}` + : `${this.getEscapedColumnName(tableName, c.columnName)}`; }).join(',\n'); } diff --git a/extensions/machine-learning/src/test/modelManagement/deployedModelService.test.ts b/extensions/machine-learning/src/test/modelManagement/deployedModelService.test.ts index ef0761a3c3..5190522319 100644 --- a/extensions/machine-learning/src/test/modelManagement/deployedModelService.test.ts +++ b/extensions/machine-learning/src/test/modelManagement/deployedModelService.test.ts @@ -83,7 +83,8 @@ describe('DeployedModelService', () => { frameworkVersion: '1', deployedBy: '1', runId: 'run1', - table: testContext.importTable + table: testContext.importTable, + contentLength: 100 } ]; @@ -141,6 +142,11 @@ describe('DeployedModelService', () => { displayValue: 'run1', isNull: false, invariantCultureDisplayValue: '' + }, + { + displayValue: '100', + isNull: false, + invariantCultureDisplayValue: '' } ] ] @@ -304,7 +310,13 @@ describe('DeployedModelService', () => { displayValue: 'run1', isNull: false, invariantCultureDisplayValue: '' - } + }, + { + displayValue: '100', + isNull: false, + invariantCultureDisplayValue: '' + }, + ]; const result = { rowCount: 1, @@ -391,7 +403,7 @@ describe('DeployedModelService', () => { testContext.importTable.tableName = 'ta[b]le'; testContext.importTable.schema = 'dbo'; const expected = ` - SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id + SELECT model_id, model_name, model_description, model_version, model_creation_time, model_framework, model_framework_version, model_deployment_time, deployed_by, run_id, len(model) FROM [d[[]]b].[dbo].[ta[[b]]le] WHERE model_name not like 'MLmodel' and model_name not like 'conda.yaml' ORDER BY model_id @@ -443,9 +455,13 @@ describe('DeployedModelService', () => { databaseName: 'd[]b', tableName: 'ta[b]le', schema: 'dbo' }; const expected = ` - SELECT model + DECLARE @str varbinary(max) + + SELECT @str=model FROM [d[[]]b].[dbo].[ta[[b]]le] WHERE model_id = 1; + + select substring(@str, 0, 1001) as d0 `; const actual = queries.getModelContentQuery(model); should.deepEqual(actual, expected, `actual: ${actual} \n expected: ${expected}`); diff --git a/extensions/machine-learning/src/test/views/dashboardWidget.test.ts b/extensions/machine-learning/src/test/views/dashboardWidget.test.ts index e99badf84a..09390c68e0 100644 --- a/extensions/machine-learning/src/test/views/dashboardWidget.test.ts +++ b/extensions/machine-learning/src/test/views/dashboardWidget.test.ts @@ -9,11 +9,13 @@ import * as TypeMoq from 'typemoq'; import { ApiWrapper } from '../../common/apiWrapper'; import { createViewContext } from './utils'; import { DashboardWidget } from '../../views/widgets/dashboardWidget'; +import { PredictService } from '../../prediction/predictService'; interface TestContext { apiWrapper: TypeMoq.IMock; view: azdata.ModelView; onClick: vscode.EventEmitter; + predictService: TypeMoq.IMock; } @@ -24,15 +26,22 @@ function createContext(): TestContext { return { apiWrapper: viewTestContext.apiWrapper, view: viewTestContext.view, - onClick: viewTestContext.onClick + onClick: viewTestContext.onClick, + predictService: TypeMoq.Mock.ofType(PredictService) }; } describe('Dashboard widget', () => { it('Should create view components successfully ', async function (): Promise { let testContext = createContext(); - const dashboard = new DashboardWidget(testContext.apiWrapper.object, ''); - dashboard.register(); + + testContext.apiWrapper.setup(x => x.registerWidget(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(async ( handler) => { + await handler(testContext.view); + }); + + testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true)); + const dashboard = new DashboardWidget(testContext.apiWrapper.object, '', testContext.predictService.object); + await dashboard.register(); testContext.onClick.fire(undefined); testContext.apiWrapper.verify(x => x.executeCommand(TypeMoq.It.isAny()), TypeMoq.Times.atLeastOnce()); }); diff --git a/extensions/machine-learning/src/test/views/models/ModelManagementController.test.ts b/extensions/machine-learning/src/test/views/models/ModelManagementController.test.ts index a54178217e..598b2c365a 100644 --- a/extensions/machine-learning/src/test/views/models/ModelManagementController.test.ts +++ b/extensions/machine-learning/src/test/views/models/ModelManagementController.test.ts @@ -115,7 +115,8 @@ const modelParameters: ModelParameters = { }; describe('Model Controller', () => { - it('Should open deploy model wizard successfully ', async function (): Promise { + + it('Should open import model wizard successfully ', async function (): Promise { let testContext = createContext(); @@ -125,16 +126,24 @@ describe('Model Controller', () => { tableName: 'table', schema: 'dbo' })); + testContext.deployModelService.setup(x => x.storeRecentImportTable(TypeMoq.It.isAny())).returns(() => Promise.resolve()); testContext.deployModelService.setup(x => x.getDeployedModels(TypeMoq.It.isAny())).returns(() => Promise.resolve(localModels)); + testContext.deployModelService.setup(x => x.verifyConfigTable(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); + testContext.deployModelService.setup(x => x.deployLocalModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.deployModelService.setup(x => x.deleteModel(TypeMoq.It.isAny())).returns(() => Promise.resolve()); + testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('path')); testContext.predictService.setup(x => x.getDatabaseList()).returns(() => Promise.resolve(dbNames)); testContext.predictService.setup(x => x.getTableList(TypeMoq.It.isAny())).returns(() => Promise.resolve(tableNames)); testContext.azureModelService.setup(x => x.getAccounts()).returns(() => Promise.resolve(accounts)); + testContext.azureModelService.setup(x => x.signInToAzure()).returns(() => Promise.resolve()); testContext.azureModelService.setup(x => x.getSubscriptions(TypeMoq.It.isAny())).returns(() => Promise.resolve(subscriptions)); testContext.azureModelService.setup(x => x.getGroups(TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(groups)); testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces)); testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models)); + testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('path')); - const view = await controller.registerModel(undefined); + const view = await controller.importModel(undefined); should.notEqual(view, undefined); }); @@ -161,7 +170,10 @@ describe('Model Controller', () => { testContext.azureModelService.setup(x => x.getWorkspaces(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(workspaces)); testContext.azureModelService.setup(x => x.getModels(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve(models)); testContext.predictService.setup(x => x.getTableColumnsList(TypeMoq.It.isAny())).returns(() => Promise.resolve(columnNames)); + testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(true)); testContext.deployModelService.setup(x => x.loadModelParameters(TypeMoq.It.isAny())).returns(() => Promise.resolve(modelParameters)); + testContext.deployModelService.setup(x => x.verifyConfigTable(TypeMoq.It.isAny())).returns(() => Promise.resolve(true)); + testContext.deployModelService.setup(x => x.installDependencies()).returns(() => Promise.resolve()); testContext.azureModelService.setup(x => x.downloadModel(TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny(), TypeMoq.It.isAny())).returns(() => Promise.resolve('file')); testContext.deployModelService.setup(x => x.downloadModel(TypeMoq.It.isAny())).returns(() => Promise.resolve('file')); @@ -169,6 +181,17 @@ describe('Model Controller', () => { should.notEqual(view, undefined); }); + + it('Should show error message if onnx is not supported ', async function (): Promise { + let testContext = createContext(); + let controller = new ModelManagementController(testContext.apiWrapper.object, '', testContext.azureModelService.object, testContext.deployModelService.object, testContext.predictService.object); + testContext.predictService.setup(x => x.serverSupportOnnxModel()).returns(() => Promise.resolve(false)); + testContext.apiWrapper.setup(x => x.showErrorMessage(TypeMoq.It.isAny())).returns(() => Promise.resolve('')); + const view = await controller.predictModel(); + should.equal(view, undefined); + testContext.apiWrapper.verify(x => x.showErrorMessage(TypeMoq.It.isAny()), TypeMoq.Times.once()); + }); + it('Should open edit model dialog successfully ', async function (): Promise { let testContext = createContext(); testContext.deployModelService.setup(x => x.updateModel(TypeMoq.It.isAny())).returns(() => Promise.resolve()); @@ -199,4 +222,5 @@ describe('Model Controller', () => { should.notEqual(view, undefined); }); + }); diff --git a/extensions/machine-learning/src/test/views/models/azureModelsComponent.test.ts b/extensions/machine-learning/src/test/views/models/azureModelsComponent.test.ts index e3cb1a78b2..73a4d22e68 100644 --- a/extensions/machine-learning/src/test/views/models/azureModelsComponent.test.ts +++ b/extensions/machine-learning/src/test/views/models/azureModelsComponent.test.ts @@ -71,21 +71,21 @@ describe('Azure Models Component', () => { name: 'model' } ]; - parent.on(ListAccountsEventName, () => { - parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts }); + parent.on(ListAccountsEventName, (args) => { + parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts }); }); - parent.on(ListSubscriptionsEventName, () => { + parent.on(ListSubscriptionsEventName, (args) => { - parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions }); + parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions }); }); - parent.on(ListGroupsEventName, () => { - parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups }); + parent.on(ListGroupsEventName, (args) => { + parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups }); }); - parent.on(ListWorkspacesEventName, () => { - parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces }); + parent.on(ListWorkspacesEventName, (args) => { + parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces }); }); - parent.on(ListAzureModelsEventName, () => { - parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); + parent.on(ListAzureModelsEventName, (args) => { + parent.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models }); }); await view.refresh(); testContext.onClick.fire(true); diff --git a/extensions/machine-learning/src/test/views/models/predictWizard.test.ts b/extensions/machine-learning/src/test/views/models/predictWizard.test.ts index fe3897ea09..2bfcd14956 100644 --- a/extensions/machine-learning/src/test/views/models/predictWizard.test.ts +++ b/extensions/machine-learning/src/test/views/models/predictWizard.test.ts @@ -9,7 +9,7 @@ import 'mocha'; import { createContext } from './utils'; import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, - ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType + ListAzureModelsEventName, ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, LoadModelParametersEventName, DownloadAzureModelEventName, DownloadRegisteredModelEventName, ModelSourceType, VerifyImportTableEventName } from '../../../views/models/modelViewBase'; import { ImportedModel, ModelParameters } from '../../../modelManagement/interfaces'; @@ -136,42 +136,45 @@ describe('Predict Wizard', () => { ] }; - view.on(ListModelsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels }); + view.on(ListModelsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { inputArgs: args, data: localModels }); }); - view.on(ListAccountsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts }); + view.on(ListAccountsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts }); }); - view.on(ListSubscriptionsEventName, () => { + view.on(ListSubscriptionsEventName, (args) => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions }); + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions }); }); - view.on(ListGroupsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups }); + view.on(ListGroupsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups }); }); - view.on(ListWorkspacesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces }); + view.on(ListWorkspacesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces }); }); - view.on(ListAzureModelsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); + view.on(ListAzureModelsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models }); }); - view.on(ListDatabaseNamesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: dbNames }); + view.on(ListDatabaseNamesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { inputArgs: args, data: dbNames }); }); - view.on(ListTableNamesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: tableNames }); + view.on(ListTableNamesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { inputArgs: args, data: tableNames }); }); - view.on(ListColumnNamesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { data: columnNames }); + view.on(ListColumnNamesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListColumnNamesEventName), { inputArgs: args, data: columnNames }); }); - view.on(LoadModelParametersEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { data: modelParameters }); + view.on(LoadModelParametersEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(LoadModelParametersEventName), { inputArgs: args, data: modelParameters }); }); - view.on(DownloadAzureModelEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { data: 'path' }); + view.on(DownloadAzureModelEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadAzureModelEventName), { inputArgs: args, data: 'path' }); }); - view.on(DownloadRegisteredModelEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { data: 'path' }); + view.on(DownloadRegisteredModelEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(DownloadRegisteredModelEventName), { inputArgs: args, data: 'path' }); + }); + view.on(VerifyImportTableEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(VerifyImportTableEventName), { inputArgs: args, data: view.importTable }); }); if (view.modelBrowsePage) { view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; diff --git a/extensions/machine-learning/src/test/views/models/registerModelWizard.test.ts b/extensions/machine-learning/src/test/views/models/registerModelWizard.test.ts index fae94d875e..c4a4a5d0ff 100644 --- a/extensions/machine-learning/src/test/views/models/registerModelWizard.test.ts +++ b/extensions/machine-learning/src/test/views/models/registerModelWizard.test.ts @@ -7,24 +7,78 @@ import * as azdata from 'azdata'; import * as should from 'should'; import 'mocha'; import { createContext } from './utils'; -import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName } from '../../../views/models/modelViewBase'; +import { ListModelsEventName, ListAccountsEventName, ListSubscriptionsEventName, ListGroupsEventName, ListWorkspacesEventName, ListAzureModelsEventName, ModelSourceType, ListDatabaseNamesEventName, ListTableNamesEventName, VerifyImportTableEventName } from '../../../views/models/modelViewBase'; import { ImportedModel } from '../../../modelManagement/interfaces'; import { azureResource } from '../../../typings/azure-resource'; import { Workspace } from '@azure/arm-machinelearningservices/esm/models'; import { ViewBase } from '../../../views/viewBase'; import { WorkspaceModel } from '../../../modelManagement/interfaces'; import { ImportModelWizard } from '../../../views/models/manageModels/importModelWizard'; +import { DatabaseTable } from '../../../prediction/interfaces'; +let accounts: azdata.Account[] = [ + { + key: { + accountId: '1', + providerId: '' + }, + displayInfo: { + displayName: 'account', + userId: '', + accountType: '', + contextualDisplayName: '' + }, + isStale: false, + properties: [] + } +]; +let subscriptions: azureResource.AzureResourceSubscription[] = [ + { + name: 'subscription', + id: '2' + } +]; +let groups: azureResource.AzureResourceResourceGroup[] = [ + { + name: 'group', + id: '3' + } +]; +let workspaces: Workspace[] = [ + { + name: 'workspace', + id: '4' + } +]; +let models: WorkspaceModel[] = [ + { + id: '5', + name: 'model' + } +]; +let localModels: ImportedModel[] = [ + { + id: 1, + modelName: 'model', + table: { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' + } + } +]; + +let importTable: DatabaseTable = { + databaseName: 'db', + tableName: 'tb', + schema: 'dbo' +}; describe('Register Model Wizard', () => { it('Should create view components successfully ', async function (): Promise { let testContext = createContext(); let view = new ImportModelWizard(testContext.apiWrapper.object, ''); - view.importTable = { - databaseName: 'db', - tableName: 'table', - schema: 'dbo' - }; + view.importTable = importTable; await view.open(); should.notEqual(view.wizardView, undefined); should.notEqual(view.modelSourcePage, undefined); @@ -34,98 +88,56 @@ describe('Register Model Wizard', () => { let testContext = createContext(); let view = new ImportModelWizard(testContext.apiWrapper.object, ''); - view.importTable = { - databaseName: 'db', - tableName: 'tb', - schema: 'dbo' - }; + view.importTable = importTable; await view.open(); - let accounts: azdata.Account[] = [ - { - key: { - accountId: '1', - providerId: '' - }, - displayInfo: { - displayName: 'account', - userId: '', - accountType: '', - contextualDisplayName: '' - }, - isStale: false, - properties: [] - } - ]; - let subscriptions: azureResource.AzureResourceSubscription[] = [ - { - name: 'subscription', - id: '2' - } - ]; - let groups: azureResource.AzureResourceResourceGroup[] = [ - { - name: 'group', - id: '3' - } - ]; - let workspaces: Workspace[] = [ - { - name: 'workspace', - id: '4' - } - ]; - let models: WorkspaceModel[] = [ - { - id: '5', - name: 'model' - } - ]; - let localModels: ImportedModel[] = [ - { - id: 1, - modelName: 'model', - table: { - databaseName: 'db', - tableName: 'tb', - schema: 'dbo' - } - } - ]; - view.on(ListModelsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { data: localModels }); - }); - view.on(ListDatabaseNamesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { data: [ - 'db', 'db1' - ] }); - }); - view.on(ListTableNamesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { data: [ - 'tb', 'tb1' - ] }); - }); - view.on(ListAccountsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { data: accounts }); - }); - view.on(ListSubscriptionsEventName, () => { - - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { data: subscriptions }); - }); - view.on(ListGroupsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { data: groups }); - }); - view.on(ListWorkspacesEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { data: workspaces }); - }); - view.on(ListAzureModelsEventName, () => { - view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { data: models }); - }); + setEvents(view); + await view.refresh(); + should.notEqual(view.modelBrowsePage, undefined); if (view.modelBrowsePage) { view.modelBrowsePage.modelSourceType = ModelSourceType.Azure; + await view.modelBrowsePage.refresh(); + should.equal(view.modelBrowsePage.modelSourceType, ModelSourceType.Azure); } - await view.refresh(); - should.notEqual(view.azureModelsComponent?.data ,undefined); + should.notEqual(view.azureModelsComponent?.data, undefined); should.notEqual(view.localModelsComponent?.data, undefined); }); + + function setEvents(view: ImportModelWizard): void { + view.on(ListModelsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListModelsEventName), { inputArgs: args, data: localModels }); + }); + view.on(ListDatabaseNamesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListDatabaseNamesEventName), { + inputArgs: args, data: [ + 'db', 'db1' + ] + }); + }); + view.on(ListTableNamesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListTableNamesEventName), { + inputArgs: args, data: [ + 'tb', 'tb1' + ] + }); + }); + view.on(ListAccountsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAccountsEventName), { inputArgs: args, data: accounts }); + }); + view.on(ListSubscriptionsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListSubscriptionsEventName), { inputArgs: args, data: subscriptions }); + }); + view.on(ListGroupsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListGroupsEventName), { inputArgs: args, data: groups }); + }); + view.on(ListWorkspacesEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListWorkspacesEventName), { inputArgs: args, data: workspaces }); + }); + view.on(ListAzureModelsEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(ListAzureModelsEventName), { inputArgs: args, data: models }); + }); + view.on(VerifyImportTableEventName, (args) => { + view.sendCallbackRequest(ViewBase.getCallbackEventName(VerifyImportTableEventName), { inputArgs: args, data: view.importTable }); + }); + } }); diff --git a/extensions/machine-learning/src/test/views/utils.ts b/extensions/machine-learning/src/test/views/utils.ts index 25a1432c6d..2ac5441332 100644 --- a/extensions/machine-learning/src/test/views/utils.ts +++ b/extensions/machine-learning/src/test/views/utils.ts @@ -240,7 +240,7 @@ export function createViewContext(): ViewTestContext { try { await handler(view); } catch (err) { - throw err; + console.log(err); } }, onValidityChanged: undefined!, @@ -305,7 +305,7 @@ export function createViewContext(): ViewTestContext { try { await handler(view); } catch (err) { - throw err; + console.log(err); } }, modelView: undefined!, diff --git a/extensions/machine-learning/src/views/controllerBase.ts b/extensions/machine-learning/src/views/controllerBase.ts index 22babbbf74..b03788d100 100644 --- a/extensions/machine-learning/src/views/controllerBase.ts +++ b/extensions/machine-learning/src/views/controllerBase.ts @@ -22,14 +22,14 @@ export abstract class ControllerBase { /** * Executes an action and sends back callback event to the view */ - public async executeAction(dialog: T, eventName: string, func: (...args: any[]) => Promise, ...args: any[]): Promise { + public async executeAction(dialog: T, eventName: string, inputArgs: any, func: (...args: any[]) => Promise, ...args: any[]): Promise { const callbackEvent = ViewBase.getCallbackEventName(eventName); try { let result = await func(...args); - dialog.sendCallbackRequest(callbackEvent, { data: result }); + dialog.sendCallbackRequest(callbackEvent, { inputArgs: inputArgs, data: result }); } catch (error) { - dialog.sendCallbackRequest(callbackEvent, { error: error }); + dialog.sendCallbackRequest(callbackEvent, { inputArgs: inputArgs, error: error }); } } @@ -39,7 +39,7 @@ export abstract class ControllerBase { */ public registerEvents(view: ViewBase): void { view.on(LocalPathsEventName, async (args) => { - await this.executeAction(view, LocalPathsEventName, this.getLocalPaths, this._apiWrapper, args); + await this.executeAction(view, LocalPathsEventName, args, this.getLocalPaths, this._apiWrapper, args); }); } diff --git a/extensions/machine-learning/src/views/models/azureModelsTable.ts b/extensions/machine-learning/src/views/models/azureModelsTable.ts index eeeb581669..eb2e8f7ef7 100644 --- a/extensions/machine-learning/src/views/models/azureModelsTable.ts +++ b/extensions/machine-learning/src/views/models/azureModelsTable.ts @@ -65,9 +65,22 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent { return { displayName: a.displayInfo.displayName, name: a.key.accountId }; }); this._accounts.values = values; this._accounts.value = values[0]; + } else { + this._accounts.values = []; + this._accounts.value = undefined; } await this.onAccountSelected(); } @@ -161,6 +164,9 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData let values = this._azureSubscriptions.map(s => { return { displayName: s.name, name: s.id }; }); this._subscriptions.values = values; this._subscriptions.value = values[0]; + } else { + this._subscriptions.values = []; + this._subscriptions.value = undefined; } await this.onSubscriptionSelected(); } @@ -171,6 +177,9 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData let values = this._azureGroups.map(s => { return { displayName: s.name, name: s.id }; }); this._groups.values = values; this._groups.value = values[0]; + } else { + this._groups.values = []; + this._groups.value = undefined; } await this.onGroupSelected(); } @@ -181,6 +190,9 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData let values = this._azureWorkspaces.map(s => { return { displayName: s.name || '', name: s.id || '' }; }); this._workspaces.values = values; this._workspaces.value = values[0]; + } else { + this._workspaces.values = []; + this._workspaces.value = undefined; } this.onWorkspaceSelectedChanged(); } diff --git a/extensions/machine-learning/src/views/models/localModelsComponent.ts b/extensions/machine-learning/src/views/models/localModelsComponent.ts index b74f786282..1306d72938 100644 --- a/extensions/machine-learning/src/views/models/localModelsComponent.ts +++ b/extensions/machine-learning/src/views/models/localModelsComponent.ts @@ -39,10 +39,7 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen }).component(); this._localBrowse = modelBuilder.button().withProperties({ label: constants.browseModels, - width: this.browseButtonMaxLength, - CSSStyles: { - 'text-align': 'end' - } + width: this.browseButtonMaxLength }).component(); this._localBrowse.onDidClick(async () => { @@ -65,7 +62,7 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen .withLayout({ flexFlow: 'row', justifyContent: 'space-between', - width: this.componentMaxLength + width: this.componentMaxLength + 200 }).withItems([ this._localPath, this._localBrowse] ).component(); @@ -80,9 +77,9 @@ export class LocalModelsComponent extends ModelViewBase implements IDataComponen public addComponents(formBuilder: azdata.FormBuilder) { if (this._flex) { formBuilder.addFormItem({ - title: '', + title: constants.modelLocalSourceTitle, component: this._flex - }); + }, { info: constants.modelLocalSourceTooltip }); } } diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts index 244136173f..edb751114b 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsComponent.ts @@ -36,8 +36,8 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { * @param modelBuilder register the components */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { - this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, false); - this._tableSelectionComponent.registerComponent(modelBuilder); + this._tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, { editable: false, preSelected: true }); + this._tableSelectionComponent.registerComponent(modelBuilder, constants.databaseName, constants.tableName); this._tableSelectionComponent.onSelectedChanged(async () => { await this.onTableSelected(); }); @@ -110,7 +110,9 @@ export class CurrentModelsComponent extends ModelViewBase implements IPageView { if (this._tableSelectionComponent?.data) { this.importTable = this._tableSelectionComponent?.data; await this.storeImportConfigTable(); - await this._dataTable?.refresh(); + if (this._dataTable) { + await this._dataTable.refresh(); + } } } diff --git a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts index f9dfb22596..ab98d9ad62 100644 --- a/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts +++ b/extensions/machine-learning/src/views/models/manageModels/currentModelsTable.ts @@ -66,6 +66,32 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< ...constants.cssStyles.tableRow }, }, + { // Version + displayName: constants.modelVersion, + ariaLabel: constants.modelVersion, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, + { // Format + displayName: constants.modelFramework, + ariaLabel: constants.modelFramework, + valueType: azdata.DeclarativeDataType.string, + isReadOnly: true, + width: 150, + headerCssStyles: { + ...constants.cssStyles.tableHeader + }, + rowCssStyles: { + ...constants.cssStyles.tableRow + }, + }, { // Action displayName: '', valueType: azdata.DeclarativeDataType.component, @@ -113,13 +139,13 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< public addComponents(formBuilder: azdata.FormBuilder) { if (this.component) { - formBuilder.addFormItem({ title: constants.modelSourcesTitle, component: this.component }); + formBuilder.addFormItem({ title: '', component: this.component }); } } public removeComponents(formBuilder: azdata.FormBuilder) { if (this.component) { - formBuilder.removeFormItem({ title: constants.modelSourcesTitle, component: this.component }); + formBuilder.removeFormItem({ title: '', component: this.component }); } } @@ -169,7 +195,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent< } private createTableRow(model: ImportedModel): any[] { - let row: any[] = [model.modelName, model.created]; + let row: any[] = [model.modelName, model.created, model.version, model.framework]; if (this._modelBuilder) { const selectButton = this.createSelectButton(model); if (selectButton) { diff --git a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts index c388d9b356..1f5ef99f13 100644 --- a/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts +++ b/extensions/machine-learning/src/views/models/manageModels/importModelWizard.ts @@ -58,11 +58,9 @@ export class ImportModelWizard extends ModelViewBase { validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; } if (validated && pageInfo.newPage === undefined) { - wizard.cancelButton.enabled = false; - wizard.backButton.enabled = false; + this.onLoading(); let result = await this.registerModel(); - wizard.cancelButton.enabled = true; - wizard.backButton.enabled = true; + this.onLoaded(); if (this._parentView) { this._parentView.importTable = this.importTable; await this._parentView.refresh(); @@ -76,6 +74,21 @@ export class ImportModelWizard extends ModelViewBase { await wizard.open(); } + private onLoading(): void { + this.refreshButtons(true); + } + + private onLoaded(): void { + this.refreshButtons(false); + } + + private refreshButtons(loading: boolean): void { + if (this.wizardView && this.wizardView.wizard) { + this.wizardView.wizard.cancelButton.enabled = !loading; + this.wizardView.wizard.backButton.enabled = !loading; + } + } + public get modelResources(): ModelSourcesComponent | undefined { return this.modelSourcePage?.modelResources; } diff --git a/extensions/machine-learning/src/views/models/manageModels/manageModelsDialog.ts b/extensions/machine-learning/src/views/models/manageModels/manageModelsDialog.ts index 7e16f55220..cd4d49b88e 100644 --- a/extensions/machine-learning/src/views/models/manageModels/manageModelsDialog.ts +++ b/extensions/machine-learning/src/views/models/manageModels/manageModelsDialog.ts @@ -40,6 +40,7 @@ export class ManageModelsDialog extends ModelViewBase { }); let dialog = this.dialogView.createDialog(constants.registerModelTitle, [this.currentLanguagesTab]); + dialog.isWide = true; dialog.customButtons = [registerModelButton]; this.mainViewPanel = dialog; dialog.okButton.hidden = true; diff --git a/extensions/machine-learning/src/views/models/manageModels/modelDetailsComponent.ts b/extensions/machine-learning/src/views/models/manageModels/modelDetailsComponent.ts index eccc01449d..4c132c78fa 100644 --- a/extensions/machine-learning/src/views/models/manageModels/modelDetailsComponent.ts +++ b/extensions/machine-learning/src/views/models/manageModels/modelDetailsComponent.ts @@ -78,7 +78,7 @@ export class ModelDetailsComponent extends ModelViewBase implements IDataCompone component: this._createdComponent }, { - title: constants.modelDeployed, + title: constants.modelImported, component: this._deployedComponent }, { title: constants.modelFramework, diff --git a/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts b/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts index 740d9af157..c319fc1f1f 100644 --- a/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts +++ b/extensions/machine-learning/src/views/models/manageModels/modelImportLocationPage.ts @@ -19,6 +19,7 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, private _form: azdata.FormContainer | undefined; private _formBuilder: azdata.FormBuilder | undefined; public tableSelectionComponent: TableSelectionComponent | undefined; + private _labelComponent: azdata.TextComponent | undefined; constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) { super(apiWrapper, parent.root, parent); @@ -31,12 +32,35 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { this._formBuilder = modelBuilder.formContainer(); - this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, true); + this.tableSelectionComponent = new TableSelectionComponent(this._apiWrapper, this, { editable: true, preSelected: true }); + this._labelComponent = modelBuilder.text().withProperties({ + width: 200 + }).component(); + const container = modelBuilder.flexContainer().withLayout({ + width: 800, + height: '400px', + justifyContent: 'center' + }).withItems([ + this._labelComponent + ], { + CSSStyles: { + 'align-items': 'center', + 'padding-top': '30px', + 'font-size': '16px' + } + }).component(); + + this.tableSelectionComponent.onSelectedChanged(async () => { await this.onTableSelected(); }); - this.tableSelectionComponent.registerComponent(modelBuilder); + this.tableSelectionComponent.registerComponent(modelBuilder, constants.databaseName, constants.tableName); this.tableSelectionComponent.addComponents(this._formBuilder); + + this._formBuilder.addFormItem({ + title: '', + component: container + }); this._form = this._formBuilder.component(); return this._form; } @@ -45,6 +69,15 @@ export class ModelImportLocationPage extends ModelViewBase implements IPageView, if (this.tableSelectionComponent?.data) { this.importTable = this.tableSelectionComponent?.data; } + + if (this.importTable && this._labelComponent) { + const validated = await this.verifyImportConfigTable(this.importTable); + if (validated) { + this._labelComponent.value = constants.modelSchemaIsAcceptedMessage; + } else { + this._labelComponent.value = constants.modelSchemaIsNotAcceptedMessage; + } + } } /** diff --git a/extensions/machine-learning/src/views/models/modelBrowsePage.ts b/extensions/machine-learning/src/views/models/modelBrowsePage.ts index 90bc74b0d7..af1f1c2a32 100644 --- a/extensions/machine-learning/src/views/models/modelBrowsePage.ts +++ b/extensions/machine-learning/src/views/models/modelBrowsePage.ts @@ -46,7 +46,6 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo editable: false }); this.registeredModelsComponent.registerComponent(modelBuilder); - this.refresh(); this._form = this._formBuilder.component(); return this._form; } @@ -173,6 +172,7 @@ export class ModelBrowsePage extends ModelViewBase implements IPageView, IDataCo fileName: x.model?.name, framework: x.model?.framework, frameworkVersion: x.model?.frameworkVersion, + description: x.model?.description, created: x.model?.createdTime }, targetImportTable: this.importTable diff --git a/extensions/machine-learning/src/views/models/modelManagementController.ts b/extensions/machine-learning/src/views/models/modelManagementController.ts index e0925477a0..ce3a72613e 100644 --- a/extensions/machine-learning/src/views/models/modelManagementController.ts +++ b/extensions/machine-learning/src/views/models/modelManagementController.ts @@ -40,19 +40,19 @@ export class ModelManagementController extends ControllerBase { apiWrapper: ApiWrapper, private _root: string, private _amlService: AzureModelRegistryService, - private _registeredModelService: DeployedModelService, + private _deployedModelService: DeployedModelService, private _predictService: PredictService) { super(apiWrapper); } /** - * Opens the dialog for model registration + * Opens the dialog for model import * @param parent parent if the view is opened from another view * @param controller controller * @param apiWrapper apiWrapper * @param root root folder path */ - public async registerModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { + public async importModel(importTable: DatabaseTable | undefined, parent?: ModelViewBase, controller?: ModelManagementController, apiWrapper?: ApiWrapper, root?: string): Promise { controller = controller || this; apiWrapper = apiWrapper || this._apiWrapper; root = root || this._root; @@ -60,7 +60,7 @@ export class ModelManagementController extends ControllerBase { if (importTable) { view.importTable = importTable; } else { - view.importTable = await controller._registeredModelService.getRecentImportTable(); + view.importTable = await controller._deployedModelService.getRecentImportTable(); } controller.registerEvents(view); @@ -93,23 +93,31 @@ export class ModelManagementController extends ControllerBase { /** * Opens the wizard for prediction */ - public async predictModel(): Promise { + public async predictModel(): Promise { - let view = new PredictWizard(this._apiWrapper, this._root); - view.importTable = await this._registeredModelService.getRecentImportTable(); + const onnxSupported = await this._predictService.serverSupportOnnxModel(); + if (onnxSupported) { + await this._deployedModelService.installDependencies(); + let view = new PredictWizard(this._apiWrapper, this._root); + view.importTable = await this._deployedModelService.getRecentImportTable(); - this.registerEvents(view); - view.on(LoadModelParametersEventName, async () => { - const modelArtifact = await view.getModelFileName(); - await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService, - modelArtifact?.filePath); - }); + this.registerEvents(view); - // Open view - // - await view.open(); - await view.refresh(); - return view; + view.on(LoadModelParametersEventName, async (args) => { + const modelArtifact = await view.getModelFileName(); + await this.executeAction(view, LoadModelParametersEventName, args, this.loadModelParameters, this._deployedModelService, + modelArtifact?.filePath); + }); + + // Open view + // + await view.open(); + await view.refresh(); + return view; + } else { + this._apiWrapper.showErrorMessage(constants.onnxNotSupportedError); + return undefined; + } } @@ -122,99 +130,99 @@ export class ModelManagementController extends ControllerBase { // Register events // super.registerEvents(view); - view.on(ListAccountsEventName, async () => { - await this.executeAction(view, ListAccountsEventName, this.getAzureAccounts, this._amlService); + view.on(ListAccountsEventName, async (args) => { + await this.executeAction(view, ListAccountsEventName, args, this.getAzureAccounts, this._amlService); }); - view.on(ListSubscriptionsEventName, async (arg) => { - let azureArgs = arg; - await this.executeAction(view, ListSubscriptionsEventName, this.getAzureSubscriptions, this._amlService, azureArgs.account); + view.on(ListSubscriptionsEventName, async (args) => { + let azureArgs = args; + await this.executeAction(view, ListSubscriptionsEventName, args, this.getAzureSubscriptions, this._amlService, azureArgs.account); }); - view.on(ListWorkspacesEventName, async (arg) => { - let azureArgs = arg; - await this.executeAction(view, ListWorkspacesEventName, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group); + view.on(ListWorkspacesEventName, async (args) => { + let azureArgs = args; + await this.executeAction(view, ListWorkspacesEventName, args, this.getWorkspaces, this._amlService, azureArgs.account, azureArgs.subscription, azureArgs.group); }); - view.on(ListGroupsEventName, async (arg) => { - let azureArgs = arg; - await this.executeAction(view, ListGroupsEventName, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription); + view.on(ListGroupsEventName, async (args) => { + let azureArgs = args; + await this.executeAction(view, ListGroupsEventName, args, this.getAzureGroups, this._amlService, azureArgs.account, azureArgs.subscription); }); - view.on(ListAzureModelsEventName, async (arg) => { - let azureArgs = arg; - await this.executeAction(view, ListAzureModelsEventName, this.getAzureModels, this._amlService + view.on(ListAzureModelsEventName, async (args) => { + let azureArgs = args; + await this.executeAction(view, ListAzureModelsEventName, args, this.getAzureModels, this._amlService , azureArgs.account, azureArgs.subscription, azureArgs.group, azureArgs.workspace); }); view.on(ListModelsEventName, async (args) => { const table = args; - await this.executeAction(view, ListModelsEventName, this.getRegisteredModels, this._registeredModelService, table); + await this.executeAction(view, ListModelsEventName, args, this.getRegisteredModels, this._deployedModelService, table); }); - view.on(RegisterLocalModelEventName, async (arg) => { - let models = arg; - await this.executeAction(view, RegisterLocalModelEventName, this.registerLocalModel, this._registeredModelService, models); + view.on(RegisterLocalModelEventName, async (args) => { + let models = args; + await this.executeAction(view, RegisterLocalModelEventName, args, this.registerLocalModel, this._deployedModelService, models); view.refresh(); }); view.on(RegisterModelEventName, async (args) => { const importTable = args; - await this.executeAction(view, RegisterModelEventName, this.registerModel, importTable, view, this, this._apiWrapper, this._root); + await this.executeAction(view, RegisterModelEventName, args, this.importModel, importTable, view, this, this._apiWrapper, this._root); }); view.on(EditModelEventName, async (args) => { const model = args; - await this.executeAction(view, EditModelEventName, this.editModel, model, view, this, this._apiWrapper, this._root); + await this.executeAction(view, EditModelEventName, args, this.editModel, model, view, this, this._apiWrapper, this._root); }); view.on(UpdateModelEventName, async (args) => { const model = args; - await this.executeAction(view, UpdateModelEventName, this.updateModel, this._registeredModelService, model); + await this.executeAction(view, UpdateModelEventName, args, this.updateModel, this._deployedModelService, model); }); view.on(DeleteModelEventName, async (args) => { const model = args; - await this.executeAction(view, DeleteModelEventName, this.deleteModel, this._registeredModelService, model); + await this.executeAction(view, DeleteModelEventName, args, this.deleteModel, this._deployedModelService, model); }); - view.on(RegisterAzureModelEventName, async (arg) => { - let models = arg; - await this.executeAction(view, RegisterAzureModelEventName, this.registerAzureModel, this._amlService, this._registeredModelService, + view.on(RegisterAzureModelEventName, async (args) => { + let models = args; + await this.executeAction(view, RegisterAzureModelEventName, args, this.registerAzureModel, this._amlService, this._deployedModelService, models); }); - view.on(DownloadAzureModelEventName, async (arg) => { - let registerArgs = arg; - await this.executeAction(view, DownloadAzureModelEventName, this.downloadAzureModel, this._amlService, + view.on(DownloadAzureModelEventName, async (args) => { + let registerArgs = args; + await this.executeAction(view, DownloadAzureModelEventName, args, this.downloadAzureModel, this._amlService, registerArgs.account, registerArgs.subscription, registerArgs.group, registerArgs.workspace, registerArgs.model); }); - view.on(ListDatabaseNamesEventName, async () => { - await this.executeAction(view, ListDatabaseNamesEventName, this.getDatabaseList, this._predictService); + view.on(ListDatabaseNamesEventName, async (args) => { + await this.executeAction(view, ListDatabaseNamesEventName, args, this.getDatabaseList, this._predictService); }); - view.on(ListTableNamesEventName, async (arg) => { - let dbName = arg; - await this.executeAction(view, ListTableNamesEventName, this.getTableList, this._predictService, dbName); + view.on(ListTableNamesEventName, async (args) => { + let dbName = args; + await this.executeAction(view, ListTableNamesEventName, args, this.getTableList, this._predictService, dbName); }); - view.on(ListColumnNamesEventName, async (arg) => { - let tableColumnsArgs = arg; - await this.executeAction(view, ListColumnNamesEventName, this.getTableColumnsList, this._predictService, + view.on(ListColumnNamesEventName, async (args) => { + let tableColumnsArgs = args; + await this.executeAction(view, ListColumnNamesEventName, args, this.getTableColumnsList, this._predictService, tableColumnsArgs); }); - view.on(PredictModelEventName, async (arg) => { - let predictArgs = arg; - await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService, + view.on(PredictModelEventName, async (args) => { + let predictArgs = args; + await this.executeAction(view, PredictModelEventName, args, this.generatePredictScript, this._predictService, predictArgs, predictArgs.model, predictArgs.filePath); }); - view.on(DownloadRegisteredModelEventName, async (arg) => { - let model = arg; - await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService, + view.on(DownloadRegisteredModelEventName, async (args) => { + let model = args; + await this.executeAction(view, DownloadRegisteredModelEventName, args, this.downloadRegisteredModel, this._deployedModelService, model); }); - view.on(StoreImportTableEventName, async (arg) => { - let importTable = arg; - await this.executeAction(view, StoreImportTableEventName, this.storeImportTable, this._registeredModelService, + view.on(StoreImportTableEventName, async (args) => { + let importTable = args; + await this.executeAction(view, StoreImportTableEventName, args, this.storeImportTable, this._deployedModelService, importTable); }); - view.on(VerifyImportTableEventName, async (arg) => { - let importTable = arg; - await this.executeAction(view, VerifyImportTableEventName, this.verifyImportTable, this._registeredModelService, + view.on(VerifyImportTableEventName, async (args) => { + let importTable = args; + await this.executeAction(view, VerifyImportTableEventName, args, this.verifyImportTable, this._deployedModelService, importTable); }); - view.on(SourceModelSelectedEventName, async (arg) => { - view.modelSourceType = arg; + view.on(SourceModelSelectedEventName, async (args) => { + view.modelSourceType = args; await view.refresh(); }); - view.on(SignInToAzureEventName, async () => { - await this.executeAction(view, SignInToAzureEventName, this.signInToAzure, this._amlService); + view.on(SignInToAzureEventName, async (args) => { + await this.executeAction(view, SignInToAzureEventName, args, this.signInToAzure, this._amlService); await view.refresh(); }); } @@ -228,7 +236,7 @@ export class ModelManagementController extends ControllerBase { if (importTable) { view.importTable = importTable; } else { - view.importTable = await this._registeredModelService.getRecentImportTable(); + view.importTable = await this._deployedModelService.getRecentImportTable(); } // Register events diff --git a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts index aa75f6f153..9fa58db490 100644 --- a/extensions/machine-learning/src/views/models/modelSourcesComponent.ts +++ b/extensions/machine-learning/src/views/models/modelSourcesComponent.ts @@ -31,12 +31,14 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone */ public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + this._sourceType = this._options && this._options.length > 0 ? this._options[0] : ModelSourceType.Local; + this.modelSourceType = this._sourceType; this._localModel = modelBuilder.card() .withProperties({ value: 'local', name: 'modelLocation', label: constants.localModelSource, - selected: this._options[0] === ModelSourceType.Local, + selected: this._sourceType === ModelSourceType.Local, cardType: azdata.CardType.VerticalButton, width: 50 }).component(); @@ -45,7 +47,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone value: 'aml', name: 'modelLocation', label: constants.azureModelSource, - selected: this._options[0] === ModelSourceType.Azure, + selected: this._sourceType === ModelSourceType.Azure, cardType: azdata.CardType.VerticalButton, width: 50 }).component(); @@ -55,7 +57,7 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone value: 'registered', name: 'modelLocation', label: constants.registeredModelsSource, - selected: this._options[0] === ModelSourceType.RegisteredModels, + selected: this._sourceType === ModelSourceType.RegisteredModels, cardType: azdata.CardType.VerticalButton, width: 50 }).component(); @@ -105,9 +107,6 @@ export class ModelSourcesComponent extends ModelViewBase implements IDataCompone break; } }); - this._sourceType = this._options[0]; - this.sendRequest(SourceModelSelectedEventName, this._sourceType); - this._flexContainer = modelBuilder.flexContainer() .withLayout({ flexFlow: 'row', diff --git a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts index 6eb45172df..106934304f 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsSelectionPage.ts @@ -74,6 +74,18 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID } } + public async validate(): Promise { + const data = this.data; + const validated = data !== undefined && data.databaseName !== undefined && data.inputColumns !== undefined && data.outputColumns !== undefined + && data.tableName !== undefined && data.databaseName !== constants.selectDatabaseTitle && data.tableName !== constants.selectTableTitle + && !data.inputColumns.find(x => x.columnName === constants.selectColumnTitle); + if (!validated) { + this.showErrorMessage(constants.invalidModelParametersError); + } + + return Promise.resolve(validated); + } + public async onEnter(): Promise { await this.inputColumnsComponent?.onLoading(); await this.outputColumnsComponent?.onLoading(); diff --git a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts index 717cae893f..6600f2b032 100644 --- a/extensions/machine-learning/src/views/models/prediction/columnsTable.ts +++ b/extensions/machine-learning/src/views/models/prediction/columnsTable.ts @@ -173,7 +173,12 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent this.createInputTableRow(input, columns))); } @@ -212,6 +217,10 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x === modelParameter.type); if (dataType) { nameInput.value = dataType; + } else { + // Output type not supported + // + modelParameter.type = dataTypes[0]; } this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type }); @@ -234,7 +243,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; }); + + let values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; }); + if (columns.length > 0 && columns[0].columnName !== constants.selectColumnTitle) { + values = [{ displayName: constants.selectColumnTitle, name: '' }].concat(values); + } let nameInput = this._modelBuilder.dropDown().withProperties({ values: values, width: this.componentMaxLength @@ -250,11 +263,28 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x.name === modelParameter.name); if (!column) { - column = values[0]; + column = values.length > 0 ? values[0] : undefined; } + const currentColumn = columns.find(x => x.columnName === column?.name); nameInput.value = column; - this._parameters.push({ columnName: column.name, paramName: name }); + if (column) { + this._parameters.push({ columnName: column.name, paramName: name, paramType: modelParameter.type }); + } + const inputContainer = this._modelBuilder.flexContainer().withLayout({ + flexFlow: 'row', + width: this.componentMaxLength + 20, + justifyContent: 'flex-start' + }).component(); + const warningButton = this.createWarningButton(); + warningButton.onDidClick(() => { + }); + + const css = { + 'padding-top': '5px', + 'padding-right': '5px', + 'margin': '0px' + }; nameInput.onValueChanged(() => { const selectedColumn = nameInput.value; @@ -264,12 +294,36 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent x.columnName === value); + if (currentColumn && modelParameter.type === currentColumn?.dataType) { + inputContainer.removeItem(warningButton); + } else { + inputContainer.addItem(warningButton, { + CSSStyles: css + }); + } }); + const label = this._modelBuilder.inputBox().withProperties({ - value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, + value: `${name}(${modelParameter.originalType ? modelParameter.originalType : constants.unsupportedModelParameterType})`, enabled: false, width: this.componentMaxLength }).component(); + + + inputContainer.addItem(label, { + CSSStyles: { + 'padding': '0px', + 'padding-right': '5px', + 'margin': '0px' + } + }); + if (currentColumn && modelParameter.type !== currentColumn?.dataType) { + inputContainer.addItem(warningButton, { + CSSStyles: css + }); + } const image = this._modelBuilder.image().withProperties({ width: 50, height: 50, @@ -281,12 +335,28 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent { await this.onTableSelected(); }); diff --git a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts index 82c0e8b9a1..dae96f9dfa 100644 --- a/extensions/machine-learning/src/views/models/prediction/predictWizard.ts +++ b/extensions/machine-learning/src/views/models/prediction/predictWizard.ts @@ -62,7 +62,7 @@ export class PredictWizard extends ModelViewBase { }); wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => { let validated: boolean = true; - if (pageInfo.newPage > pageInfo.lastPage) { + if (pageInfo.newPage === undefined || pageInfo.newPage > pageInfo.lastPage) { validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false; } if (validated) { @@ -94,7 +94,7 @@ export class PredictWizard extends ModelViewBase { private refreshButtons(loading: boolean): void { if (this.wizardView && this.wizardView.wizard) { this.wizardView.wizard.cancelButton.enabled = !loading; - this.wizardView.wizard.cancelButton.enabled = !loading; + this.wizardView.wizard.backButton.enabled = !loading; } } diff --git a/extensions/machine-learning/src/views/models/tableSelectionComponent.ts b/extensions/machine-learning/src/views/models/tableSelectionComponent.ts index 8259427bc9..06eb5614cd 100644 --- a/extensions/machine-learning/src/views/models/tableSelectionComponent.ts +++ b/extensions/machine-learning/src/views/models/tableSelectionComponent.ts @@ -7,10 +7,14 @@ import * as azdata from 'azdata'; import * as vscode from 'vscode'; import { ModelViewBase } from './modelViewBase'; import { ApiWrapper } from '../../common/apiWrapper'; -import * as constants from '../../common/constants'; import { IDataComponent } from '../interfaces'; import { DatabaseTable } from '../../prediction/interfaces'; +import * as constants from '../../common/constants'; +export interface ITableSelectionSettings { + editable: boolean, + preSelected: boolean +} /** * View to render filters to pick an azure resource */ @@ -30,7 +34,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo /** * Creates a new view */ - constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _editable: boolean) { + constructor(apiWrapper: ApiWrapper, parent: ModelViewBase, private _settings: ITableSelectionSettings) { super(apiWrapper, parent.root, parent); } @@ -38,16 +42,16 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo * Register components * @param modelBuilder model builder */ - public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component { + public registerComponent(modelBuilder: azdata.ModelBuilder, databaseTitle: string, tableTitle: string): azdata.Component { this._databases = modelBuilder.dropDown().withProperties({ width: this.componentMaxLength, - editable: this._editable, - fireOnTextChange: this._editable + editable: this._settings.editable, + fireOnTextChange: this._settings.editable }).component(); this._tables = modelBuilder.dropDown().withProperties({ width: this.componentMaxLength, - editable: this._editable, - fireOnTextChange: this._editable + editable: this._settings.editable, + fireOnTextChange: this._settings.editable }).component(); this._databases.onValueChanged(async () => { @@ -58,21 +62,21 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo // There's an issue with dropdown doesn't set the value in editable mode. this is the workaround if (this._tables && value) { - this._selectedTableName = this._editable ? value : value.selected; + this._selectedTableName = this._settings.editable ? value : value.selected; } await this.onTableSelected(); }); const databaseForm = modelBuilder.formContainer().withFormItems([{ - title: constants.columnDatabase, + title: databaseTitle, component: this._databases, - }]).withLayout({ + }], { info: databaseTitle }).withLayout({ padding: '0px' }).component(); const tableForm = modelBuilder.formContainer().withFormItems([{ - title: constants.columnTable, + title: tableTitle, component: this._tables - }]).withLayout({ + }], { info: tableTitle }).withLayout({ padding: '0px' }).component(); this._dbTableComponent = modelBuilder.flexContainer().withItems([ @@ -97,27 +101,21 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo } public addComponents(formBuilder: azdata.FormBuilder) { - if (this._databases && this._tables) { + if (this._dbTableComponent) { formBuilder.addFormItems([{ - title: constants.databaseName, - component: this._databases - }, { - title: constants.tableName, - component: this._tables + title: '', + component: this._dbTableComponent }]); } } public removeComponents(formBuilder: azdata.FormBuilder) { - if (this._databases && this._tables) { + if (this._dbTableComponent) { formBuilder.removeFormItem({ - title: constants.databaseName, - component: this._databases - }); - formBuilder.removeFormItem({ - title: constants.tableName, - component: this._tables + title: '', + component: this._dbTableComponent }); + } } @@ -140,13 +138,19 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo */ public async loadData(): Promise { this._dbNames = await this.listDatabaseNames(); - if (this._databases && this._dbNames && this._dbNames.length > 0) { - this._databases.values = this._dbNames; - if (this.importTable) { + let dbNames = this._dbNames; + if (!this._settings.preSelected && !this._dbNames.find(x => x === constants.selectDatabaseTitle)) { + dbNames = [constants.selectDatabaseTitle].concat(this._dbNames); + } + if (this._databases && dbNames && dbNames.length > 0) { + this._databases.values = dbNames; + + if (this.importTable && this._settings.preSelected) { this._databases.value = this.importTable.databaseName; } else { - this._databases.value = this._dbNames[0]; + this._databases.value = dbNames[0]; } + } await this.onDatabaseSelected(); } @@ -160,18 +164,25 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo private async onDatabaseSelected(): Promise { this._tableNames = await this.listTableNames(this.databaseName || ''); - if (this._tables && this._tableNames && this._tableNames.length > 0) { - this._tables.values = this._tableNames.map(t => this.getTableFullName(t)); + let tableNames = this._tableNames; + + if (this._tableNames && !this._settings.preSelected && !this._tableNames.find(x => x.tableName === constants.selectTableTitle)) { + const firstRow: DatabaseTable = { tableName: constants.selectTableTitle, databaseName: '', schema: '' }; + tableNames = [firstRow].concat(this._tableNames); + } + + if (this._tables && tableNames && tableNames.length > 0) { + this._tables.values = tableNames.map(t => this.getTableFullName(t)); if (this.importTable) { - const selectedTable = this._tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema); + const selectedTable = tableNames.find(t => t.tableName === this.importTable?.tableName && t.schema === this.importTable?.schema); if (selectedTable) { this._selectedTableName = this.getTableFullName(selectedTable); this._tables.value = this.getTableFullName(selectedTable); } else { - this._selectedTableName = this._editable ? this.getTableFullName(this.importTable) : this.getTableFullName(this._tableNames[0]); + this._selectedTableName = this._settings.editable ? this.getTableFullName(this.importTable) : this.getTableFullName(tableNames[0]); } } else { - this._selectedTableName = this.getTableFullName(this._tableNames[0]); + this._selectedTableName = this.getTableFullName(tableNames[0]); } this._tables.value = this._selectedTableName; } else if (this._tables) { @@ -182,7 +193,7 @@ export class TableSelectionComponent extends ModelViewBase implements IDataCompo } private getTableFullName(table: DatabaseTable): string { - return `${table.schema}.${table.tableName}`; + return table.tableName === constants.selectTableTitle ? table.tableName : `${table.schema}.${table.tableName}`; } private async onTableSelected(): Promise { diff --git a/extensions/machine-learning/src/views/viewBase.ts b/extensions/machine-learning/src/views/viewBase.ts index 8cc36d9862..36bf18043f 100644 --- a/extensions/machine-learning/src/views/viewBase.ts +++ b/extensions/machine-learning/src/views/viewBase.ts @@ -12,12 +12,7 @@ import * as path from 'path'; import { EventEmitterCollection } from '../common/eventEmitter'; export interface CallbackEventArgs { - data?: any; - error?: (reason?: any) => void; -} - - -export interface CallbackEventArgs { + inputArgs?: any; data?: any; error?: (reason?: any) => void; } @@ -95,28 +90,38 @@ export abstract class ViewBase extends EventEmitterCollection { this.fire(requestType, arg); } - public sendDataRequest( + public async sendDataRequest( eventName: string, arg?: any, callbackEventName?: string): Promise { - return new Promise((resolve, reject) => { + let emitter: vscode.EventEmitter | undefined; + let promise = new Promise((resolve, reject) => { if (!callbackEventName) { callbackEventName = ViewBase.getCallbackEventName(eventName); } - this.on(callbackEventName, result => { + emitter = this.on(callbackEventName, result => { let callbackArgs = result; if (callbackArgs) { - if (callbackArgs.error) { - reject(callbackArgs.error); - } else { - resolve(callbackArgs.data); + if (callbackArgs.inputArgs === arg) { + if (callbackArgs.error) { + reject(callbackArgs.error); + } else { + resolve(callbackArgs.data); + } } } else { reject(constants.notSupportedEventArg); } }); + this.fire(eventName, arg); }); + const result = await promise; + if (emitter && callbackEventName) { + this.disposeEvent(callbackEventName, emitter); + } + + return result; } public async getLocalPaths(options: vscode.OpenDialogOptions): Promise { diff --git a/extensions/machine-learning/src/views/widgets/dashboardWidget.ts b/extensions/machine-learning/src/views/widgets/dashboardWidget.ts index 57343c7903..702bab5b70 100644 --- a/extensions/machine-learning/src/views/widgets/dashboardWidget.ts +++ b/extensions/machine-learning/src/views/widgets/dashboardWidget.ts @@ -9,6 +9,7 @@ import { ApiWrapper } from '../../common/apiWrapper'; import * as path from 'path'; import * as constants from '../../common/constants'; import * as utils from '../../common/utils'; +import { PredictService } from '../../prediction/predictService'; interface IActionMetadata { title?: string, @@ -25,53 +26,56 @@ export class DashboardWidget { /** * Creates new instance of dashboard */ - constructor(private _apiWrapper: ApiWrapper, private _root: string) { + constructor(private _apiWrapper: ApiWrapper, private _root: string, private _predictService: PredictService) { } - public register(): void { - this._apiWrapper.registerWidget('mls.dashboard', async (view) => { - const container = view.modelBuilder.flexContainer().withLayout({ - flexFlow: 'column', - width: '100%', - height: '100%' - }).component(); - const header = this.createHeader(view); - const tasksContainer = this.createTasks(view); - const footerContainer = this.createFooter(view); - container.addItem(header, { - CSSStyles: { - 'background-image': `url(${vscode.Uri.file(this.asAbsolutePath('images/background.svg'))})`, - 'background-repeat': 'no-repeat', - 'background-position': 'bottom', - 'width': `${maxWidth}px`, - 'height': '330px', - 'background-size': `${maxWidth}px ${headerMaxHeight}px`, - 'margin-bottom': '-60px' - } - }); - container.addItem(tasksContainer, { - CSSStyles: { - 'width': `${maxWidth}px`, - 'height': '150px', - } - }); - container.addItem(footerContainer, { - CSSStyles: { - 'width': `${maxWidth}px`, - 'height': '500px', - } - }); - const mainContainer = view.modelBuilder.flexContainer() - .withLayout({ + public register(): Promise { + return new Promise(resolve => { + this._apiWrapper.registerWidget('mls.dashboard', async (view) => { + const container = view.modelBuilder.flexContainer().withLayout({ flexFlow: 'column', width: '100%', - height: '100%', - position: 'absolute' + height: '100%' }).component(); - mainContainer.addItem(container, { - CSSStyles: { 'padding-top': '25px', 'padding-left': '5px' } + const header = this.createHeader(view); + const tasksContainer = await this.createTasks(view); + const footerContainer = this.createFooter(view); + container.addItem(header, { + CSSStyles: { + 'background-image': `url(${vscode.Uri.file(this.asAbsolutePath('images/background.svg'))})`, + 'background-repeat': 'no-repeat', + 'background-position': 'bottom', + 'width': `${maxWidth}px`, + 'height': '330px', + 'background-size': `${maxWidth}px ${headerMaxHeight}px`, + 'margin-bottom': '-60px' + } + }); + container.addItem(tasksContainer, { + CSSStyles: { + 'width': `${maxWidth}px`, + 'height': '150px', + } + }); + container.addItem(footerContainer, { + CSSStyles: { + 'width': `${maxWidth}px`, + 'height': '500px', + } + }); + const mainContainer = view.modelBuilder.flexContainer() + .withLayout({ + flexFlow: 'column', + width: '100%', + height: '100%', + position: 'absolute' + }).component(); + mainContainer.addItem(container, { + CSSStyles: { 'padding-top': '25px', 'padding-left': '5px' } + }); + await view.initializeModel(mainContainer); + resolve(); }); - await view.initializeModel(mainContainer); }); } @@ -445,7 +449,7 @@ export class DashboardWidget { return path.join(this._root || '', filePath); } - private createTasks(view: azdata.ModelView): azdata.Component { + private async createTasks(view: azdata.ModelView): Promise { const tasksContainer = view.modelBuilder.flexContainer().withLayout({ flexFlow: 'row', width: '100%', @@ -489,6 +493,7 @@ export class DashboardWidget { 'padding': '10px' } }); + predictionButton.enabled = await this._predictService.serverSupportOnnxModel(); return tasksContainer; } @@ -506,7 +511,7 @@ export class DashboardWidget { const iconContainer = view.modelBuilder.flexContainer().withLayout({ flexFlow: 'row', width: maxWidth, - height: maxHeight - 20, + height: maxHeight - 23, alignItems: 'flex-start' }).component(); const labelsContainer = view.modelBuilder.flexContainer().withLayout({ @@ -571,7 +576,7 @@ export class DashboardWidget { } }); mainContainer.onDidClick(async () => { - if (taskMetaData.command) { + if (mainContainer.enabled && taskMetaData.command) { await this._apiWrapper.executeCommand(taskMetaData.command); } });