Send server contextualization to Copilot extension (#24230)

* Send server contextualization to Copilot extension

* Keep context in editor input

* Remove unnecessary server context and extension service

* Send context when connecting from open editor

* Remove contextualization complete event

* Contextualize editor after connection success

* Minor clean up

* Remove nested then and use async/await

* Create helper function

* Remove unneeded async and add comment

* Encapsulate all context logic in service

* Use void operator to fix floating promise

* Correct return comment
This commit is contained in:
Lewis Sanchez
2023-09-01 09:26:29 -07:00
committed by GitHub
parent e3d0670609
commit 5152823306
12 changed files with 100 additions and 72 deletions

View File

@@ -1567,8 +1567,8 @@ export interface ServerContextualizationParams {
ownerUri: string; ownerUri: string;
} }
export namespace GenerateServerContextualizationNotification { export namespace GenerateServerContextualizationRequest {
export const type = new NotificationType<ServerContextualizationParams, void>('metadata/generateServerContext'); export const type = new RequestType<ServerContextualizationParams, azdata.contextualization.GenerateServerContextualizationResult, void, void>('metadata/generateServerContext');
} }
export namespace GetServerContextualizationRequest { export namespace GetServerContextualizationRequest {

View File

@@ -1310,7 +1310,7 @@ export class ExecutionPlanServiceFeature extends SqlOpsFeature<undefined> {
*/ */
export class ServerContextualizationServiceFeature extends SqlOpsFeature<undefined> { export class ServerContextualizationServiceFeature extends SqlOpsFeature<undefined> {
private static readonly messagesTypes: RPCMessageType[] = [ private static readonly messagesTypes: RPCMessageType[] = [
contracts.GenerateServerContextualizationNotification.type contracts.GenerateServerContextualizationRequest.type
]; ];
constructor(client: SqlOpsDataClient) { constructor(client: SqlOpsDataClient) {
@@ -1330,12 +1330,18 @@ export class ServerContextualizationServiceFeature extends SqlOpsFeature<undefin
protected registerProvider(options: undefined): Disposable { protected registerProvider(options: undefined): Disposable {
const client = this._client; const client = this._client;
const generateServerContextualization = (ownerUri: string): void => { const generateServerContextualization = (ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> => {
const params: contracts.ServerContextualizationParams = { const params: contracts.ServerContextualizationParams = {
ownerUri: ownerUri ownerUri: ownerUri
}; };
return client.sendNotification(contracts.GenerateServerContextualizationNotification.type, params); return client.sendRequest(contracts.GenerateServerContextualizationRequest.type, params).then(
r => r,
e => {
client.logFailedRequest(contracts.GenerateServerContextualizationRequest.type, e);
return Promise.reject(e);
}
);
}; };
const getServerContextualization = (ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> => { const getServerContextualization = (ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> => {

View File

@@ -901,7 +901,7 @@ declare module 'azdata' {
* Copilot for improved suggestions. * Copilot for improved suggestions.
* @param provider The provider to register * @param provider The provider to register
*/ */
export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable export function registerServerContextualizationProvider(provider: contextualization.ServerContextualizationProvider): vscode.Disposable;
} }
export namespace designers { export namespace designers {
@@ -1783,11 +1783,18 @@ declare module 'azdata' {
} }
export namespace contextualization { export namespace contextualization {
export interface GenerateServerContextualizationResult {
/**
* The generated server context.
*/
context: string | undefined;
}
export interface GetServerContextualizationResult { export interface GetServerContextualizationResult {
/** /**
* An array containing the generated server context. * The retrieved server context.
*/ */
context: string[]; context: string | undefined;
} }
export interface ServerContextualizationProvider extends DataProvider { export interface ServerContextualizationProvider extends DataProvider {
@@ -1795,7 +1802,7 @@ declare module 'azdata' {
* Generates server context. * Generates server context.
* @param ownerUri The URI of the connection to generate context for. * @param ownerUri The URI of the connection to generate context for.
*/ */
generateServerContextualization(ownerUri: string): void; generateServerContextualization(ownerUri: string): Thenable<GenerateServerContextualizationResult>;
/** /**
* Gets server context, which can be in the form of create scripts but is left up each provider. * Gets server context, which can be in the form of create scripts but is left up each provider.

View File

@@ -972,8 +972,8 @@ export class ExtHostDataProtocol extends ExtHostDataProtocolShape {
// Database Server Contextualization API // Database Server Contextualization API
public override $generateServerContextualization(handle: number, ownerUri: string): void { public override $generateServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> {
this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).generateServerContextualization(ownerUri); return this._resolveProvider<azdata.contextualization.ServerContextualizationProvider>(handle).generateServerContextualization(ownerUri);
} }
public override $getServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> { public override $getServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GetServerContextualizationResult> {

View File

@@ -598,7 +598,7 @@ export abstract class ExtHostDataProtocolShape {
/** /**
* Generates server context. * Generates server context.
*/ */
$generateServerContextualization(handle: number, ownerUri: string): void { throw ni(); } $generateServerContextualization(handle: number, ownerUri: string): Thenable<azdata.contextualization.GenerateServerContextualizationResult> { throw ni(); }
/** /**
* Gets server context. * Gets server context.
*/ */

View File

@@ -19,7 +19,6 @@ import { IInstantiationService } from 'vs/platform/instantiation/common/instanti
import { EditorInput } from 'vs/workbench/common/editor/editorInput'; import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IResourceEditorInput } from 'vs/platform/editor/common/editor'; import { IResourceEditorInput } from 'vs/platform/editor/common/editor';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export class FileQueryEditorInput extends QueryEditorInput { export class FileQueryEditorInput extends QueryEditorInput {
@@ -33,10 +32,9 @@ export class FileQueryEditorInput extends QueryEditorInput {
@IQueryModelService queryModelService: IQueryModelService, @IQueryModelService queryModelService: IQueryModelService,
@IConfigurationService configurationService: IConfigurationService, @IConfigurationService configurationService: IConfigurationService,
@IInstantiationService instantiationService: IInstantiationService, @IInstantiationService instantiationService: IInstantiationService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService, @IServerContextualizationService serverContextualizationService: IServerContextualizationService
@IExtensionService extensionService: IExtensionService
) { ) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService); super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService);
} }
public override resolve(): Promise<ITextFileEditorModel | BinaryEditorModel> { public override resolve(): Promise<ITextFileEditorModel | BinaryEditorModel> {

View File

@@ -23,7 +23,6 @@ import { IEditorResolverService } from 'vs/workbench/services/editor/common/edit
import { Uri } from 'vscode'; import { Uri } from 'vscode';
import { ILogService } from 'vs/platform/log/common/log'; import { ILogService } from 'vs/platform/log/common/log';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
export class UntitledQueryEditorInput extends QueryEditorInput implements IUntitledQueryEditorInput { export class UntitledQueryEditorInput extends QueryEditorInput implements IUntitledQueryEditorInput {
@@ -39,10 +38,9 @@ export class UntitledQueryEditorInput extends QueryEditorInput implements IUntit
@IInstantiationService instantiationService: IInstantiationService, @IInstantiationService instantiationService: IInstantiationService,
@ILogService private readonly logService: ILogService, @ILogService private readonly logService: ILogService,
@IEditorResolverService private readonly editorResolverService: IEditorResolverService, @IEditorResolverService private readonly editorResolverService: IEditorResolverService,
@IServerContextualizationService serverContextualizationService: IServerContextualizationService, @IServerContextualizationService serverContextualizationService: IServerContextualizationService
@IExtensionService extensionService: IExtensionService
) { ) {
super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService, extensionService); super(description, text, results, connectionManagementService, queryModelService, configurationService, instantiationService, serverContextualizationService);
// Set the mode explicitely to stop the auto language detection service from changing the mode unexpectedly. // Set the mode explicitely to stop the auto language detection service from changing the mode unexpectedly.
// the auto language detection service won't do the language change only if the mode is explicitely set. // the auto language detection service won't do the language change only if the mode is explicitely set.
// if the mode (e.g. kusto, sql) do not exist for whatever reason, we will default it to sql. // if the mode (e.g. kusto, sql) do not exist for whatever reason, we will default it to sql.

View File

@@ -21,7 +21,6 @@ import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { EditorInput } from 'vs/workbench/common/editor/editorInput'; import { EditorInput } from 'vs/workbench/common/editor/editorInput';
import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation'; import { IInstantiationService } from 'vs/platform/instantiation/common/instantiation';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
const MAX_SIZE = 13; const MAX_SIZE = 13;
@@ -144,8 +143,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
private _state = this._register(new QueryEditorState()); private _state = this._register(new QueryEditorState());
public get state(): QueryEditorState { return this._state; } public get state(): QueryEditorState { return this._state; }
private _serverContext: string[];
constructor( constructor(
private _description: string | undefined, private _description: string | undefined,
protected _text: AbstractTextResourceEditorInput, protected _text: AbstractTextResourceEditorInput,
@@ -154,8 +151,7 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
@IQueryModelService private readonly queryModelService: IQueryModelService, @IQueryModelService private readonly queryModelService: IQueryModelService,
@IConfigurationService private readonly configurationService: IConfigurationService, @IConfigurationService private readonly configurationService: IConfigurationService,
@IInstantiationService protected readonly instantiationService: IInstantiationService, @IInstantiationService protected readonly instantiationService: IInstantiationService,
@IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService, @IServerContextualizationService private readonly serverContextualizationService: IServerContextualizationService
@IExtensionService private readonly extensionService: IExtensionService
) { ) {
super(); super();
@@ -241,27 +237,6 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
public override isDirty(): boolean { return this._text.isDirty(); } public override isDirty(): boolean { return this._text.isDirty(); }
public get resource(): URI { return this._text.resource; } public get resource(): URI { return this._text.resource; }
public async getServerContext(): Promise<string[]> {
const copilotExt = await this.extensionService.getExtension('github.copilot');
if (copilotExt && this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
if (!this._serverContext) {
const result = await this.serverContextualizationService.getServerContextualization(this.uri);
// TODO lewissanchez - Remove this from here once Copilot starts pulling context. That isn't implemented yet, so
// getting scripts this way for now.
this._serverContext = result.context;
return this._serverContext;
}
else {
return this._serverContext;
}
}
else {
return Promise.resolve([]);
}
}
public override getName(longForm?: boolean): string { public override getName(longForm?: boolean): string {
if (this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').showConnectionInfoInTitle) { if (this.configurationService.getValue<IQueryEditorConfiguration>('queryEditor').showConnectionInfoInTitle) {
let profile = this.connectionManagementService.getConnectionProfile(this.uri); let profile = this.connectionManagementService.getConnectionProfile(this.uri);
@@ -346,6 +321,9 @@ export abstract class QueryEditorInput extends EditorInput implements IConnectab
} }
} }
this._onDidChangeLabel.fire(); this._onDidChangeLabel.fire();
// Intentionally not awaiting, so that contextualization can happen in the background
void this.serverContextualizationService?.contextualizeUriForCopilot(this.uri);
} }
public onDisconnect(): void { public onDisconnect(): void {

View File

@@ -467,7 +467,7 @@ suite('commandLineService tests', () => {
let uri = URI.file(args._[0]); let uri = URI.file(args._[0]);
const workbenchinstantiationService = workbenchInstantiationService(); const workbenchinstantiationService = workbenchInstantiationService();
const editorInput = workbenchinstantiationService.createInstance(FileEditorInput, uri, undefined, undefined, undefined, undefined, undefined, undefined); const editorInput = workbenchinstantiationService.createInstance(FileEditorInput, uri, undefined, undefined, undefined, undefined, undefined, undefined);
const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined, undefined); const queryInput = new FileQueryEditorInput(undefined, editorInput, undefined, connectionManagementService.object, querymodelService.object, configurationService.object, workbenchinstantiationService, undefined);
queryInput.state.connected = true; queryInput.state.connected = true;
const editorService: TypeMoq.Mock<IEditorService> = TypeMoq.Mock.ofType<IEditorService>(TestEditorService, TypeMoq.MockBehavior.Strict); const editorService: TypeMoq.Mock<IEditorService> = TypeMoq.Mock.ofType<IEditorService>(TestEditorService, TypeMoq.MockBehavior.Strict);
editorService.setup(e => e.editors).returns(() => [queryInput]); editorService.setup(e => e.editors).returns(() => [queryInput]);

View File

@@ -317,7 +317,6 @@ suite('SQL QueryEditor Tests', () => {
testinstantiationService, testinstantiationService,
undefined, undefined,
undefined, undefined,
undefined,
undefined undefined
); );
}); });

View File

@@ -30,14 +30,9 @@ export interface IServerContextualizationService {
getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider; getProvider(providerId: string): azdata.contextualization.ServerContextualizationProvider;
/** /**
* Generates server context * Contextualizes the provided URI for GitHub Copilot.
* @param ownerUri The URI of the connection to generate context for. * @param uri The URI to contextualize for Copilot.
* @returns Copilot will have the URI contextualized when the promise completes.
*/ */
generateServerContextualization(ownerUri: string): void; contextualizeUriForCopilot(uri: string): Promise<void>;
/**
* Gets all database context.
* @param ownerUri The URI of the connection to get context for.
*/
getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult>;
} }

View File

@@ -5,10 +5,11 @@
import * as azdata from 'azdata'; import * as azdata from 'azdata';
import { invalidProvider } from 'sql/base/common/errors'; import { invalidProvider } from 'sql/base/common/errors';
import { IConnectionManagementService, IConnectionParams } from 'sql/platform/connection/common/connectionManagement'; import { IConnectionManagementService } from 'sql/platform/connection/common/connectionManagement';
import { IQueryEditorConfiguration } from 'sql/platform/query/common/query'; import { IQueryEditorConfiguration } from 'sql/platform/query/common/query';
import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces'; import { IServerContextualizationService } from 'sql/workbench/services/contextualization/common/interfaces';
import { Disposable } from 'vs/base/common/lifecycle'; import { Disposable } from 'vs/base/common/lifecycle';
import { ICommandService } from 'vs/platform/commands/common/commands';
import { IConfigurationService } from 'vs/platform/configuration/common/configuration'; import { IConfigurationService } from 'vs/platform/configuration/common/configuration';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions'; import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
@@ -19,18 +20,10 @@ export class ServerContextualizationService extends Disposable implements IServe
constructor( constructor(
@IConnectionManagementService private readonly _connectionManagementService: IConnectionManagementService, @IConnectionManagementService private readonly _connectionManagementService: IConnectionManagementService,
@IConfigurationService private readonly _configurationService: IConfigurationService, @IConfigurationService private readonly _configurationService: IConfigurationService,
@IExtensionService private readonly _extensionService: IExtensionService @IExtensionService private readonly _extensionService: IExtensionService,
@ICommandService private readonly _commandService: ICommandService
) { ) {
super(); super();
this._register(this._connectionManagementService.onConnect(async (e: IConnectionParams) => {
const copilotExt = await this._extensionService.getExtension('github.copilot');
if (copilotExt && this._configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled) {
const ownerUri = e.connectionUri;
await this.generateServerContextualization(ownerUri);
}
}));
} }
/** /**
@@ -63,15 +56,44 @@ export class ServerContextualizationService extends Disposable implements IServe
throw invalidProvider(providerId); throw invalidProvider(providerId);
} }
/**
* Contextualizes the provided URI for GitHub Copilot.
* @param uri The URI to contextualize for Copilot.
* @returns Copilot will have the URI contextualized when the promise completes.
*/
public async contextualizeUriForCopilot(uri: string): Promise<void> {
// Don't need to take any actions if contextualization is not enabled and can return
const isContextualizationNeeded = await this.isContextualizationNeeded();
if (!isContextualizationNeeded) {
return;
}
const getServerContextualizationResult = await this.getServerContextualization(uri);
if (getServerContextualizationResult.context) {
await this.sendServerContextualizationToCopilot(getServerContextualizationResult.context);
}
else {
const generateServerContextualizationResult = await this.generateServerContextualization(uri);
if (generateServerContextualizationResult.context) {
await this.sendServerContextualizationToCopilot(generateServerContextualizationResult.context);
}
}
}
/** /**
* Generates server context * Generates server context
* @param ownerUri The URI of the connection to generate context for. * @param ownerUri The URI of the connection to generate context for.
*/ */
public generateServerContextualization(ownerUri: string): void { private async generateServerContextualization(ownerUri: string): Promise<azdata.contextualization.GenerateServerContextualizationResult> {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri); const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName); const handler = this.getProvider(providerName);
if (handler) { if (handler) {
handler.generateServerContextualization(ownerUri); return await handler.generateServerContextualization(ownerUri);
}
else {
return Promise.resolve({
context: undefined
});
} }
} }
@@ -79,7 +101,7 @@ export class ServerContextualizationService extends Disposable implements IServe
* Gets all database context. * Gets all database context.
* @param ownerUri The URI of the connection to get context for. * @param ownerUri The URI of the connection to get context for.
*/ */
public async getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult> { private async getServerContextualization(ownerUri: string): Promise<azdata.contextualization.GetServerContextualizationResult> {
const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri); const providerName = this._connectionManagementService.getProviderIdFromUri(ownerUri);
const handler = this.getProvider(providerName); const handler = this.getProvider(providerName);
if (handler) { if (handler) {
@@ -87,8 +109,33 @@ export class ServerContextualizationService extends Disposable implements IServe
} }
else { else {
return Promise.resolve({ return Promise.resolve({
context: [] context: undefined
}); });
} }
} }
/**
* Sends the provided context over to copilot, so that it can be used to generate improved suggestions.
* @param serverContext The context to be sent over to Copilot
*/
private async sendServerContextualizationToCopilot(serverContext: string | undefined): Promise<void> {
if (serverContext) {
// LEWISSANCHEZ TODO: Find way to set context on untitled query editor files. Need to save first for Copilot status to say "Has Context"
await this._commandService.executeCommand('github.copilot.provideContext', '**/*.sql', {
value: serverContext
});
}
}
/**
* Checks if contextualization is needed. This is based on whether the Copilot extension is installed and the GitHub Copilot
* contextualization setting is enabled.
* @returns A promise that resolves to true if contextualization is needed, false otherwise.
*/
private async isContextualizationNeeded(): Promise<boolean> {
const copilotExt = await this._extensionService.getExtension('github.copilot');
const isContextualizationEnabled = this._configurationService.getValue<IQueryEditorConfiguration>('queryEditor').githubCopilotContextualizationEnabled
return (copilotExt && isContextualizationEnabled);
}
} }