Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion packages/client/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"make-synchronous": "^1.0.0",
"meow": "^13.1.0",
"np": "github:mmkal/np#6a58244afa28fd7d3f8c4dc4b6457c22d74e82de",
"pg-promise": "^11.5.4",
"pg-mem": "3.0.2",
"quicktype-core": "^23.0.81",
"slonik": "^37.2.0",
Expand All @@ -58,6 +59,14 @@
},
"dependencies": {
"pg": "~8.14.1",
"pg-promise": "^11.5.4"
"postgres": "^3.4.7"
},
"peerDependencies": {
"pg-promise": ">=11"
},
"peerDependenciesMeta": {
"pg-promise": {
"optional": true
}
}
}
86 changes: 45 additions & 41 deletions packages/client/src/client.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import * as crypto from 'node:crypto'
import TypeOverrides from 'pg/lib/type-overrides'
import pgPromise from 'pg-promise'
import {createPgPromiseDriver} from './drivers/pg-promise'
import {QueryError, errorFromUnknown} from './errors'
import {createSqlTag} from './sql'
import {applyRecommendedTypeParsers} from './type-parsers'
Expand All @@ -10,6 +8,8 @@ import {
SQLQueryRowType,
ClientOptions,
Connection,
DriverInfo,
DriverScope,
Transaction,
Result,
DriverQueryable,
Expand Down Expand Up @@ -111,6 +111,18 @@ export const createQueryFn = (pgpQueryable: DriverQueryable): Queryable['query']
}
}

const getCompatibilityProps = (info: DriverInfo) => {
const {name: _name, raw: _raw, ...compatibilityProps} = info
return compatibilityProps
}

const createTransactionFactory = (
scope: Pick<DriverScope, 'transaction'>,
createTransaction: (transactionScope: DriverScope) => Transaction,
): Connection['transaction'] => {
return async callback => scope.transaction(async transaction => callback(createTransaction(transaction)))
}

export const createClient = (connectionString: string, options: ClientOptions = {}): Client => {
if (typeof connectionString !== 'string') throw new Error(`Expected connectionString, got ${typeof connectionString}`)
if (!connectionString) throw new Error(`Expected a valid connectionString, got "${connectionString}"`)
Expand All @@ -121,62 +133,54 @@ export const createClient = (connectionString: string, options: ClientOptions =
...options,
}

const types = new TypeOverrides()
// note: this should be done "high up" in the app: https://stackoverflow.com/questions/34382796/where-should-i-initialize-pg-promise
const initializedPgPromise = pgPromise(options.pgpOptions?.initialize)

options.applyTypeParsers?.({
setTypeParser: (id, parseFn) => types.setTypeParser(id, parseFn as (input: unknown) => unknown),
builtins: initializedPgPromise.pg.types.builtins,
})
const driver = options.driver ?? createPgPromiseDriver(options.pgpOptions)

const createWrappedQueryFn: typeof createQueryFn = queryable => {
const queryFn = createQueryFn(queryable)
return options.wrapQueryFn ? options.wrapQueryFn(queryFn) : queryFn
}

const pgPromiseClient = initializedPgPromise({
const runtime = driver.create({
connectionString,
types,
...options.pgpOptions?.connect,
applyTypeParsers: options.applyTypeParsers,
})

const transactionFnFromTask =
<U>(task: pgPromise.ITask<U> | pgPromise.IDatabase<U>): Connection['transaction'] =>
async txCallback => {
return task.tx({tag: crypto.randomUUID()}, async tx => {
const pgSuiteTransaction: Transaction = {
...createQueryable(createWrappedQueryFn(tx)),
transactionInfo: {pgp: tx},
connectionInfo: {pgp: task},
transaction: transactionFnFromTask(tx),
}
return txCallback(pgSuiteTransaction)
})
const createTransaction = (transactionScope: DriverScope, connectionInfo: DriverInfo): Transaction => {
return {
...getCompatibilityProps(transactionScope.info),
...createQueryable(createWrappedQueryFn(transactionScope.queryable)),
transactionInfo: transactionScope.info,
connectionInfo,
transaction: createTransactionFactory(transactionScope, nested => createTransaction(nested, connectionInfo)),
}
}

const taskMethod: Client['task'] = async callback => {
return pgPromiseClient.task({tag: crypto.randomUUID()}, async task => {
const connectionInfo: Connection['connectionInfo'] = {pgp: task}
const pgSuiteConnection: Connection = {
connectionInfo,
transaction: transactionFnFromTask(task),
...createQueryable(createWrappedQueryFn(task)),
}

return callback(pgSuiteConnection)
})
const createConnection = (connectionScope: DriverScope): Connection => {
const connectionInfo = connectionScope.info
return {
...getCompatibilityProps(connectionInfo),
...createQueryable(createWrappedQueryFn(connectionScope.queryable)),
connectionInfo,
transaction: createTransactionFactory(connectionScope, transaction =>
createTransaction(transaction, connectionInfo),
),
}
}

const taskMethod: Client['task'] = async callback =>
runtime.connect(async connection => callback(createConnection(connection)))

return {
options,
pgp: pgPromiseClient,
...getCompatibilityProps(runtime.info),
driverInfo: runtime.info,
pgpOptions: options.pgpOptions || {},
...createQueryable(createWrappedQueryFn(pgPromiseClient)),
...createQueryable(createWrappedQueryFn(runtime.queryable)),
connectionString: () => connectionString,
end: async () => pgPromiseClient.$pool.end(),
end: async () => runtime.end(),
connect: taskMethod,
task: taskMethod,
transaction: transactionFnFromTask(pgPromiseClient),
transaction: async callback =>
runtime.transaction(async transaction => callback(createTransaction(transaction, runtime.info))),
}
}
84 changes: 84 additions & 0 deletions packages/client/src/drivers/pg-promise.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import * as crypto from 'node:crypto'
import TypeOverrides from 'pg/lib/type-overrides'
import {ClientDriver, DriverInfo, DriverQueryable, DriverScope, PGPOptions} from '../types'

type PgPromiseTask = {
result<T>(query: string, values?: unknown[]): Promise<T>
tx<T>(options: {tag: string}, callback: (task: PgPromiseTask) => Promise<T>): Promise<T>
}

type PgPromiseDatabase = PgPromiseTask & {
task<T>(options: {tag: string}, callback: (task: PgPromiseTask) => Promise<T>): Promise<T>
$pool: {end(): Promise<void>}
}

type PgPromiseConnector = ((connect: Record<string, unknown>) => PgPromiseDatabase) & {
pg: {types: {builtins: Record<string, number>}}
}

type PgPromiseInitializer = (options?: unknown) => PgPromiseConnector

export type PgPromiseDriverOptions = PGPOptions & {
pgPromise?: PgPromiseInitializer
}

const createInfo = (raw: unknown): DriverInfo => ({name: 'pg-promise', raw, pgp: raw})

const createQueryable = (queryable: PgPromiseTask): DriverQueryable => ({
result: async <T>(query: string, values?: unknown[]) => {
return queryable.result<T>(query, values && values.length > 0 ? values : undefined)
},
})

const createScope = (queryable: PgPromiseTask, connectionInfo: DriverInfo): DriverScope => ({
info: createInfo(queryable),
queryable: createQueryable(queryable),
transaction: async callback => {
return queryable.tx({tag: crypto.randomUUID()}, async transaction =>
callback(createScope(transaction, connectionInfo)),
)
},
})

const getPgPromise = (override?: PgPromiseInitializer) => {
if (override) return override
try {
return require('pg-promise') as PgPromiseInitializer
} catch (cause) {
throw new Error('The pg-promise driver requires `pg-promise` to be installed in your app.', {cause})
}
}

export const createPgPromiseDriver = (options: PgPromiseDriverOptions = {}): ClientDriver => ({
name: 'pg-promise',
create({connectionString, applyTypeParsers}) {
const pgPromise = getPgPromise(options.pgPromise)
const initializedPgPromise = pgPromise(options.initialize)
const types = new TypeOverrides()

applyTypeParsers?.({
setTypeParser: (id, parseFn) => types.setTypeParser(id, ((input: unknown) => parseFn(input as never)) as any),
builtins: initializedPgPromise.pg.types.builtins as any,
})

const database = initializedPgPromise({
connectionString,
types,
...options.connect,
})

return {
info: createInfo(database),
queryable: createQueryable(database),
end: async () => database.$pool.end(),
connect: async callback => {
return database.task({tag: crypto.randomUUID()}, async task => callback(createScope(task, createInfo(task))))
},
transaction: async callback => {
return database.tx({tag: crypto.randomUUID()}, async transaction => {
return callback(createScope(transaction, createInfo(database)))
})
},
}
},
})
84 changes: 84 additions & 0 deletions packages/client/src/drivers/pg.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import * as crypto from 'node:crypto'
import {Pool, PoolClient, PoolConfig, types as pgTypes} from 'pg'
import TypeOverrides from 'pg/lib/type-overrides'
import {ClientDriver, DriverInfo, DriverQueryable, DriverScope} from '../types'

const createInfo = (raw: unknown): DriverInfo => ({name: 'pg', raw})

const toResult = async <T>(
queryable: {query(query: string, values?: unknown[]): Promise<any>},
query: string,
values?: unknown[],
) => {
const result = await queryable.query(query, values)
return {
rows: result.rows as T[],
fields: result.fields,
command: result.command,
rowCount: result.rowCount,
}
}

const createQueryable = (queryable: {query(query: string, values?: unknown[]): Promise<any>}): DriverQueryable => ({
result: async <T>(query: string, values?: unknown[]) => toResult<T>(queryable, query, values),
})

const savepointName = () => `pgkit_${crypto.randomUUID().replaceAll('-', '_')}`

const createScope = (client: PoolClient, connectionInfo: DriverInfo, inTransaction: boolean): DriverScope => ({
info: createInfo(client),
queryable: createQueryable(client),
transaction: async callback => {
const savepoint = savepointName()
await client.query(inTransaction ? `savepoint ${savepoint}` : 'begin')
try {
const result = await callback(createScope(client, connectionInfo, true))
await client.query(inTransaction ? `release savepoint ${savepoint}` : 'commit')
return result
} catch (error) {
await client.query(inTransaction ? `rollback to savepoint ${savepoint}` : 'rollback')
throw error
}
},
})

export type PgDriverOptions = PoolConfig

export const createPgDriver = (options: PgDriverOptions = {}): ClientDriver => ({
name: 'pg',
create({connectionString, applyTypeParsers}) {
const types = new TypeOverrides()
applyTypeParsers?.({
setTypeParser: (id, parseFn) => types.setTypeParser(id, ((input: unknown) => parseFn(input as never)) as any),
builtins: pgTypes.builtins,
})

const pool = new Pool({
connectionString,
types,
...options,
})

return {
info: createInfo(pool),
queryable: createQueryable(pool),
end: async () => pool.end(),
connect: async callback => {
const client = await pool.connect()
try {
return await callback(createScope(client, createInfo(client), false))
} finally {
client.release()
}
},
transaction: async callback => {
const client = await pool.connect()
try {
return await createScope(client, createInfo(pool), false).transaction(callback)
} finally {
client.release()
}
},
}
},
})
Loading
Loading