Skip to content

Commit

Permalink
Merge pull request #295 from BrainJS/294-hidden-size-fix
Browse files Browse the repository at this point in the history
fix: Resolve issue with different size hidden layers for recurrent ne…
  • Loading branch information
robertleeplummerjr committed Nov 4, 2018
2 parents 71864ac + 772eb8c commit 0f98db9
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 27 deletions.
10 changes: 3 additions & 7 deletions browser.js
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <fayearthur@gmail.com>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.4.3
* version: 1.4.4
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -4236,13 +4236,11 @@ var RNN = function () {
}, {
key: 'mapModel',
value: function mapModel() {
var _this = this;

var model = this.model;
var hiddenLayers = model.hiddenLayers;
var allMatrices = model.allMatrices;
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
return new _matrix2.default(_this.hiddenLayers[0], 1);
return new _matrix2.default(size, 1);
});

this.createInputMatrix();
Expand Down Expand Up @@ -4572,8 +4570,6 @@ var RNN = function () {
}, {
key: 'fromJSON',
value: function fromJSON(json) {
var _this2 = this;

var defaults = this.constructor.defaults;
var options = json.options;
this.model = null;
Expand Down Expand Up @@ -4619,7 +4615,7 @@ var RNN = function () {
equationConnections: []
};
this.initialLayerInputs = this.hiddenLayers.map(function (size) {
return new _matrix2.default(_this2.hiddenLayers[0], 1);
return new _matrix2.default(size, 1);
});
this.bindEquation();
}
Expand Down
14 changes: 7 additions & 7 deletions browser.min.js

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions dist/recurrent/rnn.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/recurrent/rnn.js.map

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -1,7 +1,7 @@
{
"name": "brain.js",
"description": "Neural network library",
"version": "1.4.3",
"version": "1.4.4",
"author": "Heather Arthur <fayearthur@gmail.com>",
"repository": {
"type": "git",
Expand Down
4 changes: 2 additions & 2 deletions src/recurrent/rnn.js
Expand Up @@ -152,7 +152,7 @@ export default class RNN {
let model = this.model;
let hiddenLayers = model.hiddenLayers;
let allMatrices = model.allMatrices;
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));

this.createInputMatrix();
if (!model.input) throw new Error('net.model.input not set');
Expand Down Expand Up @@ -506,7 +506,7 @@ export default class RNN {
equations: [],
equationConnections: [],
};
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(this.hiddenLayers[0], 1));
this.initialLayerInputs = this.hiddenLayers.map((size) => new Matrix(size, 1));
this.bindEquation();
}

Expand Down
37 changes: 34 additions & 3 deletions test/recurrent/rnn.js
Expand Up @@ -20,6 +20,28 @@ describe('rnn', () => {
net.initialize();
assert.notEqual(net.model, null);
});
it('can setup different size hiddenLayers', () => {
const inputSize = 2;
const hiddenLayers = [5,4,3];
const networkOptions = {
learningRate: 0.001,
decayRate: 0.75,
inputSize: inputSize,
hiddenLayers,
outputSize: inputSize
};

const net = new RNN(networkOptions);
net.initialize();
net.bindEquation();
assert.equal(net.model.hiddenLayers.length, 3);
assert.equal(net.model.hiddenLayers[0].weight.columns, inputSize);
assert.equal(net.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
assert.equal(net.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
assert.equal(net.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
assert.equal(net.model.hiddenLayers[2].weight.columns, hiddenLayers[1]);
assert.equal(net.model.hiddenLayers[2].weight.rows, hiddenLayers[2]);
});
});
describe('basic operations', () => {
it('starts with zeros in input.deltas', () => {
Expand Down Expand Up @@ -354,9 +376,12 @@ describe('rnn', () => {

describe('.fromJSON', () => {
it('can import model from json', () => {
let dataFormatter = new DataFormatter('abcdef'.split(''));
let jsonString = JSON.stringify(new RNN({
inputSize: 6, //<- length
const inputSize = 6;
const hiddenLayers = [10, 20];
const dataFormatter = new DataFormatter('abcdef'.split(''));
const jsonString = JSON.stringify(new RNN({
inputSize, //<- length
hiddenLayers,
inputRange: dataFormatter.characters.length,
outputSize: dataFormatter.characters.length //<- length
}).toJSON(), null, 2);
Expand All @@ -368,6 +393,12 @@ describe('rnn', () => {
assert.equal(clone.inputSize, 6);
assert.equal(clone.inputRange, dataFormatter.characters.length);
assert.equal(clone.outputSize, dataFormatter.characters.length);

assert.equal(clone.model.hiddenLayers.length, 2);
assert.equal(clone.model.hiddenLayers[0].weight.columns, inputSize);
assert.equal(clone.model.hiddenLayers[0].weight.rows, hiddenLayers[0]);
assert.equal(clone.model.hiddenLayers[1].weight.columns, hiddenLayers[0]);
assert.equal(clone.model.hiddenLayers[1].weight.rows, hiddenLayers[1]);
});

it('can import model from json using .fromJSON()', () => {
Expand Down

0 comments on commit 0f98db9

Please sign in to comment.