feat(misc): embeddings script and list sources in ai response (#18455)

Author: Katerina Skroumpelou, 2023-08-03 18:51:31 +03:00, committed by GitHub
parent 0c0e61e122
commit e9d50af945
15 changed files with 1228 additions and 362 deletions

File (new): GitHub Actions workflow "Generate embeddings"

@@ -0,0 +1,53 @@
name: Generate embeddings
on:
  schedule:
    - cron: "0 5 * * 0,4" # sunday, thursday 5AM
  workflow_dispatch:
jobs:
  cache-and-install:
    if: github.repository == 'nrwl/nx'
    runs-on: ubuntu-latest
    strategy:
      matrix:
        node-version: [18]
    steps:
      - name: Checkout
        uses: actions/checkout@v3
      - name: Install Node.js
        uses: actions/setup-node@v3
        with:
          node-version: 18
      - name: Install pnpm
        uses: pnpm/action-setup@v2
        id: pnpm-install
        with:
          version: 7
          run_install: false
      - name: Get pnpm store directory
        id: pnpm-cache
        shell: bash
        run: |
          echo "STORE_PATH=$(pnpm store path)" >> $GITHUB_OUTPUT
      - name: Setup pnpm cache
        uses: actions/cache@v3
        with:
          path: ${{ steps.pnpm-cache.outputs.STORE_PATH }}
          key: ${{ runner.os }}-pnpm-store-${{ hashFiles('**/pnpm-lock.yaml') }}
          restore-keys: |
            ${{ runner.os }}-pnpm-store-
      - name: Install dependencies
        run: pnpm install --no-frozen-lockfile
      - name: Run embeddings script
        run: pnpm exec nx run tools-documentation-create-embeddings:run-node
        env:
          NX_NEXT_PUBLIC_SUPABASE_URL: ${{ secrets.NX_NEXT_PUBLIC_SUPABASE_URL }}
          NX_SUPABASE_SERVICE_ROLE_KEY: ${{ secrets.NX_SUPABASE_SERVICE_ROLE_KEY }}
          NX_OPENAI_KEY: ${{ secrets.NX_OPENAI_KEY }}
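
Two details worth noting: GitHub Actions evaluates cron schedules in UTC, so these runs land at 5AM UTC on Sundays and Thursdays, with workflow_dispatch available for manual triggering; and the three secrets correspond to the NX_*-prefixed environment variables that both the embeddings script (main.mts, below) and the nx-dev AI library read at runtime.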

File (modified): nx-dev AI data-access library

@@ -11,7 +11,13 @@ import {
ChatCompletionRequestMessageRoleEnum,
CreateCompletionResponseUsage,
} from 'openai';
-import { getMessageFromResponse, sanitizeLinksInResponse } from './utils';
+import {
+  PageSection,
+  getListOfSources,
+  getMessageFromResponse,
+  sanitizeLinksInResponse,
+  toMarkdownList,
+} from './utils';
const openAiKey = process.env['NX_OPENAI_KEY'];
const supabaseUrl = process.env['NX_NEXT_PUBLIC_SUPABASE_URL'];
@@ -21,9 +27,12 @@ const config = new Configuration({
});
const openai = new OpenAIApi(config);
-export async function nxDevDataAccessAi(
-  query: string
-): Promise<{ textResponse: string; usage?: CreateCompletionResponseUsage }> {
+export async function nxDevDataAccessAi(query: string): Promise<{
+  textResponse: string;
+  usage?: CreateCompletionResponseUsage;
+  sources: { heading: string; url: string }[];
+  sourcesMarkdown: string;
+}> {
try {
if (!openAiKey) {
throw new ApplicationError('Missing environment variable NX_OPENAI_KEY');
@@ -80,11 +89,11 @@ export async function nxDevDataAccessAi(
}: CreateEmbeddingResponse = embeddingResponse.data;
const { error: matchError, data: pageSections } = await supabaseClient.rpc(
-    'match_page_sections',
+    'match_page_sections_2',
{
embedding,
match_threshold: 0.78,
-      match_count: 10,
+      match_count: 15,
min_content_length: 50,
}
);
@@ -97,13 +106,13 @@ export async function nxDevDataAccessAi(
let tokenCount = 0;
let contextText = '';
-    for (let i = 0; i < pageSections.length; i++) {
-      const pageSection = pageSections[i];
+    for (let i = 0; i < (pageSections as PageSection[]).length; i++) {
+      const pageSection: PageSection = pageSections[i];
const content = pageSection.content;
const encoded = tokenizer.encode(content);
tokenCount += encoded.text.length;
-      if (tokenCount >= 1500) {
+      if (tokenCount >= 2500) {
break;
}
@@ -163,9 +172,13 @@ export async function nxDevDataAccessAi(
const responseWithoutBadLinks = await sanitizeLinksInResponse(message);
+    const sources = getListOfSources(pageSections);
return {
textResponse: responseWithoutBadLinks,
usage: response.data.usage,
+      sources,
+      sourcesMarkdown: toMarkdownList(sources),
};
} catch (err: unknown) {
if (err instanceof UserError) {
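
Taken together, this diff widens the retrieval step — the query embedding is matched against more candidate sections (match_count 15, up from 10) through the renamed match_page_sections_2 Postgres function, and the prompt context budget grows from 1500 to 2500 tokens — and the matched sections are now surfaced to callers as sources. A minimal consumption sketch of the new return shape; the field names come from the diff above, while the import path and query string are illustrative assumptions:

// Sketch only: the import path below is hypothetical.
import { nxDevDataAccessAi } from '@nx/nx-dev/data-access-ai';

async function ask(): Promise<void> {
  const { textResponse, usage, sources, sourcesMarkdown } =
    await nxDevDataAccessAi('How do I cache task results?');
  console.log(textResponse); // sanitized answer text
  console.log(usage?.total_tokens); // OpenAI token usage, when reported
  console.log(sources.map((s) => s.url)); // deduplicated nx.dev URLs
  console.log(sourcesMarkdown); // the same sources as '- [heading](url)' bullets
}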

File (modified): nx-dev AI data-access library — utils.ts

@@ -1,4 +1,13 @@
import { CreateChatCompletionResponse } from 'openai';
+export interface PageSection {
+  id: number;
+  page_id: number;
+  content: string;
+  heading: string;
+  similarity: number;
+  slug: string;
+  url_partial: string | null;
+}
export function getMessageFromResponse(
response: CreateChatCompletionResponse
@@ -11,6 +20,34 @@ export function getMessageFromResponse(
return response.choices[0].message?.content ?? '';
}
+export function getListOfSources(
+  pageSections: PageSection[]
+): { heading: string; url: string }[] {
+  const uniqueUrlPartials = new Set<string | null>();
+  const result = pageSections
+    .filter((section) => {
+      if (section.url_partial && !uniqueUrlPartials.has(section.url_partial)) {
+        uniqueUrlPartials.add(section.url_partial);
+        return true;
+      }
+      return false;
+    })
+    .map((section) => ({
+      heading: section.heading,
+      url: `https://nx.dev${section.url_partial}`,
+    }));
+  return result;
+}
+export function toMarkdownList(
+  sections: { heading: string; url: string }[]
+): string {
+  return sections
+    .map((section) => `- [${section.heading}](${section.url})`)
+    .join('\n');
+}
export async function sanitizeLinksInResponse(
response: string
): Promise<string> {
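
getListOfSources deduplicates sections by url_partial before mapping them to absolute nx.dev URLs, and toMarkdownList renders the result as bullet links. An illustrative example (the PageSection values are made up):

// Two sections from the same page: only the first survives deduplication.
const sections: PageSection[] = [
  { id: 1, page_id: 7, content: '...', heading: 'Caching', similarity: 0.91, slug: 'caching', url_partial: '/concepts/how-caching-works' },
  { id: 2, page_id: 7, content: '...', heading: 'Inputs', similarity: 0.84, slug: 'inputs', url_partial: '/concepts/how-caching-works' },
];
getListOfSources(sections);
// => [{ heading: 'Caching', url: 'https://nx.dev/concepts/how-caching-works' }]
toMarkdownList(getListOfSources(sections));
// => '- [Caching](https://nx.dev/concepts/how-caching-works)'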

File (modified): nx-dev feature-ai component

@@ -11,6 +11,7 @@ export function FeatureAi(): JSX.Element {
const [query, setSearchTerm] = useState('');
const [loading, setLoading] = useState(false);
const [feedbackSent, setFeedbackSent] = useState<boolean>(false);
+  const [sources, setSources] = useState('');
const warning = `
{% callout type="warning" title="Always double check!" %}
@@ -23,19 +24,33 @@
setLoading(true);
let completeText = '';
let usage;
+    let sourcesMarkdown = '';
try {
const aiResponse = await nxDevDataAccessAi(query);
completeText = aiResponse.textResponse;
usage = aiResponse.usage;
+      setSources(
+        JSON.stringify(aiResponse.sources?.map((source) => source.url))
+      );
+      sourcesMarkdown = aiResponse.sourcesMarkdown;
setLoading(false);
} catch (error) {
setError(error as any);
setLoading(false);
}
-    sendCustomEvent('ai_query', 'ai', 'query', undefined, { query, ...usage });
+    sendCustomEvent('ai_query', 'ai', 'query', undefined, {
+      query,
+      ...usage,
+    });
setFeedbackSent(false);
+    const sourcesMd = `
+{% callout type="info" title="Sources" %}
+${sourcesMarkdown}
+{% /callout %}`;
setFinalResult(
-      renderMarkdown(warning + completeText, { filePath: '' }).node
+      renderMarkdown(warning + completeText + sourcesMd, { filePath: '' }).node
);
};
@@ -44,6 +59,7 @@ export function FeatureAi(): JSX.Element {
sendCustomEvent('ai_feedback', 'ai', type, undefined, {
query,
result: finalResult,
+      sources,
});
setFeedbackSent(true);
} catch (error) {
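
The markdown handed to renderMarkdown is now the warning callout, the model's answer, then a Markdoc sources callout. Roughly, with illustrative values:

// Illustrative only — the callout syntax matches the component above.
const markdoc = `
{% callout type="warning" title="Always double check!" %}
...warning body from the component...
{% /callout %}
...model answer...
{% callout type="info" title="Sources" %}
- [How Caching Works](https://nx.dev/concepts/how-caching-works)
{% /callout %}`;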

File (modified): package.json

@@ -175,6 +175,7 @@
"flat": "^5.0.2",
"fork-ts-checker-webpack-plugin": "7.2.13",
"fs-extra": "^11.1.0",
"github-slugger": "^2.0.0",
"gpt3-tokenizer": "^1.1.5",
"html-webpack-plugin": "5.5.0",
"http-server": "14.1.0",
@@ -191,6 +192,7 @@
"jest": "29.4.3",
"jest-config": "^29.4.1",
"jest-environment-jsdom": "29.4.3",
"jest-environment-node": "^29.4.1",
"jest-resolve": "^29.4.1",
"jest-util": "^29.4.1",
"js-tokens": "^4.0.0",
@@ -206,6 +208,9 @@
"loader-utils": "2.0.3",
"magic-string": "~0.30.2",
"markdown-factory": "^0.0.6",
"mdast-util-from-markdown": "^1.3.1",
"mdast-util-to-markdown": "^1.5.0",
"mdast-util-to-string": "^3.2.0",
"memfs": "^3.0.1",
"metro-config": "0.76.7",
"metro-resolver": "0.76.7",
@@ -267,6 +272,7 @@
"typedoc": "0.24.8",
"typedoc-plugin-markdown": "3.15.3",
"typescript": "~5.1.3",
"unist-builder": "^4.0.0",
"unzipper": "^0.10.11",
"url-loader": "^4.1.1",
"use-sync-external-store": "^1.2.0",
@@ -359,4 +365,3 @@
}
}
}
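
The new devDependencies map directly onto the embeddings script below: github-slugger generates heading anchors, mdast-util-from-markdown / mdast-util-to-markdown / mdast-util-to-string and unist-builder parse, split, and re-serialize the markdown docs, and jest-environment-node backs the new project's node test environment.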

pnpm-lock.yaml (generated): 831 lines changed; diff suppressed because it is too large.

File (new): tools/documentation/create-embeddings/.eslintrc.json

@@ -0,0 +1,18 @@
{
"extends": ["../../../.eslintrc.json"],
"ignorePatterns": ["!**/*"],
"overrides": [
{
"files": ["*.ts", "*.tsx", "*.js", "*.jsx"],
"rules": {}
},
{
"files": ["*.ts", "*.tsx"],
"rules": {}
},
{
"files": ["*.js", "*.jsx"],
"rules": {}
}
]
}

File (new): tools/documentation/create-embeddings/jest.config.ts

@@ -0,0 +1,18 @@
/* eslint-disable */
export default {
displayName: 'tools-documentation-create-embeddings',
preset: './jest.preset.js',
testEnvironment: 'node',
transform: {
'^.+\\.(ts|tsx|js|jsx|mts|mjs)$': [
'ts-jest',
{ tsconfig: '<rootDir>/tsconfig.spec.json' },
],
},
transformIgnorePatterns: [
// Ensure that Jest does not ignore github-slugger
'<rootDir>/node_modules/.pnpm/(?!(github-slugger)@)',
],
moduleFileExtensions: ['mts', 'ts', 'js', 'html'],
coverageDirectory: '../../../coverage/tools/documentation/create-embeddings',
};
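
github-slugger v2 ships as an ES module, so the negative lookahead in transformIgnorePatterns opts it back into transformation rather than letting Jest skip it; the node_modules/.pnpm prefix reflects pnpm's nested store layout.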

File (new): tools/documentation/create-embeddings/jest.preset.js

@@ -0,0 +1,16 @@
const nxPreset = require('@nx/jest/preset').default;
module.exports = {
...nxPreset,
testTimeout: 30000,
testMatch: ['**/+(*.)+(spec|test).+(ts|js)?(x)'],
transform: {
'^.+\\.(ts|tsx|js|jsx|mts|mjs)$': 'ts-jest',
},
resolver: '../../../scripts/patched-jest-resolver.js',
// Fixes https://github.com/jestjs/jest/issues/11956
runtime: '@side/jest-runtime',
moduleFileExtensions: ['ts', 'js', 'mts', 'html'],
coverageReporters: ['html'],
maxWorkers: 1,
};

File (new): tools/documentation/create-embeddings/project.json

@@ -0,0 +1,78 @@
{
"name": "tools-documentation-create-embeddings",
"$schema": "../../../node_modules/nx/schemas/project-schema.json",
"sourceRoot": "tools/documentation/create-embeddings/src",
"projectType": "application",
"targets": {
"build": {
"executor": "@nx/esbuild:esbuild",
"outputs": ["{options.outputPath}"],
"defaultConfiguration": "production",
"options": {
"platform": "node",
"outputPath": "dist/tools/documentation/create-embeddings",
"format": ["esm"],
"bundle": false,
"main": "tools/documentation/create-embeddings/src/main.mts",
"tsConfig": "tools/documentation/create-embeddings/tsconfig.app.json",
"assets": ["tools/documentation/create-embeddings/src/assets"],
"generatePackageJson": true,
"esbuildOptions": {
"sourcemap": true,
"outExtension": {
".js": ".js"
}
}
},
"configurations": {
"development": {},
"production": {
"esbuildOptions": {
"sourcemap": false,
"outExtension": {
".js": ".js"
}
}
}
}
},
"run-node": {
"executor": "@nx/js:node",
"defaultConfiguration": "development",
"options": {
"buildTarget": "tools-documentation-create-embeddings:build",
"watch": false
},
"configurations": {
"development": {
"buildTarget": "tools-documentation-create-embeddings:build:development"
},
"production": {
"buildTarget": "tools-documentation-create-embeddings:build:production"
}
}
},
"lint": {
"executor": "@nx/linter:eslint",
"outputs": ["{options.outputFile}"],
"options": {
"lintFilePatterns": ["tools/documentation/create-embeddings/**/*.ts"]
}
},
"test": {
"executor": "@nx/jest:jest",
"outputs": ["{workspaceRoot}/coverage/{projectRoot}"],
"options": {
"jestConfig": "tools/documentation/create-embeddings/jest.config.ts",
"passWithNoTests": true
},
"configurations": {
"ci": {
"ci": true,
"codeCoverage": true
}
}
}
},
"tags": []
}
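
The build deliberately emits ES modules ("format": ["esm"]): the entry point is a .mts file whose JSON import assertions only work in an ES-module context. The run-node target, which the workflow above invokes, builds through the buildTarget reference and then executes the output with Node.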

File (new): tools/documentation/create-embeddings/src/main.mts

@@ -0,0 +1,434 @@
// based on:
// https://github.com/supabase-community/nextjs-openai-doc-search/blob/main/lib/generate-embeddings.ts
import { createClient } from '@supabase/supabase-js';
import * as dotenv from 'dotenv';
import { readFile } from 'fs/promises';
import 'openai';
import { Configuration, OpenAIApi } from 'openai';
import { inspect } from 'util';
import yargs from 'yargs';
import { createHash } from 'crypto';
import GithubSlugger from 'github-slugger';
import { fromMarkdown } from 'mdast-util-from-markdown';
import { toMarkdown } from 'mdast-util-to-markdown';
import { toString } from 'mdast-util-to-string';
import { u } from 'unist-builder';
import mapJson from '../../../../docs/map.json' assert { type: 'json' };
import manifestsCloud from '../../../../docs/generated/manifests/cloud.json' assert { type: 'json' };
import manifestsExtending from '../../../../docs/generated/manifests/extending-nx.json' assert { type: 'json' };
import manifestsNx from '../../../../docs/generated/manifests/nx.json' assert { type: 'json' };
import manifestsPackages from '../../../../docs/generated/manifests/packages.json' assert { type: 'json' };
import manifestsRecipes from '../../../../docs/generated/manifests/recipes.json' assert { type: 'json' };
import manifestsTags from '../../../../docs/generated/manifests/tags.json' assert { type: 'json' };
dotenv.config();
type ProcessedMdx = {
checksum: string;
sections: Section[];
};
type Section = {
content: string;
heading?: string;
slug?: string;
};
/**
* Splits a `mdast` tree into multiple trees based on
* a predicate function. Will include the splitting node
* at the beginning of each tree.
*
* Useful to split a markdown file into smaller sections.
*/
export function splitTreeBy(tree: any, predicate: (node: any) => boolean) {
return tree.children.reduce((trees: any, node: any) => {
const [lastTree] = trees.slice(-1);
if (!lastTree || predicate(node)) {
const tree = u('root', [node]);
return trees.concat(tree);
}
lastTree.children.push(node);
return trees;
}, []);
}
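// Worked example (added for illustration, not part of the commit): for children
// [paragraph P, heading A, paragraph PA, heading B, paragraph PB] and the
// predicate 'node is a heading', the reduce above yields three trees:
//   root[P]   root[A, PA]   root[B, PB]
// Each heading starts a new tree; content before the first heading becomes a
// heading-less section (its heading/slug stay undefined downstream).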
/**
* Processes markdown content for search indexing.
* It computes a checksum and splits the content into sub-sections at heading boundaries.
*/
export function processMdxForSearch(content: string): ProcessedMdx {
const checksum = createHash('sha256').update(content).digest('base64');
const mdTree = fromMarkdown(content, {});
if (!mdTree) {
return {
checksum,
sections: [],
};
}
const sectionTrees = splitTreeBy(mdTree, (node) => node.type === 'heading');
const slugger = new GithubSlugger();
const sections = sectionTrees.map((tree: any) => {
const [firstNode] = tree.children;
const heading =
firstNode.type === 'heading' ? toString(firstNode) : undefined;
const slug = heading ? slugger.slug(heading) : undefined;
return {
content: toMarkdown(tree),
heading,
slug,
};
});
return {
checksum,
sections,
};
}
type WalkEntry = {
path: string;
url_partial: string;
};
abstract class BaseEmbeddingSource {
checksum?: string;
sections?: Section[];
constructor(
public source: string,
public path: string,
public url_partial: string
) {}
abstract load(): Promise<{
checksum: string;
sections: Section[];
}>;
}
class MarkdownEmbeddingSource extends BaseEmbeddingSource {
type: 'markdown' = 'markdown';
constructor(
source: string,
public filePath: string,
public url_partial: string
) {
const path = filePath.replace(/^docs/, '').replace(/\.md?$/, '');
super(source, path, url_partial);
}
async load() {
const contents = await readFile(this.filePath, 'utf8');
const { checksum, sections } = processMdxForSearch(contents);
this.checksum = checksum;
this.sections = sections;
return {
checksum,
sections,
};
}
}
type EmbeddingSource = MarkdownEmbeddingSource;
async function generateEmbeddings() {
const argv = await yargs().option('refresh', {
alias: 'r',
description: 'Refresh data',
type: 'boolean',
}).argv;
const shouldRefresh = argv.refresh;
if (!process.env.NX_NEXT_PUBLIC_SUPABASE_URL) {
throw new Error(
'Environment variable NX_NEXT_PUBLIC_SUPABASE_URL is required: skipping embeddings generation'
);
}
if (!process.env.NX_SUPABASE_SERVICE_ROLE_KEY) {
throw new Error(
'Environment variable NX_SUPABASE_SERVICE_ROLE_KEY is required: skipping embeddings generation'
);
}
if (!process.env.NX_OPENAI_KEY) {
throw new Error(
'Environment variable NX_OPENAI_KEY is required: skipping embeddings generation'
);
}
const supabaseClient = createClient(
process.env.NX_NEXT_PUBLIC_SUPABASE_URL,
process.env.NX_SUPABASE_SERVICE_ROLE_KEY,
{
auth: {
persistSession: false,
autoRefreshToken: false,
},
}
);
const allFilesPaths = [
...getAllFilesFromMapJson(mapJson),
...getAllFilesWithItemList(manifestsCloud),
...getAllFilesWithItemList(manifestsExtending),
...getAllFilesWithItemList(manifestsNx),
...getAllFilesWithItemList(manifestsPackages),
...getAllFilesWithItemList(manifestsRecipes),
...getAllFilesWithItemList(manifestsTags),
].filter((entry) => !entry.path.includes('sitemap'));
const embeddingSources: EmbeddingSource[] = [
...allFilesPaths.map((entry) => {
return new MarkdownEmbeddingSource(
'guide',
entry.path,
entry.url_partial
);
}),
];
console.log(`Discovered ${embeddingSources.length} pages`);
if (!shouldRefresh) {
console.log('Checking which pages are new or have changed');
} else {
console.log('Refresh flag set, re-generating all pages');
}
for (const [index, embeddingSource] of embeddingSources.entries()) {
const { type, source, path, url_partial } = embeddingSource;
try {
const { checksum, sections } = await embeddingSource.load();
// Check for existing page in DB and compare checksums
const { error: fetchPageError, data: existingPage } = await supabaseClient
.from('nods_page')
.select('id, path, checksum')
.filter('path', 'eq', path)
.limit(1)
.maybeSingle();
if (fetchPageError) {
throw fetchPageError;
}
// We use checksum to determine if this page & its sections need to be regenerated
if (!shouldRefresh && existingPage?.checksum === checksum) {
continue;
}
if (existingPage) {
if (!shouldRefresh) {
console.log(
`#${index}: [${path}] Docs have changed, removing old page sections and their embeddings`
);
} else {
console.log(
`#${index}: [${path}] Refresh flag set, removing old page sections and their embeddings`
);
}
const { error: deletePageSectionError } = await supabaseClient
.from('nods_page_section')
.delete()
.filter('page_id', 'eq', existingPage.id);
if (deletePageSectionError) {
throw deletePageSectionError;
}
}
// Create/update page record. Intentionally clear checksum until we
// have successfully generated all page sections.
const { error: upsertPageError, data: page } = await supabaseClient
.from('nods_page')
.upsert(
{
checksum: null,
path,
url_partial,
type,
source,
},
{ onConflict: 'path' }
)
.select()
.limit(1)
.single();
if (upsertPageError) {
throw upsertPageError;
}
console.log(
`#${index}: [${path}] Adding ${sections.length} page sections (with embeddings)`
);
console.log(
`${embeddingSources.length - index - 1} pages remaining to process.`
);
for (const { slug, heading, content } of sections) {
// OpenAI recommends replacing newlines with spaces for best results (specific to embeddings)
const input = content.replace(/\n/g, ' ');
try {
const configuration = new Configuration({
apiKey: process.env.NX_OPENAI_KEY,
});
const openai = new OpenAIApi(configuration);
const embeddingResponse = await openai.createEmbedding({
model: 'text-embedding-ada-002',
input,
});
if (embeddingResponse.status !== 200) {
throw new Error(inspect(embeddingResponse.data, false, 2));
}
const [responseData] = embeddingResponse.data.data;
const { error: insertPageSectionError, data: pageSection } =
await supabaseClient
.from('nods_page_section')
.insert({
page_id: page.id,
slug,
heading,
content,
url_partial,
token_count: embeddingResponse.data.usage.total_tokens,
embedding: responseData.embedding,
})
.select()
.limit(1)
.single();
if (insertPageSectionError) {
throw insertPageSectionError;
}
// Add delay after each request
await delay(500); // delay of 0.5 second
} catch (err) {
// TODO: decide how to better handle failed embeddings
console.error(
`Failed to generate embeddings for '${path}' page section starting with '${input.slice(
0,
40
)}...'`
);
throw err;
}
}
// Set page checksum so that we know this page was stored successfully
const { error: updatePageError } = await supabaseClient
.from('nods_page')
.update({ checksum })
.filter('id', 'eq', page.id);
if (updatePageError) {
throw updatePageError;
}
} catch (err) {
console.error(
`Page '${path}' or one/multiple of its page sections failed to store properly. Page has been marked with null checksum to indicate that it needs to be re-generated.`
);
console.error(err);
}
}
console.log('Embedding generation complete');
}
function delay(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function getAllFilesFromMapJson(doc): WalkEntry[] {
const files: WalkEntry[] = [];
function traverse(itemList) {
for (const item of itemList) {
if (item.file && item.file.length > 0) {
// we can exclude some docs here, eg. the deprecated ones
// the path is the relative path to the file within the nx repo
// the url_partial is the relative path to the file within the docs site - under nx.dev
files.push({ path: `docs/${item.file}.md`, url_partial: item.path });
}
if (item.itemList) {
traverse(item.itemList);
}
}
}
for (const item of doc.content) {
traverse([item]);
}
return files;
}
function getAllFilesWithItemList(data): WalkEntry[] {
const files: WalkEntry[] = [];
function traverse(itemList) {
for (const item of itemList) {
if (item.file && item.file.length > 0) {
// the path is the relative path to the file within the nx repo
// the url_partial is the relative path to the file within the docs site - under nx.dev
files.push({ path: `docs/${item.file}.md`, url_partial: item.path });
}
if (item.itemList) {
traverse(item.itemList);
}
}
}
for (const key in data) {
if (data[key].itemList) {
traverse([data[key]]);
} else {
if (data[key].documents) {
files.push(...getAllFilesWithItemList(data[key].documents));
}
if (data[key].generators) {
files.push(...getAllFilesWithItemList(data[key].generators));
}
if (data[key].executors) {
files.push(...getAllFilesWithItemList(data[key].executors));
}
if (data[key]?.length > 0) {
traverse(data[key]);
}
}
}
return files;
}
async function main() {
await generateEmbeddings();
}
main().catch((err) => console.error(err));
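
As wired up in the workflow at the top of this commit, the script runs via pnpm exec nx run tools-documentation-create-embeddings:run-node. Passing --refresh (alias -r) regenerates every page; otherwise only pages whose stored checksum changed are re-embedded. Because a page's checksum is written only after all of its sections are stored, a run that fails midway leaves a null checksum and the page is picked up again on the next pass.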

File (new): tools/documentation/create-embeddings/tsconfig.app.json

@@ -0,0 +1,13 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"outDir": "../../../dist/out-tsc",
"module": "ESNext",
"moduleResolution": "nodenext",
"types": ["node"],
"esModuleInterop": true,
"resolveJsonModule": true
},
"exclude": ["jest.config.ts", "src/**/*.spec.ts", "src/**/*.test.ts"],
"include": ["src/**/*.ts", "src/**/*.mts"]
}

File (new): tools/documentation/create-embeddings/tsconfig.json

@@ -0,0 +1,18 @@
{
"extends": "../../../tsconfig.base.json",
"files": [],
"include": [],
"references": [
{
"path": "./tsconfig.app.json"
},
{
"path": "./tsconfig.spec.json"
}
],
"compilerOptions": {
"esModuleInterop": true,
"module": "esnext",
"target": "esnext"
}
}

File (new): tools/documentation/create-embeddings/tsconfig.spec.json

@@ -0,0 +1,16 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"outDir": "../../../dist/out-tsc",
"module": "commonjs",
"types": ["jest", "node"],
"allowImportingTsExtensions": true,
"emitDeclarationOnly": true
},
"include": [
"jest.config.ts",
"src/**/*.test.ts",
"src/**/*.spec.ts",
"src/**/*.d.ts"
]
}
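
A note on this spec config: TypeScript only accepts allowImportingTsExtensions when nothing is emitted (noEmit or emitDeclarationOnly), which is why emitDeclarationOnly is set alongside it; the actual test compilation is handled by the ts-jest transform configured above.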