Go 中使用 gorm 适配自定义数据库驱动
GORM 官方支持的数据库类型有:MySQL,PostgreSQL,SQLite,SQL Server 和 TiD。但是我们有的时候需要使用 gorm 接入一些其他自定义的数据库,例如:Oracle 或者 Yashan。
在本文中,我们将介绍如何在 Go 中使用 gorm 这个流行的 ORM 框架来连接 Yashan 数据库,并进行一些基本的增删改查操作。事实上这个方法适用于任何一个你想适配的数据库,如果其官方未适配 gorm 的话。
前提环境
在开始之前,我们需要准备以下内容:
- 一台安装了 Yashan 数据库的服务器,以及一个可以访问的数据库用户和密码。在本文中,我们假设服务器的 IP 地址是
192.168.1.100
,端口号是1688
,数据库名是yasdb
,用户名是sys
,密码是123456
。 - 一台安装了 Go 的开发环境,以及设置好了
GOPATH
和GOROOT
环境变量。在本文中,我们假设 Go 的版本是1.20
,并且使用了go mod
来管理依赖包。 - 一个可以编写和运行 Go 代码的编辑器或 IDE。
安装 gorm 和驱动
要使用 gorm 来链接 Yashan 数据库,我们需要安装 gorm 本身,以及一个适用于 Yashan 的驱动。由于当前 Yashan 官网上未给出 Go 语言的相关驱动,我们使用官网提供的 C 驱动然后利用 go 调用 C 实现。
go get -u gorm.io/gorm
执行上述命令后,我们在 Go 项目中成功安装了 gorm。
连接数据库
首先我们来看一下 gorm 内部支持的数据库是如何连接的。
main.go
pacakge main
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
从 gorm 官方文档-Open 中可以看到,func Open
接受一个 Dialector 对象。而我们现在需要做的就是根据 Yashan 数据库的 C 驱动,实现 Dialector
对象相关的所有接口即可。
dialect.go
package yasdb
import (
"database/sql"
"fmt"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
"regexp"
"strconv"
"strings"
)
const (
driverType = "yasdb"
)
type Dialector struct {
*Config
}
func (d Dialector) DummyTableName() string {
return "DUAL"
}
type Config struct {
DriverName string
DSN string
PreferSimpleProtocol bool
WithoutReturning bool
Conn gorm.ConnPool
DefaultStringSize uint
}
func Open(dsn string) gorm.Dialector {
return &Dialector{&Config{DSN: dsn, DefaultStringSize: 255}}
}
func (d Dialector) Name() string {
return driverType
}
func (d Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
if !d.WithoutReturning {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
}
db.ConnPool, err = sql.Open(d.Name(), d.Config.DSN)
if err != nil {
return
}
//if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
// return
//}
//for k, v := range d.ClauseBuilders() {
// db.ClauseBuilders[k] = v
//}
return
}
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
type Migrator struct {
migrator.Migrator
}
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{
Migrator: migrator.Migrator{
Config: migrator.Config{
DB: db,
Dialector: d,
CreateIndexAfterCreateTable: true,
},
},
}
}
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
return clause.Expr{SQL: "VALUES (DEFAULT)"}
}
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
_, err := writer.WriteString(":")
if err != nil {
return
}
_, err2 := writer.WriteString(strconv.Itoa(len(stmt.Vars)))
if err2 != nil {
return
}
}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
var (
underQuoted, selfQuoted bool
continuousBacktick int8
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '"':
continuousBacktick++
if continuousBacktick == 2 {
_, err := writer.WriteString(`""`)
if err != nil {
return
}
continuousBacktick = 0
}
case '.':
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
err := writer.WriteByte('"')
if err != nil {
return
}
}
err := writer.WriteByte(v)
if err != nil {
return
}
continue
default:
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
err := writer.WriteByte('"')
if err != nil {
return
}
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
_, err := writer.WriteString(`""`)
if err != nil {
return
}
}
err := writer.WriteByte(v)
if err != nil {
return
}
}
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
_, err := writer.WriteString(`""`)
if err != nil {
return
}
}
err := writer.WriteByte('"')
if err != nil {
return
}
}
func (d Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]interface{})...)
}
func (d Dialector) DataTypeOf(field *schema.Field) string {
if _, found := field.TagSettings["RESTRICT"]; found {
delete(field.TagSettings, "RESTRICT")
}
var sqlType string
switch field.DataType {
case schema.Bool, schema.Int, schema.Uint, schema.Float:
sqlType = "INTEGER"
switch {
case field.DataType == schema.Float:
sqlType = "FLOAT"
case field.Size <= 8:
sqlType = "SMALLINT"
case field.Size >= 64:
sqlType = "BIGINT(8)"
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
}
case schema.String:
size := field.Size
defaultSize := d.DefaultStringSize
if size == 0 {
if defaultSize > 0 {
size = int(defaultSize)
} else {
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
// TEXT, GEOMETRY or JSON column can't have a default value
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
size = 191 // utf8mb4
}
}
}
if size >= 2000 {
sqlType = "CLOB"
} else {
sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
}
case schema.Time:
sqlType = "TIMESTAMP"
if field.NotNull || field.PrimaryKey {
sqlType += " NOT NULL"
}
case schema.Bytes:
sqlType = "BLOB"
default:
sqlType := string(field.DataType)
if strings.EqualFold(sqlType, "text") {
sqlType = "CLOB"
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
}
notNull, _ := field.TagSettings["NOT NULL"]
unique, _ := field.TagSettings["UNIQUE"]
additionalType := fmt.Sprintf("%s %s", notNull, unique)
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
if value, ok := field.TagSettings["COMMENT"]; ok {
return " COMMENT " + value
}
return ""
}())
}
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
}
return sqlType
}
现在,我们即可以使用以下代码进行数据库的连接了。
main.go
package main
import (
"fmt"
// 此处需要引入驱动外部库
"gorm.io/gorm"
"strings"
)
const (
connectFormat = `%s/%s@%s:%s`
)
func InitYashan() {
replacer := strings.NewReplacer("@", "\\@", "/", "\\/", "\\", "\\\\")
datasource := config.Conf.Datasource
dsn := fmt.Sprintf(connectFormat, replacer.Replace(datasource.Username), replacer.Replace(datasource.Password), datasource.Host, datasource.Port)
yasDB, err := gorm.Open(Open(dsn), &gorm.Config{})
// 检查是否有错误
if err != nil {
fmt.Println("连接数据库失败: ", err)
panic(err)
}
dataBaseModel, err := yasDB.DB()
if err != nil {
global.LOG.Error("连接数据库失败, error=" + err.Error())
panic(err)
}
dataBaseModel.SetMaxOpenConns(datasource.MaxOpenConns)
dataBaseModel.SetMaxIdleConns(datasource.MaxIdleConns)
}
贡献者
更新日志
2025/2/24 09:08
查看所有更新日志