Skip to content

connection.py

This module contains the Connection class that manages the connection to the database, and the conn function that provides access to a persistent connection in datajoint.

translate_query_error(client_error, query)

Take client error and original query and return the corresponding DataJoint exception.

Parameters:

Name Type Description Default
client_error

the exception raised by the client interface

required
query

sql query with placeholders

required

Returns:

Type Description

an instance of the corresponding subclass of datajoint.errors.DataJointError

Source code in datajoint/connection.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def translate_query_error(client_error, query):
    """
    Take client error and original query and return the corresponding DataJoint exception.

    The first element of ``client_error.args`` is treated as the error code and the
    remaining elements as its details. Errors that carry no arguments at all, or whose
    code is not recognized, are returned unchanged.

    :param client_error: the exception raised by the client interface
    :param query: sql query with placeholders
    :return: an instance of the corresponding subclass of datajoint.errors.DataJointError,
        or ``client_error`` itself when no translation applies
    """
    logger.debug("type: {}, args: {}".format(type(client_error), client_error.args))

    if not client_error.args:
        # Nothing to classify; unpacking an empty tuple would raise ValueError
        # and hide the original error from the caller.
        return client_error

    err, *args = client_error.args
    # Some branches need the first detail (the message); tolerate its absence
    # rather than raising IndexError from inside the translator.
    detail = args[0] if args else ""

    # Loss of connection errors
    if err in (0, "(0, '')"):
        return errors.LostConnectionError(
            "Server connection lost due to an interface error.", *args
        )
    if err == 2006:
        return errors.LostConnectionError("Connection timed out", *args)
    if err == 2013:
        return errors.LostConnectionError("Server connection lost", *args)
    # Access errors
    if err in (1044, 1142):
        return errors.AccessError("Insufficient privileges.", detail, query)
    # Integrity errors
    if err == 1062:
        return errors.DuplicateError(*args)
    if err in (1217, 1451, 1452):  # 1217 is the MySQL 8 error code
        return errors.IntegrityError(*args)
    # Syntax errors
    if err == 1064:
        return errors.QuerySyntaxError(detail, query)
    # Existence errors
    if err == 1146:
        return errors.MissingTableError(detail, query)
    if err == 1364:
        return errors.MissingAttributeError(*args)
    if err == 1054:
        return errors.UnknownAttributeError(*args)
    # all the other errors are returned in original form
    return client_error

conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None)

Returns a persistent connection object to be shared by multiple modules. If the connection is not yet established or reset=True, a new connection is set up. If connection information is not provided, it is taken from config which takes the information from dj_local_conf.json. If the password is not specified in that file datajoint prompts for the password.

Parameters:

Name Type Description Default
host

hostname

None
user

mysql user

None
password

mysql password

None
init_fun

initialization function

None
reset

whether the connection should be reset or not

False
use_tls

TLS encryption option. Valid options are: True (required), False (required no TLS), None (TLS preferred, default), dict (Manually specify values per https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#encrypted-connection-options).

None
Source code in datajoint/connection.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
def conn(
    host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None
):
    """
    Return the shared, persistent connection, creating it on first use.

    The connection is stored as an attribute of this function. It is (re)created when
    no connection has been stored yet or when ``reset`` is true. Any argument left as
    None falls back to the corresponding entry of ``config``; if the user name or the
    password is still None after that, it is requested interactively.

    :param host: hostname
    :param user: mysql user
    :param password: mysql password
    :param init_fun: initialization function
    :param reset: whether the connection should be reset or not
    :param use_tls: TLS encryption option. Valid options are: True (required), False
        (required no TLS), None (TLS preferred, default), dict (Manually specify values per
        https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#encrypted-connection-options).
    :return: the shared Connection object
    """
    # Fast path: reuse the connection that is already stored on the function.
    if hasattr(conn, "connection") and not reset:
        return conn.connection

    if host is None:
        host = config["database.host"]
    if user is None:
        user = config["database.user"]
    if password is None:
        password = config["database.password"]

    # Neither the caller nor the configuration supplied credentials: ask for them.
    if user is None:
        user = input("Please enter DataJoint username: ")
    if password is None:
        password = getpass(prompt="Please enter DataJoint password: ")

    if init_fun is None:
        init_fun = config["connection.init_function"]
    if use_tls is None:
        use_tls = config["database.use_tls"]

    conn.connection = Connection(host, user, password, None, init_fun, use_tls)
    return conn.connection

EmulatedCursor

acts like a cursor

Source code in datajoint/connection.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
class EmulatedCursor:
    """
    Acts like a cursor over an in-memory collection of rows.

    Iteration and ``fetchone`` share one iterator over the data, so each row is
    handed out once across both. ``fetchall`` and ``rowcount`` always refer to the
    complete data set, regardless of how many rows were already consumed.

    :param data: a sized iterable of rows
    """

    def __init__(self, data):
        self._data = data
        self._iter = iter(self._data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iter)

    def fetchall(self):
        """
        :return: the complete data passed to the constructor
        """
        return self._data

    def fetchone(self):
        """
        :return: the next row, or None when no rows are left
        """
        # Return None on exhaustion, like a DB-API cursor, instead of leaking
        # StopIteration to the caller.
        return next(self._iter, None)

    @property
    def rowcount(self):
        """
        :return: the total number of rows in the data
        """
        return len(self._data)

Connection

A dj.Connection object manages a connection to a database server. It also catalogues modules, schemas, tables, and their dependencies (foreign keys).

Most of the parameters below should be set in the local configuration file.

Parameters:

Name Type Description Default
host

host name, may include port number as hostname:port, in which case it overrides the value in port

required
user

user name

required
password

password

required
port

port number

None
init_fun

connection initialization function (SQL)

None
use_tls

TLS encryption option

None
Source code in datajoint/connection.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
class Connection:
    """
    A dj.Connection object manages a connection to a database server.
    It also catalogues modules, schemas, tables, and their dependencies (foreign keys).

    Most of the parameters below should be set in the local configuration file.

    :param host: host name, may include port number as hostname:port, in which case it overrides the value in port
    :param user: user name
    :param password: password
    :param port: port number
    :param init_fun: connection initialization function (SQL)
    :param use_tls: TLS encryption option
    """

    def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None):
        """
        Store the connection settings, connect, and record the server's connection id.

        :raises errors.LostConnectionError: if the connection is not alive after the
            connect hook has run
        """
        # host_input keeps the value exactly as passed; host is what get_host_hook returns
        host_input, host = (host, get_host_hook(host))
        if ":" in host:
            # the port in the hostname overrides the port argument
            host, port = host.split(":")
            port = int(port)
        elif port is None:
            port = config["database.port"]
        self.conn_info = dict(host=host, port=port, user=user, passwd=password)
        if use_tls is not False:
            # a dict is passed through as-is; any other value gets the placeholder dict
            self.conn_info["ssl"] = (
                use_tls if isinstance(use_tls, dict) else {"ssl": {}}
            )
        # bookkeeping entries; connect() filters them out before calling the client
        self.conn_info["ssl_input"] = use_tls
        self.conn_info["host_input"] = host_input
        self.init_fun = init_fun
        logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info))
        self._conn = None
        self._query_cache = None
        # Must exist before the first query(): its reconnect path reads this flag.
        self._in_transaction = False
        # NOTE(review): presumably this ends up calling self.connect() — confirm in the hook
        connect_host_hook(self)
        if self.is_connected:
            logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
            self.connection_id = self.query("SELECT connection_id()").fetchone()[0]
        else:
            raise errors.LostConnectionError("Connection failed.")
        self.schemas = dict()
        self.dependencies = Dependencies(self)

    def __eq__(self, other):
        """Two connections are equal when their ``conn_info`` dictionaries are equal."""
        if not isinstance(other, Connection):
            # let Python try the reflected comparison instead of raising AttributeError
            return NotImplemented
        return self.conn_info == other.conn_info

    def __repr__(self):
        connected = "connected" if self.is_connected else "disconnected"
        return "DataJoint connection ({connected}) {user}@{host}:{port}".format(
            connected=connected, **self.conn_info
        )

    def connect(self):
        """
        Connect to the database server.

        If the first attempt raises ``client.err.InternalError``, a second attempt is
        made; that attempt leaves out the ``ssl`` setting when ``use_tls`` was given
        as None.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", ".*deprecated.*")
            try:
                self._conn = client.connect(
                    init_command=self.init_fun,
                    sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
                    "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
                    charset=config["connection.charset"],
                    **{
                        k: v
                        for k, v in self.conn_info.items()
                        if k not in ["ssl_input", "host_input"]
                    },
                )
            except client.err.InternalError:
                # retry; drop "ssl" only when TLS was not explicitly requested (None)
                self._conn = client.connect(
                    init_command=self.init_fun,
                    sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
                    "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
                    charset=config["connection.charset"],
                    **{
                        k: v
                        for k, v in self.conn_info.items()
                        if not (
                            k in ["ssl_input", "host_input"]
                            or k == "ssl"
                            and self.conn_info["ssl_input"] is None
                        )
                    },
                )
        self._conn.autocommit(True)

    def set_query_cache(self, query_cache=None):
        """
        When query_cache is not None, the connection switches into the query caching mode, which entails:
        1. Only SELECT queries are allowed.
        2. The results of queries are cached under the path indicated by dj.config['query_cache']
        3. query_cache is a string that differentiates different cache states.

        :param query_cache: a string to initialize the hash for query results
        """
        self._query_cache = query_cache

    def purge_query_cache(self):
        """
        Purges all query cache.

        Deletes every non-directory entry directly inside the configured cache
        directory. Does nothing when the configured value is not a string or is not
        an existing directory.
        """
        if (
            isinstance(config.get(cache_key), str)
            and pathlib.Path(config[cache_key]).is_dir()
        ):
            for path in pathlib.Path(config[cache_key]).iterdir():
                if not path.is_dir():
                    path.unlink()

    def close(self):
        """Close the underlying client connection."""
        self._conn.close()

    def register(self, schema):
        """
        Record a schema under its database name and clear the dependencies.

        :param schema: object with a ``database`` attribute
        """
        self.schemas[schema.database] = schema
        self.dependencies.clear()

    def ping(self):
        """Ping the connection or raises an exception if the connection is closed."""
        self._conn.ping(reconnect=False)

    @property
    def is_connected(self):
        """Return true if the object is connected to the database server."""
        try:
            self.ping()
        except Exception:
            # Any failure to ping counts as "not connected". KeyboardInterrupt and
            # SystemExit are deliberately not swallowed.
            return False
        return True

    @staticmethod
    def _execute_query(cursor, query, args, suppress_warnings):
        """
        Execute the query on the cursor, translating client errors.

        :raises: whatever translate_query_error returns for a ``client.err.Error``
        """
        try:
            with warnings.catch_warnings():
                if suppress_warnings:
                    # suppress all warnings arising from underlying SQL library
                    warnings.simplefilter("ignore")
                cursor.execute(query, args)
        except client.err.Error as err:
            raise translate_query_error(err, query)

    def query(
        self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None
    ):
        """
        Execute the specified query and return the tuple generator (cursor).

        In query caching mode the result is read from, or written to, a file in the
        configured cache directory and an EmulatedCursor is returned instead.

        :param query: SQL query
        :param args: additional arguments for the client.cursor
        :param as_dict: If as_dict is set to True, the returned cursor objects returns
                        query results as dictionary.
        :param suppress_warnings: If True, suppress all warnings arising from underlying query library
        :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
        :raises errors.DataJointError: in caching mode, when the query does not start
            with SELECT or SHOW or when no cache path is configured
        :raises errors.LostConnectionError: when the connection is lost and either
            reconnecting is disabled or a transaction was in progress
        """
        # check cache first:
        use_query_cache = bool(self._query_cache)
        if use_query_cache and not re.match(r"\s*(SELECT|SHOW)", query):
            raise errors.DataJointError(
                "Only SELECT queries are allowed when query caching is on."
            )
        if use_query_cache:
            if not config[cache_key]:
                raise errors.DataJointError(
                    f"Provide filepath dj.config['{cache_key}'] when using query caching."
                )
            # cache key: cache state + query (backtick-quoted `$name` tokens removed) + args
            hash_ = uuid_from_buffer(
                (str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode()
                + pack(args)
            )
            cache_path = pathlib.Path(config[cache_key]) / str(hash_)
            try:
                buffer = cache_path.read_bytes()
            except FileNotFoundError:
                pass  # proceed to query the database
            else:
                return EmulatedCursor(unpack(buffer))

        if reconnect is None:
            reconnect = config["database.reconnect"]
        logger.debug("Executing SQL:" + query[:query_log_max_length])
        cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor
        cursor = self._conn.cursor(cursor=cursor_class)
        try:
            self._execute_query(cursor, query, args, suppress_warnings)
        except errors.LostConnectionError:
            if not reconnect:
                raise
            logger.warning("MySQL server has gone away. Reconnecting to the server.")
            connect_host_hook(self)
            if self._in_transaction:
                # the interrupted transaction is cancelled rather than replayed
                self.cancel_transaction()
                raise errors.LostConnectionError(
                    "Connection was lost during a transaction."
                )
            logger.debug("Re-executing")
            cursor = self._conn.cursor(cursor=cursor_class)
            self._execute_query(cursor, query, args, suppress_warnings)

        if use_query_cache:
            data = cursor.fetchall()
            cache_path.write_bytes(pack(data))
            return EmulatedCursor(data)

        return cursor

    def get_user(self):
        """
        :return: the user name and host name provided by the client to the server.
        """
        return self.query("SELECT user()").fetchone()[0]

    # ---------- transaction processing
    @property
    def in_transaction(self):
        """
        :return: True if there is an open transaction.
        """
        # the flag is dropped as soon as the connection is found to be dead
        self._in_transaction = self._in_transaction and self.is_connected
        return self._in_transaction

    def start_transaction(self):
        """
        Starts a transaction.

        :raises errors.DataJointError: if a transaction is already open
        """
        if self.in_transaction:
            raise errors.DataJointError("Nested connections are not supported.")
        self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT")
        self._in_transaction = True
        logger.debug("Transaction started")

    def cancel_transaction(self):
        """
        Cancels the current transaction and rolls back all changes made during the transaction.
        """
        self.query("ROLLBACK")
        self._in_transaction = False
        logger.debug("Transaction cancelled. Rolling back ...")

    def commit_transaction(self):
        """
        Commit all changes made during the transaction and close it.

        """
        self.query("COMMIT")
        self._in_transaction = False
        logger.debug("Transaction committed and closed.")

    # -------- context manager for transactions
    @property
    @contextmanager
    def transaction(self):
        """
        Context manager for transactions. Opens a transaction and closes it after the with statement.
        If an error is caught during the transaction, the commits are automatically rolled back.
        All errors are raised again.

        Example:
        >>> import datajoint as dj
        >>> with dj.conn().transaction as conn:
        >>>     # transaction is open here
        """
        # NOTE(review): start_transaction() runs inside the try, so a failure to
        # start (e.g. a transaction is already open) also triggers the rollback
        # below — confirm this is intended.
        try:
            self.start_transaction()
            yield self
        except BaseException:
            # roll back on any exit, including KeyboardInterrupt, then re-raise
            self.cancel_transaction()
            raise
        else:
            self.commit_transaction()

connect()

Connect to the database server.

Source code in datajoint/connection.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def connect(self):
    """
    Open the client connection described by ``self.conn_info`` and enable autocommit.

    The bookkeeping entries ``ssl_input`` and ``host_input`` are never passed to the
    client. If the first attempt raises ``client.err.InternalError``, a second attempt
    is made; that attempt also leaves out ``ssl`` when ``ssl_input`` is None.
    """
    bookkeeping = ["ssl_input", "host_input"]

    def open_with(params):
        # Both attempts share every setting except the filtered conn_info entries.
        return client.connect(
            init_command=self.init_fun,
            sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
            "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
            charset=config["connection.charset"],
            **params,
        )

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", ".*deprecated.*")
        try:
            self._conn = open_with(
                {
                    key: value
                    for key, value in self.conn_info.items()
                    if key not in bookkeeping
                }
            )
        except client.err.InternalError:
            # Second attempt: additionally skip "ssl" when TLS was left unspecified.
            self._conn = open_with(
                {
                    key: value
                    for key, value in self.conn_info.items()
                    if key not in bookkeeping
                    and not (key == "ssl" and self.conn_info["ssl_input"] is None)
                }
            )
    self._conn.autocommit(True)

set_query_cache(query_cache=None)

When query_cache is not None, the connection switches into the query caching mode, which entails:

1. Only SELECT queries are allowed.
2. The results of queries are cached under the path indicated by dj.config['query_cache'].
3. query_cache is a string that differentiates different cache states.

Parameters:

Name Type Description Default
query_cache

a string to initialize the hash for query results

None
Source code in datajoint/connection.py
249
250
251
252
253
254
255
256
257
258
def set_query_cache(self, query_cache=None):
    """
    Switch query caching on or off for this connection.

    A non-None value puts the connection into query caching mode:
    1. Only SELECT queries are allowed.
    2. Query results are cached under the path given by dj.config['query_cache'].
    3. The value of query_cache distinguishes one cache state from another.

    :param query_cache: a string to initialize the hash for query results
    """
    # The value is only stored here; it is consulted when a query is executed.
    self._query_cache = query_cache

purge_query_cache()

Purges all query cache.

Source code in datajoint/connection.py
260
261
262
263
264
265
266
267
268
def purge_query_cache(self):
    """
    Delete all cached query results.

    Removes every non-directory entry directly inside the configured cache directory.
    Nothing happens when the configured value is not a string or does not name an
    existing directory.
    """
    if not isinstance(config.get(cache_key), str):
        return
    cache_dir = pathlib.Path(config[cache_key])
    if not cache_dir.is_dir():
        return
    for entry in cache_dir.iterdir():
        if entry.is_dir():
            continue  # subdirectories are left untouched
        entry.unlink()

ping()

Ping the connection or raises an exception if the connection is closed.

Source code in datajoint/connection.py
277
278
279
def ping(self):
    """
    Check that the connection is alive.

    Pings through the underlying client connection with reconnecting disabled, so
    any exception raised by the client propagates to the caller.
    """
    self._conn.ping(reconnect=False)

is_connected property

Return true if the object is connected to the database server.

query(query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None)

Execute the specified query and return the tuple generator (cursor).

Parameters:

Name Type Description Default
query

SQL query

required
args

additional arguments for the client.cursor

()
as_dict

If as_dict is set to True, the returned cursor objects returns query results as dictionary.

False
suppress_warnings

If True, suppress all warnings arising from underlying query library

True
reconnect

when None, get from config, when True, attempt to reconnect if disconnected

None
Source code in datajoint/connection.py
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def query(
    self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None
):
    """
    Run an SQL query on this connection and return a cursor over its result.

    In query caching mode the result is looked up in, or saved to, a file in the
    configured cache directory, and an EmulatedCursor is returned instead of the
    client cursor.

    :param query: SQL query
    :param args: additional arguments for the client.cursor
    :param as_dict: If True, the returned cursor yields query results as dictionaries.
    :param suppress_warnings: If True, suppress all warnings arising from underlying query library
    :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
    :return: the client cursor, or an EmulatedCursor when query caching is on
    """
    caching = bool(self._query_cache)
    if caching:
        # Caching mode is restricted to statements that start with SELECT or SHOW.
        if not re.match(r"\s*(SELECT|SHOW)", query):
            raise errors.DataJointError(
                "Only SELECT queries are allowed when query caching is on."
            )
        if not config[cache_key]:
            raise errors.DataJointError(
                f"Provide filepath dj.config['{cache_key}'] when using query caching."
            )
        # The file name is derived from the cache state, the query with its
        # backtick-quoted `$name` tokens removed, and the packed arguments.
        seed = (str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode()
        digest = uuid_from_buffer(seed + pack(args))
        cache_file = pathlib.Path(config[cache_key]) / str(digest)
        try:
            cached = cache_file.read_bytes()
        except FileNotFoundError:
            cached = None  # not cached yet: fall through and ask the database
        if cached is not None:
            return EmulatedCursor(unpack(cached))

    if reconnect is None:
        reconnect = config["database.reconnect"]
    logger.debug("Executing SQL:" + query[:query_log_max_length])
    if as_dict:
        cursor_type = client.cursors.DictCursor
    else:
        cursor_type = client.cursors.Cursor
    cur = self._conn.cursor(cursor=cursor_type)
    try:
        self._execute_query(cur, query, args, suppress_warnings)
    except errors.LostConnectionError:
        if not reconnect:
            raise
        logger.warning("MySQL server has gone away. Reconnecting to the server.")
        connect_host_hook(self)
        if self._in_transaction:
            # An interrupted transaction is cancelled instead of being replayed.
            self.cancel_transaction()
            raise errors.LostConnectionError(
                "Connection was lost during a transaction."
            )
        logger.debug("Re-executing")
        cur = self._conn.cursor(cursor=cursor_type)
        self._execute_query(cur, query, args, suppress_warnings)

    if not caching:
        return cur

    # Store the freshly fetched rows so the next identical query hits the cache.
    rows = cur.fetchall()
    cache_file.write_bytes(pack(rows))
    return EmulatedCursor(rows)

get_user()

Returns:

Type Description

the user name and host name provided by the client to the server.

Source code in datajoint/connection.py
365
366
367
368
369
def get_user(self):
    """
    Ask the server which user this connection is logged in as.

    :return: the user name and host name provided by the client to the server.
    """
    cursor = self.query("SELECT user()")
    row = cursor.fetchone()
    return row[0]

in_transaction property

Returns:

Type Description

True if there is an open transaction.

start_transaction()

Starts a transaction.

Source code in datajoint/connection.py
380
381
382
383
384
385
386
387
388
def start_transaction(self):
    """
    Open a new transaction on this connection.

    Raises errors.DataJointError if a transaction is already open.
    """
    # Guard: only one open transaction per connection.
    if self.in_transaction:
        raise errors.DataJointError("Nested connections are not supported.")

    self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT")
    # The flag is set only after the statement has been issued without error.
    self._in_transaction = True
    logger.debug("Transaction started")

cancel_transaction()

Cancels the current transaction and rolls back all changes made during the transaction.

Source code in datajoint/connection.py
390
391
392
393
394
395
396
def cancel_transaction(self):
    """
    Roll back the current transaction, discarding all changes made within it.
    """
    self.query("ROLLBACK")
    # Reached only when the ROLLBACK statement did not raise.
    self._in_transaction = False
    logger.debug("Transaction cancelled. Rolling back ...")

commit_transaction()

Commit all changes made during the transaction and close it.

Source code in datajoint/connection.py
398
399
400
401
402
403
404
405
def commit_transaction(self):
    """
    Make all changes of the current transaction permanent and close the transaction.
    """
    self.query("COMMIT")
    # Reached only when the COMMIT statement did not raise.
    self._in_transaction = False
    logger.debug("Transaction committed and closed.")

transaction property

Context manager for transactions. Opens a transaction and closes it after the with statement. If an error is caught during the transaction, the changes are automatically rolled back. All errors are raised again.

Example:

import datajoint as dj
with dj.conn().transaction as conn:
    # transaction is open here