Skip to content

Golang中使用gorm适配自定义数据库驱动

1199字约4分钟

Go

2024-02-28

GORM 官方支持的数据库类型有: MySQL, PostgreSQL, SQLite, SQL Server 和 TiDB. 但是有的时候我们需要使用 gorm 接入一些其他自定义的数据库, 例如: Oracle 或者 Yashan.

在本文中, 我们将介绍如何在 golang 中使用 gorm 这个流行的 ORM 框架来连接 Yashan 数据库, 并进行一些基本的增删改查操作.

先决条件

在开始之前, 我们需要准备以下内容:

  • 一台安装了 Yashan 数据库的服务器, 以及一个可以访问的数据库用户和密码. 在本文中, 我们假设服务器的 IP 地址是 192.168.1.100, 端口号是 1688, 数据库名是 yasdb, 用户名是 sys, 密码是 123456.
  • 一台安装了 golang 的开发环境, 以及设置好了 GOPATH 和 GOROOT 环境变量. 在本文中, 我们假设 golang 的版本是 1.20, 并且使用了 go mod 来管理依赖包.
  • 一个可以编写和运行 golang 代码的编辑器或 IDE.

安装gorm和驱动

要使用 gorm 来连接 Yashan 数据库, 我们需要安装 gorm 本身, 以及一个适用于 Yashan 的驱动. 由于当前 Yashan 官网上未给出 golang 语言的相关驱动, 我们使用官网提供的 C 驱动, 然后利用 go 调用 C 来实现.

go get -u gorm.io/gorm

执行上述命令后, 我们在 golang 项目中成功安装了 gorm

连接数据库

首先我们来看一下 gorm 内部支持的数据库是如何连接的.

import (
  "gorm.io/driver/postgres"
  "gorm.io/gorm"
)

// Connection string for the PostgreSQL example, written as space-separated key=value pairs.
dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
// postgres.Open(dsn) supplies the Dialector argument that gorm.Open expects.
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})

从 gorm 官方文档的 Open 部分可以看到, func Open 接受一个 Dialector 对象. 而我们现在需要做的, 就是基于 Yashan 数据库的 C 驱动, 实现 Dialector 接口的所有方法即可.

package yasdb

import (
	"database/sql"
	"fmt"
	"github.com/thoas/go-funk"
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/migrator"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
	"regexp"
	"strconv"
	"strings"
)

const (
	// driverType is the dialect name returned by Dialector.Name.
	driverType = "yasdb"
)

// Dialector is the gorm dialector for the yasdb driver. It embeds *Config, so
// the configuration fields are reachable directly on the dialector value.
// A Dialector whose embedded Config pointer is nil cannot be used, because the
// methods read Config fields without a nil check; construct it through Open.
type Dialector struct {
	*Config
}

// DummyTableName returns the constant table name "DUAL".
// NOTE(review): presumably the placeholder table for SELECT statements that
// need a FROM clause but no real table; confirm against the callers.
func (d Dialector) DummyTableName() string {
	return "DUAL"
}

// Config holds the settings carried by a Dialector.
type Config struct {
	// DriverName is presumably the database/sql driver name; verify how
	// Initialize consumes it.
	DriverName           string
	// DSN is the data source name handed to sql.Open when a pool is opened.
	DSN                  string
	// NOTE(review): PreferSimpleProtocol is not referenced by any method in
	// this file; it looks carried over from another dialector. Confirm
	// whether it is still needed.
	PreferSimpleProtocol bool
	// WithoutReturning is consulted by Initialize when deciding whether to
	// register gorm's default callbacks.
	WithoutReturning     bool
	// Conn is presumably an already-established connection pool supplied by
	// the caller; verify how Initialize consumes it.
	Conn                 gorm.ConnPool
	// DefaultStringSize is the VARCHAR2 length DataTypeOf uses for string
	// fields that declare no size. Open sets it to 255.
	DefaultStringSize    uint
}

// Open builds a Dialector for the given data source name. The returned
// dialector uses a default string size of 255; every other Config field is
// left at its zero value.
func Open(dsn string) gorm.Dialector {
	cfg := &Config{
		DSN:               dsn,
		DefaultStringSize: 255,
	}
	return &Dialector{Config: cfg}
}

// Name returns the dialect name, which is the driverType constant.
func (d Dialector) Name() string {
	return driverType
}

// Initialize prepares db for use with this dialector: it registers gorm's
// default callbacks and installs the connection pool.
//
// Pool selection:
//   - when Config.Conn is set, that pool is used as-is;
//   - otherwise a pool is opened through database/sql with Config.DSN, using
//     Config.DriverName as the driver name and falling back to Name() when
//     DriverName is empty.
//
// It returns the error reported by sql.Open. On failure db.ConnPool is left
// untouched, so callers never observe an interface wrapping a nil *sql.DB.
func (d Dialector) Initialize(db *gorm.DB) (err error) {
	// NOTE(review): the default callbacks are registered only when
	// WithoutReturning is false, so enabling that flag leaves db without any
	// default callbacks. Confirm that this is intended.
	if !d.WithoutReturning {
		callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
	}

	if d.Conn != nil {
		// The caller supplied a ready-made pool; do not open another one.
		db.ConnPool = d.Conn
	} else {
		driverName := d.DriverName
		if driverName == "" {
			driverName = d.Name()
		}

		// Open into a local first: assigning a nil *sql.DB directly to the
		// ConnPool interface would make it compare non-nil.
		var pool *sql.DB
		pool, err = sql.Open(driverName, d.DSN)
		if err != nil {
			return err
		}
		db.ConnPool = pool
	}

	// Custom create callback and clause builders are not wired up yet:
	//if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
	//	return
	//}
	//for k, v := range d.ClauseBuilders() {
	//	db.ClauseBuilders[k] = v
	//}
	return nil
}

// numericPlaceholder matches the positional bind variables that BindVarTo
// writes into statements (":1", ":2", ...), capturing the index. Explain
// passes it to logger.ExplainSQL so the placeholders can be replaced by their
// values; the pattern therefore has to match the ":N" form, not "$N".
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)

// Migrator is the migrator type returned by Dialector.Migrator. It embeds
// gorm's generic migrator.Migrator and therefore exposes that type's methods.
type Migrator struct {
	migrator.Migrator
}

// Migrator returns a Migrator bound to db and to this dialector. The embedded
// generic migrator is configured with CreateIndexAfterCreateTable enabled.
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	cfg := migrator.Config{
		CreateIndexAfterCreateTable: true,
		Dialector:                   d,
		DB:                          db,
	}
	base := migrator.Migrator{Config: cfg}
	return Migrator{Migrator: base}
}

// DefaultValueOf returns the fixed SQL expression "VALUES (DEFAULT)". The
// field argument is ignored, so every field yields the same expression.
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
	return clause.Expr{SQL: "VALUES (DEFAULT)"}
}

// BindVarTo writes a positional bind variable of the form ":N" to writer,
// where N is the current length of stmt.Vars. The value v itself is not used.
// Write errors are not reported; the function stops writing on the first one.
//
// NOTE(review): presumably gorm has already appended v to stmt.Vars when this
// is called, making N the 1-based position of v; confirm against gorm.
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	if _, err := writer.WriteString(":"); err != nil {
		return
	}
	position := strconv.Itoa(len(stmt.Vars))
	// Nothing follows this write, so its error has no further effect.
	_, _ = writer.WriteString(position)
}

// QuoteTo writes str to writer as a double-quoted identifier. The input is
// processed byte by byte; '.' separates segments and each segment is wrapped
// in its own pair of double quotes, so `a.b` is written as `"a"."b"`.
// Double quotes already present in str are either taken as the segment's own
// quoting or written doubled (`""`), depending on where they appear.
//
// Any write error aborts the function silently, leaving partial output in
// writer.
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
	// underQuoted:        an opening quote has been written for the current segment.
	// selfQuoted:         the current segment began with a quote in the input.
	// continuousBacktick: number of consecutive '"' bytes seen and not yet
	//                     written (the name refers to backticks, but the code
	//                     counts double quotes).
	// shiftDelimiter:     number of bytes consumed since the current segment began.
	var (
		underQuoted, selfQuoted bool
		continuousBacktick      int8
		shiftDelimiter          int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '"':
			// Count the quote; two in a row are written as an escaped quote.
			continuousBacktick++
			if continuousBacktick == 2 {
				_, err := writer.WriteString(`""`)
				if err != nil {
					return
				}
				continuousBacktick = 0
			}
		case '.':
			// Segment boundary: close the current segment when a quote is
			// pending or the segment was not quoted in the input, then reset
			// the per-segment state.
			if continuousBacktick > 0 || !selfQuoted {
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				err := writer.WriteByte('"')
				if err != nil {
					return
				}
			}
			err := writer.WriteByte(v)
			if err != nil {
				return
			}
			// Skip the shiftDelimiter increment below: the dot belongs to no segment.
			continue
		default:
			// First ordinary byte of a segment: write the opening quote. A
			// pending input quote is consumed as that opening quote and marks
			// the segment as self-quoted.
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				err := writer.WriteByte('"')
				if err != nil {
					return
				}
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Any quotes still pending are written doubled.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				_, err := writer.WriteString(`""`)
				if err != nil {
					return
				}
			}

			err := writer.WriteByte(v)
			if err != nil {
				return
			}
		}
		shiftDelimiter++
	}

	// A trailing quote in a segment that was not self-quoted is written doubled.
	if continuousBacktick > 0 && !selfQuoted {
		_, err := writer.WriteString(`""`)
		if err != nil {
			return
		}
	}
	// Closing quote of the last segment.
	err := writer.WriteByte('"')
	if err != nil {
		return
	}
}

// Explain returns the result of gorm's logger.ExplainSQL for sql and vars,
// using numericPlaceholder as the placeholder pattern and a single quote as
// the escaper. Before the call every bool in vars is converted to the int 1
// (true) or 0 (false); all other values are passed through unchanged.
func (d Dialector) Explain(sql string, vars ...interface{}) string {
	boolToInt := func(arg interface{}) interface{} {
		flag, isBool := arg.(bool)
		if !isBool {
			return arg
		}
		if flag {
			return 1
		}
		return 0
	}

	normalized := funk.Map(vars, boolToInt).([]interface{})
	return logger.ExplainSQL(sql, numericPlaceholder, `'`, normalized...)
}

// DataTypeOf returns the column type used in DDL for field.
//
// Mapping by field.DataType:
//   - Bool/Int/Uint: SMALLINT when Size <= 8, BIGINT(8) when Size >= 64,
//     INTEGER otherwise;
//   - Float: FLOAT;
//   - String: VARCHAR2(n), or CLOB when n >= 2000, where n is field.Size or,
//     when that is 0, the dialector's DefaultStringSize;
//   - Time: TIMESTAMP, with NOT NULL appended for not-null or primary-key fields;
//   - Bytes: BLOB;
//   - anything else: the data type string itself ("text" becomes CLOB),
//     followed by the NOT NULL / UNIQUE / DEFAULT / COMMENT tag settings.
//
// Numeric fields whose AUTOINCREMENT tag setting is truthy additionally get
// "GENERATED BY DEFAULT AS IDENTITY".
//
// Side effect: the "RESTRICT" entry is removed from field.TagSettings.
// The function panics when a custom data type resolves to an empty string.
func (d Dialector) DataTypeOf(field *schema.Field) string {
	// delete is a no-op for a missing key, so no existence check is needed.
	delete(field.TagSettings, "RESTRICT")

	var sqlType string

	switch field.DataType {
	case schema.Bool, schema.Int, schema.Uint, schema.Float:
		switch {
		case field.DataType == schema.Float:
			sqlType = "FLOAT"
		case field.Size <= 8:
			sqlType = "SMALLINT"
		case field.Size >= 64:
			sqlType = "BIGINT(8)"
		default:
			sqlType = "INTEGER"
		}

		if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
			sqlType += " GENERATED BY DEFAULT AS IDENTITY"
		}
	case schema.String:
		size := field.Size
		if size == 0 {
			if defaultSize := d.DefaultStringSize; defaultSize > 0 {
				size = int(defaultSize)
			} else {
				hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
				// Keys, indexed columns and columns with a default value need
				// a bounded length rather than a large-object type.
				if field.PrimaryKey || field.HasDefaultValue || hasIndex {
					size = 191 // utf8mb4
				}
			}
		}

		if size >= 2000 {
			sqlType = "CLOB"
		} else {
			sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
		}
	case schema.Time:
		sqlType = "TIMESTAMP"
		if field.NotNull || field.PrimaryKey {
			sqlType += " NOT NULL"
		}
	case schema.Bytes:
		sqlType = "BLOB"
	default:
		// Assign to the outer sqlType. Declaring a new variable here would
		// shadow it and make the function return "" for every custom type.
		sqlType = string(field.DataType)

		if strings.EqualFold(sqlType, "text") {
			sqlType = "CLOB"
		}

		if sqlType == "" {
			panic(fmt.Sprintf("invalid sql type %s (%s) for yasdb", field.FieldType.Name(), field.FieldType.String()))
		}

		notNull := field.TagSettings["NOT NULL"]
		unique := field.TagSettings["UNIQUE"]
		additionalType := fmt.Sprintf("%s %s", notNull, unique)
		if value, ok := field.TagSettings["DEFAULT"]; ok {
			comment := ""
			if text, ok := field.TagSettings["COMMENT"]; ok {
				comment = " COMMENT " + text
			}
			additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, comment)
		}
		sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
	}

	return sqlType
}

现在, 我们就可以使用以下代码进行数据库的连接了.


import (
	"fmt"
	// 此处需要引入驱动外部库
	"gorm.io/gorm"
	"strings"
)

const (
	// connectFormat is the DSN layout filled in by InitYashan, in the order
	// username/password@host:port.
	connectFormat = `%s/%s@%s:%s`
)

// InitYashan opens a gorm connection to the database described by
// config.Conf.Datasource and applies the configured pool limits.
//
// The username and password are escaped ("@", "/" and "\" each get a leading
// backslash) before being placed into the DSN. The function panics when gorm
// cannot open the connection or when the underlying *sql.DB cannot be obtained.
//
// NOTE(review): the *gorm.DB handle is neither stored nor returned here, so
// callers cannot reach the opened connection through this function.
// NOTE(review): connectFormat formats the port with %s; this assumes
// datasource.Port is a string. Confirm against the config type.
func InitYashan() {
	ds := config.Conf.Datasource

	escaper := strings.NewReplacer("@", "\\@", "/", "\\/", "\\", "\\\\")
	user := escaper.Replace(ds.Username)
	password := escaper.Replace(ds.Password)
	dsn := fmt.Sprintf(connectFormat, user, password, ds.Host, ds.Port)

	orm, openErr := gorm.Open(Open(dsn), &gorm.Config{})
	if openErr != nil {
		fmt.Println("连接数据库失败: ", openErr)
		panic(openErr)
	}

	pool, poolErr := orm.DB()
	if poolErr != nil {
		global.LOG.Error("连接数据库失败, error=" + poolErr.Error())
		panic(poolErr)
	}

	pool.SetMaxOpenConns(ds.MaxOpenConns)
	pool.SetMaxIdleConns(ds.MaxIdleConns)
}