diff --git a/pyproject.toml b/pyproject.toml index aca8155a03683431d20445e4aaa2acd29820670d..0ba4b4696bb50d5cdba0fa9bae84dcaedd902411 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = ["François Laurent"] [tool.poetry.dependencies] python = "^3.8" julia = "^0.5.7" -hdf5storage = "^0.1.18" +hdf5storage = ">0.1.18" h5py = "^3.1.0" numpy = "^1.19.3" diff --git a/src/taggingbackends/data/trxmat.py b/src/taggingbackends/data/trxmat.py index 489c69aee592f9691eb8fadb70f743b11c2f35d9..6e0ccad570d80e6b5b261aedc9095f269dddcc05 100644 --- a/src/taggingbackends/data/trxmat.py +++ b/src/taggingbackends/data/trxmat.py @@ -79,9 +79,25 @@ class TrxMat: records = records.split() if not lowlevel: records, memoized_records = self._parse_record_names(records) - trx = hdf5storage.loadmat(self.path, - variable_names=["trx/"+record for record in records]) + file = h5py.File(self.path, 'r') + try: + trx = {} + for record in records: + varname = 'trx/'+record + refs = file[varname][0,:] + # np arrays (and transpose) for backward compatibility + trx[varname] = numpy.empty(refs.shape, dtype=object) + if record == 'id': + for i, ref in enumerate(refs): + trx[varname][i] = numpy.array([file[ref][:].tobytes().decode('utf-16')]) + else: + for i, ref in enumerate(refs): + trx[varname][i] = numpy.transpose(file[ref][:]) + assert not numpy.isscalar(trx[varname]) + finally: + file.close() elif lowlevel: + # TODO: replace hdf5storage or deprecate this use case trx = hdf5storage.loadmat(self.path) else: # explicit record names required for memoization @@ -90,8 +106,8 @@ class TrxMat: for varname in trx: record = varname[4:] vardata = trx[varname] - assert len(vardata) == 1 - vardata = vardata[0] + if numpy.isscalar(vardata): # hdf5storage + vardata = vardata[0] try: hook = getattr(self, record+"_hook") except AttributeError: