Machine Learning Services - Model detection in predict wizard (#9609)

* Machine Learning Services - Model detection in predict wizard
This commit is contained in:
Leila Lali
2020-03-25 13:18:19 -07:00
committed by GitHub
parent 176edde2aa
commit ab82c04766
44 changed files with 2265 additions and 376 deletions

View File

@@ -107,14 +107,14 @@ export abstract class LanguageViewBase {
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : constants.extLangLocal}`;
}
return constants.packageManagerNoConnection;
return constants.noConnectionError;
}
public getServerTitle(): string {
if (this.connection) {
return this.connection.serverName;
}
return constants.packageManagerNoConnection;
return constants.noConnectionError;
}
private async getCurrentConnectionUrl(): Promise<string> {

View File

@@ -18,6 +18,7 @@ export interface IPageView {
onLeave?: () => Promise<void>;
validate?: () => Promise<boolean>;
refresh: () => Promise<void>;
disposePage?: () => Promise<void>;
viewPanel: azdata.window.ModelViewPanel | undefined;
title: string;
}

View File

@@ -37,6 +37,16 @@ export class MainViewBase {
}
}
public async disposePages(): Promise<void> {
if (this._pages) {
await Promise.all(this._pages.map(async (p) => {
if (p.disposePage) {
await p.disposePage();
}
}));
}
}
public async refresh(): Promise<void> {
if (this._pages) {
await Promise.all(this._pages.map(async (p) => await p.refresh()));

View File

@@ -9,6 +9,7 @@ import { ApiWrapper } from '../../common/apiWrapper';
import { AzureResourceFilterComponent } from './azureResourceFilterComponent';
import { AzureModelsTable } from './azureModelsTable';
import { IDataComponent, AzureModelResource } from '../interfaces';
import { ModelArtifact } from './prediction/modelArtifact';
export class AzureModelsComponent extends ModelViewBase implements IDataComponent<AzureModelResource> {
@@ -17,6 +18,7 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
private _loader: azdata.LoadingComponent | undefined;
private _form: azdata.FormContainer | undefined;
private _downloadedFile: ModelArtifact | undefined;
/**
* Component to render a view to pick an azure model
@@ -37,8 +39,14 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
.withProperties({
loading: true
}).component();
this.azureModelsTable.onModelSelectionChanged(async () => {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
});
this.azureFilterComponent.onWorkspacesSelected(async () => {
this.azureFilterComponent.onWorkspacesSelectedChanged(async () => {
await this.onLoading();
await this.azureModelsTable?.loadData(this.azureFilterComponent?.data);
await this.onLoaded();
@@ -107,6 +115,22 @@ export class AzureModelsComponent extends ModelViewBase implements IDataComponen
});
}
public async getDownloadedModel(): Promise<ModelArtifact> {
if (!this._downloadedFile) {
this._downloadedFile = new ModelArtifact(await this.downloadAzureModel(this.data));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/

View File

@@ -4,6 +4,7 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../common/constants';
import { ModelViewBase } from './modelViewBase';
import { ApiWrapper } from '../../common/apiWrapper';
@@ -18,6 +19,8 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
private _table: azdata.DeclarativeTableComponent;
private _selectedModelId: any;
private _models: WorkspaceModel[] | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates a view to render azure models in a table
@@ -115,6 +118,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
this._table.data = tableData;
}
this._onModelSelectionChanged.fire();
}
private createTableRow(model: WorkspaceModel): any[] {
@@ -128,6 +132,7 @@ export class AzureModelsTable extends ModelViewBase implements IDataComponent<Wo
}).component();
selectModelButton.onDidClick(() => {
this._selectedModelId = model.id;
this._onModelSelectionChanged.fire();
});
return [model.name, model.createdTime, model.frameworkVersion, selectModelButton];
}

View File

@@ -27,8 +27,8 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
private _azureSubscriptions: azureResource.AzureResourceSubscription[] = [];
private _azureGroups: azureResource.AzureResource[] = [];
private _azureWorkspaces: Workspace[] = [];
private _onWorkspacesSelected: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onWorkspacesSelected: vscode.Event<void> = this._onWorkspacesSelected.event;
private _onWorkspacesSelectedChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onWorkspacesSelectedChanged: vscode.Event<void> = this._onWorkspacesSelectedChanged.event;
/**
* Creates a new view
@@ -59,7 +59,7 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
await this.onGroupSelected();
});
this._workspaces.onValueChanged(async () => {
await this.onWorkspaceSelected();
await this.onWorkspaceSelectedChanged();
});
this._form = this._modelBuilder.formContainer().withFormItems([{
@@ -182,26 +182,26 @@ export class AzureResourceFilterComponent extends ModelViewBase implements IData
this._workspaces.values = values;
this._workspaces.value = values[0];
}
this.onWorkspaceSelected();
this.onWorkspaceSelectedChanged();
}
private onWorkspaceSelected(): void {
this._onWorkspacesSelected.fire();
private onWorkspaceSelectedChanged(): void {
this._onWorkspacesSelectedChanged.fire();
}
private get workspace(): Workspace | undefined {
return this._azureWorkspaces ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
return this._azureWorkspaces && this._workspaces.value ? this._azureWorkspaces.find(a => a.id === (<azdata.CategoryValue>this._workspaces.value).name) : undefined;
}
private get account(): azdata.Account | undefined {
return this._azureAccounts ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
return this._azureAccounts && this._accounts.value ? this._azureAccounts.find(a => a.key.accountId === (<azdata.CategoryValue>this._accounts.value).name) : undefined;
}
private get group(): azureResource.AzureResource | undefined {
return this._azureGroups ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
return this._azureGroups && this._groups.value ? this._azureGroups.find(a => a.id === (<azdata.CategoryValue>this._groups.value).name) : undefined;
}
private get subscription(): azureResource.AzureResourceSubscription | undefined {
return this._azureSubscriptions ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
return this._azureSubscriptions && this._subscriptions.value ? this._azureSubscriptions.find(a => a.id === (<azdata.CategoryValue>this._subscriptions.value).name) : undefined;
}
}

View File

@@ -9,15 +9,15 @@ import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { AzureModelRegistryService } from '../../modelManagement/azureModelRegistryService';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { RegisteredModelService } from '../../modelManagement/registeredModelService';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { DeployedModelService } from '../../modelManagement/deployedModelService';
import { RegisteredModelsDialog } from './registerModels/registeredModelsDialog';
import {
AzureResourceEventArgs, ListAzureModelsEventName, ListSubscriptionsEventName, ListModelsEventName, ListWorkspacesEventName,
ListGroupsEventName, ListAccountsEventName, RegisterLocalModelEventName, RegisterLocalModelEventArgs, RegisterAzureModelEventName,
RegisterAzureModelEventArgs, ModelViewBase, SourceModelSelectedEventName, RegisterModelEventName, DownloadAzureModelEventName,
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs
ListDatabaseNamesEventName, ListTableNamesEventName, ListColumnNamesEventName, PredictModelEventName, PredictModelEventArgs, DownloadRegisteredModelEventName, LoadModelParametersEventName
} from './modelViewBase';
import { ControllerBase } from '../controllerBase';
import { RegisterModelWizard } from './registerModels/registerModelWizard';
@@ -39,7 +39,7 @@ export class ModelManagementController extends ControllerBase {
apiWrapper: ApiWrapper,
private _root: string,
private _amlService: AzureModelRegistryService,
private _registeredModelService: RegisteredModelService,
private _registeredModelService: DeployedModelService,
private _predictService: PredictService) {
super(apiWrapper);
}
@@ -61,7 +61,7 @@ export class ModelManagementController extends ControllerBase {
// Open view
//
view.open();
await view.open();
await view.refresh();
return view;
}
@@ -74,10 +74,15 @@ export class ModelManagementController extends ControllerBase {
let view = new PredictWizard(this._apiWrapper, this._root);
this.registerEvents(view);
view.on(LoadModelParametersEventName, async () => {
const modelArtifact = await view.getModelFileName();
await this.executeAction(view, LoadModelParametersEventName, this.loadModelParameters, this._registeredModelService,
modelArtifact?.filePath);
});
// Open view
//
view.open();
await view.open();
await view.refresh();
return view;
}
@@ -151,6 +156,11 @@ export class ModelManagementController extends ControllerBase {
await this.executeAction(view, PredictModelEventName, this.generatePredictScript, this._predictService,
predictArgs, predictArgs.model, predictArgs.filePath);
});
view.on(DownloadRegisteredModelEventName, async (arg) => {
let model = <RegisteredModel>arg;
await this.executeAction(view, DownloadRegisteredModelEventName, this.downloadRegisteredModel, this._registeredModelService,
model);
});
view.on(SourceModelSelectedEventName, () => {
view.refresh();
});
@@ -191,8 +201,8 @@ export class ModelManagementController extends ControllerBase {
return await service.getWorkspaces(account, subscription, group);
}
private async getRegisteredModels(registeredModelService: RegisteredModelService): Promise<RegisteredModel[]> {
return registeredModelService.getRegisteredModels();
private async getRegisteredModels(registeredModelService: DeployedModelService): Promise<RegisteredModel[]> {
return registeredModelService.getDeployedModels();
}
private async getAzureModels(
@@ -207,9 +217,9 @@ export class ModelManagementController extends ControllerBase {
return await service.getModels(account, subscription, resourceGroup, workspace) || [];
}
private async registerLocalModel(service: RegisteredModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
private async registerLocalModel(service: DeployedModelService, filePath: string, details: RegisteredModelDetails | undefined): Promise<void> {
if (filePath) {
await service.registerLocalModel(filePath, details);
await service.deployLocalModel(filePath, details);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -218,7 +228,7 @@ export class ModelManagementController extends ControllerBase {
private async registerAzureModel(
azureService: AzureModelRegistryService,
service: RegisteredModelService,
service: DeployedModelService,
account: azdata.Account | undefined,
subscription: azureResource.AzureResourceSubscription | undefined,
resourceGroup: azureResource.AzureResource | undefined,
@@ -231,7 +241,7 @@ export class ModelManagementController extends ControllerBase {
const filePath = await azureService.downloadModel(account, subscription, resourceGroup, workspace, model);
if (filePath) {
await service.registerLocalModel(filePath, details);
await service.deployLocalModel(filePath, details);
await fs.promises.unlink(filePath);
} else {
throw Error(constants.invalidModelToRegisterError);
@@ -246,7 +256,7 @@ export class ModelManagementController extends ControllerBase {
return await predictService.getTableList(databaseName);
}
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<string[]> {
public async getTableColumnsList(predictService: PredictService, databaseTable: DatabaseTable): Promise<TableColumn[]> {
return await predictService.getTableColumnsList(databaseTable);
}
@@ -263,6 +273,24 @@ export class ModelManagementController extends ControllerBase {
return result;
}
private async downloadRegisteredModel(
registeredModelService: DeployedModelService,
model: RegisteredModel | undefined): Promise<string> {
if (!model) {
throw Error(constants.invalidModelToPredictError);
}
return await registeredModelService.downloadModel(model);
}
private async loadModelParameters(
registeredModelService: DeployedModelService,
model: string | undefined): Promise<ModelParameters | undefined> {
if (!model) {
return undefined;
}
return await registeredModelService.loadModelParameters(model);
}
private async downloadAzureModel(
azureService: AzureModelRegistryService,
account: azdata.Account | undefined,

View File

@@ -120,4 +120,14 @@ export class ModelSourcePage extends ModelViewBase implements IPageView, IDataCo
}
return Promise.resolve(validated);
}
public async disposePage(): Promise<void> {
if (this.azureModelsComponent) {
await this.azureModelsComponent.disposeComponent();
}
if (this.registeredModelsComponent) {
await this.registeredModelsComponent.disposeComponent();
}
}
}

View File

@@ -8,8 +8,8 @@ import * as azdata from 'azdata';
import { azureResource } from '../../typings/azure-resource';
import { ApiWrapper } from '../../common/apiWrapper';
import { ViewBase } from '../viewBase';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable } from '../../prediction/interfaces';
import { RegisteredModel, WorkspaceModel, RegisteredModelDetails, ModelParameters } from '../../modelManagement/interfaces';
import { PredictParameters, DatabaseTable, TableColumn } from '../../prediction/interfaces';
import { Workspace } from '@azure/arm-machinelearningservices/esm/models';
import { AzureWorkspaceResource, AzureModelResource } from '../interfaces';
@@ -47,9 +47,11 @@ export const ListWorkspacesEventName = 'listWorkspaces';
export const RegisterLocalModelEventName = 'registerLocalModel';
export const RegisterAzureModelEventName = 'registerAzureLocalModel';
export const DownloadAzureModelEventName = 'downloadAzureLocalModel';
export const DownloadRegisteredModelEventName = 'downloadRegisteredModel';
export const PredictModelEventName = 'predictModel';
export const RegisterModelEventName = 'registerModel';
export const SourceModelSelectedEventName = 'sourceModelSelected';
export const LoadModelParametersEventName = 'loadModelParameters';
/**
* Base class for all model management views
@@ -75,7 +77,9 @@ export abstract class ModelViewBase extends ViewBase {
ListTableNamesEventName,
ListColumnNamesEventName,
PredictModelEventName,
DownloadAzureModelEventName]);
DownloadAzureModelEventName,
DownloadRegisteredModelEventName,
LoadModelParametersEventName]);
}
/**
@@ -124,7 +128,7 @@ export abstract class ModelViewBase extends ViewBase {
/**
* lists column names
*/
public async listColumnNames(table: DatabaseTable): Promise<string[]> {
public async listColumnNames(table: DatabaseTable): Promise<TableColumn[]> {
return await this.sendDataRequest(ListColumnNamesEventName, table);
}
@@ -151,6 +155,14 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(RegisterLocalModelEventName, args);
}
/**
* downloads registered model
* @param model model to download
*/
public async downloadRegisteredModel(model: RegisteredModel | undefined): Promise<string> {
return await this.sendDataRequest(DownloadRegisteredModelEventName, model);
}
/**
* download azure model
* @param args azure resource
@@ -159,6 +171,13 @@ export abstract class ModelViewBase extends ViewBase {
return await this.sendDataRequest(DownloadAzureModelEventName, resource);
}
/**
* Loads model parameters
*/
public async loadModelParameters(): Promise<ModelParameters | undefined> {
return await this.sendDataRequest(LoadModelParametersEventName);
}
/**
* registers azure model
* @param args azure resource

View File

@@ -8,7 +8,7 @@ import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IPageView, IDataComponent } from '../../interfaces';
import { ColumnsFilterComponent } from './columnsFilterComponent';
import { InputColumnsComponent } from './inputColumnsComponent';
import { OutputColumnsComponent } from './outputColumnsComponent';
import { PredictParameters } from '../../../prediction/interfaces';
@@ -19,7 +19,7 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
private _form: azdata.FormContainer | undefined;
private _formBuilder: azdata.FormBuilder | undefined;
public columnsFilterComponent: ColumnsFilterComponent | undefined;
public inputColumnsComponent: InputColumnsComponent | undefined;
public outputColumnsComponent: OutputColumnsComponent | undefined;
constructor(apiWrapper: ApiWrapper, parent: ModelViewBase) {
@@ -32,15 +32,14 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._formBuilder = modelBuilder.formContainer();
this.columnsFilterComponent = new ColumnsFilterComponent(this._apiWrapper, this);
this.columnsFilterComponent.registerComponent(modelBuilder);
this.columnsFilterComponent.addComponents(this._formBuilder);
this.refresh();
this.inputColumnsComponent = new InputColumnsComponent(this._apiWrapper, this);
this.inputColumnsComponent.registerComponent(modelBuilder);
this.inputColumnsComponent.addComponents(this._formBuilder);
this.outputColumnsComponent = new OutputColumnsComponent(this._apiWrapper, this);
this.outputColumnsComponent.registerComponent(modelBuilder);
this.outputColumnsComponent.addComponents(this._formBuilder);
this.refresh();
this._form = this._formBuilder.component();
return this._form;
}
@@ -49,8 +48,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
* Returns selected data
*/
public get data(): PredictParameters | undefined {
return this.columnsFilterComponent?.data && this.outputColumnsComponent?.data ?
Object.assign({}, this.columnsFilterComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
return this.inputColumnsComponent?.data && this.outputColumnsComponent?.data ?
Object.assign({}, this.inputColumnsComponent.data, { outputColumns: this.outputColumnsComponent.data }) :
undefined;
}
@@ -66,8 +65,8 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
*/
public async refresh(): Promise<void> {
if (this._formBuilder) {
if (this.columnsFilterComponent) {
await this.columnsFilterComponent.refresh();
if (this.inputColumnsComponent) {
await this.inputColumnsComponent.refresh();
}
if (this.outputColumnsComponent) {
await this.outputColumnsComponent.refresh();
@@ -75,6 +74,24 @@ export class ColumnsSelectionPage extends ModelViewBase implements IPageView, ID
}
}
public async onEnter(): Promise<void> {
await this.inputColumnsComponent?.onLoading();
await this.outputColumnsComponent?.onLoading();
try {
const modelParameters = await this.loadModelParameters();
if (modelParameters && this.inputColumnsComponent && this.outputColumnsComponent) {
this.inputColumnsComponent.modelParameters = modelParameters;
this.outputColumnsComponent.modelParameters = modelParameters;
await this.inputColumnsComponent.refresh();
await this.outputColumnsComponent.refresh();
}
} catch (error) {
this.showErrorMessage(constants.loadModelParameterFailedError, error);
}
await this.inputColumnsComponent?.onLoaded();
await this.outputColumnsComponent?.onLoaded();
}
/**
* Returns page title
*/

View File

@@ -8,133 +8,280 @@ import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { IDataComponent } from '../../interfaces';
import { PredictColumn, DatabaseTable } from '../../../prediction/interfaces';
import { PredictColumn, DatabaseTable, TableColumn } from '../../../prediction/interfaces';
import { ModelParameter, ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render azure models in a table
*/
export class ColumnsTable extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _table: azdata.DeclarativeTableComponent;
private _selectedColumns: PredictColumn[] = [];
private _columns: string[] | undefined;
private _table: azdata.DeclarativeTableComponent | undefined;
private _parameters: PredictColumn[] = [];
private _loader: azdata.LoadingComponent;
private _dataTypes: string[] = [
'bigint',
'int',
'smallint',
'real',
'float',
'varchar(MAX)',
'bit'
];
/**
* Creates a view to render azure models in a table
*/
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase) {
constructor(apiWrapper: ApiWrapper, private _modelBuilder: azdata.ModelBuilder, parent: ModelViewBase, private _forInput: boolean = true) {
super(apiWrapper, parent.root, parent);
this._table = this.registerComponent(this._modelBuilder);
this._loader = this.registerComponent(this._modelBuilder);
}
/**
* Register components
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.LoadingComponent {
let columnHeader: azdata.DeclarativeTableColumn[];
if (this._forInput) {
columnHeader = [
{ // Action
displayName: constants.columnName,
ariaLabel: constants.columnName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: '',
ariaLabel: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Name
displayName: constants.inputName,
ariaLabel: constants.inputName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 120,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
} else {
columnHeader = [
{ // Name
displayName: constants.outputName,
ariaLabel: constants.outputName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 200,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.displayName,
ariaLabel: constants.displayName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.dataTypeName,
ariaLabel: constants.dataTypeName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
];
}
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
{
columns: [
{ // Name
displayName: constants.columnDatabase,
ariaLabel: constants.columnName,
valueType: azdata.DeclarativeDataType.string,
isReadOnly: true,
width: 120,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: constants.inputName,
ariaLabel: constants.inputName,
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
},
{ // Action
displayName: '',
valueType: azdata.DeclarativeDataType.component,
isReadOnly: true,
width: 50,
headerCssStyles: {
...constants.cssStyles.tableHeader
},
rowCssStyles: {
...constants.cssStyles.tableRow
},
}
],
columns: columnHeader,
data: [],
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public get component(): azdata.DeclarativeTableComponent {
return this._table;
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
public get component(): azdata.Component {
return this._loader;
}
/**
* Load data in the component
* @param workspaceResource Azure workspace
*/
public async loadData(table: DatabaseTable): Promise<void> {
this._selectedColumns = [];
if (this._table) {
this._columns = await this.listColumnNames(table);
let tableData: any[][] = [];
public async loadInputs(modelParameters: ModelParameters | undefined, table: DatabaseTable): Promise<void> {
await this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
if (this._columns) {
tableData = tableData.concat(this._columns.map(model => this.createTableRow(model)));
if (this._table) {
if (this._forInput) {
const columns = await this.listColumnNames(table);
if (modelParameters?.inputs && columns) {
tableData = tableData.concat(modelParameters.inputs.map(input => this.createInputTableRow(input, columns)));
}
}
this._table.data = tableData;
}
await this.onLoaded();
}
private createTableRow(column: string): any[] {
if (this._modelBuilder) {
let selectRowButton = this._modelBuilder.checkBox().withProperties({
public async loadOutputs(modelParameters: ModelParameters | undefined): Promise<void> {
this.onLoading();
this._parameters = [];
let tableData: any[][] = [];
width: 15,
height: 15,
checked: true
if (this._table) {
if (!this._forInput) {
if (modelParameters?.outputs && this._dataTypes) {
tableData = tableData.concat(modelParameters.outputs.map(output => this.createOutputTableRow(output, this._dataTypes)));
}
}
this._table.data = tableData;
}
this.onLoaded();
}
private createOutputTableRow(modelParameter: ModelParameter, dataTypes: string[]): any[] {
if (this._modelBuilder) {
let nameInput = this._modelBuilder.dropDown().withProperties({
values: dataTypes,
width: this.componentMaxLength
}).component();
let nameInputBox = this._modelBuilder.inputBox().withProperties({
value: '',
width: 150
}).component();
this._selectedColumns.push({ name: column });
selectRowButton.onChanged(() => {
if (selectRowButton.checked) {
if (!this._selectedColumns.find(x => x.name === column)) {
this._selectedColumns.push({ name: column });
}
} else {
if (this._selectedColumns.find(x => x.name === column)) {
this._selectedColumns = this._selectedColumns.filter(x => x.name !== column);
const name = modelParameter.name;
const dataType = dataTypes.find(x => x === modelParameter.type);
if (dataType) {
nameInput.value = dataType;
}
this._parameters.push({ columnName: name, paramName: name, dataType: modelParameter.type });
nameInput.onValueChanged(() => {
const value = <string>nameInput.value;
if (value !== modelParameter.type) {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.dataType = value;
}
}
});
nameInputBox.onTextChanged(() => {
let selectedRow = this._selectedColumns.find(x => x.name === column);
let displayNameInput = this._modelBuilder.inputBox().withProperties({
value: name,
width: 200
}).component();
displayNameInput.onTextChanged(() => {
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.displayName = nameInputBox.value;
selectedRow.columnName = displayNameInput.value || name;
}
});
return [column, nameInputBox, selectRowButton];
return [`${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`, displayNameInput, nameInput];
}
return [];
}
private createInputTableRow(modelParameter: ModelParameter, columns: TableColumn[] | undefined): any[] {
if (this._modelBuilder && columns) {
const values = columns.map(c => { return { name: c.columnName, displayName: `${c.columnName}(${c.dataType})` }; });
let nameInput = this._modelBuilder.dropDown().withProperties({
values: values,
width: this.componentMaxLength
}).component();
const name = modelParameter.name;
let column = values.find(x => x.name === modelParameter.name);
if (!column) {
column = values[0];
}
nameInput.value = column;
this._parameters.push({ columnName: column.name, paramName: name });
nameInput.onValueChanged(() => {
const selectedColumn = nameInput.value;
const value = selectedColumn ? (<azdata.CategoryValue>selectedColumn).name : undefined;
let selectedRow = this._parameters.find(x => x.paramName === name);
if (selectedRow) {
selectedRow.columnName = value || '';
}
});
const label = this._modelBuilder.inputBox().withProperties({
value: `${name}(${modelParameter.type ? modelParameter.type : constants.unsupportedModelParameterType})`,
enabled: false,
width: this.componentMaxLength
}).component();
const image = this._modelBuilder.image().withProperties({
width: 50,
height: 50,
iconPath: {
dark: this.asAbsolutePath('images/arrow.svg'),
light: this.asAbsolutePath('images/arrow.svg')
},
iconWidth: 20,
iconHeight: 20,
title: 'maps'
}).component();
return [nameInput, image, label];
}
return [];
@@ -144,7 +291,7 @@ export class ColumnsTable extends ModelViewBase implements IDataComponent<Predic
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._selectedColumns;
return this._parameters;
}
/**

View File

@@ -8,13 +8,14 @@ import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { ColumnsTable } from './columnsTable';
import { PredictColumn, PredictInputParameters, DatabaseTable } from '../../../prediction/interfaces';
import { ModelParameters } from '../../../modelManagement/interfaces';
import { ColumnsTable } from './columnsTable';
/**
* View to render filters to pick an azure resource
*/
export class ColumnsFilterComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
export class InputColumnsComponent extends ModelViewBase implements IDataComponent<PredictInputParameters> {
private _form: azdata.FormContainer | undefined;
private _databases: azdata.DropDownComponent | undefined;
@@ -22,7 +23,9 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
private _columns: ColumnsTable | undefined;
private _dbNames: string[] = [];
private _tableNames: DatabaseTable[] = [];
private _modelParameters: ModelParameters | undefined;
private _dbTableComponent: azdata.FlexContainer | undefined;
private tableMaxLength = this.componentMaxLength * 2 + 70;
/**
* Creates a new view
*/
@@ -52,27 +55,47 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
});
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
const databaseForm = modelBuilder.formContainer().withFormItems([{
title: constants.columnDatabase,
component: this._databases
}, {
title: constants.azureSubscription,
}]).withLayout({
padding: '0px'
}).component();
const tableForm = modelBuilder.formContainer().withFormItems([{
title: constants.columnTable,
component: this._tables
}]).withLayout({
padding: '0px'
}).component();
this._dbTableComponent = modelBuilder.flexContainer().withItems([
databaseForm,
tableForm
], {
flex: '0 0 auto',
CSSStyles: {
'align-items': 'flex-start'
}
}).withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.tableMaxLength
}).component();
this._form = modelBuilder.formContainer().withFormItems([{
title: '',
component: this._dbTableComponent
}, {
title: constants.azureGroup,
title: constants.inputColumns,
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables && this._columns) {
if (this._columns && this._dbTableComponent) {
formBuilder.addFormItems([{
title: constants.columnDatabase,
component: this._databases
}, {
title: constants.columnTable,
component: this._tables
title: '',
component: this._dbTableComponent
}, {
title: constants.inputColumns,
component: this._columns.component
@@ -81,17 +104,13 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._databases && this._tables && this._columns) {
if (this._columns && this._dbTableComponent) {
formBuilder.removeFormItem({
title: constants.azureAccount,
component: this._databases
title: '',
component: this._dbTableComponent
});
formBuilder.removeFormItem({
title: constants.azureSubscription,
component: this._tables
});
formBuilder.removeFormItem({
title: constants.azureGroup,
title: constants.inputColumns,
component: this._columns.component
});
}
@@ -125,6 +144,22 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
await this.onDatabaseSelected();
}
public set modelParameters(value: ModelParameters) {
this._modelParameters = value;
}
public async onLoading(): Promise<void> {
if (this._columns) {
await this._columns.onLoading();
}
}
public async onLoaded(): Promise<void> {
if (this._columns) {
await this._columns.onLoaded();
}
}
/**
* refreshes the view
*/
@@ -146,7 +181,7 @@ export class ColumnsFilterComponent extends ModelViewBase implements IDataCompon
}
private async onTableSelected(): Promise<void> {
this._columns?.loadData(this.databaseTable);
this._columns?.loadInputs(this._modelParameters, this.databaseTable);
}
private get databaseName(): string | undefined {

View File

@@ -0,0 +1,35 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the Source EULA. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as utils from '../../../common/utils';
/**
* Wizard to register a model
*/
export class ModelArtifact {
/**
* Creates new model artifact
*/
constructor(private _filePath: string, private _deleteAtClose: boolean = true) {
}
public get filePath(): string {
return this._filePath;
}
/**
* Closes the artifact and disposes the resources
*/
public async close(): Promise<void> {
if (this._deleteAtClose) {
try {
await utils.deleteFile(this._filePath);
} catch {
}
}
}
}

View File

@@ -9,25 +9,18 @@ import { ApiWrapper } from '../../../common/apiWrapper';
import * as constants from '../../../common/constants';
import { IDataComponent } from '../../interfaces';
import { PredictColumn } from '../../../prediction/interfaces';
import { ColumnsTable } from './columnsTable';
import { ModelParameters } from '../../../modelManagement/interfaces';
/**
* View to render filters to pick an azure resource
*/
const componentWidth = 60;
export class OutputColumnsComponent extends ModelViewBase implements IDataComponent<PredictColumn[]> {
private _form: azdata.FormContainer | undefined;
private _flex: azdata.FlexContainer | undefined;
private _columnName: azdata.InputBoxComponent | undefined;
private _columnTypes: azdata.DropDownComponent | undefined;
private _dataTypes: string[] = [
'int',
'nvarchar(MAX)',
'varchar(MAX)',
'float',
'double',
'bit'
];
private _columns: ColumnsTable | undefined;
private _modelParameters: ModelParameters | undefined;
/**
* Creates a new view
@@ -41,49 +34,29 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* @param modelBuilder model builder
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._columnName = modelBuilder.inputBox().withProperties({
width: this.componentMaxLength - componentWidth - this.spaceBetweenComponentsLength
}).component();
this._columnTypes = modelBuilder.dropDown().withProperties({
width: componentWidth
}).component();
let flex = modelBuilder.flexContainer()
.withLayout({
width: this._columnName.width
}).withItems([
this._columnName]
).component();
this._flex = modelBuilder.flexContainer()
.withLayout({
flexFlow: 'row',
justifyContent: 'space-between',
width: this.componentMaxLength
}).withItems([
flex, this._columnTypes]
).component();
this._columns = new ColumnsTable(this._apiWrapper, modelBuilder, this, false);
this._form = modelBuilder.formContainer().withFormItems([{
title: constants.azureAccount,
component: this._flex
component: this._columns.component
}]).component();
return this._form;
}
public addComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
if (this._columns) {
formBuilder.addFormItems([{
title: constants.outputColumns,
component: this._flex
component: this._columns.component
}]);
}
}
public removeComponents(formBuilder: azdata.FormBuilder) {
if (this._flex) {
if (this._columns) {
formBuilder.removeFormItem({
title: constants.outputColumns,
component: this._flex
component: this._columns.component
});
}
}
@@ -99,9 +72,24 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* loads data in the components
*/
public async loadData(): Promise<void> {
if (this._columnTypes) {
this._columnTypes.values = this._dataTypes;
this._columnTypes.value = this._dataTypes[0];
if (this._modelParameters) {
this._columns?.loadOutputs(this._modelParameters);
}
}
public set modelParameters(value: ModelParameters) {
this._modelParameters = value;
}
public async onLoading(): Promise<void> {
if (this._columns) {
await this._columns.onLoading();
}
}
public async onLoaded(): Promise<void> {
if (this._columns) {
await this._columns.onLoaded();
}
}
@@ -116,9 +104,6 @@ export class OutputColumnsComponent extends ModelViewBase implements IDataCompon
* Returns selected data
*/
public get data(): PredictColumn[] | undefined {
return this._columnName && this._columnTypes ? [{
name: this._columnName.value || '',
dataType: <string>this._columnTypes.value || ''
}] : undefined;
return this._columns?.data;
}
}

View File

@@ -14,6 +14,7 @@ import { WizardView } from '../../wizardView';
import { ModelSourcePage } from '../modelSourcePage';
import { ColumnsSelectionPage } from './columnsSelectionPage';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { ModelArtifact } from './modelArtifact';
/**
* Wizard to register a model
@@ -21,7 +22,6 @@ import { RegisteredModel } from '../../../modelManagement/interfaces';
export class PredictWizard extends ModelViewBase {
public modelSourcePage: ModelSourcePage | undefined;
//public modelDetailsPage: ModelDetailsPage | undefined;
public columnsSelectionPage: ColumnsSelectionPage | undefined;
public wizardView: WizardView | undefined;
private _parentView: ModelViewBase | undefined;
@@ -37,7 +37,7 @@ export class PredictWizard extends ModelViewBase {
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
public async open(): Promise<void> {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this, [ModelSourceType.RegisteredModels, ModelSourceType.Local, ModelSourceType.Azure]);
this.columnsSelectionPage = new ColumnsSelectionPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
@@ -50,16 +50,22 @@ export class PredictWizard extends ModelViewBase {
wizard.doneButton.label = constants.predictModel;
wizard.generateScriptButton.hidden = true;
wizard.displayPageTitles = true;
wizard.doneButton.onClick(async () => {
await this.onClose();
});
wizard.cancelButton.onClick(async () => {
await this.onClose();
});
wizard.registerNavigationValidator(async (pageInfo: azdata.window.WizardPageChangeInfo) => {
let validated = this.wizardView ? await this.wizardView.validate(pageInfo) : false;
if (validated && pageInfo.newPage === undefined) {
wizard.cancelButton.enabled = false;
wizard.backButton.enabled = false;
await this.predict();
wizard.cancelButton.enabled = true;
wizard.backButton.enabled = true;
if (this._parentView) {
this._parentView?.refresh();
if (validated) {
if (pageInfo.newPage === undefined) {
this.onLoading();
await this.predict();
this.onLoaded();
if (this._parentView) {
this._parentView?.refresh();
}
}
return true;
@@ -67,7 +73,22 @@ export class PredictWizard extends ModelViewBase {
return validated;
});
wizard.open();
await wizard.open();
}
private onLoading(): void {
this.refreshButtons(true);
}
private onLoaded(): void {
this.refreshButtons(false);
}
private refreshButtons(loading: boolean): void {
if (this.wizardView && this.wizardView.wizard) {
this.wizardView.wizard.cancelButton.enabled = !loading;
this.wizardView.wizard.cancelButton.enabled = !loading;
}
}
public get modelResources(): ModelSourcesComponent | undefined {
@@ -82,16 +103,26 @@ export class PredictWizard extends ModelViewBase {
return this.modelSourcePage?.azureModelsComponent;
}
public async getModelFileName(): Promise<ModelArtifact | undefined> {
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
return new ModelArtifact(this.localModelsComponent.data, false);
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
return await this.azureModelsComponent.getDownloadedModel();
} else if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
return await this.modelSourcePage.registeredModelsComponent.getDownloadedModel();
}
return undefined;
}
private async predict(): Promise<boolean> {
try {
let modelFilePath: string = '';
let modelFilePath: string | undefined;
let registeredModel: RegisteredModel | undefined = undefined;
if (this.modelResources && this.localModelsComponent && this.modelResources.data === ModelSourceType.Local) {
modelFilePath = this.localModelsComponent.data;
} else if (this.modelResources && this.azureModelsComponent && this.modelResources.data === ModelSourceType.Azure) {
modelFilePath = await this.downloadAzureModel(this.azureModelsComponent?.data);
} else {
if (this.modelSourcePage && this.modelSourcePage.registeredModelsComponent) {
registeredModel = this.modelSourcePage?.registeredModelsComponent?.data;
} else {
const artifact = await this.getModelFileName();
modelFilePath = artifact?.filePath;
}
await this.generatePredictScript(registeredModel, modelFilePath, this.columnsSelectionPage?.data);
@@ -102,6 +133,14 @@ export class PredictWizard extends ModelViewBase {
}
}
private async onClose(): Promise<void> {
const artifact = await this.getModelFileName();
if (artifact) {
artifact.close();
}
await this.wizardView?.disposePages();
}
/**
* Refresh the pages
*/

View File

@@ -15,7 +15,7 @@ import { IPageView } from '../../interfaces';
* View to render current registered models
*/
export class CurrentModelsPage extends ModelViewBase implements IPageView {
private _tableComponent: azdata.DeclarativeTableComponent | undefined;
private _tableComponent: azdata.Component | undefined;
private _dataTable: CurrentModelsTable | undefined;
private _loader: azdata.LoadingComponent | undefined;

View File

@@ -4,11 +4,13 @@
*--------------------------------------------------------------------------------------------*/
import * as azdata from 'azdata';
import * as vscode from 'vscode';
import * as constants from '../../../common/constants';
import { ModelViewBase } from '../modelViewBase';
import { ApiWrapper } from '../../../common/apiWrapper';
import { RegisteredModel } from '../../../modelManagement/interfaces';
import { IDataComponent } from '../../interfaces';
import { ModelArtifact } from '../prediction/modelArtifact';
/**
* View to render registered models table
@@ -18,6 +20,10 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
private _table: azdata.DeclarativeTableComponent | undefined;
private _modelBuilder: azdata.ModelBuilder | undefined;
private _selectedModel: any;
private _loader: azdata.LoadingComponent | undefined;
private _downloadedFile: ModelArtifact | undefined;
private _onModelSelectionChanged: vscode.EventEmitter<void> = new vscode.EventEmitter<void>();
public readonly onModelSelectionChanged: vscode.Event<void> = this._onModelSelectionChanged.event;
/**
* Creates new view
@@ -30,7 +36,7 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
*
* @param modelBuilder register the components
*/
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.DeclarativeTableComponent {
public registerComponent(modelBuilder: azdata.ModelBuilder): azdata.Component {
this._modelBuilder = modelBuilder;
this._table = modelBuilder.declarativeTable()
.withProperties<azdata.DeclarativeTableProperties>(
@@ -92,7 +98,12 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
ariaLabel: constants.mlsConfigTitle
})
.component();
return this._table;
this._loader = modelBuilder.loadingComponent()
.withItem(this._table)
.withProperties({
loading: true
}).component();
return this._loader;
}
public addComponents(formBuilder: azdata.FormBuilder) {
@@ -111,14 +122,15 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
/**
* Returns the component
*/
public get component(): azdata.DeclarativeTableComponent | undefined {
return this._table;
public get component(): azdata.Component | undefined {
return this._loader;
}
/**
* Loads the data in the component
*/
public async loadData(): Promise<void> {
await this.onLoading();
if (this._table) {
let models: RegisteredModel[] | undefined;
@@ -131,6 +143,20 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
this._table.data = tableData;
}
this.onModelSelected();
await this.onLoaded();
}
public async onLoading(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: true });
}
}
public async onLoaded(): Promise<void> {
if (this._loader) {
await this._loader.updateProperties({ loading: false });
}
}
private createTableRow(model: RegisteredModel): any[] {
@@ -142,8 +168,9 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
height: 15,
checked: false
}).component();
selectModelButton.onDidClick(() => {
selectModelButton.onDidClick(async () => {
this._selectedModel = model;
await this.onModelSelected();
});
return [model.artifactName, model.title, model.created, selectModelButton];
}
@@ -151,6 +178,14 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return [];
}
private async onModelSelected(): Promise<void> {
this._onModelSelectionChanged.fire();
if (this._downloadedFile) {
await this._downloadedFile.close();
}
this._downloadedFile = undefined;
}
/**
* Returns selected data
*/
@@ -158,6 +193,22 @@ export class CurrentModelsTable extends ModelViewBase implements IDataComponent<
return this._selectedModel;
}
public async getDownloadedModel(): Promise<ModelArtifact> {
if (!this._downloadedFile) {
this._downloadedFile = new ModelArtifact(await this.downloadRegisteredModel(this.data));
}
return this._downloadedFile;
}
/**
* disposes the view
*/
public async disposeComponent(): Promise<void> {
if (this._downloadedFile) {
await this._downloadedFile.close();
}
}
/**
* Refreshes the view
*/

View File

@@ -35,7 +35,7 @@ export class RegisterModelWizard extends ModelViewBase {
/**
* Opens a dialog to manage packages used by notebooks.
*/
public open(): void {
public async open(): Promise<void> {
this.modelSourcePage = new ModelSourcePage(this._apiWrapper, this);
this.modelDetailsPage = new ModelDetailsPage(this._apiWrapper, this);
this.wizardView = new WizardView(this._apiWrapper);
@@ -63,7 +63,7 @@ export class RegisterModelWizard extends ModelViewBase {
return validated;
});
wizard.open();
await wizard.open();
}
public get modelResources(): ModelSourcesComponent | undefined {

View File

@@ -128,14 +128,14 @@ export abstract class ViewBase extends EventEmitterCollection {
if (connection) {
return `${connection.serverName} ${connection.databaseName ? connection.databaseName : ''}`;
}
return constants.packageManagerNoConnection;
return constants.noConnectionError;
}
public getServerTitle(): string {
if (this.connection) {
return this.connection.serverName;
}
return constants.packageManagerNoConnection;
return constants.noConnectionError;
}
private async getCurrentConnectionUrl(): Promise<string> {

View File

@@ -68,7 +68,7 @@ export class WizardView extends MainViewBase {
this._pages = pages;
this._wizard.pages = pages.map(x => this.createWizardPage(x.title || '', x));
this._wizard.onPageChanged(async (info) => {
this.onWizardPageChanged(info);
await this.onWizardPageChanged(info);
});
return this._wizard;
@@ -85,17 +85,17 @@ export class WizardView extends MainViewBase {
return true;
}
private onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
private async onWizardPageChanged(pageInfo: azdata.window.WizardPageChangeInfo) {
let idxLast = pageInfo.lastPage;
let lastPage = this._pages[idxLast];
if (lastPage && lastPage.onLeave) {
lastPage.onLeave();
await lastPage.onLeave();
}
let idx = pageInfo.newPage;
let page = this._pages[idx];
if (page && page.onEnter) {
page.onEnter();
await page.onEnter();
}
}