ML.NET案例详解:在.NET下使用机器学习API实现化学分子式数据格式的判定

By | 2019年2月22日

半年前写过一篇类似的文章,题目是:《在.NET中使用机器学习API(ML.NET)实现化学分子式数据格式的判定》,在该文中,我介绍了化学分子式数据格式的基本知识,同时给出了一个案例,展示了如何在.NET/.NET Core中,使用微软开源的ML.NET框架,通过机器学习,实现化学分子式数据格式的预测。

时隔半年,ML.NET有了很大的发展。在阅读我之前那篇文章的时候,或许还会对给出的案例代码有些疑问,ML.NET经过几个版本的更新之后,API的设计变得更为合理易用,所开放的接口也越来越多(比如,新版本的ML.NET中,对机器学习引擎的OutputSchema进行了完全开放,开发者可以根据自己的需要进行调用),因此,本文就再一次回到这个话题并进行更为详细的介绍,用新版本的ML.NET重新实现化学分子式数据格式的判定。

有关化学分子式的相关知识,在这里也就不多说了,直接看代码实现部分。

准备数据

我们的数据仍然是一个CSV文件,通过逗号分隔,文件包含两个字段:结构式数据(ChemicalStructure),以及该结构式数据的类型(Type),以下是这个文件的部分片段,注意,在这个文件中,我们没有定义CSV头,不过这不重要,只要记得在后面的代码实现中,将这个设置体现出来就可以了。

[O-]C(CCCCCCCCCCCCCCCCC)=O.[Na+],SMILES
O=C(C1)N(C2[C@@]3(CC4)[C@](N4C5)([H])C[C@@]6([H])C5=CCOC1[C@]62[H])C7=C3C=CC=C7.O[N+]([O-])=O,SMILES
O=C1CC2C(C3[C@]45C(C=CC=C6)=C6N31)C(CC4N(CC5)C7)C7=CCO2.OS(O)(=O)=O.O=C8CC9C(C%10[C@@]%11%12C(C=CC=C%13)=C%13N%108)C(CC%11N(CC%12)C%14)C%14=CCO9,SMILES
C=CC1=CC=CC=C1,SMILES
N=C(OC)CCCCCCC(OC)=N.Cl.Cl,SMILES
NC(CCC(N)=O)=O,SMILES
O=C(O)C1(N(CCOC)CCOC)CCC(C)CC1,SMILES
CN(C)C(C)CC(C1=CC=CC=C1)(C(CC)=O)C2=CC=CC=C2,SMILES
NCC1(CCC(CCC)CC1)N(C)CC2=COC=C2,SMILES
AAADceByOAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHgAAAAAACBThgAYCCAMABAAIAACQCAAAAAAAAAAAAAEIAAACABQAgAAHAAAFIAAQAAAkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==,BASE64_CDX
AAADceByOAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHgAACAAACBThgAYCCAMABgAIAACQCAAAAAAAAAAAAAEIAAACABQAgAAHQAAFIAAQAAAkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==,BASE64_CDX
AAADccBCIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHgAQCAAACBThgAYCAABAAgAAAAAAAAAAAAAAAAAAAIAAAAACEAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==,BASE64_CDX
AAADccBjgAAAAAAAAAAAAAAAAAAAAWAAAAAsAAAAAAAAAFgB+AAAHAAQAAAACAjBFwQH8L9MEACgAQZhZACAgC0REKABUCAoVBCASABASEAUBAgIAALAACAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==,BASE64_CDX
AAADceB7uAAAAAAAAAAAAAAAAAAAAAAAAAAwQIAAAAAAAACBAAAAHgAQCAAADCjBmAYxyIPAAgCoAiXS/ACCAAElAgAJiIGIZMiKYDLA1bGUYQhslgLYyce8rwCeCAAAAAAAAAAQAAAAAAAAAAAAAAAAAA==,BASE64_CDX
OC1=C(C2=C(C=C1)C[C@@]3([C@]45[H])[H])O[C@]([C@@]52CCN3C)([C@H](CC4)OC)[H],SMILES
OC1=C(O2)C([C@]([C@]2(C)C(CC3)=O)(CCN4C)[C@]3([H])[C@H]4C5)=C5C=C1,SMILES
........

注意:你不需要将这些数据复制下来,本文结尾会给出源代码,其中包含了这个完整的数据文件。

实现过程

可以基于.NET Framework 4.6.1或者.NET Core创建一个新的控制台应用程序,在这个控制台应用程序上,添加对ML.NET NuGet包的引用。实现的第一步就是定义我们的样本数据对象。根据上面的CSV文件结构,我们可以设计如下的类:

public class ChemicalData
{
    [Column("0")]
    public string ChemicalStructure;

    [Column("1")]
    public string Type;
}

这个类非常简单,仅仅是针对CSV文件两个列的映射。接下来,我们需要定义用于保存预测结果的数据对象,该对象不仅会用来保存预测结果值,而且还会提供基于不同分类的可信度得分(Score):

public class ChemicalDataPrediction
{
    [ColumnName("PredictedLabel")]
    public string Type;

    public float[] Score;
}

OK,到这里我们基本上已经清楚我们的机器学习应用场景了:我们在使用Multi-class Classification对化学结构式数据进行分类。在机器学习的应用过程中,了解应用场景是非常重要的。然后,回到Main函数,实现如下代码:

static void Main(string[] args)
{
    // 创建机器学习上下文实例
    var mlContext = new MLContext();

    // 从data.txt读入样本数据
    var dataView = mlContext.Data.ReadFromTextFile("data.txt", new TextLoader.Arguments
    {
        Separators = new char[] { ',' }, // 逗号分隔
        HasHeader = false, // 文件中不包含CSV头信息
        Column = new[] {
            new TextLoader.Column("ChemicalStructure", DataKind.Text, 0),  // 化学结构式数据字段
            new TextLoader.Column("Type", DataKind.Text, 1)  // 化学结构式数据类型字段
        }
    });

    // 创建机器学习管道,指定我们需要使用CSV文件中的Type字段进行标记并分类
    var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Type")
        
        // 指定将由ChemicalStructure字段提供特征信息
        .Append(mlContext.Transforms.Text.FeaturizeText("Features", "ChemicalStructure"))

        // 选择机器学习算法
        .Append(mlContext.MulticlassClassification.Trainers.LogisticRegression())

        // 计算结果将输出到由PredictedLabel所标记的对象字段上
        .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

    // 基于样本数据和所选择的管道选项,进行模型训练,并返回模型
    var model = pipeline.Fit(dataView);

    // 创建预测引擎
    var engine = model.CreatePredictionEngine<ChemicalData, ChemicalDataPrediction>(mlContext);

    // 对给定的测试数据进行预测,并输出测试结果
    var sample = new ChemicalData { ChemicalStructure = "NC(C(N)=O)=O" };
    var prediction = engine.Predict(sample);
    Console.WriteLine(prediction.Type);
}

代码非常简单,有几个点说明一下:

  • 新的ML.NET需要创建MLContext对象,所有的机器学习工作都会依赖于这个上下文
  • 通过MapValueToKey方法来指定读入数据的哪个字段是用来进行分类标记的,这个Label是ML.NET的一个保留字段名,在模型训练的时候会找到由Label所标记的字段进行计算
  • Features也是ML.NET的一个保留字段名,它指定了哪个(或哪些)字段将提供特征数据
  • PredictedLabel也是ML.NET的保留字段名,它指定了计算结果应该输出到哪个对象字段中

直接运行程序,可以看到,程序毫无悬念地输出了正确结果:

image

可信度得分的获取

在上面的代码中,如果我们将断点设置在最后一句Console.WriteLine方法上,然后调试程序,查看prediction的数值,会发现,各个分类的可信度已经在Score字段里了:

image

可问题是,我如何知道某个得分到底是属于哪个分类呢?在ML.NET 0.6之前的版本,在训练好的模型对象上,会有一个TryGetScoreLabelNames的扩展方法,它能够返回可信度得分的分类名称,顺序和Score数组的顺序一致。但从ML.NET 0.6开始,这个扩展方法已经没有了,但这并不是说ML.NET变得更弱了,相反,新版本中直接将OutputSchema对象暴露出来,开发者可以自己实现所需的方法。下面的代码展示了如何基于预测引擎的OutputSchema来获取各个分类的名称,以及所对应的可信度得分:

static void Main(string[] args)
{
    // ...
    // 接上文代码
    
    var outputSchema = engine.OutputSchema;
    TryGetScoreLabelNames(outputSchema, out var names);
    var confidences = new Dictionary<string, float>();
    for (var idx = 0; idx < names.Length; idx++)
    {
        confidences.Add(names[idx], prediction.Score[idx]);
    }

    Console.WriteLine(JsonConvert.SerializeObject(
        new
        {
            Label = prediction.Type,
            Confidences = confidences
        },
        Formatting.Indented));
}

static bool TryGetScoreLabelNames(Schema outputSchema, out string[] names, string scoreColumnName = DefaultColumnNames.Score)
{
    names = (string[])null;
    var scoreColumn = outputSchema.GetColumnOrNull(scoreColumnName);
    var slotNames = new VBuffer<ReadOnlyMemory<char>>();
    scoreColumn.Value.GetSlotNames(ref slotNames);
    names = new string[slotNames.Length];
    var num = 0;
    foreach (var denseValue in slotNames.DenseValues())
    {
        names[num++] = denseValue.ToString();
    }
    return true;
}

再次执行程序,可以看到,我们已经可以输出各个分类的可信度得分了:

image

预测失误

现在我们做个试验,将最后用于测试的数据从SMILES换成INCHI,比如:

var sample = new ChemicalData { ChemicalStructure = "InChI=1S/ClH/h1H/p-1" };

然后再次运行程序,结果发现,我们本想得到INCHI的输出,却仍然得到SMILES的结果,只不过SMILES的可信度降低了,InChi的可信度升高了:

image

这个问题主要是因为我们所提供的用于训练的样本数据还不够多,如果训练数据量大,并且干扰比较小的话,得到的预测结果就会更准确。因此,在实践机器学习的过程中,保证训练数据的纯净度和数据量是非常重要的,这也就是为什么目前机器学习的项目中,在数据清洗这一步中有着相当大的投入。回到我们的案例,让我们在样本CSV文件中多加一些InChi数据,来帮助机器学习得到更精确的结果:

"InChI=1/C2H6O/c1-2-3/h3H,2H2,1H3",InChi
"InChI=1/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1",InChi
"InChI=1S/C6H8O6/c7-1-2(8)5-3(9)4(10)6(11)12-5/h2,5,7-10H,1H2/t2-,5+/m0/s1",InChi
"InChI=1S/CH4/h1H4",InChi
"InChI=1S/C2H6/c1-2/h1-2H3",InChi
"InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3",InChi
"InChI=1S/C3H7NO2/c1-2(4)3(5)6/h2H,4H2,1H3,(H,5,6)/t2-/m0/s1",InChi
"InChI=1S/ClH/h1H/p-1",InChi
"InChI=1S/C6H7NO/c1-5-3-2-4-7-6(5)8/h2-4H,1H3,(H,7,8)",InChi
"InChI=1S/CH2N2/c1-3-2/h1H2",InChi
"InChI=1S/C7H5N3O/c11-7-5-3-1-2-4-6(5)8-10-9-7/h1-4H,(H,8,9,11)",InChi
"InChI=1S/C8H6N2O/c11-8-6-3-1-2-4-7(6)9-5-10-8/h1-5H,(H,9,10,11)",InChi
"InChI=1S/C2H6N2O/c1-4(2)3-5/h1-2H3",InChi
"InChI=1S/C9H8N2O/c1-6-10-8-5-3-2-4-7(8)9(12)11-6/h2-5H,1H3,(H,10,11,12)",InChi
"InChI=1S/C6H8O/c1-2-3-4-5-6-7/h2-6H,1H3/b3-2+,5-4+",InChi

再次运行程序,我们已经可以得到正确的输出了(虽然它仍然认为有31%的可能性是SMILES):

image

模型的保存与使用

我们可以用下面的代码将训练好的模型保存到本地ZIP文件中,以便今后直接在项目中使用:

using (var fs = new FileStream("ml_model.zip", FileMode.Create, FileAccess.Write, FileShare.Write))
{
    mlContext.Model.Save(model, fs);
}

然后使用下面的代码,读入保存的模型,并进行新的预测:

var mlContext2 = new MLContext();
ITransformer loadedModel;
using (var stream = new FileStream("ml_model.zip", FileMode.Open, FileAccess.Read, FileShare.Read))
{
    loadedModel = mlContext2.Model.Load(stream);
    var engine2 = loadedModel.CreatePredictionEngine<ChemicalData, ChemicalDataPrediction>(mlContext2);
    var pred = engine2.Predict(new ChemicalData { ChemicalStructure = "c1ccccc1" });
    Console.WriteLine(pred.Type);
}

总结

本文再一次介绍了如何使用微软开源的ML.NET框架,实现化学结构式数据格式的预测和判定。本文对使用ML.NET的整个流程进行了详细完整的介绍,但只演示了Multi-class Classification的应用场景。其它应用场景其实也大同小异,开发人员需要根据实际情况进行选择。通过ML.NET产生的训练模型是可以序列化到ZIP文件的,因此,模型可以方便地重用。ML.NET支持.NET Core,因此,基于docker和ASP.NET Core实现机器学习的RESTful API也是轻而易举的事情,本文就不继续深入了。

源代码下载

请【点击此处】下载本文案例的源代码。

(总访问量:509;当日访问量:1)

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据